From 7ce416c6eaf1a012782b2c7906c2461789a380c7 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Fri, 9 May 2025 21:17:41 +0530 Subject: [PATCH 1/2] Wait for test server goroutines to exit during shutdown --- internal/transport/transport_test.go | 87 ++++++++++++++++++---------- 1 file changed, 57 insertions(+), 30 deletions(-) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 8b1219597912..bac0302bc224 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -320,14 +320,15 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream) } type server struct { - lis net.Listener - port string - startedErr chan error // error (or nil) with server start value - mu sync.Mutex - conns map[ServerTransport]net.Conn - h *testStreamHandler - ready chan struct{} - channelz *channelz.Server + lis net.Listener + port string + startedErr chan error // error (or nil) with server start value + mu sync.Mutex + conns map[ServerTransport]net.Conn + h *testStreamHandler + ready chan struct{} + channelz *channelz.Server + servingTasks sync.WaitGroup } func newTestServer() *server { @@ -383,40 +384,65 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.mu.Unlock() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() + s.servingTasks.Add(1) switch ht { case notifyCall: - go transport.HandleStreams(ctx, h.handleStreamAndNotify) + go func() { + transport.HandleStreams(ctx, h.handleStreamAndNotify) + s.servingTasks.Done() + }() case suspended: - go transport.HandleStreams(ctx, func(*ServerStream) {}) + go func() { + transport.HandleStreams(ctx, func(*ServerStream) {}) + s.servingTasks.Done() + }() case misbehaved: - go transport.HandleStreams(ctx, func(s *ServerStream) { - go h.handleStreamMisbehave(t, s) - }) + go func() { + transport.HandleStreams(ctx, func(s *ServerStream) { + go h.handleStreamMisbehave(t, s) + }) + s.servingTasks.Done() + }() case encodingRequiredStatus: - go transport.HandleStreams(ctx, func(s *ServerStream) { - go h.handleStreamEncodingRequiredStatus(s) - }) + go func() { + transport.HandleStreams(ctx, func(s *ServerStream) { + go h.handleStreamEncodingRequiredStatus(s) + }) + s.servingTasks.Done() + }() case invalidHeaderField: - go transport.HandleStreams(ctx, func(s *ServerStream) { - go h.handleStreamInvalidHeaderField(s) - }) + go func() { + transport.HandleStreams(ctx, func(s *ServerStream) { + go h.handleStreamInvalidHeaderField(s) + }) + s.servingTasks.Done() + }() case delayRead: h.notify = make(chan struct{}) h.getNotified = make(chan struct{}) s.mu.Lock() close(s.ready) s.mu.Unlock() - go transport.HandleStreams(ctx, func(s *ServerStream) { - go h.handleStreamDelayRead(t, s) - }) + go func() { + transport.HandleStreams(ctx, func(s *ServerStream) { + go h.handleStreamDelayRead(t, s) + }) + s.servingTasks.Done() + }() case pingpong: - go transport.HandleStreams(ctx, func(s *ServerStream) { - go h.handleStreamPingPong(t, s) - }) + go func() { + transport.HandleStreams(ctx, func(s *ServerStream) { + go h.handleStreamPingPong(t, s) + }) + s.servingTasks.Done() + }() default: - go transport.HandleStreams(ctx, func(s *ServerStream) { - go h.handleStream(t, s) - }) + go func() { + transport.HandleStreams(ctx, func(s *ServerStream) { + go h.handleStream(t, s) + }) + s.servingTasks.Done() + }() } } } @@ -440,6 +466,7 @@ func (s *server) stop() { } s.conns = nil s.mu.Unlock() + s.servingTasks.Wait() } func (s *server) addr() string { @@ -2253,11 +2280,11 @@ func (s) TestPingPong1B(t *testing.T) { runPingPongTest(t, 1) } -func TestPingPong1KB(t *testing.T) { +func (s) TestPingPong1KB(t *testing.T) { runPingPongTest(t, 1024) } -func TestPingPong64KB(t *testing.T) { +func (s) TestPingPong64KB(t *testing.T) { runPingPongTest(t, 65536) } From 6c431008e1b692926dba9d0110452e16ad8bf97e Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Fri, 9 May 2025 22:24:10 +0530 Subject: [PATCH 2/2] Use channel + waitgroup --- internal/transport/transport_test.go | 51 ++++++++++++++++------------ 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index bac0302bc224..bf4d033e93a4 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -320,22 +320,23 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream) } type server struct { - lis net.Listener - port string - startedErr chan error // error (or nil) with server start value - mu sync.Mutex - conns map[ServerTransport]net.Conn - h *testStreamHandler - ready chan struct{} - channelz *channelz.Server - servingTasks sync.WaitGroup + lis net.Listener + port string + startedErr chan error // error (or nil) with server start value + mu sync.Mutex + conns map[ServerTransport]net.Conn + h *testStreamHandler + ready chan struct{} + channelz *channelz.Server + servingTasksDone chan struct{} } func newTestServer() *server { return &server{ - startedErr: make(chan error, 1), - ready: make(chan struct{}), - channelz: channelz.RegisterServer("test server"), + startedErr: make(chan error, 1), + ready: make(chan struct{}), + servingTasksDone: make(chan struct{}), + channelz: channelz.RegisterServer("test server"), } } @@ -359,6 +360,12 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.port = p s.conns = make(map[ServerTransport]net.Conn) s.startedErr <- nil + wg := sync.WaitGroup{} + defer func() { + wg.Wait() + close(s.servingTasksDone) + }() + for { conn, err := s.lis.Accept() if err != nil { @@ -384,38 +391,38 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.mu.Unlock() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - s.servingTasks.Add(1) + wg.Add(1) switch ht { case notifyCall: go func() { transport.HandleStreams(ctx, h.handleStreamAndNotify) - s.servingTasks.Done() + wg.Done() }() case suspended: go func() { transport.HandleStreams(ctx, func(*ServerStream) {}) - s.servingTasks.Done() + wg.Done() }() case misbehaved: go func() { transport.HandleStreams(ctx, func(s *ServerStream) { go h.handleStreamMisbehave(t, s) }) - s.servingTasks.Done() + wg.Done() }() case encodingRequiredStatus: go func() { transport.HandleStreams(ctx, func(s *ServerStream) { go h.handleStreamEncodingRequiredStatus(s) }) - s.servingTasks.Done() + wg.Done() }() case invalidHeaderField: go func() { transport.HandleStreams(ctx, func(s *ServerStream) { go h.handleStreamInvalidHeaderField(s) }) - s.servingTasks.Done() + wg.Done() }() case delayRead: h.notify = make(chan struct{}) @@ -427,21 +434,21 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT transport.HandleStreams(ctx, func(s *ServerStream) { go h.handleStreamDelayRead(t, s) }) - s.servingTasks.Done() + wg.Done() }() case pingpong: go func() { transport.HandleStreams(ctx, func(s *ServerStream) { go h.handleStreamPingPong(t, s) }) - s.servingTasks.Done() + wg.Done() }() default: go func() { transport.HandleStreams(ctx, func(s *ServerStream) { go h.handleStream(t, s) }) - s.servingTasks.Done() + wg.Done() }() } } @@ -466,7 +473,7 @@ func (s *server) stop() { } s.conns = nil s.mu.Unlock() - s.servingTasks.Wait() + <-s.servingTasksDone } func (s *server) addr() string {