Skip to content

internal/transport: Wait for server goroutines to exit during shutdown in test #8306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 67 additions & 33 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,21 +320,23 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream)
}

type server struct {
lis net.Listener
port string
startedErr chan error // error (or nil) with server start value
mu sync.Mutex
conns map[ServerTransport]net.Conn
h *testStreamHandler
ready chan struct{}
channelz *channelz.Server
lis net.Listener
port string
startedErr chan error // error (or nil) with server start value
mu sync.Mutex
conns map[ServerTransport]net.Conn
h *testStreamHandler
ready chan struct{}
channelz *channelz.Server
servingTasksDone chan struct{}
}

func newTestServer() *server {
return &server{
startedErr: make(chan error, 1),
ready: make(chan struct{}),
channelz: channelz.RegisterServer("test server"),
startedErr: make(chan error, 1),
ready: make(chan struct{}),
servingTasksDone: make(chan struct{}),
channelz: channelz.RegisterServer("test server"),
}
}

Expand All @@ -358,6 +360,12 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.port = p
s.conns = make(map[ServerTransport]net.Conn)
s.startedErr <- nil
wg := sync.WaitGroup{}
defer func() {
wg.Wait()
close(s.servingTasksDone)
}()

for {
conn, err := s.lis.Accept()
if err != nil {
Expand All @@ -383,40 +391,65 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.mu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
wg.Add(1)
switch ht {
case notifyCall:
go transport.HandleStreams(ctx, h.handleStreamAndNotify)
go func() {
transport.HandleStreams(ctx, h.handleStreamAndNotify)
wg.Done()
}()
case suspended:
go transport.HandleStreams(ctx, func(*ServerStream) {})
go func() {
transport.HandleStreams(ctx, func(*ServerStream) {})
wg.Done()
}()
case misbehaved:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamMisbehave(t, s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamMisbehave(t, s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this (and similar) be outside the other goroutine now? Or does HandleStreams track this and not return until all streams it's handling are done?

})
wg.Done()
}()
case encodingRequiredStatus:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamEncodingRequiredStatus(s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamEncodingRequiredStatus(s)
})
wg.Done()
}()
case invalidHeaderField:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamInvalidHeaderField(s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamInvalidHeaderField(s)
})
wg.Done()
}()
case delayRead:
h.notify = make(chan struct{})
h.getNotified = make(chan struct{})
s.mu.Lock()
close(s.ready)
s.mu.Unlock()
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamDelayRead(t, s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamDelayRead(t, s)
})
wg.Done()
}()
case pingpong:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamPingPong(t, s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStreamPingPong(t, s)
})
wg.Done()
}()
default:
go transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStream(t, s)
})
go func() {
transport.HandleStreams(ctx, func(s *ServerStream) {
go h.handleStream(t, s)
})
wg.Done()
}()
}
}
}
Expand All @@ -440,6 +473,7 @@ func (s *server) stop() {
}
s.conns = nil
s.mu.Unlock()
<-s.servingTasksDone
}

func (s *server) addr() string {
Expand Down Expand Up @@ -2253,11 +2287,11 @@ func (s) TestPingPong1B(t *testing.T) {
runPingPongTest(t, 1)
}

func TestPingPong1KB(t *testing.T) {
func (s) TestPingPong1KB(t *testing.T) {
runPingPongTest(t, 1024)
}

func TestPingPong64KB(t *testing.T) {
func (s) TestPingPong64KB(t *testing.T) {
runPingPongTest(t, 65536)
}

Expand Down