Skip to content

Commit 85dac47

Browse files
committed
In the prefix plugin, keep both block size parameters (the legacy one defined in characters and the new one defined in tokens), and update the tests accordingly
Signed-off-by: Maya Barnea <mayab@il.ibm.com>
1 parent 9b30592 commit 85dac47

3 files changed

Lines changed: 77 additions & 17 deletions

File tree

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"encoding/binary"
2222
"encoding/json"
23+
"errors"
2324
"fmt"
2425
"sync"
2526
"time"
@@ -68,7 +69,8 @@ const (
6869

6970
var DefaultConfig = Config{
7071
AutoTune: true,
71-
BlockSizeTokens: DefaultBlockSizeTokens,
72+
BlockSize: 0,
73+
BlockSizeTokens: 0,
7274
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
7375
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
7476
}
@@ -77,9 +79,13 @@ type Config struct {
7779
// If set to true, the plugin will automatically adjust the configuration based on various
7880
// metrics from the model servers.
7981
AutoTune bool `json:"autoTune"`
80-
// The input prompt is broken into sizes of BlockSizeTokens to calculate block hashes . Requests
82+
// The input prompt is broken into sizes of BlockSizeTokens to calculate block hashes. Requests
8183
// with length shorter than the block size will be ignored.
82-
BlockSizeTokens int `json:"blockSize"`
84+
BlockSizeTokens int `json:"blockSizeTokens"`
85+
// Deprecated: Legacy block size defined in number of characters.
86+
// In case only BlockSize is defined in the configuration - plugin initialization will fail.
87+
// In case both BlockSize and BlockSizeTokens are defined - BlockSizeTokens is used.
88+
BlockSize int `json:"blockSize"`
8389
// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
8490
// be ignored.
8591
MaxPrefixBlocksToMatch int `json:"maxPrefixBlocksToMatch"`
@@ -163,13 +169,25 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle
163169
}
164170
}
165171

166-
p := New(handle.Context(), parameters).WithName(name)
172+
p, err := New(handle.Context(), parameters)
173+
if err != nil {
174+
return nil, err
175+
}
176+
177+
p.WithName(name)
167178
go p.CleanUpInactivePods(handle.Context(), handle)
168179
return p, nil
169180
}
170181

171182
// New initializes a new prefix Plugin and returns its pointer.
172-
func New(ctx context.Context, config Config) *Plugin {
183+
func New(ctx context.Context, config Config) (*Plugin, error) {
184+
// invalid configuration: only BlockSize is defined
185+
if config.BlockSize > 0 && config.BlockSizeTokens <= 0 {
186+
err := errors.New("BlockSize is depricated, use BlockSizeTokens instead, the value should be defined in tokens")
187+
log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "invalid prefix plugin configuration")
188+
return nil, err
189+
}
190+
173191
if config.LRUCapacityPerServer <= 0 {
174192
config.LRUCapacityPerServer = DefaultLRUCapacityPerServer
175193
log.FromContext(ctx).V(logutil.DEFAULT).Info(
@@ -196,7 +214,7 @@ func New(ctx context.Context, config Config) *Plugin {
196214
config: config,
197215
pluginState: plugins.NewPluginState(ctx),
198216
indexer: newIndexer(ctx, config.LRUCapacityPerServer),
199-
}
217+
}, nil
200218
}
201219

202220
// TypedName returns the type and name tuple of this plugin instance.

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,46 @@ import (
3939
// static check to ensure Plugin implements the PrepareDataPlugin interface.
4040
var _ requestcontrol.PrepareDataPlugin = &Plugin{}
4141

42+
func TestPrefixPluginValidation(t *testing.T) {
43+
validConfigs := []Config{{
44+
AutoTune: false,
45+
BlockSizeTokens: 1,
46+
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
47+
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
48+
}, {
49+
AutoTune: false,
50+
BlockSize: 1,
51+
BlockSizeTokens: 1,
52+
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
53+
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
54+
}}
55+
invalidConfigs := []Config{{
56+
AutoTune: false,
57+
BlockSize: 1,
58+
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
59+
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
60+
}}
61+
62+
for _, config := range validConfigs {
63+
_, err := New(context.Background(), config)
64+
assert.NoError(t, err)
65+
}
66+
67+
for _, config := range invalidConfigs {
68+
_, err := New(context.Background(), config)
69+
assert.Error(t, err)
70+
}
71+
}
72+
4273
func TestPrefixPluginCompletion(t *testing.T) {
4374
config := Config{
4475
AutoTune: false,
4576
BlockSizeTokens: 1,
4677
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
4778
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
4879
}
49-
plugin := New(context.Background(), config)
80+
plugin, err := New(context.Background(), config)
81+
assert.NoError(t, err)
5082

5183
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState()}
5284
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState()}
@@ -212,7 +244,8 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
212244
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
213245
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
214246
}
215-
plugin := New(context.Background(), config)
247+
plugin, err := New(context.Background(), config)
248+
assert.NoError(t, err)
216249

