Skip to content

Commit 1ba13f3

Browse files
authored
Consolidating down to FULL_DUPLEX_STREAMED supported ext-proc server (#672)
1 parent 4761c71 commit 1ba13f3

File tree

10 files changed

+490
-1336
lines changed

10 files changed

+490
-1336
lines changed

cmd/epp/main.go

-6
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,6 @@ func run() error {
120120
flag.Parse()
121121
initLogging(&opts)
122122

123-
useStreamingServer, err := strconv.ParseBool(os.Getenv("USE_STREAMING"))
124-
if err != nil {
125-
setupLog.Error(err, "Failed to parse env var USE_STREAMING, defaulting to false")
126-
}
127-
128123
// Validate flags
129124
if err := validateFlags(); err != nil {
130125
setupLog.Error(err, "Failed to validate flags")
@@ -178,7 +173,6 @@ func run() error {
178173
Datastore: datastore,
179174
SecureServing: *secureServing,
180175
CertPath: *certPath,
181-
UseStreaming: useStreamingServer,
182176
RefreshPrometheusMetricsInterval: *refreshPrometheusMetricsInterval,
183177
}
184178
if err := serverRunner.SetupWithManager(ctx, mgr); err != nil {

config/charts/inferencepool/templates/epp-deployment.yaml

-3
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@ spec:
3535
- "9003"
3636
- -metricsPort
3737
- "9090"
38-
env:
39-
- name: USE_STREAMING
40-
value: "true"
4138
ports:
4239
- name: grpc
4340
containerPort: 9002

config/manifests/inferencepool-resources.yaml

-3
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@ spec:
6262
- "9002"
6363
- -grpcHealthPort
6464
- "9003"
65-
env:
66-
- name: USE_STREAMING
67-
value: "true"
6865
ports:
6966
- containerPort: 9002
7067
- containerPort: 9003

pkg/epp/handlers/request.go

+51-111
Original file line numberDiff line numberDiff line change
@@ -21,190 +21,130 @@ import (
2121
"encoding/json"
2222
"fmt"
2323
"strconv"
24+
"time"
2425

25-
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2626
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
27-
"google.golang.org/protobuf/types/known/structpb"
2827
"sigs.k8s.io/controller-runtime/pkg/log"
2928
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
3029
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3130
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
3231
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3332
)
3433

35-
// HandleRequestBody handles body of the request to the backend server, such as parsing the "model"
36-
// parameter.
37-
// Envoy sends the request body to ext proc before sending the request to the backend server.
38-
func (s *Server) HandleRequestBody(
34+
// HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling.
35+
func (s *StreamingServer) HandleRequestBody(
3936
ctx context.Context,
4037
reqCtx *RequestContext,
4138
req *extProcPb.ProcessingRequest,
42-
) (*extProcPb.ProcessingResponse, error) {
39+
requestBodyMap map[string]interface{},
40+
) (*RequestContext, error) {
41+
var requestBodyBytes []byte
4342
logger := log.FromContext(ctx)
44-
loggerVerbose := logger.V(logutil.VERBOSE)
45-
loggerVerbose.Info("Handling request body")
46-
47-
// Unmarshal request body (must be JSON).
48-
v := req.Request.(*extProcPb.ProcessingRequest_RequestBody)
49-
var rb map[string]interface{}
50-
if err := json.Unmarshal(v.RequestBody.Body, &rb); err != nil {
51-
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
52-
return nil, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("error unmarshaling request body: %v", err)}
53-
}
54-
loggerVerbose.Info("Request body unmarshalled", "body", rb)
5543

5644
// Resolve target models.
57-
model, ok := rb["model"].(string)
45+
model, ok := requestBodyMap["model"].(string)
5846
if !ok {
59-
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"}
47+
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"}
6048
}
61-
loggerVerbose.Info("Model requested", "model", model)
49+
6250
modelName := model
6351

6452
// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
6553
// This might be a security risk in the future where adapters not registered in the InferenceModel
6654
// are able to be requested by using their distinct name.
6755
modelObj := s.datastore.ModelGet(model)
6856
if modelObj == nil {
69-
return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
57+
return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
7058
}
7159
if len(modelObj.Spec.TargetModels) > 0 {
7260
modelName = RandomWeightedDraw(logger, modelObj, 0)
7361
if modelName == "" {
74-
return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
62+
return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
7563
}
7664
}
7765
llmReq := &schedulingtypes.LLMRequest{
7866
Model: model,
7967
ResolvedTargetModel: modelName,
8068
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
8169
}
82-
loggerVerbose.Info("LLM request assembled", "request", llmReq)
70+
logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical)
8371

84-
requestBody := v.RequestBody.Body
8572
var err error
8673
// Update target models in the body.
8774
if llmReq.Model != llmReq.ResolvedTargetModel {
88-
rb["model"] = llmReq.ResolvedTargetModel
89-
requestBody, err = json.Marshal(rb)
90-
if err != nil {
91-
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
92-
return nil, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
93-
}
94-
loggerVerbose.Info("Updated request body marshalled", "body", string(requestBody))
75+
requestBodyMap["model"] = llmReq.ResolvedTargetModel
76+
}
77+
78+
requestBodyBytes, err = json.Marshal(requestBodyMap)
79+
if err != nil {
80+
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
81+
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
9582
}
9683

9784
target, err := s.scheduler.Schedule(ctx, llmReq)
9885
if err != nil {
99-
return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
86+
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
10087
}
10188
targetPod := target.GetPod()
10289

103-
logger.V(logutil.DEFAULT).Info("Request handled",
104-
"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod)
105-
10690
// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
10791
// Attach the port number
10892
pool, err := s.datastore.PoolGet()
10993
if err != nil {
110-
return nil, err
94+
return reqCtx, err
11195
}
11296
endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
11397

98+
logger.V(logutil.DEFAULT).Info("Request handled",
99+
"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics",
100+
fmt.Sprintf("%+v", target))
101+
114102
reqCtx.Model = llmReq.Model
115103
reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel
116-
reqCtx.RequestSize = len(v.RequestBody.Body)
104+
reqCtx.RequestSize = len(requestBodyBytes)
117105
reqCtx.TargetPod = targetPod.NamespacedName.String()
118106
reqCtx.TargetEndpoint = endpoint
119107

120-
headers := []*configPb.HeaderValueOption{
121-
{
122-
Header: &configPb.HeaderValue{
123-
Key: s.destinationEndpointHintKey,
124-
RawValue: []byte(endpoint),
125-
},
126-
},
127-
// We need to update the content length header if the body is mutated, see Envoy doc:
128-
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
129-
{
130-
Header: &configPb.HeaderValue{
131-
Key: "Content-Length",
132-
RawValue: []byte(strconv.Itoa(len(requestBody))),
133-
},
134-
},
135-
}
136-
// Print headers for debugging
137-
for _, header := range headers {
138-
logger.V(logutil.DEBUG).Info("Request body header", "key", header.Header.Key, "value", header.Header.RawValue)
139-
}
140-
141-
targetEndpointValue := &structpb.Struct{
142-
Fields: map[string]*structpb.Value{
143-
s.destinationEndpointHintKey: {
144-
Kind: &structpb.Value_StringValue{
145-
StringValue: endpoint,
146-
},
147-
},
148-
},
149-
}
150-
dynamicMetadata := targetEndpointValue
151-
if s.destinationEndpointHintMetadataNamespace != "" {
152-
// If a namespace is defined, wrap the selected endpoint with that.
153-
dynamicMetadata = &structpb.Struct{
154-
Fields: map[string]*structpb.Value{
155-
s.destinationEndpointHintMetadataNamespace: {
156-
Kind: &structpb.Value_StructValue{
157-
StructValue: targetEndpointValue,
158-
},
159-
},
160-
},
161-
}
162-
}
108+
s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes))
163109

164-
resp := &extProcPb.ProcessingResponse{
110+
reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{
165111
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
166112
// and as an unstructure ext-proc response metadata key/value pair. This enables different integration
167113
// options for gateway providers.
168114
Response: &extProcPb.ProcessingResponse_RequestBody{
169115
RequestBody: &extProcPb.BodyResponse{
170116
Response: &extProcPb.CommonResponse{
171-
HeaderMutation: &extProcPb.HeaderMutation{
172-
SetHeaders: headers,
173-
},
174117
BodyMutation: &extProcPb.BodyMutation{
175-
Mutation: &extProcPb.BodyMutation_Body{
176-
Body: requestBody,
118+
Mutation: &extProcPb.BodyMutation_StreamedResponse{
119+
StreamedResponse: &extProcPb.StreamedBodyResponse{
120+
Body: requestBodyBytes,
121+
EndOfStream: true,
122+
},
177123
},
178124
},
179125
},
180126
},
181127
},
182-
DynamicMetadata: dynamicMetadata,
183128
}
184-
return resp, nil
129+
return reqCtx, nil
185130
}
186131

187-
func HandleRequestHeaders(
188-
ctx context.Context,
189-
reqCtx *RequestContext,
190-
req *extProcPb.ProcessingRequest,
191-
) *extProcPb.ProcessingResponse {
192-
r := req.Request
193-
h := r.(*extProcPb.ProcessingRequest_RequestHeaders)
194-
log.FromContext(ctx).V(logutil.VERBOSE).Info("Handling request headers", "headers", h)
195-
196-
resp := &extProcPb.ProcessingResponse{
197-
Response: &extProcPb.ProcessingResponse_RequestHeaders{
198-
RequestHeaders: &extProcPb.HeadersResponse{
199-
Response: &extProcPb.CommonResponse{
200-
// Set `clear_route_cache = true` to force Envoy to recompute the target cluster
201-
// based on the new "target-pod" header.
202-
// See https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto#service-ext-proc-v3-commonresponse.
203-
ClearRouteCache: true,
204-
},
205-
},
206-
},
132+
func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
133+
reqCtx.RequestReceivedTimestamp = time.Now()
134+
135+
// an EoS in the request headers means this request has no body or trailers.
136+
if req.RequestHeaders.EndOfStream {
137+
// We will route this request to a random pod as this is assumed to just be a GET
138+
// More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
139+
// The above PR will address endpoint admission, but currently any request without a body will be
140+
// routed to a random upstream pod.
141+
pod := GetRandomPod(s.datastore)
142+
pool, err := s.datastore.PoolGet()
143+
if err != nil {
144+
return err
145+
}
146+
endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
147+
s.populateRequestHeaderResponse(reqCtx, endpoint, 0)
207148
}
208-
209-
return resp
149+
return nil
210150
}

0 commit comments

Comments
 (0)