Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 77 additions & 10 deletions pkg/vmcp/optimizer/internal/similarity/tei_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@ const (

// embedPath is the TEI endpoint path for generating embeddings.
embedPath = "/embed"

// infoPath is the TEI endpoint that returns server metadata including max batch size.
infoPath = "/info"

// defaultMaxBatchSize is used when the TEI /info endpoint does not report a max batch size.
defaultMaxBatchSize = 32
)

// teiClient implements types.EmbeddingClient by calling the HuggingFace
// Text Embeddings Inference (TEI) HTTP API.
type teiClient struct {
	// baseURL is the root URL of the TEI server, with no trailing path.
	baseURL string
	// httpClient is reused for all TEI requests and carries the configured timeout.
	httpClient *http.Client
	// maxBatchSize caps the number of inputs sent per /embed request; it is
	// discovered from the TEI /info endpoint at construction time (or falls
	// back to defaultMaxBatchSize when discovery fails).
	maxBatchSize int
}

// NewEmbeddingClient creates an EmbeddingClient from the given optimizer
Expand All @@ -42,6 +49,7 @@ func NewEmbeddingClient(cfg *types.OptimizerConfig) (types.EmbeddingClient, erro
}

// newTEIClient creates a new TEI embedding client that calls the specified endpoint.
// It queries the TEI /info endpoint to discover the server's maximum batch size.
func newTEIClient(baseURL string, timeout time.Duration) (*teiClient, error) {
if baseURL == "" {
return nil, fmt.Errorf("TEI BaseURL is required")
Expand All @@ -51,16 +59,54 @@ func newTEIClient(baseURL string, timeout time.Duration) (*teiClient, error) {
timeout = defaultTimeout
}

slog.Debug("TEI embedding client created", "base_url", baseURL, "timeout", timeout)
httpClient := &http.Client{Timeout: timeout}

maxBatch, err := fetchMaxBatchSize(baseURL, httpClient)
if err != nil {
slog.Warn("failed to query TEI /info, using default max batch size",
"error", err, "default", defaultMaxBatchSize)
maxBatch = defaultMaxBatchSize
}

slog.Debug("TEI embedding client created",
"base_url", baseURL, "timeout", timeout, "max_batch_size", maxBatch)

return &teiClient{
baseURL: baseURL,
httpClient: &http.Client{
Timeout: timeout,
},
baseURL: baseURL,
httpClient: httpClient,
maxBatchSize: maxBatch,
}, nil
}

// teiInfoResponse is a subset of the TEI /info endpoint response.
type teiInfoResponse struct {
	// MaxClientBatchSize mirrors the server's max_client_batch_size setting:
	// the largest number of inputs it accepts in a single /embed request.
	MaxClientBatchSize int `json:"max_client_batch_size"`
}

// fetchMaxBatchSize queries the TEI /info endpoint and returns the max client
// batch size. A non-positive reported value is replaced by defaultMaxBatchSize
// so callers always receive a usable batch size alongside a nil error.
func fetchMaxBatchSize(baseURL string, httpClient *http.Client) (int, error) {
	res, err := httpClient.Get(baseURL + infoPath) // #nosec G107 -- URL is built from the configured TEI base URL
	if err != nil {
		return 0, fmt.Errorf("TEI /info request failed: %w", err)
	}
	defer func() { _ = res.Body.Close() }()

	if res.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("TEI /info returned status %d", res.StatusCode)
	}

	var payload teiInfoResponse
	if decodeErr := json.NewDecoder(res.Body).Decode(&payload); decodeErr != nil {
		return 0, fmt.Errorf("failed to decode TEI /info response: %w", decodeErr)
	}

	// Guard against servers that omit or zero the field.
	if size := payload.MaxClientBatchSize; size > 0 {
		return size, nil
	}
	return defaultMaxBatchSize, nil
}

// embedRequest is the JSON body sent to the TEI /embed endpoint.
type embedRequest struct {
Inputs []string `json:"inputs"`
Expand All @@ -83,12 +129,35 @@ func (c *teiClient) Embed(ctx context.Context, text string) ([]float32, error) {
return results[0], nil
}

// EmbedBatch returns vector embeddings for multiple texts, automatically
// chunking requests to respect the TEI server's maximum batch size.
func (c *teiClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	if len(texts) == 0 {
		return nil, nil
	}

	total := len(texts)
	out := make([][]float32, 0, total)
	chunks := 0

	for lo := 0; lo < total; lo += c.maxBatchSize {
		hi := lo + c.maxBatchSize
		if hi > total {
			hi = total
		}

		vecs, err := c.embedChunk(ctx, texts[lo:hi])
		if err != nil {
			return nil, err
		}
		out = append(out, vecs...)
		chunks++
	}

	// texts is non-empty here, so out has at least one embedding.
	slog.Debug("TEI embedding batch completed",
		"inputs", total, "chunks", chunks,
		"dimensions", len(out[0]))

	return out, nil
}

// embedChunk sends a single batch of texts to the TEI /embed endpoint.
func (c *teiClient) embedChunk(ctx context.Context, texts []string) ([][]float32, error) {
reqBody := embedRequest{
Inputs: texts,
Truncate: true,
Expand Down Expand Up @@ -125,8 +194,6 @@ func (c *teiClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32
return nil, fmt.Errorf("TEI returned %d embeddings for %d inputs", len(embeddings), len(texts))
}

slog.Debug("TEI embedding batch completed", "inputs", len(texts), "dimensions", len(embeddings[0]))

return embeddings, nil
}

Expand Down
Loading
Loading