diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index ed2d93af67e5..ef56592b944e 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -592,6 +592,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) // Send out timeout regardless its value. The server can detect timeout context by itself. // TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire. timeout := time.Until(dl) + if timeout <= 0 { + return nil, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) + } headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: grpcutil.EncodeDuration(timeout)}) } for k, v := range authData { diff --git a/test/end2end_test.go b/test/end2end_test.go index fd5e29ea8c96..75b27f4c224d 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5351,6 +5351,31 @@ func testRPCTimeout(t *testing.T, e env) { } } +// Tests that the client doesn't send a negative timeout to the server. If the +// server receives a negative timeout, it would return an internal status. The +// client checks the context error before starting a stream, however the context +// may expire after this check and before the timeout is calculated. +func (s) TestNegativeRPCTimeout(t *testing.T) { + server := stubserver.StartTestService(t, nil) + defer server.Stop() + + if err := server.StartClient(); err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Try increasingly larger timeout values to trigger the condition when the + // context has expired while creating the grpc-timeout header. + for i := range 10 { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(i*100)*time.Nanosecond) + defer cancel() + + client := server.Client + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v; want , error code: %s", err, codes.DeadlineExceeded) + } + } +} + func (s) TestDisabledIOBuffers(t *testing.T) { payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(60000)) if err != nil {