Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Core/Config/AppConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ namespace KernelMemory.Core.Config;
/// Root configuration for Kernel Memory application
/// Loaded from ~/.km/config.json or custom path
/// </summary>
#pragma warning disable CA1724 // Conflicts with Microsoft.Identity.Client.AppConfig when Azure.Identity is referenced.
public sealed class AppConfig : IValidatable
#pragma warning restore CA1724
{
/// <summary>
/// Named memory nodes (e.g., "personal", "work")
Expand Down
5 changes: 5 additions & 0 deletions src/Core/Config/Embeddings/AzureOpenAIEmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,10 @@ public override void Validate(string path)
throw new ConfigException(path,
"Azure OpenAI: specify either ApiKey or UseManagedIdentity, not both");
}

if (this.BatchSize < 1)
{
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
}
}
}
8 changes: 8 additions & 0 deletions src/Core/Config/Embeddings/EmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ public abstract class EmbeddingsConfig : IValidatable
[JsonIgnore]
public abstract EmbeddingsTypes Type { get; }

/// <summary>
/// Maximum number of texts to send per embeddings API request.
/// Providers that support batch requests should chunk input using this size.
/// Default: 10.
/// </summary>
[JsonPropertyName("batchSize")]
public int BatchSize { get; set; } = Constants.EmbeddingDefaults.DefaultBatchSize;

/// <summary>
/// Validates the embeddings configuration
/// </summary>
Expand Down
11 changes: 9 additions & 2 deletions src/Core/Config/Embeddings/HuggingFaceEmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,16 @@ public override void Validate(string path)
throw new ConfigException($"{path}.Model", "HuggingFace model name is required");
}

if (string.IsNullOrWhiteSpace(this.ApiKey))
// ApiKey can be provided via config or HF_TOKEN environment variable
if (string.IsNullOrWhiteSpace(this.ApiKey) &&
string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("HF_TOKEN")))
{
throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required");
throw new ConfigException($"{path}.ApiKey", "HuggingFace API key is required (set ApiKey or HF_TOKEN)");
}

if (this.BatchSize < 1)
{
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
}

if (string.IsNullOrWhiteSpace(this.BaseUrl))
Expand Down
5 changes: 5 additions & 0 deletions src/Core/Config/Embeddings/OllamaEmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ public override void Validate(string path)
throw new ConfigException($"{path}.BaseUrl", "Ollama base URL is required");
}

if (this.BatchSize < 1)
{
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
}

if (!Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _))
{
throw new ConfigException($"{path}.BaseUrl",
Expand Down
5 changes: 5 additions & 0 deletions src/Core/Config/Embeddings/OpenAIEmbeddingsConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ public override void Validate(string path)
throw new ConfigException($"{path}.ApiKey", "OpenAI API key is required");
}

if (this.BatchSize < 1)
{
throw new ConfigException($"{path}.BatchSize", "BatchSize must be >= 1");
}

if (!string.IsNullOrWhiteSpace(this.BaseUrl) &&
!Uri.TryCreate(this.BaseUrl, UriKind.Absolute, out _))
{
Expand Down
33 changes: 33 additions & 0 deletions src/Core/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,39 @@ public static bool TryGetDimensions(string modelName, out int dimensions)
}
}

/// <summary>
/// Constants for HTTP retry/backoff used by external providers (embeddings, etc.).
/// </summary>
/// <summary>
/// Shared retry/backoff tuning values for outbound HTTP calls made by
/// external providers (embeddings and similar integrations).
/// </summary>
public static class HttpRetryDefaults
{
    /// <summary>
    /// Total number of tries allowed for a request, counting the initial attempt.
    /// </summary>
    public const int MaxAttempts = 5;

    /// <summary>
    /// How long a single attempt may run before it is abandoned, in seconds.
    /// Enforced by <see cref="KernelMemory.Core.Http.HttpRetryPolicy"/> so a
    /// stalled remote call cannot hang indefinitely.
    /// </summary>
    public const int DefaultPerAttemptTimeoutSeconds = 60;

    /// <summary>
    /// Attempt timeout, in seconds, used for calls to a local Ollama instance.
    /// Deliberately short so development and test runs fail fast when the
    /// Ollama server is not available.
    /// </summary>
    public const int OllamaPerAttemptTimeoutSeconds = 5;

    /// <summary>
    /// Initial wait, in milliseconds, used as the base of the exponential backoff.
    /// </summary>
    public const int BaseDelayMs = 200;

