diff --git a/go/ai/gen.go b/go/ai/gen.go index 1f16227c6..828108459 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -181,23 +181,6 @@ type GenerationUsage struct { TotalTokens int `json:"totalTokens,omitempty"` } -type GenkitError struct { - Data *GenkitErrorData `json:"data,omitempty"` - Details any `json:"details,omitempty"` - Message string `json:"message,omitempty"` - Stack string `json:"stack,omitempty"` -} - -type GenkitErrorData struct { - GenkitErrorDetails *GenkitErrorDetails `json:"genkitErrorDetails,omitempty"` - GenkitErrorMessage string `json:"genkitErrorMessage,omitempty"` -} - -type GenkitErrorDetails struct { - Stack string `json:"stack,omitempty"` - TraceID string `json:"traceId,omitempty"` -} - type Media struct { ContentType string `json:"contentType,omitempty"` Url string `json:"url,omitempty"` diff --git a/go/ai/generate.go b/go/ai/generate.go index d6b7573e2..df932b576 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -153,7 +153,7 @@ func LookupModel(r *registry.Registry, provider, name string) Model { // It returns an error if the model was not defined. func LookupModelByName(r *registry.Registry, modelName string) (Model, error) { if modelName == "" { - return nil, errors.New("ai.LookupModelByName: model not specified") + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.LookupModelByName: model not specified") } provider, name, found := strings.Cut(modelName, "/") @@ -165,9 +165,9 @@ func LookupModelByName(r *registry.Registry, modelName string) (Model, error) { model := LookupModel(r, provider, name) if model == nil { if provider == "" { - return nil, fmt.Errorf("ai.LookupModelByName: no model named %q", name) + return nil, core.NewError(core.NOT_FOUND, "ai.LookupModelByName: model %q not found", name) } - return nil, fmt.Errorf("ai.LookupModelByName: no model named %q for provider %q", name, provider) + return nil, core.NewError(core.NOT_FOUND, "ai.LookupModelByName: model %q by provider %q not found", name, provider) } return model, nil @@ -180,7 +180,7 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera opts.Model = defaultModel } if opts.Model == "" { - return nil, errors.New("ai.GenerateWithRequest: model is required") + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: model is required") } } @@ -193,12 +193,12 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera toolDefMap := make(map[string]*ToolDefinition) for _, t := range opts.Tools { if _, ok := toolDefMap[t]; ok { - return nil, fmt.Errorf("ai.GenerateWithRequest: duplicate tool found: %q", t) + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: duplicate tool %q", t) } tool := LookupTool(r, t) if tool == nil { - return nil, fmt.Errorf("ai.GenerateWithRequest: tool not found: %q", t) + return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: tool %q not found", t) } toolDefMap[t] = tool.Definition() @@ -210,7 +210,7 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera maxTurns := opts.MaxTurns if maxTurns < 0 { - return nil, fmt.Errorf("ai.GenerateWithRequest: max turns must be greater than 0, got %d", maxTurns) + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: max turns must be greater than 0, got %d", maxTurns) } if maxTurns == 0 { maxTurns = 5 // Default max turns. @@ -276,7 +276,8 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera resp.Message, err = formatHandler.ParseMessage(resp.Message) if err != nil { logger.FromContext(ctx).Debug("model failed to generate output matching expected schema", "error", err.Error()) - return nil, fmt.Errorf("model failed to generate output matching expected schema: %w", err) + return nil, core.NewError(core.INTERNAL, "model failed to generate output matching expected schema: %v", err) + } } @@ -291,7 +292,7 @@ func GenerateWithRequest(ctx context.Context, r *registry.Registry, opts *Genera } if currentTurn+1 > maxTurns { - return nil, fmt.Errorf("exceeded maximum tool call iterations (%d)", maxTurns) + return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, cb) @@ -318,7 +319,7 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) genOpts := &generateOptions{} for _, opt := range opts { if err := opt.applyGenerate(genOpts); err != nil { - return nil, fmt.Errorf("ai.Generate: error applying options: %w", err) + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: error applying options: %v", err) } } @@ -421,7 +422,7 @@ func (m *model) Name() string { // Generate applies the [Action] to provided request. func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { if m == nil { - return nil, errors.New("Model.Generate: generate called on a nil model; check that all models are defined") + return nil, core.NewError(core.INVALID_ARGUMENT, "Model.Generate: generate called on a nil model; check that all models are defined") } return (*core.ActionDef[*ModelRequest, *ModelResponse, *ModelResponseChunk])(m).Run(ctx, req, cb) @@ -478,12 +479,12 @@ func cloneMessage(m *Message) *Message { panic(fmt.Sprintf("failed to marshal message: %v", err)) } - var copy Message - if err := json.Unmarshal(bytes, ©); err != nil { + var msgCopy Message + if err := json.Unmarshal(bytes, &msgCopy); err != nil { panic(fmt.Sprintf("failed to unmarshal message: %v", err)) } - return © + return &msgCopy } // handleToolRequests processes any tool requests in the response, returning @@ -520,7 +521,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq toolReq := p.ToolRequest tool := LookupTool(r, toolReq.Name) if tool == nil { - resultChan <- toolResult{idx, nil, fmt.Errorf("tool %q not found", toolReq.Name)} + resultChan <- toolResult{idx, nil, core.NewError(core.NOT_FOUND, "tool %q not found", toolReq.Name)} return } @@ -538,7 +539,7 @@ func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelReq resultChan <- toolResult{idx, nil, interruptErr} return } - resultChan <- toolResult{idx, nil, fmt.Errorf("tool %q failed: %w", toolReq.Name, err)} + resultChan <- toolResult{idx, nil, core.NewError(core.INTERNAL, "tool %q failed: %v", toolReq.Name, err)} return } diff --git a/go/ai/model_middleware.go b/go/ai/model_middleware.go index 7afe13eb0..76f4610f1 100644 --- a/go/ai/model_middleware.go +++ b/go/ai/model_middleware.go @@ -27,6 +27,7 @@ import ( "strconv" "strings" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" ) @@ -101,28 +102,28 @@ func validateSupport(model string, info *ModelInfo) ModelMiddleware { for _, msg := range input.Messages { for _, part := range msg.Content { if part.IsMedia() { - return nil, fmt.Errorf("model %q does not support media, but media was provided. Request: %+v", model, input) + return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support media, but media was provided. Request: %+v", model, input) } } } } if !info.Supports.Tools && len(input.Tools) > 0 { - return nil, fmt.Errorf("model %q does not support tool use, but tools were provided. Request: %+v", model, input) + return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support tool use, but tools were provided. Request: %+v", model, input) } if !info.Supports.Multiturn && len(input.Messages) > 1 { - return nil, fmt.Errorf("model %q does not support multiple messages, but %d were provided. Request: %+v", model, len(input.Messages), input) + return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support multiple messages, but %d were provided. Request: %+v", model, len(input.Messages), input) } if !info.Supports.ToolChoice && input.ToolChoice != "" && input.ToolChoice != ToolChoiceAuto { - return nil, fmt.Errorf("model %q does not support tool choice, but tool choice was provided. Request: %+v", model, input) + return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support tool choice, but tool choice was provided. Request: %+v", model, input) } if !info.Supports.SystemRole { for _, msg := range input.Messages { if msg.Role == RoleSystem { - return nil, fmt.Errorf("model %q does not support system role, but system role was provided. Request: %+v", model, input) + return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support system role, but system role was provided. Request: %+v", model, input) } } } @@ -140,7 +141,7 @@ func validateSupport(model string, info *ModelInfo) ModelMiddleware { info.Supports.Constrained == ConstrainedSupportNone || (info.Supports.Constrained == ConstrainedSupportNoTools && len(input.Tools) > 0)) && input.Output != nil && input.Output.Constrained { - return nil, fmt.Errorf("model %q does not support native constrained output, but constrained output was requested. Request: %+v", model, input) + return nil, core.NewError(core.INVALID_ARGUMENT, "model %q does not support native constrained output, but constrained output was requested. Request: %+v", model, input) } if err := validateVersion(model, info.Versions, input.Config); err != nil { @@ -176,14 +177,14 @@ func validateVersion(model string, versions []string, config any) error { version, ok := versionVal.(string) if !ok { - return fmt.Errorf("version must be a string, got %T", versionVal) + return core.NewError(core.INVALID_ARGUMENT, "version must be a string, got %T", versionVal) } if slices.Contains(versions, version) { return nil } - return fmt.Errorf("model %q does not support version %q, supported versions: %v", model, version, versions) + return core.NewError(core.INVALID_ARGUMENT, "model %q does not support version %q, supported versions: %v", model, version, versions) } // ContextItemTemplate is the default item template for context augmentation. @@ -302,13 +303,13 @@ func DownloadRequestMedia(options *DownloadMediaOptions) ModelMiddleware { resp, err := client.Get(mediaUrl) if err != nil { - return nil, fmt.Errorf("HTTP error downloading media %q: %w", mediaUrl, err) + return nil, core.NewError(core.INVALID_ARGUMENT, "HTTP error downloading media %q: %v", mediaUrl, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("HTTP error downloading media %q: %s", mediaUrl, string(body)) + return nil, core.NewError(core.UNKNOWN, "HTTP error downloading media %q: %s", mediaUrl, string(body)) } contentType := part.ContentType @@ -324,7 +325,7 @@ func DownloadRequestMedia(options *DownloadMediaOptions) ModelMiddleware { data, err = io.ReadAll(resp.Body) } if err != nil { - return nil, fmt.Errorf("error reading media %q: %v", mediaUrl, err) + return nil, core.NewError(core.UNKNOWN, "error reading media %q: %v", mediaUrl, err) } message.Content[j] = NewMediaPart(contentType, fmt.Sprintf("data:%s;base64,%s", contentType, base64.StdEncoding.EncodeToString(data))) diff --git a/go/core/action.go b/go/core/action.go index 076a2a750..a5cd85f85 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -20,7 +20,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "reflect" "time" @@ -213,7 +212,7 @@ func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb Strea func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. if err := base.ValidateJSON(input, a.inputSchema); err != nil { - return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err} + return nil, NewError(INVALID_ARGUMENT, err.Error()) } var in In if input != nil { diff --git a/go/core/error.go b/go/core/error.go new file mode 100644 index 000000000..e9bfa0bb1 --- /dev/null +++ b/go/core/error.go @@ -0,0 +1,138 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package core provides base error types and utilities for Genkit. +package core + +import ( + "fmt" + "runtime/debug" +) + +type ReflectionErrorDetails struct { + Stack *string `json:"stack,omitempty"` // Use pointer for optional + TraceID *string `json:"traceId,omitempty"` +} + +// ReflectionError is the wire format for HTTP errors for Reflection API responses. +type ReflectionError struct { + Details *ReflectionErrorDetails `json:"details,omitempty"` + Message string `json:"message"` + Code int `json:"code"` +} + +// GenkitError is the base error type for Genkit errors. +type GenkitError struct { + Message string `json:"message"` // Exclude from default JSON if embedded elsewhere + Status StatusName `json:"status"` + HTTPCode int `json:"-"` // Exclude from default JSON + Details map[string]any `json:"details"` // Use map for arbitrary details + Source *string `json:"source,omitempty"` // Pointer for optional +} + +// UserFacingError is the base error type for user facing errors. +type UserFacingError struct { + Message string `json:"message"` // Exclude from default JSON if embedded elsewhere + Status StatusName `json:"status"` + Details map[string]any `json:"details"` // Use map for arbitrary details +} + +// NewPublicError allows a web framework handler to know it +// is safe to return the message in a request. Other kinds of errors will +// result in a generic 500 message to avoid the possibility of internal +// exceptions being leaked to attackers. +func NewPublicError(status StatusName, message string, details map[string]any) *UserFacingError { + return &UserFacingError{ + Status: status, + Details: details, + Message: message, + } +} + +// Error implements the standard error interface for UserFacingError. +func (e *UserFacingError) Error() string { + return fmt.Sprintf("%s: %s", e.Status, e.Message) +} + +// NewError creates a new GenkitError with a stack trace. +func NewError(status StatusName, message string, args ...any) *GenkitError { + // Prevents a compile-time warning about non-constant message. + msg := message + + ge := &GenkitError{ + Status: status, + Message: fmt.Sprintf(msg, args...), + } + + errStack := getErrorStack(ge) + if errStack != "" { + ge.Details = make(map[string]any) + ge.Details["stack"] = errStack + } + return ge +} + +// Error implements the standard error interface. +func (e *GenkitError) Error() string { + if e == nil { + return "" + } + return e.Message +} + +// ToReflectionError returns a JSON-serializable representation for reflection API responses. +func (e *GenkitError) ToReflectionError() ReflectionError { + errDetails := &ReflectionErrorDetails{} + if stackVal, ok := e.Details["stack"].(string); ok { + errDetails.Stack = &stackVal + } + if traceVal, ok := e.Details["traceId"].(string); ok { + errDetails.TraceID = &traceVal + } + return ReflectionError{ + Details: errDetails, + Code: HTTPStatusCode(e.Status), + Message: e.Message, + } +} + +// ToReflectionError gets the JSON representation for reflection API Error responses. +func ToReflectionError(err error) ReflectionError { + if ge, ok := err.(*GenkitError); ok { + return ge.ToReflectionError() + } + + stack := getErrorStack(err) + detailsWire := &ReflectionErrorDetails{} + if stack != "" { + detailsWire.Stack = &stack + } + + return ReflectionError{ + Message: err.Error(), + Code: HTTPStatusCode(INTERNAL), + Details: detailsWire, + } +} + +// getErrorStack extracts stack trace from an error object. +// This captures the stack trace of the current goroutine when called. +func getErrorStack(err error) string { + if err == nil { + return "" + } + return string(debug.Stack()) +} diff --git a/go/core/schemas.config b/go/core/schemas.config index f40e52e7d..8b6bf69fa 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -272,4 +272,6 @@ Score omit Embedding.embedding type []float32 -GenkitErrorDataGenkitErrorDetails name GenkitErrorDetails \ No newline at end of file +GenkitError omit +GenkitErrorData omit +GenkitErrorDataGenkitErrorDetails omit \ No newline at end of file diff --git a/go/core/status_types.go b/go/core/status_types.go new file mode 100644 index 000000000..97a9592a7 --- /dev/null +++ b/go/core/status_types.go @@ -0,0 +1,150 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package status defines canonical status codes, names, and related types +// inspired by gRPC status codes. +package core + +import "net/http" // Import standard http package for status codes + +// StatusName defines the set of canonical status names. +type StatusName string + +// Constants for canonical status names. +const ( + OK StatusName = "OK" + CANCELLED StatusName = "CANCELLED" + UNKNOWN StatusName = "UNKNOWN" + INVALID_ARGUMENT StatusName = "INVALID_ARGUMENT" + DEADLINE_EXCEEDED StatusName = "DEADLINE_EXCEEDED" + NOT_FOUND StatusName = "NOT_FOUND" + ALREADY_EXISTS StatusName = "ALREADY_EXISTS" + PERMISSION_DENIED StatusName = "PERMISSION_DENIED" + UNAUTHENTICATED StatusName = "UNAUTHENTICATED" + RESOURCE_EXHAUSTED StatusName = "RESOURCE_EXHAUSTED" + FAILED_PRECONDITION StatusName = "FAILED_PRECONDITION" + ABORTED StatusName = "ABORTED" + OUT_OF_RANGE StatusName = "OUT_OF_RANGE" + UNIMPLEMENTED StatusName = "UNIMPLEMENTED" + INTERNAL StatusName = "INTERNAL_SERVER_ERROR" + UNAVAILABLE StatusName = "UNAVAILABLE" + DATA_LOSS StatusName = "DATA_LOSS" +) + +// Constants for canonical status codes (integer values). +const ( + // CodeOK means not an error; returned on success. + CodeOK = 0 + // CodeCancelled means the operation was cancelled, typically by the caller. + CodeCancelled = 1 + // CodeUnknown means an unknown error occurred. + CodeUnknown = 2 + // CodeInvalidArgument means the client specified an invalid argument. + CodeInvalidArgument = 3 + // CodeDeadlineExceeded means the deadline expired before the operation could complete. + CodeDeadlineExceeded = 4 + // CodeNotFound means some requested entity (e.g., file or directory) was not found. + CodeNotFound = 5 + // CodeAlreadyExists means the entity that a client attempted to create already exists. + CodeAlreadyExists = 6 + // CodePermissionDenied means the caller does not have permission to execute the operation. + CodePermissionDenied = 7 + // CodeUnauthenticated means the request does not have valid authentication credentials. + CodeUnauthenticated = 16 + // CodeResourceExhausted means some resource has been exhausted. + CodeResourceExhausted = 8 + // CodeFailedPrecondition means the operation was rejected because the system is not in a state required. + CodeFailedPrecondition = 9 + // CodeAborted means the operation was aborted, typically due to some issue. + CodeAborted = 10 + // CodeOutOfRange means the operation was attempted past the valid range. + CodeOutOfRange = 11 + // CodeUnimplemented means the operation is not implemented or not supported/enabled. + CodeUnimplemented = 12 + // CodeInternal means internal errors. Some invariants expected by the underlying system were broken. + CodeInternal = 13 + // CodeUnavailable means the service is currently unavailable. + CodeUnavailable = 14 + // CodeDataLoss means unrecoverable data loss or corruption. + CodeDataLoss = 15 +) + +// StatusNameToCode maps status names to their integer code values. +// Exported for potential use elsewhere if needed. +var StatusNameToCode = map[StatusName]int{ + OK: CodeOK, + CANCELLED: CodeCancelled, + UNKNOWN: CodeUnknown, + INVALID_ARGUMENT: CodeInvalidArgument, + DEADLINE_EXCEEDED: CodeDeadlineExceeded, + NOT_FOUND: CodeNotFound, + ALREADY_EXISTS: CodeAlreadyExists, + PERMISSION_DENIED: CodePermissionDenied, + UNAUTHENTICATED: CodeUnauthenticated, + RESOURCE_EXHAUSTED: CodeResourceExhausted, + FAILED_PRECONDITION: CodeFailedPrecondition, + ABORTED: CodeAborted, + OUT_OF_RANGE: CodeOutOfRange, + UNIMPLEMENTED: CodeUnimplemented, + INTERNAL: CodeInternal, + UNAVAILABLE: CodeUnavailable, + DATA_LOSS: CodeDataLoss, +} + +// statusNameToHTTPCode maps status names to HTTP status codes. +// Kept unexported as it's primarily used by the HTTPStatusCode function. +var statusNameToHTTPCode = map[StatusName]int{ + OK: http.StatusOK, // 200 + CANCELLED: 499, // Client Closed Request (non-standard but common) + UNKNOWN: http.StatusInternalServerError, // 500 + INVALID_ARGUMENT: http.StatusBadRequest, // 400 + DEADLINE_EXCEEDED: http.StatusGatewayTimeout, // 504 + NOT_FOUND: http.StatusNotFound, // 404 + ALREADY_EXISTS: http.StatusConflict, // 409 + PERMISSION_DENIED: http.StatusForbidden, // 403 + UNAUTHENTICATED: http.StatusUnauthorized, // 401 + RESOURCE_EXHAUSTED: http.StatusTooManyRequests, // 429 + FAILED_PRECONDITION: http.StatusBadRequest, // 400 + ABORTED: http.StatusConflict, // 409 + OUT_OF_RANGE: http.StatusBadRequest, // 400 + UNIMPLEMENTED: http.StatusNotImplemented, // 501 + INTERNAL: http.StatusInternalServerError, // 500 + UNAVAILABLE: http.StatusServiceUnavailable, // 503 + DATA_LOSS: http.StatusInternalServerError, // 500 +} + +// HTTPStatusCode gets the corresponding HTTP status code for a given Genkit status name. +func HTTPStatusCode(name StatusName) int { + if code, ok := statusNameToHTTPCode[name]; ok { + return code + } + + return http.StatusInternalServerError +} + +// Status represents a status condition, typically used in responses or errors. +type Status struct { + Name StatusName `json:"name"` + Message string `json:"message,omitempty"` +} + +// NewStatus creates a new Status object. +func NewStatus(name StatusName, message string) *Status { + return &Status{ + Name: name, + Message: message, + } +} diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index eba5be0be..69248b792 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -19,7 +19,6 @@ package genkit import ( "context" "encoding/json" - "errors" "fmt" "log/slog" "net" @@ -29,13 +28,11 @@ import ( "strconv" "time" - "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" "github.com/firebase/genkit/go/internal/action" - "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/internal/registry" "go.opentelemetry.io/otel/trace" ) @@ -250,26 +247,9 @@ func wrapReflectionHandler(h func(w http.ResponseWriter, r *http.Request) error) w.Header().Set("x-genkit-version", "go/"+internal.Version) if err = h(w, r); err != nil { - var traceID string - statusCode := http.StatusInternalServerError - if herr, ok := err.(*base.HTTPError); ok { - traceID = herr.TraceID - statusCode = herr.Code - } - - genkitErr := &ai.GenkitError{ - Message: err.Error(), - Details: struct { - TraceID string `json:"traceId"` - Stack string `json:"stack"` - }{ - TraceID: traceID, - Stack: "", // TODO: Propagate stack trace from local error. - }, - } - - w.WriteHeader(statusCode) - writeJSON(ctx, w, genkitErr) + errorResponse := core.ToReflectionError(err) + w.WriteHeader(errorResponse.Code) + writeJSON(ctx, w, errorResponse) } } } @@ -287,7 +267,7 @@ func handleRunAction(reg *registry.Registry) func(w http.ResponseWriter, r *http } defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - return &base.HTTPError{Code: http.StatusBadRequest, Err: err} + return core.NewError(core.INVALID_ARGUMENT, err.Error()) } stream, err := parseBoolQueryParam(r, "stream") @@ -322,26 +302,16 @@ func handleRunAction(reg *registry.Registry) func(w http.ResponseWriter, r *http resp, err := runAction(ctx, reg, body.Key, body.Input, cb, contextMap) if err != nil { if stream { - var traceID string - if herr, ok := err.(*base.HTTPError); ok { - traceID = herr.TraceID + reflectErr, err := json.Marshal(core.ToReflectionError(err)) + if err != nil { + return err } - genkitErr := &ai.GenkitError{ - Message: err.Error(), - Data: &ai.GenkitErrorData{ - GenkitErrorMessage: err.Error(), - GenkitErrorDetails: &ai.GenkitErrorDetails{ - TraceID: traceID, - }, - }, + _, err = fmt.Fprintf(w, "%s\n\n", reflectErr) + if err != nil { + return err } - errorJSON, _ := json.Marshal(genkitErr) - _, writeErr := fmt.Fprintf(w, "%s\n\n", errorJSON) - if writeErr != nil { - return writeErr - } if f, ok := w.(http.Flusher); ok { f.Flush() } @@ -364,7 +334,7 @@ func handleNotify(reg *registry.Registry) func(w http.ResponseWriter, r *http.Re defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - return &base.HTTPError{Code: http.StatusBadRequest, Err: err} + return core.NewError(core.INVALID_ARGUMENT, err.Error()) } if os.Getenv("GENKIT_TELEMETRY_SERVER") == "" && body.TelemetryServerURL != "" { @@ -408,7 +378,7 @@ type telemetry struct { func runAction(ctx context.Context, reg *registry.Registry, key string, input json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { action := reg.LookupAction(key) if action == nil { - return nil, &base.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no action with key %q", key)} + return nil, core.NewError(core.NOT_FOUND, "action %q not found", key) } if runtimeContext != nil { ctx = core.WithActionContext(ctx, runtimeContext) @@ -421,12 +391,7 @@ func runAction(ctx context.Context, reg *registry.Registry, key string, input js return action.RunJSON(ctx, input, cb) }) if err != nil { - var herr *base.HTTPError - if errors.As(err, &herr) { - herr.TraceID = traceID - return nil, herr - } - return nil, &base.HTTPError{Code: http.StatusInternalServerError, Err: err, TraceID: traceID} + return nil, err } return &runActionResponse{ diff --git a/go/genkit/servers.go b/go/genkit/servers.go index 06db1100f..178e8d442 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -29,7 +29,6 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" - "github.com/firebase/genkit/go/internal/base" ) type HandlerOption interface { @@ -90,9 +89,9 @@ func wrapHandler(h func(http.ResponseWriter, *http.Request) error) http.HandlerF }() if err = h(w, r); err != nil { - var herr *base.HTTPError + var herr *core.GenkitError if errors.As(err, &herr) { - http.Error(w, herr.Error(), herr.Code) + http.Error(w, herr.Error(), core.HTTPStatusCode(herr.Status)) } else { http.Error(w, err.Error(), http.StatusInternalServerError) } @@ -113,7 +112,7 @@ func handler(a core.Action, params *handlerParams) func(http.ResponseWriter, *ht if r.Body != nil && r.ContentLength > 0 { defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - return &base.HTTPError{Code: http.StatusBadRequest, Err: err} + return core.NewPublicError(core.INVALID_ARGUMENT, err.Error(), nil) } } @@ -158,7 +157,7 @@ func handler(a core.Action, params *handlerParams) func(http.ResponseWriter, *ht }) if err != nil { logger.FromContext(ctx).Error("error providing action context from request", "err", err) - return &base.HTTPError{Code: http.StatusUnauthorized, Err: err} + return err } if existing := core.FromContext(ctx); existing != nil { @@ -195,7 +194,7 @@ func parseBoolQueryParam(r *http.Request, name string) (bool, error) { var err error b, err = strconv.ParseBool(s) if err != nil { - return false, &base.HTTPError{Code: http.StatusBadRequest, Err: err} + return false, core.NewPublicError(core.INVALID_ARGUMENT, err.Error(), nil) } } return b, nil diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index cd9cc6072..d203bbb99 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -132,8 +132,8 @@ func TestHandler(t *testing.T) { resp := w.Result() body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("want status code %d, got %d", http.StatusBadRequest, resp.StatusCode) + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("want status code %d, got %d", http.StatusInternalServerError, resp.StatusCode) } if !strings.Contains(string(body), "invalid character") { diff --git a/go/internal/base/misc.go b/go/internal/base/misc.go index ef7e5db6a..9e3afa1d9 100644 --- a/go/internal/base/misc.go +++ b/go/internal/base/misc.go @@ -17,8 +17,6 @@ package base import ( - "fmt" - "net/http" "net/url" ) @@ -40,14 +38,3 @@ func Zero[T any]() T { func Clean(id string) string { return url.PathEscape(id) } - -// HTTPError is an error that includes an HTTP status code. -type HTTPError struct { - Code int - Err error - TraceID string -} - -func (e *HTTPError) Error() string { - return fmt.Sprintf("%s: %s", http.StatusText(e.Code), e.Err) -} diff --git a/go/plugins/firebase/auth.go b/go/plugins/firebase/auth.go index bbd3c6f16..bb1856f97 100644 --- a/go/plugins/firebase/auth.go +++ b/go/plugins/firebase/auth.go @@ -19,7 +19,6 @@ package firebase import ( "context" "encoding/json" - "errors" "fmt" "strings" @@ -43,7 +42,7 @@ type AuthClient interface { func ContextProvider(ctx context.Context, g *genkit.Genkit, policy AuthPolicy) (core.ContextProvider, error) { f, ok := genkit.LookupPlugin(g, provider).(*Firebase) if !ok { - return nil, errors.New("firebase plugin not initialized; did you pass the plugin to genkit.Init()") + return nil, core.NewError(core.NOT_FOUND, "firebase plugin not initialized; did you pass the plugin to genkit.Init()") } client, err := f.App.Auth(ctx) if err != nil { @@ -53,19 +52,19 @@ func ContextProvider(ctx context.Context, g *genkit.Genkit, policy AuthPolicy) ( return func(ctx context.Context, input core.RequestData) (core.ActionContext, error) { authHeader, ok := input.Headers["authorization"] if !ok { - return nil, errors.New("authorization header is required but not provided") + return nil, core.NewPublicError(core.UNAUTHENTICATED, "authorization header is required but not provided", nil) } const bearerPrefix = "bearer " if !strings.HasPrefix(strings.ToLower(authHeader), bearerPrefix) { - return nil, errors.New("invalid authorization header format") + return nil, core.NewPublicError(core.UNAUTHENTICATED, "invalid authorization header format", nil) } token := authHeader[len(bearerPrefix):] authCtx, err := client.VerifyIDToken(ctx, token) if err != nil { - return nil, fmt.Errorf("error verifying ID token: %v", err) + return nil, core.NewPublicError(core.UNAUTHENTICATED, fmt.Sprintf("error verifying ID token: %v", err), nil) } if policy != nil { diff --git a/go/samples/basic-gemini/main.go b/go/samples/basic-gemini/main.go index dceb064dc..f84368f16 100644 --- a/go/samples/basic-gemini/main.go +++ b/go/samples/basic-gemini/main.go @@ -16,10 +16,10 @@ package main import ( "context" - "errors" "log" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/googlegenai" ) @@ -40,7 +40,7 @@ func main() { genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { m := googlegenai.GoogleAIModel(g, "gemini-2.5-pro-preview-03-25") if m == nil { - return "", errors.New("jokesFlow: failed to find model") + return "", core.NewError(core.INVALID_ARGUMENT, "jokesFlow: failed to find model") } resp, err := genkit.Generate(ctx, g,