Skip to content

Commit 94b69d3

Browse files
authored
Complete feature: embeddings batching, retries, managed identity (#1114)
Complete Embedding Generators and Cache feature by implementing remaining Must-Have acceptance criteria. Changes: - Add configurable embeddings batch sizing (default 10) via `EmbeddingsConfig.batchSize` and chunking for batch-capable providers. - Add shared transient HTTP retry/backoff (honors `Retry-After`) and use it in OpenAI/Azure/HF/Ollama providers. - Add Azure OpenAI managed identity auth (`DefaultAzureCredential`) alongside API key auth. - Add HuggingFace `HF_TOKEN` environment variable fallback. - Add/adjust unit + integration tests and stabilize env-var dependent tests.
1 parent 9c6ba4f commit 94b69d3

23 files changed

Lines changed: 739 additions & 86 deletions

src/Core/Config/AppConfig.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ namespace KernelMemory.Core.Config;
1010
/// Root configuration for Kernel Memory application
1111
/// Loaded from ~/.km/config.json or custom path
1212
/// </summary>
13+
#pragma warning disable CA1724 // Conflicts with Microsoft.Identity.Client.AppConfig when Azure.Identity is referenced.
1314
public sealed class AppConfig : IValidatable
15+
#pragma warning restore CA1724
1416
{
1517
/// <summary>
1618
/// Named memory nodes (e.g., "personal", "work")

src/Core/Config/Embeddings/AzureOpenAIEmbeddingsConfig.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,10 @@ public override void Validate(string path)
8181
throw new ConfigException(path,
8282
"Azure OpenAI: specify either ApiKey or UseManagedIdentity, not both");
8383
}
84+
85+
if (this.BatchSize < 1)
86+
{
87+
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
88+
}
8489
}
8590
}

src/Core/Config/Embeddings/EmbeddingsConfig.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ public abstract class EmbeddingsConfig : IValidatable
2121
[JsonIgnore]
2222
public abstract EmbeddingsTypes Type { get; }
2323

24+
/// <summary>
25+
/// Maximum number of texts to send per embeddings API request.
26+
/// Providers that support batch requests should chunk input using this size.
27+
/// Default: 10.
28+
/// </summary>
29+
[JsonPropertyName("batchSize")]
30+
public int BatchSize { get; set; } = Constants.EmbeddingDefaults.DefaultBatchSize;
31+
2432
/// <summary>
2533
/// Validates the embeddings configuration
2634
/// </summary>

src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,25 @@ public override void Validate(string path)
4444
throw new ConfigException($"{path}.Model", "HuggingFace model name is required");
4545
}
4646

47+
// ApiKey can be provided via config or HF_TOKEN environment variable.
48+
// Resolve env var into config so downstream code can rely on configuration only.
4749
if (string.IsNullOrWhiteSpace(this.ApiKey))
4850
{
49-
throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required");
51+
var token = Environment.GetEnvironmentVariable("HF_TOKEN");
52+
if (!string.IsNullOrWhiteSpace(token))
53+
{
54+
this.ApiKey = token;
55+
}
56+
}
57+
58+
if (string.IsNullOrWhiteSpace(this.ApiKey))
59+
{
60+
throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required (set ApiKey or HF_TOKEN)");
61+
}
62+
63+
if (this.BatchSize < 1)
64+
{
65+
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
5066
}
5167

5268
if (string.IsNullOrWhiteSpace(this.BaseUrl))

src/Core/Config/Embeddings/OllamaEmbeddingsConfig.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ public override void Validate(string path)
3939
throw new ConfigException($"{path}.BaseUrl", "Ollama base URL is required");
4040
}
4141

42+
if (this.BatchSize < 1)
43+
{
44+
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
45+
}
46+
4247
if (!Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _))
4348
{
4449
throw new ConfigException($"{path}.BaseUrl",

src/Core/Config/Embeddings/OpenAIEmbeddingsConfig.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ public override void Validate(string path)
4545
throw new ConfigException($"{path}.ApiKey", "OpenAI API key is required");
4646
}
4747

48+
if (this.BatchSize < 1)
49+
{
50+
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
51+
}
52+
4853
if (!string.IsNullOrWhiteSpace(this.BaseUrl) &&
4954
!Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _))
5055
{

src/Core/Constants.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,39 @@ public static bool TryGetDimensions(string modelName, out int dimensions)
210210
}
211211
}
212212

