@@ -21,190 +21,130 @@ import (
21
21
"encoding/json"
22
22
"fmt"
23
23
"strconv"
24
+ "time"
24
25
25
- configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
26
26
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
27
- "google.golang.org/protobuf/types/known/structpb"
28
27
"sigs.k8s.io/controller-runtime/pkg/log"
29
28
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
30
29
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
31
30
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
32
31
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
33
32
)
34
33
35
- // HandleRequestBody handles body of the request to the backend server, such as parsing the "model"
36
- // parameter.
37
- // Envoy sends the request body to ext proc before sending the request to the backend server.
38
- func (s * Server ) HandleRequestBody (
34
+ // HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling.
35
+ func (s * StreamingServer ) HandleRequestBody (
39
36
ctx context.Context ,
40
37
reqCtx * RequestContext ,
41
38
req * extProcPb.ProcessingRequest ,
42
- ) (* extProcPb.ProcessingResponse , error ) {
39
+ requestBodyMap map [string ]interface {},
40
+ ) (* RequestContext , error ) {
41
+ var requestBodyBytes []byte
43
42
logger := log .FromContext (ctx )
44
- loggerVerbose := logger .V (logutil .VERBOSE )
45
- loggerVerbose .Info ("Handling request body" )
46
-
47
- // Unmarshal request body (must be JSON).
48
- v := req .Request .(* extProcPb.ProcessingRequest_RequestBody )
49
- var rb map [string ]interface {}
50
- if err := json .Unmarshal (v .RequestBody .Body , & rb ); err != nil {
51
- logger .V (logutil .DEFAULT ).Error (err , "Error unmarshaling request body" )
52
- return nil , errutil.Error {Code : errutil .BadRequest , Msg : fmt .Sprintf ("error unmarshaling request body: %v" , err )}
53
- }
54
- loggerVerbose .Info ("Request body unmarshalled" , "body" , rb )
55
43
56
44
// Resolve target models.
57
- model , ok := rb ["model" ].(string )
45
+ model , ok := requestBodyMap ["model" ].(string )
58
46
if ! ok {
59
- return nil , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request" }
47
+ return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request" }
60
48
}
61
- loggerVerbose . Info ( "Model requested" , "model" , model )
49
+
62
50
modelName := model
63
51
64
52
// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
65
53
// This might be a security risk in the future where adapters not registered in the InferenceModel
66
54
// are able to be requested by using their distinct name.
67
55
modelObj := s .datastore .ModelGet (model )
68
56
if modelObj == nil {
69
- return nil , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error finding a model object in InferenceModel for input %v" , model )}
57
+ return reqCtx , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error finding a model object in InferenceModel for input %v" , model )}
70
58
}
71
59
if len (modelObj .Spec .TargetModels ) > 0 {
72
60
modelName = RandomWeightedDraw (logger , modelObj , 0 )
73
61
if modelName == "" {
74
- return nil , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error getting target model name for model %v" , modelObj .Name )}
62
+ return reqCtx , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error getting target model name for model %v" , modelObj .Name )}
75
63
}
76
64
}
77
65
llmReq := & schedulingtypes.LLMRequest {
78
66
Model : model ,
79
67
ResolvedTargetModel : modelName ,
80
68
Critical : modelObj .Spec .Criticality != nil && * modelObj .Spec .Criticality == v1alpha2 .Critical ,
81
69
}
82
- loggerVerbose . Info ("LLM request assembled" , "request " , llmReq )
70
+ logger . V ( logutil . DEBUG ). Info ("LLM request assembled" , "model " , llmReq . Model , "targetModel" , llmReq . ResolvedTargetModel , "critical" , llmReq . Critical )
83
71
84
- requestBody := v .RequestBody .Body
85
72
var err error
86
73
// Update target models in the body.
87
74
if llmReq .Model != llmReq .ResolvedTargetModel {
88
- rb ["model" ] = llmReq .ResolvedTargetModel
89
- requestBody , err = json . Marshal ( rb )
90
- if err != nil {
91
- logger . V ( logutil . DEFAULT ). Error ( err , "Error marshaling request body" )
92
- return nil , errutil. Error { Code : errutil . Internal , Msg : fmt . Sprintf ( "error marshaling request body: %v" , err )}
93
- }
94
- loggerVerbose . Info ( "Updated request body marshalled " , "body" , string ( requestBody ))
75
+ requestBodyMap ["model" ] = llmReq .ResolvedTargetModel
76
+ }
77
+
78
+ requestBodyBytes , err = json . Marshal ( requestBodyMap )
79
+ if err != nil {
80
+ logger . V ( logutil . DEFAULT ). Error ( err , "Error marshaling request body" )
81
+ return reqCtx , errutil. Error { Code : errutil . Internal , Msg : fmt . Sprintf ( "error marshaling request body: %v " , err )}
95
82
}
96
83
97
84
target , err := s .scheduler .Schedule (ctx , llmReq )
98
85
if err != nil {
99
- return nil , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : fmt .Errorf ("failed to find target pod: %w" , err ).Error ()}
86
+ return reqCtx , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : fmt .Errorf ("failed to find target pod: %w" , err ).Error ()}
100
87
}
101
88
targetPod := target .GetPod ()
102
89
103
- logger .V (logutil .DEFAULT ).Info ("Request handled" ,
104
- "model" , llmReq .Model , "targetModel" , llmReq .ResolvedTargetModel , "endpoint" , targetPod )
105
-
106
90
// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
107
91
// Attach the port number
108
92
pool , err := s .datastore .PoolGet ()
109
93
if err != nil {
110
- return nil , err
94
+ return reqCtx , err
111
95
}
112
96
endpoint := targetPod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
113
97
98
+ logger .V (logutil .DEFAULT ).Info ("Request handled" ,
99
+ "model" , llmReq .Model , "targetModel" , llmReq .ResolvedTargetModel , "endpoint" , targetPod , "endpoint metrics" ,
100
+ fmt .Sprintf ("%+v" , target ))
101
+
114
102
reqCtx .Model = llmReq .Model
115
103
reqCtx .ResolvedTargetModel = llmReq .ResolvedTargetModel
116
- reqCtx .RequestSize = len (v . RequestBody . Body )
104
+ reqCtx .RequestSize = len (requestBodyBytes )
117
105
reqCtx .TargetPod = targetPod .NamespacedName .String ()
118
106
reqCtx .TargetEndpoint = endpoint
119
107
120
- headers := []* configPb.HeaderValueOption {
121
- {
122
- Header : & configPb.HeaderValue {
123
- Key : s .destinationEndpointHintKey ,
124
- RawValue : []byte (endpoint ),
125
- },
126
- },
127
- // We need to update the content length header if the body is mutated, see Envoy doc:
128
- // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
129
- {
130
- Header : & configPb.HeaderValue {
131
- Key : "Content-Length" ,
132
- RawValue : []byte (strconv .Itoa (len (requestBody ))),
133
- },
134
- },
135
- }
136
- // Print headers for debugging
137
- for _ , header := range headers {
138
- logger .V (logutil .DEBUG ).Info ("Request body header" , "key" , header .Header .Key , "value" , header .Header .RawValue )
139
- }
140
-
141
- targetEndpointValue := & structpb.Struct {
142
- Fields : map [string ]* structpb.Value {
143
- s .destinationEndpointHintKey : {
144
- Kind : & structpb.Value_StringValue {
145
- StringValue : endpoint ,
146
- },
147
- },
148
- },
149
- }
150
- dynamicMetadata := targetEndpointValue
151
- if s .destinationEndpointHintMetadataNamespace != "" {
152
- // If a namespace is defined, wrap the selected endpoint with that.
153
- dynamicMetadata = & structpb.Struct {
154
- Fields : map [string ]* structpb.Value {
155
- s .destinationEndpointHintMetadataNamespace : {
156
- Kind : & structpb.Value_StructValue {
157
- StructValue : targetEndpointValue ,
158
- },
159
- },
160
- },
161
- }
162
- }
108
+ s .populateRequestHeaderResponse (reqCtx , endpoint , len (requestBodyBytes ))
163
109
164
- resp : = & extProcPb.ProcessingResponse {
110
+ reqCtx . reqBodyResp = & extProcPb.ProcessingResponse {
165
111
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
166
112
// and as an unstructure ext-proc response metadata key/value pair. This enables different integration
167
113
// options for gateway providers.
168
114
Response : & extProcPb.ProcessingResponse_RequestBody {
169
115
RequestBody : & extProcPb.BodyResponse {
170
116
Response : & extProcPb.CommonResponse {
171
- HeaderMutation : & extProcPb.HeaderMutation {
172
- SetHeaders : headers ,
173
- },
174
117
BodyMutation : & extProcPb.BodyMutation {
175
- Mutation : & extProcPb.BodyMutation_Body {
176
- Body : requestBody ,
118
+ Mutation : & extProcPb.BodyMutation_StreamedResponse {
119
+ StreamedResponse : & extProcPb.StreamedBodyResponse {
120
+ Body : requestBodyBytes ,
121
+ EndOfStream : true ,
122
+ },
177
123
},
178
124
},
179
125
},
180
126
},
181
127
},
182
- DynamicMetadata : dynamicMetadata ,
183
128
}
184
- return resp , nil
129
+ return reqCtx , nil
185
130
}
186
131
187
- func HandleRequestHeaders (
188
- ctx context.Context ,
189
- reqCtx * RequestContext ,
190
- req * extProcPb.ProcessingRequest ,
191
- ) * extProcPb.ProcessingResponse {
192
- r := req .Request
193
- h := r .(* extProcPb.ProcessingRequest_RequestHeaders )
194
- log .FromContext (ctx ).V (logutil .VERBOSE ).Info ("Handling request headers" , "headers" , h )
195
-
196
- resp := & extProcPb.ProcessingResponse {
197
- Response : & extProcPb.ProcessingResponse_RequestHeaders {
198
- RequestHeaders : & extProcPb.HeadersResponse {
199
- Response : & extProcPb.CommonResponse {
200
- // Set `clear_route_cache = true` to force Envoy to recompute the target cluster
201
- // based on the new "target-pod" header.
202
- // See https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto#service-ext-proc-v3-commonresponse.
203
- ClearRouteCache : true ,
204
- },
205
- },
206
- },
132
+ func (s * StreamingServer ) HandleRequestHeaders (ctx context.Context , reqCtx * RequestContext , req * extProcPb.ProcessingRequest_RequestHeaders ) error {
133
+ reqCtx .RequestReceivedTimestamp = time .Now ()
134
+
135
+ // an EoS in the request headers means this request has no body or trailers.
136
+ if req .RequestHeaders .EndOfStream {
137
+ // We will route this request to a random pod as this is assumed to just be a GET
138
+ // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
139
+ // The above PR will address endpoint admission, but currently any request without a body will be
140
+ // routed to a random upstream pod.
141
+ pod := GetRandomPod (s .datastore )
142
+ pool , err := s .datastore .PoolGet ()
143
+ if err != nil {
144
+ return err
145
+ }
146
+ endpoint := pod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
147
+ s .populateRequestHeaderResponse (reqCtx , endpoint , 0 )
207
148
}
208
-
209
- return resp
149
+ return nil
210
150
}
0 commit comments