diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go index 925a0cc5..21c0f401 100644 --- a/pkg/epp/backend/metrics/types.go +++ b/pkg/epp/backend/metrics/types.go @@ -79,6 +79,9 @@ func (p *Pod) String() string { } func (p *Pod) Clone() *Pod { + if p == nil { + return nil + } return &Pod{ NamespacedName: types.NamespacedName{ Name: p.NamespacedName.Name, @@ -118,6 +121,9 @@ func (m *Metrics) String() string { } func (m *Metrics) Clone() *Metrics { + if m == nil { + return nil + } cm := make(map[string]int, len(m.ActiveModels)) for k, v := range m.ActiveModels { cm[k] = v diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 44537923..9121b59a 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -67,7 +67,7 @@ func (s *StreamingServer) HandleRequestBody( ResolvedTargetModel: modelName, Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, } - logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical) + logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) var err error // Update target models in the body. @@ -81,11 +81,11 @@ func (s *StreamingServer) HandleRequestBody( return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)} } - target, err := s.scheduler.Schedule(ctx, llmReq) + res, err := s.scheduler.Schedule(ctx, llmReq) if err != nil { return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} } - targetPod := target.GetPod() + targetPod := res.TargetPod.GetPod() // Insert target endpoint to instruct Envoy to route requests to the specified target pod. // Attach the port number @@ -96,8 +96,7 @@ func (s *StreamingServer) HandleRequestBody( endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) logger.V(logutil.DEFAULT).Info("Request handled", - "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics", - fmt.Sprintf("%+v", target)) + "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod) reqCtx.Model = llmReq.Model reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 7bb0fcb1..e7ecc26d 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -65,7 +65,7 @@ type StreamingServer struct { } type Scheduler interface { - Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (targetPod schedulingtypes.Pod, err error) + Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) } // RequestContext stores context information during the life time of an HTTP request. 
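Illustration (not part of the diff above): the nil-receiver guards added to Pod.Clone and Metrics.Clone make cloning safe when a pod has no scraped metrics yet, so call sites such as ToSchedulerPodMetrics in pkg/epp/scheduling/types can clone without their own nil checks. A minimal sketch of that behavior, assuming only the backend metrics import path already used in this PR:

package main

import (
	"fmt"

	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
)

func main() {
	// A pod whose metrics have not been scraped yet may carry a nil *Metrics.
	var m *backendmetrics.Metrics
	// With the new guard, Clone returns nil instead of panicking on a nil receiver.
	fmt.Println(m.Clone() == nil) // prints: true
}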
diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index b474df36..56dcfca8 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -30,6 +30,7 @@ import ( const ( InferenceModelComponent = "inference_model" InferencePoolComponent = "inference_pool" + EPPComponent = "endpoint_picker" ) var ( @@ -176,6 +177,20 @@ var ( }, []string{"name"}, ) + + // Scheduler Plugin Metrics + SchedulerPluginProcessingLatencies = compbasemetrics.NewHistogramVec( + &compbasemetrics.HistogramOpts{ + Subsystem: EPPComponent, + Name: "scheduler_plugin_duration_seconds", + Help: "Scheduler plugin processing latency distribution in seconds for each plugin type and plugin name.", + Buckets: []float64{ + 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, + }, + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{"plugin_type", "plugin_name"}, + ) ) var registerMetrics sync.Once @@ -196,6 +211,8 @@ func Register() { legacyregistry.MustRegister(inferencePoolAvgKVCache) legacyregistry.MustRegister(inferencePoolAvgQueueSize) legacyregistry.MustRegister(inferencePoolReadyPods) + + legacyregistry.MustRegister(SchedulerPluginProcessingLatencies) }) } @@ -293,3 +310,8 @@ func RecordInferencePoolAvgQueueSize(name string, queueSize float64) { func RecordinferencePoolReadyPods(name string, runningPods float64) { inferencePoolReadyPods.WithLabelValues(name).Set(runningPods) } + +// RecordSchedulerPluginProcessingLatency records the processing latency for a scheduler plugin. +func RecordSchedulerPluginProcessingLatency(pluginType, pluginName string, duration time.Duration) { + SchedulerPluginProcessingLatencies.WithLabelValues(pluginType, pluginName).Observe(duration.Seconds()) +} diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go index b5f19e6d..81797e6d 100644 --- a/pkg/epp/metrics/metrics_test.go +++ b/pkg/epp/metrics/metrics_test.go @@ -556,3 +556,67 @@ func TestInferencePoolMetrics(t *testing.T) { }) } } + +func TestSchedulerPluginProcessingLatencies(t *testing.T) { + type pluginLatency struct { + pluginType string + pluginName string + duration time.Duration + } + scenarios := []struct { + name string + latencies []pluginLatency + }{ + { + name: "multiple plugins", + latencies: []pluginLatency{ + { + pluginType: "PreSchedule", + pluginName: "PluginA", + duration: 100 * time.Millisecond, + }, + { + pluginType: "PostSchedule", + pluginName: "PluginB", + duration: 200 * time.Millisecond, + }, + { + pluginType: "Filter", + pluginName: "PluginC", + duration: 50 * time.Millisecond, + }, + { + pluginType: "Scorer", + pluginName: "PluginD", + duration: 10 * time.Millisecond, + }, + { + pluginType: "Picker", + pluginName: "PluginE", + duration: 10 * time.Microsecond, + }, + }, + }, + } + Register() + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + for _, latency := range scenario.latencies { + RecordSchedulerPluginProcessingLatency(latency.pluginType, latency.pluginName, latency.duration) + } + + wantPluginLatencies, err := os.Open("testdata/scheduler_plugin_processing_latencies_metric") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := wantPluginLatencies.Close(); err != nil { + t.Error(err) + } + }() + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantPluginLatencies, "endpoint_picker_scheduler_plugin_duration_seconds"); err != nil { + t.Error(err) + } + }) + } +} diff --git a/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric
b/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric new file mode 100644 index 00000000..8c11757f --- /dev/null +++ b/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric @@ -0,0 +1,67 @@ +# HELP endpoint_picker_scheduler_plugin_duration_seconds [ALPHA] Scheduler plugin processing latency distribution in seconds for each plugin type and plugin name. +# TYPE endpoint_picker_scheduler_plugin_duration_seconds histogram +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.05"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginA",plugin_type="PreSchedule"} 0.1 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginA",plugin_type="PreSchedule"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.05"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.1"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginB",plugin_type="PostSchedule"} 0.2 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginB",plugin_type="PostSchedule"} 1 
+endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginC",plugin_type="Filter"} 0.05 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginC",plugin_type="Filter"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.01"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.02"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginD",plugin_type="Scorer"} 0.01 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginD",plugin_type="Scorer"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0001"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0002"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0005"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.001"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.002"} 1 
+endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.005"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.01"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.02"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginE",plugin_type="Picker"} 1e-05 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginE",plugin_type="Picker"} 1 diff --git a/pkg/epp/scheduling/config/config.go b/pkg/epp/scheduling/config/config.go new file mode 100644 index 00000000..e00b82ae --- /dev/null +++ b/pkg/epp/scheduling/config/config.go @@ -0,0 +1,58 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package config + +import ( + "sigs.k8s.io/controller-runtime/pkg/log" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// Config holds all the configuration values for the scheduler +type Config struct { + KVCacheThreshold float64 + QueueThresholdCritical int + QueueingThresholdLoRA int + LoraAffinityThreshold float64 +} + +const ( + // Default values to use if environment variables are not set + defaultKVCacheThreshold = 0.8 + defaultQueueThresholdCritical = 5 + defaultQueueingThresholdLoRA = 128 + defaultLoraAffinityThreshold = 0.999 +) + +// LoadConfig loads configuration from environment variables +func LoadConfig() Config { + // Use a default logger for initial configuration loading + baseLogger := log.Log.WithName("scheduling-config") + + config := Config{ + KVCacheThreshold: envutil.GetEnvFloat("KV_CACHE_THRESHOLD", defaultKVCacheThreshold, baseLogger), + QueueThresholdCritical: envutil.GetEnvInt("QUEUE_THRESHOLD_CRITICAL", defaultQueueThresholdCritical, baseLogger), + QueueingThresholdLoRA: envutil.GetEnvInt("QUEUING_THRESHOLD_LORA", defaultQueueingThresholdLoRA, baseLogger), + LoraAffinityThreshold: envutil.GetEnvFloat("LORA_AFFINITY_THRESHOLD", defaultLoraAffinityThreshold, baseLogger), + } + + baseLogger.V(logutil.DEFAULT).Info("Scheduler configuration loaded", "config", config) + + return config +} + +var Conf = LoadConfig() diff --git a/pkg/epp/scheduling/filter.go b/pkg/epp/scheduling/plugins/filter.go similarity index 60% rename from pkg/epp/scheduling/filter.go rename to pkg/epp/scheduling/plugins/filter.go index 99044e97..efcb6be1 100644 --- a/pkg/epp/scheduling/filter.go +++ b/pkg/epp/scheduling/plugins/filter.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under 
the License. */ -package scheduling +package plugins import ( "errors" @@ -22,83 +22,80 @@ import ( "math/rand" "time" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -type Filter interface { - Name() string - Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) -} - -type basicFilter struct { +type Filter struct { name string filter filterFunc } -func (bf *basicFilter) Name() string { +func (bf *Filter) Name() string { if bf == nil { return "nil" } return bf.name } -func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func (bf *Filter) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { loggerTrace := ctx.Logger.V(logutil.TRACE) loggerTrace.Info("Running a filter", "name", bf.Name(), "podCount", len(pods)) return bf.filter(ctx, pods) } -// decisionTreeFilter applies current filterFunc, and then recursively applies next filters +// DecisionTreeFilter applies current filterFunc, and then recursively applies next filters // depending success or failure of the current filter. // It can be used to construct a flow chart algorithm. -type decisionTreeFilter struct { - current Filter - // nextOnSuccess filter will be applied after successfully applying the current filter. +type DecisionTreeFilter struct { + Current types.Filter + // NextOnSuccess filter will be applied after successfully applying the current filter. // The filtered results will be passed to the next filter. - nextOnSuccess Filter - // nextOnFailure filter will be applied if current filter fails. + NextOnSuccess types.Filter + // NextOnFailure filter will be applied if current filter fails. // The original input will be passed to the next filter. - nextOnFailure Filter - // nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the + NextOnFailure types.Filter + // NextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the // success or failure of the current filter. - // NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil. + // NOTE: When using NextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil. // However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of - // nextOnSuccessOrFailure, in the success and failure scenarios, respectively. - nextOnSuccessOrFailure Filter + // NextOnSuccessOrFailure, in the success and failure scenarios, respectively. 
+ NextOnSuccessOrFailure types.Filter } -func (f *decisionTreeFilter) Name() string { +func (f *DecisionTreeFilter) Name() string { if f == nil { return "nil" } - return f.current.Name() + return f.Current.Name() } -func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func (f *DecisionTreeFilter) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { loggerTrace := ctx.Logger.V(logutil.TRACE) - filtered, err := f.current.Filter(ctx, pods) + filtered, err := f.Current.Filter(ctx, pods) - next := f.nextOnSuccessOrFailure + next := f.NextOnSuccessOrFailure if err == nil && len(filtered) > 0 { - if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil { + if f.NextOnSuccess == nil && f.NextOnSuccessOrFailure == nil { // No succeeding filters to run, return. return filtered, err } - if f.nextOnSuccess != nil { - next = f.nextOnSuccess + if f.NextOnSuccess != nil { + next = f.NextOnSuccess } loggerTrace.Info("Filter succeeded", "filter", f.Name(), "next", next.Name(), "filteredPodCount", len(filtered)) // On success, pass the filtered result to the next filter. return next.Filter(ctx, filtered) } else { - if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil { + if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil { // No succeeding filters to run, return. return filtered, err } - if f.nextOnFailure != nil { - next = f.nextOnFailure + if f.NextOnFailure != nil { + next = f.NextOnFailure } loggerTrace.Info("Filter failed", "filter", f.Name(), "next", next.Name()) // On failure, pass the initial set of pods to the next filter. @@ -107,12 +104,12 @@ func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics } // filterFunc filters a set of input pods to a subset. -type filterFunc func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) +type filterFunc func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. func toFilterFunc(pp podPredicate) filterFunc { - return func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { - filtered := []*types.PodMetrics{} + return func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + filtered := []types.Pod{} for _, pod := range pods { pass := pp(ctx.Req, pod) if pass { @@ -126,7 +123,7 @@ func toFilterFunc(pp podPredicate) filterFunc { } } -var leastQueueFilter = &basicFilter{ +var LeastQueueFilter = &Filter{ name: "least queuing", filter: leastQueuingFilterFunc, } @@ -138,34 +135,34 @@ var leastQueueFilter = &basicFilter{ // the least one as it gives more choices for the next filter, which on aggregate gave better // results. // TODO: Compare this strategy with other strategies such as top K. 
-func leastQueuingFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func leastQueuingFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { min := math.MaxInt max := 0 - filtered := []*types.PodMetrics{} + filtered := []types.Pod{} for _, pod := range pods { - if pod.WaitingQueueSize <= min { - min = pod.WaitingQueueSize + if pod.GetMetrics().WaitingQueueSize <= min { + min = pod.GetMetrics().WaitingQueueSize } - if pod.WaitingQueueSize >= max { - max = pod.WaitingQueueSize + if pod.GetMetrics().WaitingQueueSize >= max { + max = pod.GetMetrics().WaitingQueueSize } } for _, pod := range pods { - if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) { + if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) { filtered = append(filtered, pod) } } return filtered, nil } -var lowQueueFilter = &basicFilter{ +var LowQueueFilter = &Filter{ name: "low queueing filter", - filter: toFilterFunc((queueThresholdPredicate(config.QueueingThresholdLoRA))), + filter: toFilterFunc((queueThresholdPredicate(config.Conf.QueueingThresholdLoRA))), } -var leastKVCacheFilter = &basicFilter{ +var LeastKVCacheFilter = &Filter{ name: "least KV cache percent", filter: leastKVCacheFilterFunc, } @@ -176,29 +173,29 @@ var leastKVCacheFilter = &basicFilter{ // should consider them all instead of the absolute minimum one. This worked better than picking the // least one as it gives more choices for the next filter, which on aggregate gave better results. // TODO: Compare this strategy with other strategies such as top K. -func leastKVCacheFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func leastKVCacheFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { min := math.MaxFloat64 var max float64 = 0 - filtered := []*types.PodMetrics{} + filtered := []types.Pod{} for _, pod := range pods { - if pod.KVCacheUsagePercent <= min { - min = pod.KVCacheUsagePercent + if pod.GetMetrics().KVCacheUsagePercent <= min { + min = pod.GetMetrics().KVCacheUsagePercent } - if pod.KVCacheUsagePercent >= max { - max = pod.KVCacheUsagePercent + if pod.GetMetrics().KVCacheUsagePercent >= max { + max = pod.GetMetrics().KVCacheUsagePercent } } for _, pod := range pods { - if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { + if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { filtered = append(filtered, pod) } } return filtered, nil } -var loRAAffinityFilter = &basicFilter{ +var LoRAAffinityFilter = &Filter{ name: "affinity LoRA", filter: loRASoftAffinityFilterFunc, } @@ -219,20 +216,20 @@ var loRAAffinityFilter = &basicFilter{ // Returns: // - Filtered slice of pod metrics based on affinity and availability // - Error if any issues occur during filtering -func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func loRASoftAffinityFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { // Pre-allocate slices with estimated capacity - filtered_affinity := make([]*types.PodMetrics, 0, len(pods)) - filtered_available := make([]*types.PodMetrics, 0, len(pods)) + filtered_affinity := make([]types.Pod, 0, len(pods)) + filtered_available := make([]types.Pod, 0, len(pods)) // Categorize pods based on affinity and availability for _, pod := range pods { - _, active := 
pod.ActiveModels[ctx.Req.ResolvedTargetModel] - _, waiting := pod.WaitingModels[ctx.Req.ResolvedTargetModel] + _, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel] + _, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel] if active || waiting { filtered_affinity = append(filtered_affinity, pod) - } else if len(pod.ActiveModels)+len(pod.WaitingModels) < pod.MaxActiveModels { + } else if len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels { filtered_available = append(filtered_available, pod) } } @@ -243,7 +240,7 @@ func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([ // If both groups have pods, use probability to select which group to return if len(filtered_affinity) > 0 && len(filtered_available) > 0 { - if randGen.Float64() < config.LoraAffinityThreshold { + if randGen.Float64() < config.Conf.LoraAffinityThreshold { return filtered_affinity, nil } return filtered_available, nil @@ -257,23 +254,38 @@ func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([ return filtered_available, nil } +var HasCapacityFilter = &Filter{ + name: "has capacity for sheddable requests", + filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))), +} + +var DropRequestFilter = &Filter{ + name: "drop request", + filter: func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + ctx.Logger.V(logutil.DEFAULT).Info("Request dropped", "request", ctx.Req) + return []types.Pod{}, errutil.Error{ + Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", + } + }, +} + // podPredicate is a filter function to check whether a pod is desired. -type podPredicate func(req *types.LLMRequest, pod *types.PodMetrics) bool +type podPredicate func(req *types.LLMRequest, pod types.Pod) bool func queueThresholdPredicate(queueThreshold int) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { - return pod.WaitingQueueSize <= queueThreshold + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().WaitingQueueSize <= queueThreshold } } func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { - return pod.KVCacheUsagePercent <= kvCacheThreshold + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold } } func (pp podPredicate) and(another podPredicate) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { + return func(req *types.LLMRequest, pod types.Pod) bool { return pp(req, pod) && another(req, pod) } } diff --git a/pkg/epp/scheduling/filter_test.go b/pkg/epp/scheduling/plugins/filter_test.go similarity index 82% rename from pkg/epp/scheduling/filter_test.go rename to pkg/epp/scheduling/plugins/filter_test.go index 543826d0..107b423f 100644 --- a/pkg/epp/scheduling/filter_test.go +++ b/pkg/epp/scheduling/plugins/filter_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package scheduling +package plugins import ( "context" @@ -24,6 +24,7 @@ import ( "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -31,17 +32,17 @@ func TestFilter(t *testing.T) { tests := []struct { name string req *types.LLMRequest - input []*types.PodMetrics - output []*types.PodMetrics + input []types.Pod + output []types.Pod err bool - filter *decisionTreeFilter + filter *DecisionTreeFilter }{ { name: "simple filter without successor, failure", - filter: &decisionTreeFilter{ - current: &basicFilter{ + filter: &DecisionTreeFilter{ + Current: &Filter{ name: "error", - filter: func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { + filter: func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { return nil, errors.New("filter error") }, }, @@ -58,7 +59,8 @@ func TestFilter(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + opt := cmp.AllowUnexported(types.PodMetrics{}) + if diff := cmp.Diff(test.output, got, opt); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -70,43 +72,43 @@ func TestFilterFunc(t *testing.T) { name string f filterFunc req *types.LLMRequest - input []*types.PodMetrics - output []*types.PodMetrics + input []types.Pod + output []types.Pod err bool }{ { name: "least queuing empty input", f: leastQueuingFilterFunc, - input: []*types.PodMetrics{}, - output: []*types.PodMetrics{}, + input: []types.Pod{}, + output: []types.Pod{}, }, { name: "least queuing", f: leastQueuingFilterFunc, - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 10, }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, }, @@ -116,36 +118,36 @@ func TestFilterFunc(t *testing.T) { { name: "least kv cache empty input", f: leastKVCacheFilterFunc, - input: []*types.PodMetrics{}, - output: []*types.PodMetrics{}, + input: []types.Pod{}, + output: []types.Pod{}, }, { name: "least kv cache", f: leastKVCacheFilterFunc, - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0.3, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 1.0, }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0.3, }, @@ -155,22 +157,22 @@ func TestFilterFunc(t *testing.T) { { name: "lowQueueAndLessThanKVCacheThresholdPredicate", f: toFilterFunc(queueThresholdPredicate(0).and(kvCacheThresholdPredicate(0.8))), - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ // This pod should be returned. 
Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ // Queue is non zero, despite low kv cache, should not return. Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 1, KVCacheUsagePercent: 0.3, }, }, - { + &types.PodMetrics{ // High kv cache despite zero queue, should not return Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, @@ -178,8 +180,8 @@ func TestFilterFunc(t *testing.T) { }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, @@ -197,7 +199,8 @@ func TestFilterFunc(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + opt := cmp.AllowUnexported(types.PodMetrics{}) + if diff := cmp.Diff(test.output, got, opt); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -215,15 +218,15 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { ) // Save original config value to restore later - originalThreshold := config.LoraAffinityThreshold + originalThreshold := config.Conf.LoraAffinityThreshold // Set a specific test value for this test testThreshold := 0.75 // 75% - config.LoraAffinityThreshold = testThreshold + config.Conf.LoraAffinityThreshold = testThreshold // Ensure we restore the original threshold when test completes defer func() { - config.LoraAffinityThreshold = originalThreshold + config.Conf.LoraAffinityThreshold = originalThreshold }() // Create a test request and pods @@ -233,8 +236,8 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { } // Test setup: One affinity pod and one available pod - pods := []*types.PodMetrics{ - { + pods := []types.Pod{ + &types.PodMetrics{ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "affinity-pod"}}, Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, @@ -243,7 +246,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, }, - { + &types.PodMetrics{ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "available-pod"}}, Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, @@ -258,7 +261,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { availableCount := 0 // Use the test threshold value - expectedAffinityPercent := config.LoraAffinityThreshold * 100 + expectedAffinityPercent := config.Conf.LoraAffinityThreshold * 100 expectedAvailabilityPercent := 100 - expectedAffinityPercent for i := 0; i < numIterations; i++ { @@ -292,8 +295,8 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { availableUpperBound := expectedAvailabilityPercent + tolerancePercent t.Logf("Distribution results over %d iterations:", numIterations) - t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, config.LoraAffinityThreshold) - t.Logf("Expected availability percent: %.2f%% (threshold: %.2f)", expectedAvailabilityPercent, config.LoraAffinityThreshold) + t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, config.Conf.LoraAffinityThreshold) + t.Logf("Expected availability percent: %.2f%% (threshold: %.2f)", expectedAvailabilityPercent, config.Conf.LoraAffinityThreshold) t.Logf("Actual affinity percent: %.2f%% (%d out of %d)", actualAffinityPercent, affinityCount, numIterations) t.Logf("Actual available percent: %.2f%% (%d out of %d)", actualAvailablePercent, availableCount, numIterations) diff --git a/pkg/epp/scheduling/plugins/noop.go 
b/pkg/epp/scheduling/plugins/noop.go new file mode 100644 index 00000000..1abcb95b --- /dev/null +++ b/pkg/epp/scheduling/plugins/noop.go @@ -0,0 +1,38 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugins + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +// NoopPlugin provides a default, no-operation implementation of the Plugin interface. +// It can be embedded in other plugin implementations to avoid boilerplate code for +// unused methods. +type NoopPlugin struct{} + +func (p *NoopPlugin) Name() string { return "NoopPlugin" } + +func (p *NoopPlugin) Score(ctx *types.Context, pod types.Pod) (float64, error) { return 0.0, nil } + +func (p *NoopPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + return pods, nil +} + +func (p *NoopPlugin) PreSchedule(ctx *types.Context) {} + +func (p *NoopPlugin) PostSchedule(ctx *types.Context, res *types.Result) {} diff --git a/pkg/epp/scheduling/plugins/picker.go b/pkg/epp/scheduling/plugins/picker.go new file mode 100644 index 00000000..569e4e86 --- /dev/null +++ b/pkg/epp/scheduling/plugins/picker.go @@ -0,0 +1,37 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package plugins + +import ( + "fmt" + "math/rand" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +type RandomPicker struct{} + +func (rp *RandomPicker) Name() string { + return "random" +} + +func (rp *RandomPicker) Pick(ctx *types.Context, pods []types.Pod) (*types.Result, error) { + ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods)) + i := rand.Intn(len(pods)) + return &types.Result{TargetPod: pods[i]}, nil +} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 8679ffba..7cc2bd96 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -20,113 +20,71 @@ package scheduling import ( "context" "fmt" - "math/rand" + "time" "sigs.k8s.io/controller-runtime/pkg/log" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// Config holds all the configuration values for the scheduler -type Config struct { - KVCacheThreshold float64 - QueueThresholdCritical int - QueueingThresholdLoRA int - LoraAffinityThreshold float64 -} - -const ( - // Default values to use if environment variables are not set - defaultKVCacheThreshold = 0.8 - defaultQueueThresholdCritical = 5 - defaultQueueingThresholdLoRA = 128 - defaultLoraAffinityThreshold = 0.999 -) - -// LoadConfig loads configuration from environment variables -func LoadConfig() Config { - // Use a default logger for initial configuration loading - baseLogger := log.Log.WithName("scheduling-config") - - config := Config{ - KVCacheThreshold: envutil.GetEnvFloat("KV_CACHE_THRESHOLD", defaultKVCacheThreshold, baseLogger), - QueueThresholdCritical: envutil.GetEnvInt("QUEUE_THRESHOLD_CRITICAL", defaultQueueThresholdCritical, baseLogger), - QueueingThresholdLoRA: envutil.GetEnvInt("QUEUING_THRESHOLD_LORA", defaultQueueingThresholdLoRA, baseLogger), - LoraAffinityThreshold: envutil.GetEnvFloat("LORA_AFFINITY_THRESHOLD", defaultLoraAffinityThreshold, baseLogger), - } - - baseLogger.V(logutil.DEFAULT).Info("Scheduler configuration loaded", "config", config) - - return config -} - -var config = LoadConfig() - var ( - lowLatencyFilter = &decisionTreeFilter{ - current: lowQueueFilter, - nextOnSuccess: &decisionTreeFilter{ - current: loRAAffinityFilter, - nextOnSuccessOrFailure: &decisionTreeFilter{ - current: leastQueueFilter, - nextOnSuccessOrFailure: &decisionTreeFilter{ - current: leastKVCacheFilter, + lowLatencyFilter = &plugins.DecisionTreeFilter{ + Current: plugins.LowQueueFilter, + NextOnSuccess: &plugins.DecisionTreeFilter{ + Current: plugins.LoRAAffinityFilter, + NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LeastQueueFilter, + NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LeastKVCacheFilter, }, }, }, - nextOnFailure: &decisionTreeFilter{ - current: leastQueueFilter, - nextOnSuccessOrFailure: &decisionTreeFilter{ - current: loRAAffinityFilter, - nextOnSuccessOrFailure: &decisionTreeFilter{ - current: 
leastKVCacheFilter, + NextOnFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LeastQueueFilter, + NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LoRAAffinityFilter, + NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LeastKVCacheFilter, }, }, }, } - sheddableRequestFilter = &decisionTreeFilter{ + sheddableRequestFilter = &plugins.DecisionTreeFilter{ // When there is at least one model server that's not queuing requests, and still has KV // cache below a certain threshold, we consider this model server has capacity to handle // a sheddable request without impacting critical requests. - current: hasCapacityFilter, - nextOnSuccess: lowLatencyFilter, + Current: plugins.HasCapacityFilter, + NextOnSuccess: lowLatencyFilter, // If all pods are queuing or running above the KVCache threshold, we drop the sheddable // request to make room for critical requests. - nextOnFailure: dropRequestFilter, - } - - hasCapacityFilter = &basicFilter{ - name: "has capacity for sheddable requests", - filter: toFilterFunc(queueThresholdPredicate(config.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.KVCacheThreshold))), - } - - dropRequestFilter = &basicFilter{ - name: "drop request", - filter: func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { - ctx.Logger.V(logutil.DEFAULT).Info("Request dropped", "request", ctx.Req) - return []*types.PodMetrics{}, errutil.Error{ - Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", - } - }, + NextOnFailure: plugins.DropRequestFilter, } ) func NewScheduler(datastore Datastore) *Scheduler { + defaultPlugin := &defaultPlugin{} + return &Scheduler{ - datastore: datastore, - criticalRequestFilter: lowLatencyFilter, - sheddableRequestFilter: sheddableRequestFilter, + datastore: datastore, + preSchedulePlugins: []types.PreSchedule{}, + postSchedulePlugins: []types.PostSchedule{}, + scorers: []types.Scorer{}, + filters: []types.Filter{defaultPlugin}, + picker: defaultPlugin, } } type Scheduler struct { - datastore Datastore - criticalRequestFilter Filter - sheddableRequestFilter Filter + datastore Datastore + preSchedulePlugins []types.PreSchedule + postSchedulePlugins []types.PostSchedule + filters []types.Filter + scorers []types.Scorer + picker types.Picker } type Datastore interface { @@ -134,27 +92,125 @@ type Datastore interface { } // Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (targetPod types.Pod, err error) { +func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { logger := log.FromContext(ctx).WithValues("request", req) + loggerDebug := logger.V(logutil.DEBUG) // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. sCtx := types.NewContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) - logger.V(logutil.DEBUG).Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot)) + loggerDebug.Info(fmt.Sprintf("Scheduling a request. 
Metrics: %+v", sCtx.PodsSnapshot)) - var filter Filter - if req.Critical { - filter = s.criticalRequestFilter - } else { - filter = s.sheddableRequestFilter + s.runPreSchedulePlugins(sCtx) + + pods, err := s.runFilterPlugins(sCtx) + if err != nil { + return nil, err + } + + if err := s.runScorerPlugins(sCtx, pods); err != nil { + return nil, err + } + + before := time.Now() + res, err := s.picker.Pick(sCtx, pods) + metrics.RecordSchedulerPluginProcessingLatency(types.PickerPluginType, s.picker.Name(), time.Since(before)) + if err != nil { + return nil, err } + loggerDebug.Info("After running picker plugins", "result", res) - pods, err := filter.Filter(sCtx, sCtx.PodsSnapshot) - if err != nil || len(pods) == 0 { - return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err) + s.runPostSchedulePlugins(sCtx, res) + + return res, nil +} + +func (s *Scheduler) runPreSchedulePlugins(ctx *types.Context) { + for _, plugin := range s.preSchedulePlugins { + ctx.Logger.V(logutil.DEBUG).Info("Running pre-schedule plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PreSchedule(ctx) + metrics.RecordSchedulerPluginProcessingLatency(types.PreSchedulerPluginType, plugin.Name(), time.Since(before)) + } +} + +func (s *Scheduler) runPostSchedulePlugins(ctx *types.Context, res *types.Result) { + for _, plugin := range s.postSchedulePlugins { + ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PostSchedule(ctx, res) + metrics.RecordSchedulerPluginProcessingLatency(types.PostSchedulePluginType, plugin.Name(), time.Since(before)) + } +} + +func (s *Scheduler) runFilterPlugins(ctx *types.Context) ([]types.Pod, error) { + loggerDebug := ctx.Logger.V(logutil.DEBUG) + pods := ctx.PodsSnapshot + loggerDebug.Info("Before running filter plugins", "pods", pods) + for _, filter := range s.filters { + loggerDebug.Info("Running filter plugin", "plugin", filter.Name()) + before := time.Now() + filteredPods, err := filter.Filter(ctx, pods) + metrics.RecordSchedulerPluginProcessingLatency(types.FilterPluginType, filter.Name(), time.Since(before)) + if err != nil || len(filteredPods) == 0 { + return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(filteredPods), err) + } + pods = filteredPods + loggerDebug.Info("Filter plugin result", "plugin", filter.Name(), "pods", pods) + } + loggerDebug.Info("After running filter plugins", "pods", pods) + return pods, nil +} + +func (s *Scheduler) runScorerPlugins(ctx *types.Context, pods []types.Pod) error { + loggerDebug := ctx.Logger.V(logutil.DEBUG) + loggerDebug.Info("Before running score plugins", "pods", pods) + for _, pod := range pods { + score, err := runScorersForPod(ctx, s.scorers, pod) + if err != nil { + return err + } + pod.SetScore(score) + } + loggerDebug.Info("After running score plugins", "pods", pods) + return nil +} + +// Iterate through each scorer in the chain and accumulate the scores. 
+func runScorersForPod(ctx *types.Context, scorers []types.Scorer, pod types.Pod) (float64, error) { + logger := ctx.Logger.WithValues("pod", pod.GetPod().NamespacedName).V(logutil.DEBUG) + score := float64(0) + for _, scorer := range scorers { + logger.Info("Running scorer", "scorer", scorer.Name()) + before := time.Now() + oneScore, err := scorer.Score(ctx, pod) + metrics.RecordSchedulerPluginProcessingLatency(types.ScorerPluginType, scorer.Name(), time.Since(before)) + if err != nil { + logger.Error(err, "Failed to calculate score for scorer", "scorer", scorer.Name()) + return 0, err + } + score += oneScore + logger.Info("After scorer", "scorer", scorer.Name(), "score", oneScore, "total score", score) + } + return score, nil +} + +type defaultPlugin struct { + plugins.RandomPicker +} + +func (p *defaultPlugin) Name() string { + return "DefaultPlugin" +} + +func (p *defaultPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + req := ctx.Req + var filter types.Filter + if req.Critical { + filter = lowLatencyFilter + } else { + filter = sheddableRequestFilter } - logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods)) - i := rand.Intn(len(pods)) - return pods[i], nil + return filter.Filter(ctx, pods) } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 3fd3fb24..5a2265bf 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -18,22 +18,34 @@ package scheduling import ( "context" + "errors" "testing" "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +// Tests the default scheduler configuration and expected behavior. 
func TestSchedule(t *testing.T) { tests := []struct { - name string - req *types.LLMRequest - input []*backendmetrics.FakePodMetrics - output types.Pod - err bool + name string + req *types.LLMRequest + input []*backendmetrics.FakePodMetrics + wantRes *types.Result + err bool }{ + { + name: "no pods in datastore", + req: &types.LLMRequest{ + Model: "any-model", + ResolvedTargetModel: "any-model", + Critical: true, + }, + input: []*backendmetrics.FakePodMetrics{}, + err: true, + }, { name: "critical request", req: &types.LLMRequest{ @@ -80,17 +92,19 @@ func TestSchedule(t *testing.T) { }, }, }, - output: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, + wantRes: &types.Result{ + TargetPod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 3, + KVCacheUsagePercent: 0.1, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "critical": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, }, }, @@ -139,17 +153,19 @@ func TestSchedule(t *testing.T) { }, }, }, - output: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, + wantRes: &types.Result{ + TargetPod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, }, }, @@ -199,8 +215,8 @@ func TestSchedule(t *testing.T) { }, }, }, - output: nil, - err: true, + wantRes: nil, + err: true, }, } @@ -212,13 +228,205 @@ func TestSchedule(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + opt := cmp.AllowUnexported(types.PodMetrics{}) + if diff := cmp.Diff(test.wantRes, got, opt); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) } } +func TestSchedulePlugins(t *testing.T) { + tp1 := &TestPlugin{ + NameRes: "test1", + ScoreRes: 0.3, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + } + tp2 := &TestPlugin{ + NameRes: "test2", + ScoreRes: 0.8, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + } + tpFilterErr := &TestPlugin{ + NameRes: "filter err", + FilterErr: errors.New("filter error"), + } + tpScorerErr := &TestPlugin{ + NameRes: "score err", + ScoreErr: errors.New("score err"), + } + pickerPlugin := &TestPlugin{ + NameRes: "picker", + PickRes: k8stypes.NamespacedName{Name: "pod1"}, + } + pickerErr := &TestPlugin{ + NameRes: "picker err", + PickErr: errors.New("picker err"), + } + + tests := []struct { + name string + preSchedulePlugins []types.PreSchedule + postSchedulePlugins []types.PostSchedule + filters []types.Filter + scorers []types.Scorer + picker types.Picker + input []*backendmetrics.FakePodMetrics + wantTargetPod k8stypes.NamespacedName + targetPodScore float64 + // Number of expected pods to score (after 
filter) + numPodsToScore int + err bool + }{ + { + name: "all plugins executed successfully", + preSchedulePlugins: []types.PreSchedule{tp1, tp2}, + postSchedulePlugins: []types.PostSchedule{tp1, tp2}, + filters: []types.Filter{tp1, tp2}, + scorers: []types.Scorer{tp1, tp2}, + picker: pickerPlugin, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, + }, + { + name: "filter error", + preSchedulePlugins: []types.PreSchedule{tp1, tp2}, + postSchedulePlugins: []types.PostSchedule{tp1, tp2}, + filters: []types.Filter{tp1, tpFilterErr}, + scorers: []types.Scorer{tp1, tp2}, + picker: pickerPlugin, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + err: true, + }, + { + name: "scorer error", + preSchedulePlugins: []types.PreSchedule{tp1, tp2}, + postSchedulePlugins: []types.PostSchedule{tp1, tp2}, + filters: []types.Filter{tp1, tp2}, + scorers: []types.Scorer{tp1, tpScorerErr}, + picker: pickerPlugin, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + err: true, + }, + { + name: "picker error", + preSchedulePlugins: []types.PreSchedule{tp1, tp2}, + postSchedulePlugins: []types.PostSchedule{tp1, tp2}, + filters: []types.Filter{tp1, tp2}, + scorers: []types.Scorer{tp1, tp2}, + picker: pickerErr, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + err: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Reset all plugins before each new test case. 
+ for _, plugin := range test.preSchedulePlugins { + plugin.(*TestPlugin).Reset() + } + for _, plugin := range test.postSchedulePlugins { + plugin.(*TestPlugin).Reset() + } + for _, plugin := range test.filters { + plugin.(*TestPlugin).Reset() + } + for _, plugin := range test.scorers { + plugin.(*TestPlugin).Reset() + } + test.picker.(*TestPlugin).Reset() + + // Initialize the scheduler + scheduler := &Scheduler{ + datastore: &fakeDataStore{pods: test.input}, + preSchedulePlugins: test.preSchedulePlugins, + postSchedulePlugins: test.postSchedulePlugins, + filters: test.filters, + scorers: test.scorers, + picker: test.picker, + } + + req := &types.LLMRequest{Model: "test-model"} + got, err := scheduler.Schedule(context.Background(), req) + + // Validate error state + if test.err != (err != nil) { + t.Fatalf("Unexpected error, got %v, want %v", err, test.err) + } + + if err != nil { + return + } + + // Validate output + opt := cmp.AllowUnexported(types.PodMetrics{}) + wantPod := &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: test.wantTargetPod}, + } + wantPod.SetScore(test.targetPodScore) + wantRes := &types.Result{TargetPod: wantPod} + if diff := cmp.Diff(wantRes, got, opt); diff != "" { + t.Errorf("Unexpected output (-want +got): %v", diff) + } + + // Validate plugin execution counts dynamically + for _, plugin := range test.preSchedulePlugins { + tp, _ := plugin.(*TestPlugin) + if tp.PreScheduleCallCount != 1 { + t.Errorf("Plugin %s PreSchedule() called %d times, expected 1", tp.NameRes, tp.PreScheduleCallCount) + } + } + + for _, plugin := range test.postSchedulePlugins { + tp, _ := plugin.(*TestPlugin) + if tp.PostScheduleCallCount != 1 { + t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", tp.NameRes, tp.PostScheduleCallCount) + } + } + + for _, plugin := range test.filters { + tp, _ := plugin.(*TestPlugin) + if tp.FilterCallCount != 1 { + t.Errorf("Plugin %s Filter() called %d times, expected 1", tp.NameRes, tp.FilterCallCount) + } + } + + for _, plugin := range test.scorers { + tp, _ := plugin.(*TestPlugin) + if tp.ScoreCallCount != test.numPodsToScore { + t.Errorf("Plugin %s Score() called %d times, expected %d", tp.NameRes, tp.ScoreCallCount, test.numPodsToScore) + } + } + + tp, _ := test.picker.(*TestPlugin) + if tp.PickCallCount != 1 { + t.Errorf("Picker plugin %s Pick() called %d times, expected 1", tp.NameRes, tp.PickCallCount) + } + + }) + } +} + type fakeDataStore struct { pods []*backendmetrics.FakePodMetrics } @@ -230,3 +438,68 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics { } return pm } + +// TestPlugin is an implementation useful in unit tests.
+type TestPlugin struct { + NameRes string + ScoreCallCount int + ScoreRes float64 + ScoreErr error + FilterCallCount int + FilterRes []k8stypes.NamespacedName + FilterErr error + PreScheduleCallCount int + PostScheduleCallCount int + PickCallCount int + PickRes k8stypes.NamespacedName + PickErr error +} + +func (tp *TestPlugin) Name() string { return tp.NameRes } + +func (tp *TestPlugin) Score(ctx *types.Context, pod types.Pod) (float64, error) { + tp.ScoreCallCount++ + return tp.ScoreRes, tp.ScoreErr +} + +func (tp *TestPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + tp.FilterCallCount++ + return findPods(ctx, tp.FilterRes...), tp.FilterErr +} + +func (tp *TestPlugin) PreSchedule(ctx *types.Context) { + tp.PreScheduleCallCount++ +} + +func (tp *TestPlugin) PostSchedule(ctx *types.Context, res *types.Result) { + tp.PostScheduleCallCount++ +} + +func (tp *TestPlugin) Pick(ctx *types.Context, pods []types.Pod) (*types.Result, error) { + tp.PickCallCount++ + if tp.PickErr != nil { + return nil, tp.PickErr + } + pod := findPods(ctx, tp.PickRes)[0] + return &types.Result{TargetPod: pod}, nil +} + +func (tp *TestPlugin) Reset() { + tp.PreScheduleCallCount = 0 + tp.PostScheduleCallCount = 0 + tp.FilterCallCount = 0 + tp.ScoreCallCount = 0 + tp.PickCallCount = 0 +} + +func findPods(ctx *types.Context, names ...k8stypes.NamespacedName) []types.Pod { + res := []types.Pod{} + for _, pod := range ctx.PodsSnapshot { + for _, name := range names { + if pod.GetPod().NamespacedName.String() == name.String() { + res = append(res, pod) + } + } + } + return res +} diff --git a/pkg/epp/scheduling/types/interfaces.go b/pkg/epp/scheduling/types/interfaces.go new file mode 100644 index 00000000..6e954cef --- /dev/null +++ b/pkg/epp/scheduling/types/interfaces.go @@ -0,0 +1,75 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" +) + +const ( + PreSchedulerPluginType = "PreSchedule" + PostSchedulePluginType = "PostSchedule" + FilterPluginType = "Filter" + ScorerPluginType = "Scorer" + PickerPluginType = "Picker" +) + +type Pod interface { + GetPod() *backendmetrics.Pod + GetMetrics() *backendmetrics.Metrics + SetScore(float64) + Score() float64 + String() string +} + +// Plugin is the base interface shared by all scheduler plugins; the scoring, filtering, +// picking, and event-handling capabilities are defined by the more specific interfaces below. +type Plugin interface { + // Name returns the name of the plugin. + Name() string +} + +// PreSchedule is called when the scheduler receives a new request. It can be used for +// initialization work. +type PreSchedule interface { + Plugin + PreSchedule(ctx *Context) +} + +// PostSchedule is called by the scheduler after it selects a targetPod for the request. +type PostSchedule interface { + Plugin + PostSchedule(ctx *Context, res *Result) +} + +// Filter defines the interface for filtering a list of pods based on context.
+type Filter interface { + Plugin + Filter(ctx *Context, pods []Pod) ([]Pod, error) +} + +// Scorer defines the interface for scoring pods based on context. +type Scorer interface { + Plugin + Score(ctx *Context, pod Pod) (float64, error) +} + +// Picker picks the final pod(s) to send the request to. +type Picker interface { + Plugin + Pick(ctx *Context, pods []Pod) (*Result, error) +} diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 9450652e..e52e9047 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -30,23 +30,22 @@ type LLMRequest struct { Model string // Target models is a map of target model name to weight. TargetModels map[string]int + Prompt string // Resolved target model is the final target model after traffic split. ResolvedTargetModel string Critical bool } +func (r *LLMRequest) String() string { + return fmt.Sprintf("Model: %s, TargetModels: %v, ResolvedTargetModel: %s, Critical: %t, PromptLength: %v", r.Model, r.TargetModels, r.ResolvedTargetModel, r.Critical, len(r.Prompt)) +} + // Context holds contextual information during a scheduling operation. type Context struct { context.Context Logger logr.Logger Req *LLMRequest - PodsSnapshot []*PodMetrics -} - -type Pod interface { - GetPod() *backendmetrics.Pod - GetMetrics() *backendmetrics.Metrics - String() string + PodsSnapshot []Pod } func (pm *PodMetrics) String() string { @@ -64,12 +63,21 @@ func (pm *PodMetrics) GetMetrics() *backendmetrics.Metrics { return pm.Metrics } +func (pm *PodMetrics) SetScore(score float64) { + pm.score = score +} + +func (pm *PodMetrics) Score() float64 { + return pm.score +} + type PodMetrics struct { + score float64 *backendmetrics.Pod *backendmetrics.Metrics } -func NewContext(ctx context.Context, req *LLMRequest, pods []*PodMetrics) *Context { +func NewContext(ctx context.Context, req *LLMRequest, pods []Pod) *Context { logger := log.FromContext(ctx).WithValues("request", req) return &Context{ Context: ctx, @@ -79,10 +87,15 @@ func NewContext(ctx context.Context, req *LLMRequest, pods []*PodMetrics) *Conte } } -func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []*PodMetrics { - pm := make([]*PodMetrics, 0, len(pods)) +func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { + pm := make([]Pod, 0, len(pods)) for _, pod := range pods { - pm = append(pm, &PodMetrics{pod.GetPod().Clone(), pod.GetMetrics().Clone()}) + pm = append(pm, &PodMetrics{Pod: pod.GetPod().Clone(), Metrics: pod.GetMetrics().Clone()}) } return pm } + +// Result captures the scheduler result. +type Result struct { + TargetPod Pod +}
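Reviewer note: to illustrate how the new plugin contract in pkg/epp/scheduling/types is meant to be consumed, here is a minimal, hypothetical sketch of a Picker implementation. It is not part of this patch; the package name, file name, and plugin name are invented for the example, and it relies only on the Plugin, Pod, Context, and Result types introduced above.

// maxscore_picker.go — hypothetical example, not included in this change.
package example

import (
	"errors"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// MaxScorePicker selects the candidate pod with the highest score assigned
// by the Scorer plugins via Pod.SetScore.
type MaxScorePicker struct{}

// Compile-time check that MaxScorePicker satisfies the Picker interface.
var _ types.Picker = &MaxScorePicker{}

// Name satisfies the base Plugin interface.
func (p *MaxScorePicker) Name() string { return "max-score-picker" }

// Pick implements types.Picker by returning the highest-scoring pod among
// the candidates that survived the Filter plugins.
func (p *MaxScorePicker) Pick(ctx *types.Context, pods []types.Pod) (*types.Result, error) {
	if len(pods) == 0 {
		return nil, errors.New("no candidate pods to pick from")
	}
	best := pods[0]
	for _, pod := range pods[1:] {
		if pod.Score() > best.Score() {
			best = pod
		}
	}
	return &types.Result{TargetPod: best}, nil
}

Because Pick returns a *types.Result rather than a bare pod, alternative pickers (for example random or round-robin selection) can be swapped in without changing the scheduling types or their callers.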