    /// <summary>
    /// Upper bound, in milliseconds, on the wait between consecutive attempts.
    /// </summary>
    public const int MaxDelayMs = 5000;
}

/// <summary>
/// Constants for the logging system including file rotation, log levels,
/// and output formatting.
Expand Down
1 change: 1 addition & 0 deletions src/Core/Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Identity" />
<PackageReference Include="cuid.net" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" />
<PackageReference Include="Parlot" />
Expand Down
97 changes: 87 additions & 10 deletions src/Core/Embeddings/Providers/AzureOpenAIEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text.Json.Serialization;
using Azure.Core;
using Azure.Identity;
using KernelMemory.Core.Config.Enums;
using KernelMemory.Core.Http;
using Microsoft.Extensions.Logging;

namespace KernelMemory.Core.Embeddings.Providers;

/// <summary>
/// Azure OpenAI embedding generator implementation.
/// Communicates with Azure OpenAI Service.
/// Supports API key authentication (managed identity would require Azure.Identity package).
/// Supports API key authentication or managed identity via <see cref="DefaultAzureCredential"/>.
/// </summary>
public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
{
private readonly HttpClient _httpClient;
private readonly string _endpoint;
private readonly string _deployment;
private readonly string _apiKey;
private readonly string? _apiKey;
private readonly bool _useManagedIdentity;
private readonly TokenCredential? _credential;
private readonly int _batchSize;
private readonly ILogger<AzureOpenAIEmbeddingGenerator> _logger;
private readonly Func<TimeSpan, CancellationToken, Task> _delayAsync;

/// <inheritdoc />
public EmbeddingsTypes ProviderType => EmbeddingsTypes.AzureOpenAI;
Expand All @@ -38,35 +46,52 @@ public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
/// <param name="endpoint">Azure OpenAI endpoint (e.g., https://myservice.openai.azure.com).</param>
/// <param name="deployment">Deployment name in Azure.</param>
/// <param name="model">Model name for identification.</param>
/// <param name="apiKey">Azure OpenAI API key.</param>
/// <param name="apiKey">Azure OpenAI API key (required unless <paramref name="useManagedIdentity"/> is true).</param>
/// <param name="vectorDimensions">Vector dimensions produced by the model.</param>
/// <param name="isNormalized">Whether vectors are normalized.</param>
/// <param name="logger">Logger instance.</param>
/// <param name="batchSize">Maximum number of texts per API request.</param>
/// <param name="useManagedIdentity">Whether to authenticate using managed identity.</param>
/// <param name="credential">Optional token credential (used for testing); defaults to <see cref="DefaultAzureCredential"/>.</param>
/// <param name="delayAsync">Optional delay function for retries (used for fast unit tests).</param>
public AzureOpenAIEmbeddingGenerator(
HttpClient httpClient,
string endpoint,
string deployment,
string model,
string apiKey,
string? apiKey,
int vectorDimensions,
bool isNormalized,
ILogger<AzureOpenAIEmbeddingGenerator> logger)
ILogger<AzureOpenAIEmbeddingGenerator> logger,
int batchSize,
bool useManagedIdentity,
TokenCredential? credential = null,
Func<TimeSpan, CancellationToken, Task>? delayAsync = null)
{
ArgumentNullException.ThrowIfNull(httpClient, nameof(httpClient));
ArgumentNullException.ThrowIfNull(endpoint, nameof(endpoint));
ArgumentNullException.ThrowIfNull(deployment, nameof(deployment));
ArgumentNullException.ThrowIfNull(model, nameof(model));
ArgumentNullException.ThrowIfNull(apiKey, nameof(apiKey));
ArgumentNullException.ThrowIfNull(logger, nameof(logger));
ArgumentOutOfRangeException.ThrowIfLessThan(batchSize, 1, nameof(batchSize));

this._httpClient = httpClient;
this._endpoint = endpoint.TrimEnd('/');
this._deployment = deployment;
this._apiKey = apiKey;
this._useManagedIdentity = useManagedIdentity;
this._credential = credential;
this._batchSize = batchSize;
this.ModelName = model;
this.VectorDimensions = vectorDimensions;
this.IsNormalized = isNormalized;
this._logger = logger;
this._delayAsync = delayAsync ?? Task.Delay;

if (!this._useManagedIdentity && string.IsNullOrWhiteSpace(this._apiKey))
{
throw new ArgumentException("Azure OpenAI API key is required when not using managed identity", nameof(apiKey));
}

this._logger.LogDebug("AzureOpenAIEmbeddingGenerator initialized: {Endpoint}, deployment: {Deployment}, model: {Model}",
this._endpoint, this._deployment, this.ModelName);
Expand All @@ -88,21 +113,53 @@ public async Task<EmbeddingResult[]> GenerateAsync(IEnumerable<string> texts, Ca
return [];
}

var allResults = new List<EmbeddingResult>(textArray.Length);
foreach (var chunk in Chunk(textArray, this._batchSize))
{
var chunkResults = await this.GenerateBatchAsync(chunk, ct).ConfigureAwait(false);
allResults.AddRange(chunkResults);
}

return allResults.ToArray();
}