217250
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{}}
218251
pods := []types.Pod{pod1}
@@ -247,7 +280,8 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
247280
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
248281
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
249282
}
250-
plugin := New(context.Background(), config)
283+
plugin, err := New(context.Background(), config)
284+
assert.NoError(t, err)
251285

252286
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{}}
253287
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: &backendmetrics.MetricsState{}}
@@ -361,7 +395,8 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
361395
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
362396
}
363397

364-
plugin := New(context.Background(), config)
398+
plugin, err := New(context.Background(), config)
399+
assert.NoError(b, err)
365400
types.NewCycleState()
366401
var promptLen []int
367402
for i := 1; i <= 1024; {
@@ -462,8 +497,9 @@ func TestNew_InvalidConfigFallbacks(t *testing.T) {
462497
for _, tt := range tests {
463498
t.Run(tt.name, func(t *testing.T) {
464499

465-
plugin := New(context.Background(), tt.config)
500+
plugin, err := New(context.Background(), tt.config)
466501

502+
assert.NoError(t, err)
467503
assert.NotEmpty(t, plugin)
468504
assert.NotEmpty(t, plugin.indexer)
469505
assert.Equal(t, tt.expectBlock, plugin.config.BlockSizeTokens)
@@ -506,7 +542,8 @@ func TestPrefixPluginAutoTune(t *testing.T) {
506542
// Should be ignored in favor of pod metrics (1000)
507543
LRUCapacityPerServer: 1,
508544
}
509-
plugin := New(context.Background(), config)
545+
plugin, err := New(context.Background(), config)
546+
assert.NoError(t, err)
510547

511548
// 1. Verify Score uses pod metrics for block size
512549
scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
@@ -540,7 +577,8 @@ func TestPrefixPluginAutoTune(t *testing.T) {
540577
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
541578
LRUCapacityPerServer: 1, // Should be used, and the first hash should be evicted due to the small
542579
}
543-
plugin := New(context.Background(), config)
580+
plugin, err := New(context.Background(), config)
581+
assert.NoError(t, err)
544582

545583
// 1. Verify Score uses config BlockSize
546584
req.RequestId = uuid.NewString() // New request ID
@@ -584,7 +622,8 @@ func TestPrepareRequestData(t *testing.T) {
584622
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
585623
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
586624
}
587-
plugin := New(context.Background(), config)
625+
plugin, err := New(context.Background(), config)
626+
assert.NoError(t, err)
588627

589628
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()}
590629
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState(), AttributeMap: datalayer.NewAttributes()}
@@ -621,7 +660,7 @@ func TestPrepareRequestData(t *testing.T) {
621660
},
622661
}
623662

624-
err := plugin.PrepareRequestData(context.Background(), req2, pods)
663+
err = plugin.PrepareRequestData(context.Background(), req2, pods)
625664
assert.NoError(t, err)
626665

627666
// Verify pod1 has the correct prefix match info
@@ -647,7 +686,8 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
647686
MaxPrefixBlocksToMatch: maxPrefixBlocks,
648687
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
649688
}
650-
plugin := New(context.Background(), config)
689+
plugin, err := New(context.Background(), config)
690+
assert.NoError(b, err)
651691

652692
// Test scenarios: varying number of messages and message lengths
653693
scenarios := []struct {

pkg/epp/scheduling/scheduler_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222

2323
"github.com/google/go-cmp/cmp"
2424
"github.com/google/uuid"
25+
"github.com/stretchr/testify/assert"
2526
k8stypes "k8s.io/apimachinery/pkg/types"
2627

2728
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
@@ -38,7 +39,8 @@ import (
3839
func TestSchedule(t *testing.T) {
3940
kvCacheUtilizationScorer := scorer.NewKVCacheUtilizationScorer()
4041
queueingScorer := scorer.NewQueueScorer()
41-
prefixCacheScorer := prefix.New(context.Background(), prefix.DefaultConfig)
42+
prefixCacheScorer, err := prefix.New(context.Background(), prefix.DefaultConfig)
43+
assert.NoError(t, err)
4244
loraAffinityScorer := scorer.NewLoraAffinityScorer()
4345

4446
defaultProfile := framework.NewSchedulerProfile().

0 commit comments

Comments
 (0)