@@ -22,13 +22,20 @@ const (
2222
2323 // embedPath is the TEI endpoint path for generating embeddings.
2424 embedPath = "/embed"
25+
26+ // infoPath is the TEI endpoint that returns server metadata including max batch size.
27+ infoPath = "/info"
28+
29+ // defaultMaxBatchSize is used when the TEI /info endpoint does not report a max batch size.
30+ defaultMaxBatchSize = 32
2531)
2632
2733// teiClient implements types.EmbeddingClient by calling the HuggingFace
2834// Text Embeddings Inference (TEI) HTTP API.
2935type teiClient struct {
30- baseURL string
31- httpClient * http.Client
36+ baseURL string
37+ httpClient * http.Client
38+ maxBatchSize int
3239}
3340
3441// NewEmbeddingClient creates an EmbeddingClient from the given optimizer
@@ -42,6 +49,7 @@ func NewEmbeddingClient(cfg *types.OptimizerConfig) (types.EmbeddingClient, erro
4249}
4350
4451// newTEIClient creates a new TEI embedding client that calls the specified endpoint.
52+ // It queries the TEI /info endpoint to discover the server's maximum batch size.
4553func newTEIClient (baseURL string , timeout time.Duration ) (* teiClient , error ) {
4654 if baseURL == "" {
4755 return nil , fmt .Errorf ("TEI BaseURL is required" )
@@ -51,16 +59,54 @@ func newTEIClient(baseURL string, timeout time.Duration) (*teiClient, error) {
5159 timeout = defaultTimeout
5260 }
5361
54- slog .Debug ("TEI embedding client created" , "base_url" , baseURL , "timeout" , timeout )
62+ httpClient := & http.Client {Timeout : timeout }
63+
64+ maxBatch , err := fetchMaxBatchSize (baseURL , httpClient )
65+ if err != nil {
66+ slog .Warn ("failed to query TEI /info, using default max batch size" ,
67+ "error" , err , "default" , defaultMaxBatchSize )
68+ maxBatch = defaultMaxBatchSize
69+ }
70+
71+ slog .Debug ("TEI embedding client created" ,
72+ "base_url" , baseURL , "timeout" , timeout , "max_batch_size" , maxBatch )
5573
5674 return & teiClient {
57- baseURL : baseURL ,
58- httpClient : & http.Client {
59- Timeout : timeout ,
60- },
75+ baseURL : baseURL ,
76+ httpClient : httpClient ,
77+ maxBatchSize : maxBatch ,
6178 }, nil
6279}
6380
81+ // teiInfoResponse is a subset of the TEI /info endpoint response.
82+ type teiInfoResponse struct {
83+ MaxClientBatchSize int `json:"max_client_batch_size"`
84+ }
85+
86+ // fetchMaxBatchSize queries the TEI /info endpoint and returns the max client batch size.
87+ func fetchMaxBatchSize (baseURL string , httpClient * http.Client ) (int , error ) {
88+ resp , err := httpClient .Get (baseURL + infoPath ) // #nosec G107 -- URL is built from the configured TEI base URL
89+ if err != nil {
90+ return 0 , fmt .Errorf ("TEI /info request failed: %w" , err )
91+ }
92+ defer func () { _ = resp .Body .Close () }()
93+
94+ if resp .StatusCode != http .StatusOK {
95+ return 0 , fmt .Errorf ("TEI /info returned status %d" , resp .StatusCode )
96+ }
97+
98+ var info teiInfoResponse
99+ if err := json .NewDecoder (resp .Body ).Decode (& info ); err != nil {
100+ return 0 , fmt .Errorf ("failed to decode TEI /info response: %w" , err )
101+ }
102+
103+ if info .MaxClientBatchSize <= 0 {
104+ return defaultMaxBatchSize , nil
105+ }
106+
107+ return info .MaxClientBatchSize , nil
108+ }
109+
64110// embedRequest is the JSON body sent to the TEI /embed endpoint.
65111type embedRequest struct {
66112 Inputs []string `json:"inputs"`
@@ -83,12 +129,35 @@ func (c *teiClient) Embed(ctx context.Context, text string) ([]float32, error) {
83129 return results [0 ], nil
84130}
85131
86- // EmbedBatch returns vector embeddings for multiple texts in a single request.
132+ // EmbedBatch returns vector embeddings for multiple texts, automatically
133+ // chunking requests to respect the TEI server's maximum batch size.
87134func (c * teiClient ) EmbedBatch (ctx context.Context , texts []string ) ([][]float32 , error ) {
88135 if len (texts ) == 0 {
89136 return nil , nil
90137 }
91138
139+ allEmbeddings := make ([][]float32 , 0 , len (texts ))
140+
141+ for start := 0 ; start < len (texts ); start += c .maxBatchSize {
142+ end := min (start + c .maxBatchSize , len (texts ))
143+ chunk := texts [start :end ]
144+
145+ embeddings , err := c .embedChunk (ctx , chunk )
146+ if err != nil {
147+ return nil , err
148+ }
149+ allEmbeddings = append (allEmbeddings , embeddings ... )
150+ }
151+
152+ slog .Debug ("TEI embedding batch completed" ,
153+ "inputs" , len (texts ), "chunks" , (len (texts )+ c .maxBatchSize - 1 )/ c .maxBatchSize ,
154+ "dimensions" , len (allEmbeddings [0 ]))
155+
156+ return allEmbeddings , nil
157+ }
158+
159+ // embedChunk sends a single batch of texts to the TEI /embed endpoint.
160+ func (c * teiClient ) embedChunk (ctx context.Context , texts []string ) ([][]float32 , error ) {
92161 reqBody := embedRequest {
93162 Inputs : texts ,
94163 Truncate : true ,
@@ -125,8 +194,6 @@ func (c *teiClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32
125194 return nil , fmt .Errorf ("TEI returned %d embeddings for %d inputs" , len (embeddings ), len (texts ))
126195 }
127196
128- slog .Debug ("TEI embedding batch completed" , "inputs" , len (texts ), "dimensions" , len (embeddings [0 ]))
129-
130197 return embeddings , nil
131198}
132199
0 commit comments