private async Task<EmbeddingResult[]> GenerateBatchAsync(string[] textArray, CancellationToken ct)
{
var url = $"{this._endpoint}/openai/deployments/{this._deployment}/embeddings?api-version={Constants.EmbeddingDefaults.AzureOpenAIApiVersion}";

var request = new AzureEmbeddingRequest
{
Input = textArray
};

using var httpRequest = new HttpRequestMessage(HttpMethod.Post, url);
httpRequest.Headers.Add("api-key", this._apiKey);
httpRequest.Content = JsonContent.Create(request);
var bearerToken = this._useManagedIdentity
? await this.GetManagedIdentityTokenAsync(ct).ConfigureAwait(false)
: null;
Comment thread
dluc marked this conversation as resolved.

this._logger.LogTrace("Calling Azure OpenAI embeddings API: deployment={Deployment}, batch size: {BatchSize}",
this._deployment, textArray.Length);

var response = await this._httpClient.SendAsync(httpRequest, ct).ConfigureAwait(false);
using var response = await HttpRetryPolicy.SendAsync(
this._httpClient,
requestFactory: () =>
{
var httpRequest = new HttpRequestMessage(HttpMethod.Post, url);
if (bearerToken != null)
{
httpRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", bearerToken);
}
else
{
httpRequest.Headers.Add("api-key", this._apiKey);
}

httpRequest.Content = JsonContent.Create(request);
return httpRequest;
},
this._logger,
ct,
delayAsync: this._delayAsync).ConfigureAwait(false);

response.EnsureSuccessStatusCode();

var result = await response.Content.ReadFromJsonAsync<AzureEmbeddingResponse>(ct).ConfigureAwait(false);
Expand Down Expand Up @@ -141,6 +198,26 @@ public async Task<EmbeddingResult[]> GenerateAsync(IEnumerable<string> texts, Ca
return results;
}

/// <summary>
/// Acquires a bearer token for the Azure Cognitive Services scope using the
/// injected credential, or a <c>DefaultAzureCredential</c> when none was supplied.
/// </summary>
/// <param name="ct">Cancellation token propagated to the token request.</param>
/// <returns>The raw access token string to place in the Authorization header.</returns>
// NOTE(review): when this._credential is null a new DefaultAzureCredential is
// constructed on every call, and the resulting token is not cached here —
// presumably acceptable because Azure.Identity caches internally, but worth
// confirming; consider holding a single credential instance if this is hot.
private async Task<string> GetManagedIdentityTokenAsync(CancellationToken ct)
{
var credential = this._credential ?? new DefaultAzureCredential();
// Scope for Azure OpenAI / Cognitive Services data-plane access.
var token = await credential.GetTokenAsync(
new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]),
ct).ConfigureAwait(false);
return token.Token;
}
Comment thread
dluc marked this conversation as resolved.

/// <summary>
/// Lazily splits <paramref name="items"/> into consecutive arrays of at most
/// <paramref name="chunkSize"/> elements; the final chunk may be shorter.
/// Yields nothing for an empty input.
/// </summary>
/// <param name="items">Source array to partition.</param>
/// <param name="chunkSize">Maximum number of elements per chunk.</param>
/// <returns>A sequence of newly allocated sub-arrays, in original order.</returns>
private static IEnumerable<string[]> Chunk(string[] items, int chunkSize)
{
    for (int offset = 0; offset < items.Length; offset += chunkSize)
    {
        var count = Math.Min(chunkSize, items.Length - offset);
        // Range slicing on an array allocates a fresh string[] of length `count`.
        yield return items[offset..(offset + count)];
    }
}

/// <summary>
/// Request body for Azure OpenAI embeddings API.
/// </summary>
Expand Down
Loading
Loading