213+
/// <summary>
214+
/// Constants for HTTP retry/backoff used by external providers (embeddings, etc.).
215+
/// </summary>
216+
public static class HttpRetryDefaults
217+
{
218+
/// <summary>
219+
/// Maximum attempts including the first try.
220+
/// </summary>
221+
public const int MaxAttempts = 5;
222+
223+
/// <summary>
224+
/// Default timeout per attempt (seconds).
225+
/// Applied by <see cref="KernelMemory.Core.Http.HttpRetryPolicy"/> to avoid hanging calls.
226+
/// </summary>
227+
public const int DefaultPerAttemptTimeoutSeconds = 60;
228+
229+
/// <summary>
230+
/// Per-attempt timeout for local Ollama calls (seconds).
231+
/// Keep this low so local development and tests fail fast when Ollama is not running.
232+
/// </summary>
233+
public const int OllamaPerAttemptTimeoutSeconds = 5;
234+
235+
/// <summary>
236+
/// Base delay for exponential backoff.
237+
/// </summary>
238+
public const int BaseDelayMs = 200;
239+
240+
/// <summary>
241+
/// Maximum delay between attempts.
242+
/// </summary>
243+
public const int MaxDelayMs = 5000;
244+
}
245+
213246
/// <summary>
214247
/// Constants for the logging system including file rotation, log levels,
215248
/// and output formatting.

src/Core/Core.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
</PropertyGroup>
88

99
<ItemGroup>
10+
<PackageReference Include="Azure.Identity" />
1011
<PackageReference Include="cuid.net" />
1112
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" />
1213
<PackageReference Include="Parlot" />

src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,31 @@
11
// Copyright (c) Microsoft. All rights reserved.
2+
using System.Net.Http.Headers;
23
using System.Net.Http.Json;
34
using System.Text.Json.Serialization;
5+
using Azure.Core;
6+
using Azure.Identity;
47
using KernelMemory.Core.Config.Enums;
8+
using KernelMemory.Core.Http;
59
using Microsoft.Extensions.Logging;
610

711
namespace KernelMemory.Core.Embeddings.Providers;
812

913
/// <summary>
1014
/// Azure OpenAI embedding generator implementation.
1115
/// Communicates with Azure OpenAI Service.
12-
/// Supports API key authentication (managed identity would require Azure.Identity package).
16+
/// Supports API key authentication or managed identity via <see cref="DefaultAzureCredential"/>.
1317
/// </summary>
1418
public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
1519
{
1620
private readonly HttpClient _httpClient;
1721
private readonly string _endpoint;
1822
private readonly string _deployment;
19-
private readonly string _apiKey;
23+
private readonly string? _apiKey;
24+
private readonly bool _useManagedIdentity;
25+
private readonly TokenCredential? _credential;
26+
private readonly int _batchSize;
2027
private readonly ILogger<AzureOpenAIEmbeddingGenerator> _logger;
28+
private readonly Func<TimeSpan, CancellationToken, Task> _delayAsync;
2129

2230
/// <inheritdoc />
2331
public EmbeddingsTypes ProviderType => EmbeddingsTypes.AzureOpenAI;
@@ -38,35 +46,52 @@ public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
3846
/// <param name="endpoint">Azure OpenAI endpoint (e.g., https://myservice.openai.azure.com).</param>
3947
/// <param name="deployment">Deployment name in Azure.</param>
4048
/// <param name="model">Model name for identification.</param>
41-
/// <param name="apiKey">Azure OpenAI API key.</param>
49+
/// <param name="apiKey">Azure OpenAI API key (required unless <paramref name="useManagedIdentity"/> is true).</param>
4250
/// <param name="vectorDimensions">Vector dimensions produced by the model.</param>
4351
/// <param name="isNormalized">Whether vectors are normalized.</param>
4452
/// <param name="logger">Logger instance.</param>
53+
/// <param name="batchSize">Maximum number of texts per API request.</param>
54+
/// <param name="useManagedIdentity">Whether to authenticate using managed identity.</param>
55+
/// <param name="credential">Optional token credential (used for testing); defaults to <see cref="DefaultAzureCredential"/>.</param>
56+
/// <param name="delayAsync">Optional delay function for retries (used for fast unit tests).</param>
4557
public AzureOpenAIEmbeddingGenerator(
4658
HttpClient httpClient,
4759
string endpoint,
4860
string deployment,
4961
string model,
50-
string apiKey,
62+
string? apiKey,
5163
int vectorDimensions,
5264
bool isNormalized,
53-
ILogger<AzureOpenAIEmbeddingGenerator> logger)
65+
ILogger<AzureOpenAIEmbeddingGenerator> logger,
66+
int batchSize,
67+
bool useManagedIdentity,
68+
TokenCredential? credential = null,
69+
Func<TimeSpan, CancellationToken, Task>? delayAsync = null)
5470
{
5571
ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient));
5672
ArgumentNullException.ThrowIfNull(endpoint, nameof(endpoint));
5773
ArgumentNullException.ThrowIfNull(deployment, nameof(deployment));
5874
ArgumentNullException.ThrowIfNull(model, nameof(model));
59-
ArgumentNullException.ThrowIfNull(apiKey, nameof(apiKey));
6075
ArgumentNullException.ThrowIfNull(logger, nameof(logger));
76+
ArgumentOutOfRangeException.ThrowIfLessThan(batchSize, 1, nameof(batchSize));
6177

