Skip to content

Commit b33a55c

Browse files
Auto-chunk TEI embedding requests to respect server batch size limit (#4029)
TEI (Text Embeddings Inference) does not automatically batch embedding requests — it rejects requests that exceed max_client_batch_size with a 422 error. This wasn't caught earlier because the Optimizer was indexing fewer than 32 tools (the default limit). Once we exceeded that threshold, upserts started failing. Query the TEI /info endpoint at client creation to discover max_client_batch_size, then split EmbedBatch requests into chunks that fit within the limit. Falls back to a default of 32 when /info is unavailable. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bc9b534 commit b33a55c

2 files changed

Lines changed: 311 additions & 53 deletions

File tree

pkg/vmcp/optimizer/internal/similarity/tei_client.go

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,20 @@ const (
2222

2323
// embedPath is the TEI endpoint path for generating embeddings.
2424
embedPath = "/embed"
25+
26+
// infoPath is the TEI endpoint that returns server metadata including max batch size.
27+
infoPath = "/info"
28+
29+
// defaultMaxBatchSize is used when the TEI /info endpoint does not report a max batch size.
30+
defaultMaxBatchSize = 32
2531
)
2632

2733
// teiClient implements types.EmbeddingClient by calling the HuggingFace
2834
// Text Embeddings Inference (TEI) HTTP API.
2935
type teiClient struct {
30-
baseURL string
31-
httpClient *http.Client
36+
baseURL string
37+
httpClient *http.Client
38+
maxBatchSize int
3239
}
3340

3441
// NewEmbeddingClient creates an EmbeddingClient from the given optimizer
@@ -42,6 +49,7 @@ func NewEmbeddingClient(cfg *types.OptimizerConfig) (types.EmbeddingClient, erro
4249
}
4350

4451
// newTEIClient creates a new TEI embedding client that calls the specified endpoint.
52+
// It queries the TEI /info endpoint to discover the server's maximum batch size.
4553
func newTEIClient(baseURL string, timeout time.Duration) (*teiClient, error) {
4654
if baseURL == "" {
4755
return nil, fmt.Errorf("TEI BaseURL is required")
@@ -51,16 +59,54 @@ func newTEIClient(baseURL string, timeout time.Duration) (*teiClient, error) {
5159
timeout = defaultTimeout
5260
}
5361

54-
slog.Debug("TEI embedding client created", "base_url", baseURL, "timeout", timeout)
62+
httpClient := &http.Client{Timeout: timeout}
63+
64+
maxBatch, err := fetchMaxBatchSize(baseURL, httpClient)
65+
if err != nil {
66+
slog.Warn("failed to query TEI /info, using default max batch size",
67+
"error", err, "default", defaultMaxBatchSize)
68+
maxBatch = defaultMaxBatchSize
69+
}
70+
71+
slog.Debug("TEI embedding client created",
72+
"base_url", baseURL, "timeout", timeout, "max_batch_size", maxBatch)
5573

5674
return &teiClient{
57-
baseURL: baseURL,
58-
httpClient: &http.Client{
59-
Timeout: timeout,
60-
},
75+
baseURL: baseURL,
76+
httpClient: httpClient,
77+
maxBatchSize: maxBatch,
6178
}, nil
6279
}
6380

81+
// teiInfoResponse is a subset of the TEI /info endpoint response.
82+
type teiInfoResponse struct {
83+
MaxClientBatchSize int `json:"max_client_batch_size"`
84+
}
85+
86+
// fetchMaxBatchSize queries the TEI /info endpoint and returns the max client batch size.
87+
func fetchMaxBatchSize(baseURL string, httpClient *http.Client) (int, error) {
88+
resp, err := httpClient.Get(baseURL + infoPath) // #nosec G107 -- URL is built from the configured TEI base URL
89+
if err != nil {
90+
return 0, fmt.Errorf("TEI /info request failed: %w", err)
91+
}
92+
defer func() { _ = resp.Body.Close() }()
93+
94+
if resp.StatusCode != http.StatusOK {
95+
return 0, fmt.Errorf("TEI /info returned status %d", resp.StatusCode)
96+
}
97+
98+
var info teiInfoResponse
99+
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
100+
return 0, fmt.Errorf("failed to decode TEI /info response: %w", err)
101+
}
102+
103+
if info.MaxClientBatchSize <= 0 {
104+
return defaultMaxBatchSize, nil
105+
}
106+
107+
return info.MaxClientBatchSize, nil
108+
}
109+
64110
// embedRequest is the JSON body sent to the TEI /embed endpoint.
65111
type embedRequest struct {
66112
Inputs []string `json:"inputs"`
@@ -83,12 +129,35 @@ func (c *teiClient) Embed(ctx context.Context, text string) ([]float32, error) {
83129
return results[0], nil
84130
}
85131

86-
// EmbedBatch returns vector embeddings for multiple texts in a single request.
132+
// EmbedBatch returns vector embeddings for multiple texts, automatically
133+
// chunking requests to respect the TEI server's maximum batch size.
87134
func (c *teiClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
88135
if len(texts) == 0 {
89136
return nil, nil
90137
}
91138

139+
allEmbeddings := make([][]float32, 0, len(texts))
140+
141+
for start := 0; start < len(texts); start += c.maxBatchSize {
142+
end := min(start+c.maxBatchSize, len(texts))
143+
chunk := texts[start:end]
144+
145+
embeddings, err := c.embedChunk(ctx, chunk)
146+
if err != nil {
147+
return nil, err
148+
}
149+
allEmbeddings = append(allEmbeddings, embeddings...)
150+
}
151+
152+
slog.Debug("TEI embedding batch completed",
153+
"inputs", len(texts), "chunks", (len(texts)+c.maxBatchSize-1)/c.maxBatchSize,
154+
"dimensions", len(allEmbeddings[0]))
155+
156+
return allEmbeddings, nil
157+
}
158+
159+
// embedChunk sends a single batch of texts to the TEI /embed endpoint.
160+
func (c *teiClient) embedChunk(ctx context.Context, texts []string) ([][]float32, error) {
92161
reqBody := embedRequest{
93162
Inputs: texts,
94163
Truncate: true,
@@ -125,8 +194,6 @@ func (c *teiClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32
125194
return nil, fmt.Errorf("TEI returned %d embeddings for %d inputs", len(embeddings), len(texts))
126195
}
127196

128-
slog.Debug("TEI embedding batch completed", "inputs", len(texts), "dimensions", len(embeddings[0]))
129-
130197
return embeddings, nil
131198
}
132199

0 commit comments

Comments
 (0)