diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index 3613d7b64817..968d556a1f1b 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -196,11 +196,14 @@ func decodeTimeout(s string) (time.Duration, error) { if !ok { return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s) } - t, err := strconv.ParseInt(s[:size-1], 10, 64) + t, err := strconv.ParseUint(s[:size-1], 10, 64) if err != nil { return 0, err } - const maxHours = math.MaxInt64 / int64(time.Hour) + if t == 0 { + return 0, fmt.Errorf("transport: timeout must be positive: %q", s) + } + const maxHours = math.MaxInt64 / uint64(time.Hour) if d == time.Hour && t > maxHours { // This timeout would overflow math.MaxInt64; clamp it. return time.Duration(math.MaxInt64), nil diff --git a/internal/transport/http_util_test.go b/internal/transport/http_util_test.go index 5a259d43cdc2..546e8e7f39dc 100644 --- a/internal/transport/http_util_test.go +++ b/internal/transport/http_util_test.go @@ -22,28 +22,49 @@ import ( "errors" "fmt" "io" + "math" "net" "reflect" "testing" "time" ) -func (s) TestTimeoutDecode(t *testing.T) { +func (s) TestDecodeTimeout(t *testing.T) { for _, test := range []struct { // input s string // output - d time.Duration - err error + d time.Duration + wantErr bool }{ - {"1234S", time.Second * 1234, nil}, - {"1234x", 0, fmt.Errorf("transport: timeout unit is not recognized: %q", "1234x")}, - {"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")}, - {"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")}, + + {"00000001n", time.Nanosecond, false}, + {"10u", time.Microsecond * 10, false}, + {"00000010m", time.Millisecond * 10, false}, + {"1234S", time.Second * 1234, false}, + {"00000001M", time.Minute, false}, + {"09999999S", time.Second * 9999999, false}, + {"99999999S", time.Second * 99999999, false}, + {"99999999M", time.Minute * 99999999, false}, + {"2562047H", time.Hour * 2562047, false}, + {"2562048H", time.Duration(math.MaxInt64), false}, + {"99999999H", time.Duration(math.MaxInt64), false}, + {"-1S", 0, true}, + {"1234x", 0, true}, + {"1234s", 0, true}, + {"1234", 0, true}, + {"1", 0, true}, + {"", 0, true}, + {"9a1S", 0, true}, + {"0S", 0, true}, // PROTOCOL-HTTP2.md requires positive integers + {"00000000S", 0, true}, + {"000000000S", 0, true}, } { d, err := decodeTimeout(test.s) - if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) { - t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err) + gotErr := err != nil + if d != test.d || gotErr != test.wantErr { + t.Errorf("timeoutDecode(%q) = %d, %v, want %d, wantErr=%v", + test.s, int64(d), err, int64(test.d), test.wantErr) } } }