6278
this._httpClient = httpClient;
6379
this._endpoint = endpoint.TrimEnd('/');
6480
this._deployment = deployment;
6581
this._apiKey = apiKey;
82+
this._useManagedIdentity = useManagedIdentity;
83+
this._credential = credential;
84+
this._batchSize = batchSize;
6685
this.ModelName = model;
6786
this.VectorDimensions = vectorDimensions;
6887
this.IsNormalized = isNormalized;
6988
this._logger = logger;
89+
this._delayAsync = delayAsync ?? Task.Delay;
90+
91+
if (!this._useManagedIdentity && string.IsNullOrWhiteSpace(this._apiKey))
92+
{
93+
throw new ArgumentException("Azure OpenAI API key is required when not using managed identity", nameof(apiKey));
94+
}
7095

7196
this._logger.LogDebug("AzureOpenAIEmbeddingGenerator initialized: {Endpoint}, deployment: {Deployment}, model: {Model}",
7297
this._endpoint, this._deployment, this.ModelName);
@@ -88,21 +113,53 @@ public async Task<EmbeddingResult[]> GenerateAsync(IEnumerable<string> texts, Ca
88113
return [];
89114
}
90115

116+
var allResults = new List<EmbeddingResult>(textArray.Length);
117+
foreach (var chunk in Chunk(textArray, this._batchSize))
118+
{
119+
var chunkResults = await this.GenerateBatchAsync(chunk, ct).ConfigureAwait(false);
120+
allResults.AddRange(chunkResults);
121+
}
122+
123+
return allResults.ToArray();
124+
}
125+
126+
private async Task<EmbeddingResult[]> GenerateBatchAsync(string[] textArray, CancellationToken ct)
127+
{
91128
var url = $"{this._endpoint}/openai/deployments/{this._deployment}/embeddings?api-version={Constants.EmbeddingDefaults.AzureOpenAIApiVersion}";
92129

93130
var request = new AzureEmbeddingRequest
94131
{
95132
Input = textArray
96133
};
97134

98-
using var httpRequest = new HttpRequestMessage(HttpMethod.Post, url);
99-
httpRequest.Headers.Add("api-key", this._apiKey);
100-
httpRequest.Content = JsonContent.Create(request);
135+
var bearerToken = this._useManagedIdentity
136+
? await this.GetManagedIdentityTokenAsync(ct).ConfigureAwait(false)
137+
: null;
101138

102139
this._logger.LogTrace("Calling Azure OpenAI embeddings API: deployment={Deployment}, batch size: {BatchSize}",
103140
this._deployment, textArray.Length);
104141

105-
var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false);
142+
using var response = await HttpRetryPolicy.SendAsync(
143+
this._httpClient,
144+
requestFactory: () =>
145+
{
146+
var httpRequest = new HttpRequestMessage(HttpMethod.Post, url);
147+
if (bearerToken != null)
148+
{
149+
httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", bearerToken);
150+
}
151+
else
152+
{
153+
httpRequest.Headers.Add("api-key", this._apiKey);
154+
}
155+
156+
httpRequest.Content = JsonContent.Create(request);
157+
return httpRequest;
158+
},
159+
this._logger,
160+
ct,
161+
delayAsync: this._delayAsync).ConfigureAwait(false);
162+
106163
response.EnsureSuccessStatusCode();
107164

108165
var result = await response.Content.ReadFromJsonAsync<AzureEmbeddingResponse>(ct).ConfigureAwait(false);
@@ -141,6 +198,26 @@ public async Task<EmbeddingResult[]> GenerateAsync(IEnumerable<string> texts, Ca
141198
return results;
142199
}
143200

201+
private async Task<string> GetManagedIdentityTokenAsync(CancellationToken ct)
202+
{
203+
var credential = this._credential ?? new DefaultAzureCredential();
204+
var token = await credential.GetTokenAsync(
205+
new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]),
206+
ct).ConfigureAwait(false);
207+
return token.Token;
208+
}
209+
210+
private static IEnumerable<string[]> Chunk(string[] items, int chunkSize)
211+
{
212+
for (int i = 0; i < items.Length; i += chunkSize)
213+
{
214+
var length = Math.Min(chunkSize, items.Length - i);
215+
var chunk = new string[length];
216+
Array.Copy(items, i, chunk, 0, length);
217+
yield return chunk;
218+
}
219+
}
220+
144221
/// <summary>
145222
/// Request body for Azure OpenAI embeddings API.
146223
/// </summary>

0 commit comments

Comments (0)