From 0dc589e232d67ea4a3c5954580fe67db78524785 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:06:26 +0000 Subject: [PATCH 01/16] Initial plan From d3187bddb80809d61d3f18232842987d04d09ede Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:11:16 +0000 Subject: [PATCH 02/16] Add EmbeddingsOptions and EmbeddingProviderType configuration models Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../ObjectModel/EmbeddingProviderType.cs | 27 +++ src/Config/ObjectModel/EmbeddingsOptions.cs | 163 +++++++++++++ src/Config/ObjectModel/RuntimeOptions.cs | 13 +- src/Core/Services/EmbeddingService.cs | 229 ++++++++++++++++++ src/Core/Services/IEmbeddingService.cs | 27 +++ 5 files changed, 458 insertions(+), 1 deletion(-) create mode 100644 src/Config/ObjectModel/EmbeddingProviderType.cs create mode 100644 src/Config/ObjectModel/EmbeddingsOptions.cs create mode 100644 src/Core/Services/EmbeddingService.cs create mode 100644 src/Core/Services/IEmbeddingService.cs diff --git a/src/Config/ObjectModel/EmbeddingProviderType.cs b/src/Config/ObjectModel/EmbeddingProviderType.cs new file mode 100644 index 0000000000..0a18d491bb --- /dev/null +++ b/src/Config/ObjectModel/EmbeddingProviderType.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Runtime.Serialization; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.Converters; + +namespace Azure.DataApiBuilder.Config.ObjectModel; + +/// +/// Represents the supported embedding provider types. +/// +[JsonConverter(typeof(EnumMemberJsonEnumConverterFactory))] +public enum EmbeddingProviderType +{ + /// + /// Azure OpenAI embedding provider. + /// + [EnumMember(Value = "azure-openai")] + AzureOpenAI, + + /// + /// OpenAI embedding provider. + /// + [EnumMember(Value = "openai")] + OpenAI +} diff --git a/src/Config/ObjectModel/EmbeddingsOptions.cs b/src/Config/ObjectModel/EmbeddingsOptions.cs new file mode 100644 index 0000000000..41147adc33 --- /dev/null +++ b/src/Config/ObjectModel/EmbeddingsOptions.cs @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel; + +/// +/// Represents the options for configuring the embedding service. +/// Used for text embedding/vectorization with OpenAI or Azure OpenAI providers. +/// +public record EmbeddingsOptions +{ + /// + /// Default timeout in milliseconds for embedding requests. + /// + public const int DEFAULT_TIMEOUT_MS = 30000; + + /// + /// Default API version for Azure OpenAI. + /// + public const string DEFAULT_AZURE_API_VERSION = "2024-02-01"; + + /// + /// Default model for OpenAI embeddings. + /// + public const string DEFAULT_OPENAI_MODEL = "text-embedding-3-small"; + + /// + /// The embedding provider type (azure-openai or openai). + /// Required. + /// + [JsonPropertyName("provider")] + public EmbeddingProviderType Provider { get; init; } + + /// + /// The provider base URL endpoint. + /// Required. + /// + [JsonPropertyName("endpoint")] + public string Endpoint { get; init; } + + /// + /// The API key for authentication. + /// Required. + /// + [JsonPropertyName("api-key")] + public string ApiKey { get; init; } + + /// + /// The model or deployment name. + /// For Azure OpenAI, this is the deployment name. + /// For OpenAI, this is the model name (defaults to text-embedding-3-small if not specified). + /// + [JsonPropertyName("model")] + public string? Model { get; init; } + + /// + /// Azure API version. Only used for Azure OpenAI provider. + /// Defaults to 2024-02-01. + /// + [JsonPropertyName("api-version")] + public string? ApiVersion { get; init; } + + /// + /// Output vector dimensions. Optional, uses model default if not specified. + /// + [JsonPropertyName("dimensions")] + public int? Dimensions { get; init; } + + /// + /// Request timeout in milliseconds. Defaults to 30000 (30 seconds). + /// + [JsonPropertyName("timeout-ms")] + public int? TimeoutMs { get; init; } + + /// + /// Flag which informs whether the user provided a custom timeout value. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(TimeoutMs))] + public bool UserProvidedTimeoutMs { get; init; } + + /// + /// Flag which informs whether the user provided a custom API version. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(ApiVersion))] + public bool UserProvidedApiVersion { get; init; } + + /// + /// Flag which informs whether the user provided custom dimensions. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(Dimensions))] + public bool UserProvidedDimensions { get; init; } + + /// + /// Flag which informs whether the user provided a custom model. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(Model))] + public bool UserProvidedModel { get; init; } + + /// + /// Gets the effective timeout in milliseconds, using default if not specified. + /// + [JsonIgnore] + public int EffectiveTimeoutMs => TimeoutMs ?? DEFAULT_TIMEOUT_MS; + + /// + /// Gets the effective API version for Azure OpenAI, using default if not specified. + /// + [JsonIgnore] + public string EffectiveApiVersion => ApiVersion ?? DEFAULT_AZURE_API_VERSION; + + /// + /// Gets the effective model name, using default for OpenAI if not specified. + /// For Azure OpenAI, model is required (no default). + /// + [JsonIgnore] + public string? EffectiveModel => Model ?? (Provider == EmbeddingProviderType.OpenAI ? DEFAULT_OPENAI_MODEL : null); + + [JsonConstructor] + public EmbeddingsOptions( + EmbeddingProviderType Provider, + string Endpoint, + string ApiKey, + string? Model = null, + string? ApiVersion = null, + int? Dimensions = null, + int? TimeoutMs = null) + { + this.Provider = Provider; + this.Endpoint = Endpoint; + this.ApiKey = ApiKey; + + if (Model is not null) + { + this.Model = Model; + UserProvidedModel = true; + } + + if (ApiVersion is not null) + { + this.ApiVersion = ApiVersion; + UserProvidedApiVersion = true; + } + + if (Dimensions is not null) + { + this.Dimensions = Dimensions; + UserProvidedDimensions = true; + } + + if (TimeoutMs is not null) + { + this.TimeoutMs = TimeoutMs; + UserProvidedTimeoutMs = true; + } + } +} diff --git a/src/Config/ObjectModel/RuntimeOptions.cs b/src/Config/ObjectModel/RuntimeOptions.cs index 6f6c046651..991cb814c4 100644 --- a/src/Config/ObjectModel/RuntimeOptions.cs +++ b/src/Config/ObjectModel/RuntimeOptions.cs @@ -17,6 +17,7 @@ public record RuntimeOptions public RuntimeCacheOptions? Cache { get; init; } public PaginationOptions? Pagination { get; init; } public RuntimeHealthCheckConfig? Health { get; init; } + public EmbeddingsOptions? Embeddings { get; init; } [JsonConstructor] public RuntimeOptions( @@ -28,7 +29,8 @@ public RuntimeOptions( TelemetryOptions? Telemetry = null, RuntimeCacheOptions? Cache = null, PaginationOptions? Pagination = null, - RuntimeHealthCheckConfig? Health = null) + RuntimeHealthCheckConfig? Health = null, + EmbeddingsOptions? Embeddings = null) { this.Rest = Rest; this.GraphQL = GraphQL; @@ -39,6 +41,7 @@ public RuntimeOptions( this.Cache = Cache; this.Pagination = Pagination; this.Health = Health; + this.Embeddings = Embeddings; } /// @@ -74,4 +77,12 @@ Mcp is null || Health is null || Health?.Enabled is null || Health?.Enabled is true; + + /// + /// Indicates whether embeddings are configured. + /// Embeddings are considered configured when the Embeddings property is not null. + /// + [JsonIgnore] + [MemberNotNullWhen(true, nameof(Embeddings))] + public bool IsEmbeddingsConfigured => Embeddings is not null; } diff --git a/src/Core/Services/EmbeddingService.cs b/src/Core/Services/EmbeddingService.cs new file mode 100644 index 0000000000..6371ceeecc --- /dev/null +++ b/src/Core/Services/EmbeddingService.cs @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net.Http.Headers; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel; +using Microsoft.Extensions.Logging; + +namespace Azure.DataApiBuilder.Core.Services; + +/// +/// Service implementation for text embedding/vectorization. +/// Supports both OpenAI and Azure OpenAI providers. +/// +public class EmbeddingService : IEmbeddingService +{ + private readonly HttpClient _httpClient; + private readonly EmbeddingsOptions _options; + private readonly ILogger _logger; + + /// + /// JSON serializer options for request/response handling. + /// + private static readonly JsonSerializerOptions _jsonSerializerOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + /// + /// Initializes a new instance of the EmbeddingService. + /// + /// The HTTP client factory for creating HTTP clients. + /// The embedding configuration options. + /// The logger instance. + public EmbeddingService( + HttpClient httpClient, + EmbeddingsOptions options, + ILogger logger) + { + _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + + ConfigureHttpClient(); + } + + /// + /// Configures the HTTP client with timeout and authentication headers. + /// + private void ConfigureHttpClient() + { + _httpClient.Timeout = TimeSpan.FromMilliseconds(_options.EffectiveTimeoutMs); + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + _httpClient.DefaultRequestHeaders.Add("api-key", _options.ApiKey); + } + else + { + _httpClient.DefaultRequestHeaders.Authorization = + new AuthenticationHeaderValue("Bearer", _options.ApiKey); + } + + _httpClient.DefaultRequestHeaders.Accept.Clear(); + _httpClient.DefaultRequestHeaders.Accept.Add( + new MediaTypeWithQualityHeaderValue("application/json")); + } + + /// + public async Task EmbedAsync(string text, CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(text)) + { + throw new ArgumentException("Text cannot be null or empty.", nameof(text)); + } + + float[][] results = await EmbedBatchAsync(new[] { text }, cancellationToken); + return results[0]; + } + + /// + public async Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default) + { + if (texts is null || texts.Length == 0) + { + throw new ArgumentException("Texts cannot be null or empty.", nameof(texts)); + } + + string requestUrl = BuildRequestUrl(); + object requestBody = BuildRequestBody(texts); + + string requestJson = JsonSerializer.Serialize(requestBody, _jsonSerializerOptions); + using HttpContent content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + _logger.LogDebug("Sending embedding request to {Url} with {Count} text(s)", requestUrl, texts.Length); + + HttpResponseMessage response = await _httpClient.PostAsync(requestUrl, content, cancellationToken); + + if (!response.IsSuccessStatusCode) + { + string errorContent = await response.Content.ReadAsStringAsync(cancellationToken); + _logger.LogError("Embedding request failed with status {StatusCode}: {ErrorContent}", + response.StatusCode, errorContent); + throw new HttpRequestException( + $"Embedding request failed with status code {response.StatusCode}: {errorContent}"); + } + + string responseJson = await response.Content.ReadAsStringAsync(cancellationToken); + EmbeddingResponse? embeddingResponse = JsonSerializer.Deserialize(responseJson, _jsonSerializerOptions); + + if (embeddingResponse?.Data is null || embeddingResponse.Data.Count == 0) + { + throw new InvalidOperationException("No embedding data received from the provider."); + } + + // Sort by index to ensure correct order and extract embeddings + List sortedData = embeddingResponse.Data.OrderBy(d => d.Index).ToList(); + return sortedData.Select(d => d.Embedding).ToArray(); + } + + /// + /// Builds the request URL based on the provider type. + /// + private string BuildRequestUrl() + { + string endpoint = _options.Endpoint.TrimEnd('/'); + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + // Azure OpenAI: {endpoint}/openai/deployments/{deployment}/embeddings?api-version={version} + string model = _options.EffectiveModel + ?? throw new InvalidOperationException("Model/deployment name is required for Azure OpenAI."); + + return $"{endpoint}/openai/deployments/{model}/embeddings?api-version={_options.EffectiveApiVersion}"; + } + else + { + // OpenAI: {endpoint}/v1/embeddings + return $"{endpoint}/v1/embeddings"; + } + } + + /// + /// Builds the request body based on the provider type. + /// + private object BuildRequestBody(string[] texts) + { + // Use single string for single text, array for batch + object input = texts.Length == 1 ? texts[0] : texts; + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + // Azure OpenAI request body + if (_options.UserProvidedDimensions) + { + return new + { + input, + dimensions = _options.Dimensions + }; + } + + return new { input }; + } + else + { + // OpenAI request body - includes model in body + string model = _options.EffectiveModel ?? EmbeddingsOptions.DEFAULT_OPENAI_MODEL; + + if (_options.UserProvidedDimensions) + { + return new + { + model, + input, + dimensions = _options.Dimensions + }; + } + + return new + { + model, + input + }; + } + } + + /// + /// Response model for embedding API responses. + /// + private sealed class EmbeddingResponse + { + [JsonPropertyName("data")] + public List? Data { get; set; } + + [JsonPropertyName("model")] + public string? Model { get; set; } + + [JsonPropertyName("usage")] + public EmbeddingUsage? Usage { get; set; } + } + + /// + /// Individual embedding data in the response. + /// + private sealed class EmbeddingData + { + [JsonPropertyName("index")] + public int Index { get; set; } + + [JsonPropertyName("embedding")] + public float[] Embedding { get; set; } = Array.Empty(); + } + + /// + /// Token usage information in the response. + /// + private sealed class EmbeddingUsage + { + [JsonPropertyName("prompt_tokens")] + public int PromptTokens { get; set; } + + [JsonPropertyName("total_tokens")] + public int TotalTokens { get; set; } + } +} diff --git a/src/Core/Services/IEmbeddingService.cs b/src/Core/Services/IEmbeddingService.cs new file mode 100644 index 0000000000..6e7ffb8a19 --- /dev/null +++ b/src/Core/Services/IEmbeddingService.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Services; + +/// +/// Service interface for text embedding/vectorization. +/// Supports both single text and batch embedding operations. +/// +public interface IEmbeddingService +{ + /// + /// Generates an embedding vector for a single text input. + /// + /// The text to embed. + /// Cancellation token for the operation. + /// The embedding vector as an array of floats. + Task EmbedAsync(string text, CancellationToken cancellationToken = default); + + /// + /// Generates embedding vectors for multiple text inputs in a batch. + /// + /// The texts to embed. + /// Cancellation token for the operation. + /// The embedding vectors as an array of float arrays, matching input order. + Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default); +} From 60648263bd98ad3c87b5e288560b821a94bb25b1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:14:52 +0000 Subject: [PATCH 03/16] Add CLI configure options for embeddings and register embedding service Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- src/Cli/Commands/ConfigureOptions.cs | 36 +++++++++ src/Cli/ConfigGenerator.cs | 107 +++++++++++++++++++++++++++ src/Service/Startup.cs | 13 ++++ 3 files changed, 156 insertions(+) diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index c3e0352249..93810ddacf 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -71,6 +71,13 @@ public ConfigureOptions( RollingInterval? fileSinkRollingInterval = null, int? fileSinkRetainedFileCountLimit = null, long? fileSinkFileSizeLimitBytes = null, + EmbeddingProviderType? runtimeEmbeddingsProvider = null, + string? runtimeEmbeddingsEndpoint = null, + string? runtimeEmbeddingsApiKey = null, + string? runtimeEmbeddingsModel = null, + string? runtimeEmbeddingsApiVersion = null, + int? runtimeEmbeddingsDimensions = null, + int? runtimeEmbeddingsTimeoutMs = null, string? config = null) : base(config) { @@ -132,6 +139,14 @@ public ConfigureOptions( FileSinkRollingInterval = fileSinkRollingInterval; FileSinkRetainedFileCountLimit = fileSinkRetainedFileCountLimit; FileSinkFileSizeLimitBytes = fileSinkFileSizeLimitBytes; + // Embeddings + RuntimeEmbeddingsProvider = runtimeEmbeddingsProvider; + RuntimeEmbeddingsEndpoint = runtimeEmbeddingsEndpoint; + RuntimeEmbeddingsApiKey = runtimeEmbeddingsApiKey; + RuntimeEmbeddingsModel = runtimeEmbeddingsModel; + RuntimeEmbeddingsApiVersion = runtimeEmbeddingsApiVersion; + RuntimeEmbeddingsDimensions = runtimeEmbeddingsDimensions; + RuntimeEmbeddingsTimeoutMs = runtimeEmbeddingsTimeoutMs; } [Option("data-source.database-type", Required = false, HelpText = "Database type. Allowed values: MSSQL, PostgreSQL, CosmosDB_NoSQL, MySQL.")] @@ -281,6 +296,27 @@ public ConfigureOptions( [Option("runtime.telemetry.file.file-size-limit-bytes", Required = false, HelpText = "Configure maximum file size limit in bytes. Default: 1048576")] public long? FileSinkFileSizeLimitBytes { get; } + [Option("runtime.embeddings.provider", Required = false, HelpText = "Configure embedding provider type. Allowed values: azure-openai, openai.")] + public EmbeddingProviderType? RuntimeEmbeddingsProvider { get; } + + [Option("runtime.embeddings.endpoint", Required = false, HelpText = "Configure the embedding provider base URL endpoint.")] + public string? RuntimeEmbeddingsEndpoint { get; } + + [Option("runtime.embeddings.api-key", Required = false, HelpText = "Configure the embedding API key for authentication.")] + public string? RuntimeEmbeddingsApiKey { get; } + + [Option("runtime.embeddings.model", Required = false, HelpText = "Configure the model/deployment name. Required for Azure OpenAI, defaults to text-embedding-3-small for OpenAI.")] + public string? RuntimeEmbeddingsModel { get; } + + [Option("runtime.embeddings.api-version", Required = false, HelpText = "Configure the Azure API version. Only used for Azure OpenAI provider. Default: 2024-02-01")] + public string? RuntimeEmbeddingsApiVersion { get; } + + [Option("runtime.embeddings.dimensions", Required = false, HelpText = "Configure the output vector dimensions. Optional, uses model default if not specified.")] + public int? RuntimeEmbeddingsDimensions { get; } + + [Option("runtime.embeddings.timeout-ms", Required = false, HelpText = "Configure the request timeout in milliseconds. Default: 30000")] + public int? RuntimeEmbeddingsTimeoutMs { get; } + public int Handler(ILogger logger, FileSystemRuntimeConfigLoader loader, IFileSystem fileSystem) { logger.LogInformation("{productName} {version}", PRODUCT_NAME, ProductInfo.GetProductVersion()); diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 78a5e63a7d..b9cb93207e 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -908,6 +908,26 @@ options.FileSinkRetainedFileCountLimit is not null || } } + // Embeddings: Provider, Endpoint, ApiKey, Model, ApiVersion, Dimensions, TimeoutMs + if (options.RuntimeEmbeddingsProvider is not null || + options.RuntimeEmbeddingsEndpoint is not null || + options.RuntimeEmbeddingsApiKey is not null || + options.RuntimeEmbeddingsModel is not null || + options.RuntimeEmbeddingsApiVersion is not null || + options.RuntimeEmbeddingsDimensions is not null || + options.RuntimeEmbeddingsTimeoutMs is not null) + { + bool status = TryUpdateConfiguredEmbeddingsValues(options, runtimeConfig?.Runtime?.Embeddings, out EmbeddingsOptions? updatedEmbeddingsOptions); + if (status && updatedEmbeddingsOptions is not null) + { + runtimeConfig = runtimeConfig! with { Runtime = runtimeConfig.Runtime! with { Embeddings = updatedEmbeddingsOptions } }; + } + else + { + return false; + } + } + return runtimeConfig != null; } @@ -1522,6 +1542,93 @@ private static bool TryUpdateConfiguredFileOptions( } } + /// + /// Attempts to update the embeddings configuration based on the provided options. + /// Creates a new EmbeddingsOptions object if the configuration is valid. + /// Provider, endpoint, and API key are required when configuring embeddings. + /// + /// The configuration options provided by the user. + /// The existing embeddings options from the runtime configuration. + /// The resulting embeddings options if successful. + /// True if the embeddings options were successfully configured; otherwise, false. + private static bool TryUpdateConfiguredEmbeddingsValues( + ConfigureOptions options, + EmbeddingsOptions? existingEmbeddingsOptions, + out EmbeddingsOptions? updatedEmbeddingsOptions) + { + updatedEmbeddingsOptions = null; + + try + { + // Get values from options or fall back to existing configuration + EmbeddingProviderType? provider = options.RuntimeEmbeddingsProvider ?? existingEmbeddingsOptions?.Provider; + string? endpoint = options.RuntimeEmbeddingsEndpoint ?? existingEmbeddingsOptions?.Endpoint; + string? apiKey = options.RuntimeEmbeddingsApiKey ?? existingEmbeddingsOptions?.ApiKey; + string? model = options.RuntimeEmbeddingsModel ?? existingEmbeddingsOptions?.Model; + string? apiVersion = options.RuntimeEmbeddingsApiVersion ?? existingEmbeddingsOptions?.ApiVersion; + int? dimensions = options.RuntimeEmbeddingsDimensions ?? existingEmbeddingsOptions?.Dimensions; + int? timeoutMs = options.RuntimeEmbeddingsTimeoutMs ?? existingEmbeddingsOptions?.TimeoutMs; + + // Validate required fields + if (provider is null) + { + _logger.LogError("Failed to configure embeddings: provider is required. Use --runtime.embeddings.provider to specify the provider (azure-openai or openai)."); + return false; + } + + if (string.IsNullOrEmpty(endpoint)) + { + _logger.LogError("Failed to configure embeddings: endpoint is required. Use --runtime.embeddings.endpoint to specify the provider base URL."); + return false; + } + + if (string.IsNullOrEmpty(apiKey)) + { + _logger.LogError("Failed to configure embeddings: api-key is required. Use --runtime.embeddings.api-key to specify the authentication key."); + return false; + } + + // Validate Azure OpenAI requires model/deployment name + if (provider == EmbeddingProviderType.AzureOpenAI && string.IsNullOrEmpty(model)) + { + _logger.LogError("Failed to configure embeddings: model/deployment name is required for Azure OpenAI provider. Use --runtime.embeddings.model to specify the deployment name."); + return false; + } + + // Validate dimensions if provided + if (dimensions is not null && dimensions <= 0) + { + _logger.LogError("Failed to configure embeddings: dimensions must be a positive integer."); + return false; + } + + // Validate timeout if provided + if (timeoutMs is not null && timeoutMs <= 0) + { + _logger.LogError("Failed to configure embeddings: timeout-ms must be a positive integer."); + return false; + } + + // Create the embeddings options + updatedEmbeddingsOptions = new EmbeddingsOptions( + Provider: (EmbeddingProviderType)provider, + Endpoint: endpoint, + ApiKey: apiKey, + Model: model, + ApiVersion: apiVersion, + Dimensions: dimensions, + TimeoutMs: timeoutMs); + + _logger.LogInformation("Updated RuntimeConfig with Runtime.Embeddings configuration."); + return true; + } + catch (Exception ex) + { + _logger.LogError("Failed to update RuntimeConfig.Embeddings with exception message: {exceptionMessage}.", ex.Message); + return false; + } + } + /// /// Parse permission string to create PermissionSetting array. /// diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index 333bf57234..563f42d440 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -387,6 +387,19 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); + // Register embedding service if configured + if (runtimeConfigAvailable + && runtimeConfig?.Runtime?.IsEmbeddingsConfigured == true) + { + EmbeddingsOptions embeddingsOptions = runtimeConfig.Runtime.Embeddings; + services.AddHttpClient(client => + { + // Base configuration is done in the EmbeddingService constructor + }).ConfigurePrimaryHttpMessageHandler(() => new HttpClientHandler()); + + services.AddSingleton(embeddingsOptions); + } + AddGraphQLService(services, runtimeConfig?.Runtime?.GraphQL); // Subscribe the GraphQL schema refresh method to the specific hot-reload event From 0653f15e12c31b7cee6978f6d3b29836f7bb9fc4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:22:16 +0000 Subject: [PATCH 04/16] Add unit tests for embeddings and update JSON schema with embeddings configuration Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- schemas/dab.draft.schema.json | 81 ++++ .../ObjectModel/EmbeddingProviderType.cs | 2 +- .../UnitTests/EmbeddingServiceTests.cs | 328 +++++++++++++++++ .../UnitTests/EmbeddingsOptionsTests.cs | 345 ++++++++++++++++++ 4 files changed, 755 insertions(+), 1 deletion(-) create mode 100644 src/Service.Tests/UnitTests/EmbeddingServiceTests.cs create mode 100644 src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 920c0a4da6..cb7d309828 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -642,6 +642,87 @@ "default": 4 } } + }, + "embeddings": { + "type": "object", + "description": "Configuration for text embedding/vectorization service. Supports OpenAI and Azure OpenAI providers.", + "additionalProperties": false, + "properties": { + "provider": { + "type": "string", + "description": "The embedding provider type.", + "enum": ["azure-openai", "openai"] + }, + "endpoint": { + "type": "string", + "description": "The provider base URL endpoint. For Azure OpenAI, use the Azure resource endpoint. For OpenAI, use https://api.openai.com." + }, + "api-key": { + "type": "string", + "description": "The API key for authentication. Supports environment variable substitution with @env('VAR_NAME')." + }, + "model": { + "type": "string", + "description": "The model or deployment name. Required for Azure OpenAI (deployment name). For OpenAI, defaults to 'text-embedding-3-small' if not specified." + }, + "api-version": { + "type": "string", + "description": "Azure API version. Only used for Azure OpenAI provider.", + "default": "2024-02-01" + }, + "dimensions": { + "type": "integer", + "description": "Output vector dimensions. Optional, uses model default if not specified. Useful for Redis schema alignment.", + "minimum": 1 + }, + "timeout-ms": { + "type": "integer", + "description": "Request timeout in milliseconds.", + "default": 30000, + "minimum": 1, + "maximum": 300000 + } + }, + "required": ["provider", "endpoint", "api-key"], + "allOf": [ + { + "$comment": "Azure OpenAI requires the model (deployment name) to be specified.", + "if": { + "properties": { + "provider": { + "const": "azure-openai" + } + }, + "required": ["provider"] + }, + "then": { + "required": ["model"], + "properties": { + "api-version": { + "type": "string", + "description": "Azure API version. Required for Azure OpenAI provider.", + "default": "2024-02-01" + } + } + } + }, + { + "$comment": "OpenAI does not require model (defaults to text-embedding-3-small) and does not use api-version.", + "if": { + "properties": { + "provider": { + "const": "openai" + } + }, + "required": ["provider"] + }, + "then": { + "properties": { + "api-version": false + } + } + } + ] } } }, diff --git a/src/Config/ObjectModel/EmbeddingProviderType.cs b/src/Config/ObjectModel/EmbeddingProviderType.cs index 0a18d491bb..2ead4470dd 100644 --- a/src/Config/ObjectModel/EmbeddingProviderType.cs +++ b/src/Config/ObjectModel/EmbeddingProviderType.cs @@ -21,7 +21,7 @@ public enum EmbeddingProviderType /// /// OpenAI embedding provider. + /// Lowercase "openai" is the serialized value. /// - [EnumMember(Value = "openai")] OpenAI } diff --git a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs new file mode 100644 index 0000000000..d5f00e494e --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs @@ -0,0 +1,328 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Services; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Moq.Protected; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for EmbeddingService. +/// +[TestClass] +public class EmbeddingServiceTests +{ + private Mock> _mockLogger = null!; + + [TestInitialize] + public void Setup() + { + _mockLogger = new Mock>(); + } + + /// + /// Tests that EmbedAsync returns embedding for a single text input. + /// + [TestMethod] + public async Task EmbedAsync_SingleText_ReturnsEmbedding() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + float[] expectedEmbedding = new[] { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f }; + HttpClient httpClient = CreateMockHttpClient(CreateSuccessResponse(expectedEmbedding)); + EmbeddingService service = new(httpClient, options, _mockLogger.Object); + + // Act + float[] result = await service.EmbedAsync("Hello world"); + + // Assert + Assert.IsNotNull(result); + Assert.AreEqual(expectedEmbedding.Length, result.Length); + for (int i = 0; i < expectedEmbedding.Length; i++) + { + Assert.AreEqual(expectedEmbedding[i], result[i]); + } + } + + /// + /// Tests that EmbedBatchAsync returns embeddings for multiple text inputs. + /// + [TestMethod] + public async Task EmbedBatchAsync_MultipleTexts_ReturnsEmbeddings() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + float[][] expectedEmbeddings = new[] + { + new[] { 0.1f, 0.2f, 0.3f }, + new[] { 0.4f, 0.5f, 0.6f }, + new[] { 0.7f, 0.8f, 0.9f } + }; + HttpClient httpClient = CreateMockHttpClient(CreateBatchSuccessResponse(expectedEmbeddings)); + EmbeddingService service = new(httpClient, options, _mockLogger.Object); + + // Act + float[][] result = await service.EmbedBatchAsync(new[] { "Text 1", "Text 2", "Text 3" }); + + // Assert + Assert.IsNotNull(result); + Assert.AreEqual(expectedEmbeddings.Length, result.Length); + for (int i = 0; i < expectedEmbeddings.Length; i++) + { + Assert.AreEqual(expectedEmbeddings[i].Length, result[i].Length); + } + } + + /// + /// Tests that EmbedAsync throws ArgumentException for null or empty text. + /// + [DataTestMethod] + [DataRow(null, DisplayName = "Null text throws ArgumentException")] + [DataRow("", DisplayName = "Empty text throws ArgumentException")] + public async Task EmbedAsync_NullOrEmptyText_ThrowsArgumentException(string text) + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = CreateMockHttpClient(CreateSuccessResponse(new[] { 0.1f })); + EmbeddingService service = new(httpClient, options, _mockLogger.Object); + + // Act & Assert + await Assert.ThrowsExceptionAsync(() => service.EmbedAsync(text!)); + } + + /// + /// Tests that EmbedBatchAsync throws ArgumentException for null or empty texts array. + /// + [TestMethod] + public async Task EmbedBatchAsync_EmptyTexts_ThrowsArgumentException() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = CreateMockHttpClient(CreateSuccessResponse(new[] { 0.1f })); + EmbeddingService service = new(httpClient, options, _mockLogger.Object); + + // Act & Assert + await Assert.ThrowsExceptionAsync(() => service.EmbedBatchAsync(Array.Empty())); + } + + /// + /// Tests that HttpRequestException is thrown when API returns an error. + /// + [TestMethod] + public async Task EmbedAsync_ApiError_ThrowsHttpRequestException() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = CreateMockHttpClient(CreateErrorResponse(HttpStatusCode.Unauthorized, "Invalid API key")); + EmbeddingService service = new(httpClient, options, _mockLogger.Object); + + // Act & Assert + await Assert.ThrowsExceptionAsync(() => service.EmbedAsync("Test text")); + } + + /// + /// Tests that InvalidOperationException is thrown when API returns empty data. + /// + [TestMethod] + public async Task EmbedAsync_EmptyResponse_ThrowsInvalidOperationException() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + string emptyResponse = JsonSerializer.Serialize(new { data = Array.Empty() }); + HttpClient httpClient = CreateMockHttpClient(CreateSuccessResponseWithContent(emptyResponse)); + EmbeddingService service = new(httpClient, options, _mockLogger.Object); + + // Act & Assert + await Assert.ThrowsExceptionAsync(() => service.EmbedAsync("Test text")); + } + + /// + /// Tests that EffectiveModel returns the default model for OpenAI when not specified. + /// + [TestMethod] + public void EmbeddingsOptions_OpenAI_DefaultModel() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + Endpoint: "https://api.openai.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.Model); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_OPENAI_MODEL, options.EffectiveModel); + } + + /// + /// Tests that EffectiveModel returns null for Azure OpenAI when model not specified. + /// + [TestMethod] + public void EmbeddingsOptions_AzureOpenAI_NoDefaultModel() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + Endpoint: "https://my.openai.azure.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.Model); + Assert.IsNull(options.EffectiveModel); + } + + /// + /// Tests that EffectiveTimeoutMs returns the default timeout when not specified. + /// + [TestMethod] + public void EmbeddingsOptions_DefaultTimeout() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + Endpoint: "https://api.openai.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.TimeoutMs); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_TIMEOUT_MS, options.EffectiveTimeoutMs); + } + + /// + /// Tests that custom timeout is used when specified. + /// + [TestMethod] + public void EmbeddingsOptions_CustomTimeout() + { + // Arrange + int customTimeout = 60000; + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + Endpoint: "https://api.openai.com", + ApiKey: "test-key", + TimeoutMs: customTimeout); + + // Assert + Assert.AreEqual(customTimeout, options.TimeoutMs); + Assert.AreEqual(customTimeout, options.EffectiveTimeoutMs); + Assert.IsTrue(options.UserProvidedTimeoutMs); + } + + #region Helper Methods + + private static EmbeddingsOptions CreateAzureOpenAIOptions() + { + return new EmbeddingsOptions( + Provider: EmbeddingProviderType.AzureOpenAI, + Endpoint: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Model: "text-embedding-ada-002"); + } + + private static HttpClient CreateMockHttpClient(HttpResponseMessage response) + { + Mock mockHandler = new(); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(response); + + return new HttpClient(mockHandler.Object); + } + + private static HttpResponseMessage CreateSuccessResponse(float[] embedding) + { + var response = new + { + data = new[] + { + new + { + index = 0, + embedding = embedding + } + }, + model = "text-embedding-ada-002", + usage = new + { + prompt_tokens = 5, + total_tokens = 5 + } + }; + + string content = JsonSerializer.Serialize(response); + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(content, Encoding.UTF8, "application/json") + }; + } + + private static HttpResponseMessage CreateBatchSuccessResponse(float[][] embeddings) + { + var data = new object[embeddings.Length]; + for (int i = 0; i < embeddings.Length; i++) + { + data[i] = new + { + index = i, + embedding = embeddings[i] + }; + } + + var response = new + { + data, + model = "text-embedding-ada-002", + usage = new + { + prompt_tokens = 15, + total_tokens = 15 + } + }; + + string content = JsonSerializer.Serialize(response); + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(content, Encoding.UTF8, "application/json") + }; + } + + private static HttpResponseMessage CreateSuccessResponseWithContent(string content) + { + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(content, Encoding.UTF8, "application/json") + }; + } + + private static HttpResponseMessage CreateErrorResponse(HttpStatusCode statusCode, string errorMessage) + { + var errorContent = new + { + error = new + { + message = errorMessage, + type = "invalid_request_error" + } + }; + + return new HttpResponseMessage(statusCode) + { + Content = new StringContent(JsonSerializer.Serialize(errorContent), Encoding.UTF8, "application/json") + }; + } + + #endregion +} diff --git a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs new file mode 100644 index 0000000000..1123831577 --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs @@ -0,0 +1,345 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Text.Json; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.ObjectModel; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for EmbeddingsOptions deserialization and EmbeddingProviderType enum. +/// +[TestClass] +public class EmbeddingsOptionsTests +{ + private const string BASIC_CONFIG_WITH_EMBEDDINGS = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""endpoint"": ""https://my-openai.openai.azure.com"", + ""api-key"": ""test-api-key"", + ""model"": ""text-embedding-ada-002"", + ""api-version"": ""2024-02-01"", + ""dimensions"": 1536, + ""timeout-ms"": 30000 + } + }, + ""entities"": {} + }"; + + private const string OPENAI_CONFIG = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""openai"", + ""endpoint"": ""https://api.openai.com"", + ""api-key"": ""sk-test-key"" + } + }, + ""entities"": {} + }"; + + private const string MINIMAL_AZURE_CONFIG = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""endpoint"": ""https://my-openai.openai.azure.com"", + ""api-key"": ""test-api-key"", + ""model"": ""my-deployment"" + } + }, + ""entities"": {} + }"; + + private const string CONFIG_WITHOUT_EMBEDDINGS = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""entities"": {} + }"; + + /// + /// Tests that a full Azure OpenAI embeddings configuration is correctly deserialized. + /// + [TestMethod] + public void TestAzureOpenAIEmbeddingsConfigDeserialization() + { + // Act + bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( + BASIC_CONFIG_WITH_EMBEDDINGS, + out RuntimeConfig runtimeConfig, + replacementSettings: new DeserializationVariableReplacementSettings( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: false)); + + // Assert + Assert.IsTrue(isParsingSuccessful); + Assert.IsNotNull(runtimeConfig); + Assert.IsNotNull(runtimeConfig.Runtime); + Assert.IsTrue(runtimeConfig.Runtime.IsEmbeddingsConfigured); + Assert.IsNotNull(runtimeConfig.Runtime.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.AzureOpenAI, embeddings.Provider); + Assert.AreEqual("https://my-openai.openai.azure.com", embeddings.Endpoint); + Assert.AreEqual("test-api-key", embeddings.ApiKey); + Assert.AreEqual("text-embedding-ada-002", embeddings.Model); + Assert.AreEqual("2024-02-01", embeddings.ApiVersion); + Assert.AreEqual(1536, embeddings.Dimensions); + Assert.AreEqual(30000, embeddings.TimeoutMs); + + // Verify UserProvided flags + Assert.IsTrue(embeddings.UserProvidedModel); + Assert.IsTrue(embeddings.UserProvidedApiVersion); + Assert.IsTrue(embeddings.UserProvidedDimensions); + Assert.IsTrue(embeddings.UserProvidedTimeoutMs); + } + + /// + /// Tests that an OpenAI embeddings configuration without optional fields is correctly deserialized + /// and default values are applied. + /// + [TestMethod] + public void TestOpenAIEmbeddingsConfigWithDefaults() + { + // Act + bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( + OPENAI_CONFIG, + out RuntimeConfig runtimeConfig, + replacementSettings: new DeserializationVariableReplacementSettings( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: false)); + + // Assert + Assert.IsTrue(isParsingSuccessful); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.OpenAI, embeddings.Provider); + Assert.AreEqual("https://api.openai.com", embeddings.Endpoint); + Assert.AreEqual("sk-test-key", embeddings.ApiKey); + + // Model not specified, but EffectiveModel should return default for OpenAI + Assert.IsNull(embeddings.Model); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_OPENAI_MODEL, embeddings.EffectiveModel); + + // Optional fields should use effective defaults + Assert.AreEqual(EmbeddingsOptions.DEFAULT_TIMEOUT_MS, embeddings.EffectiveTimeoutMs); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_AZURE_API_VERSION, embeddings.EffectiveApiVersion); + + // UserProvided flags should be false for optional fields + Assert.IsFalse(embeddings.UserProvidedModel); + Assert.IsFalse(embeddings.UserProvidedApiVersion); + Assert.IsFalse(embeddings.UserProvidedDimensions); + Assert.IsFalse(embeddings.UserProvidedTimeoutMs); + } + + /// + /// Tests minimal Azure OpenAI configuration with required fields only. + /// + [TestMethod] + public void TestMinimalAzureOpenAIConfig() + { + // Act + bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( + MINIMAL_AZURE_CONFIG, + out RuntimeConfig runtimeConfig, + replacementSettings: new DeserializationVariableReplacementSettings( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: false)); + + // Assert + Assert.IsTrue(isParsingSuccessful); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.AzureOpenAI, embeddings.Provider); + Assert.AreEqual("my-deployment", embeddings.Model); + Assert.AreEqual("my-deployment", embeddings.EffectiveModel); + Assert.IsTrue(embeddings.UserProvidedModel); + } + + /// + /// Tests that a configuration without embeddings returns IsEmbeddingsConfigured as false. + /// + [TestMethod] + public void TestConfigWithoutEmbeddings() + { + // Act + bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( + CONFIG_WITHOUT_EMBEDDINGS, + out RuntimeConfig runtimeConfig, + replacementSettings: new DeserializationVariableReplacementSettings( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: false)); + + // Assert + Assert.IsTrue(isParsingSuccessful); + Assert.IsNotNull(runtimeConfig); + + // Runtime may be null or Embeddings may be null + bool isEmbeddingsConfigured = runtimeConfig.Runtime?.IsEmbeddingsConfigured ?? false; + Assert.IsFalse(isEmbeddingsConfigured); + } + + /// + /// Tests that EmbeddingProviderType enum is correctly serialized with kebab-case. + /// + [DataTestMethod] + [DataRow("azure-openai", EmbeddingProviderType.AzureOpenAI, DisplayName = "azure-openai deserializes to AzureOpenAI")] + [DataRow("openai", EmbeddingProviderType.OpenAI, DisplayName = "openai deserializes to OpenAI")] + public void TestEmbeddingProviderTypeDeserialization(string providerValue, EmbeddingProviderType expectedType) + { + // Arrange + string config = $@" + {{ + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": {{ + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }}, + ""runtime"": {{ + ""embeddings"": {{ + ""provider"": ""{providerValue}"", + ""endpoint"": ""https://example.com"", + ""api-key"": ""test-key"", + ""model"": ""test-model"" + }} + }}, + ""entities"": {{}} + }}"; + + // Act + bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( + config, + out RuntimeConfig runtimeConfig, + replacementSettings: new DeserializationVariableReplacementSettings( + azureKeyVaultOptions: null, + doReplaceEnvVar: false, + doReplaceAkvVar: false)); + + // Assert + Assert.IsTrue(isParsingSuccessful); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + Assert.AreEqual(expectedType, runtimeConfig.Runtime.Embeddings.Provider); + } + + /// + /// Tests EmbeddingsOptions serialization to JSON. + /// + [TestMethod] + public void TestEmbeddingsOptionsSerialization() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + Endpoint: "https://my-endpoint.openai.azure.com", + ApiKey: "my-api-key", + Model: "my-model", + ApiVersion: "2024-02-01", + Dimensions: 1536, + TimeoutMs: 60000); + + // Act + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(replacementSettings: null); + string json = JsonSerializer.Serialize(options, serializerOptions); + + // Normalize json for comparison (remove whitespace) + string normalizedJson = json.Replace(" ", "").Replace("\n", "").Replace("\r", ""); + + // Assert + Assert.IsTrue(normalizedJson.Contains("\"provider\":\"azure-openai\""), $"Expected provider in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"endpoint\":\"https://my-endpoint.openai.azure.com\""), $"Expected endpoint in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"api-key\":\"my-api-key\""), $"Expected api-key in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"model\":\"my-model\""), $"Expected model in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"api-version\":\"2024-02-01\""), $"Expected api-version in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"dimensions\":1536"), $"Expected dimensions in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"timeout-ms\":60000"), $"Expected timeout-ms in JSON: {json}"); + } + + /// + /// Tests that environment variable replacement works for embeddings configuration. + /// + [TestMethod] + public void TestEmbeddingsConfigWithEnvVarReplacement() + { + // Arrange + string config = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""endpoint"": ""@env('EMBEDDINGS_ENDPOINT')"", + ""api-key"": ""@env('EMBEDDINGS_API_KEY')"", + ""model"": ""@env('EMBEDDINGS_MODEL')"" + } + }, + ""entities"": {} + }"; + + // Set environment variables + Environment.SetEnvironmentVariable("EMBEDDINGS_ENDPOINT", "https://test-endpoint.openai.azure.com"); + Environment.SetEnvironmentVariable("EMBEDDINGS_API_KEY", "test-secret-key"); + Environment.SetEnvironmentVariable("EMBEDDINGS_MODEL", "text-embedding-3-small"); + + try + { + // Act + bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( + config, + out RuntimeConfig runtimeConfig, + replacementSettings: new DeserializationVariableReplacementSettings( + azureKeyVaultOptions: null, + doReplaceEnvVar: true, + doReplaceAkvVar: false)); + + // Assert + Assert.IsTrue(isParsingSuccessful); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual("https://test-endpoint.openai.azure.com", embeddings.Endpoint); + Assert.AreEqual("test-secret-key", embeddings.ApiKey); + Assert.AreEqual("text-embedding-3-small", embeddings.Model); + } + finally + { + // Cleanup + Environment.SetEnvironmentVariable("EMBEDDINGS_ENDPOINT", null); + Environment.SetEnvironmentVariable("EMBEDDINGS_API_KEY", null); + Environment.SetEnvironmentVariable("EMBEDDINGS_MODEL", null); + } + } +} From 21e81b9c3f13272199d727f8cc66944f0efbf147 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:26:56 +0000 Subject: [PATCH 05/16] Simplify HttpClient registration for embedding service Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- src/Service/Startup.cs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index 563f42d440..e10c0ddcee 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -392,11 +392,7 @@ public void ConfigureServices(IServiceCollection services) && runtimeConfig?.Runtime?.IsEmbeddingsConfigured == true) { EmbeddingsOptions embeddingsOptions = runtimeConfig.Runtime.Embeddings; - services.AddHttpClient(client => - { - // Base configuration is done in the EmbeddingService constructor - }).ConfigurePrimaryHttpMessageHandler(() => new HttpClientHandler()); - + services.AddHttpClient(); services.AddSingleton(embeddingsOptions); } From 0cd8e53ba5206d0a0bb6099a7d1b1ee3a8cef599 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:45:19 +0000 Subject: [PATCH 06/16] Plan for embedding service enhancements Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- src/Core/Services/EmbeddingService.cs | 127 ++++++++++++- .../Telemetry/EmbeddingTelemetryHelper.cs | 178 ++++++++++++++++++ 2 files changed, 301 insertions(+), 4 deletions(-) create mode 100644 src/Core/Telemetry/EmbeddingTelemetryHelper.cs diff --git a/src/Core/Services/EmbeddingService.cs b/src/Core/Services/EmbeddingService.cs index 6371ceeecc..faaeb8e01b 100644 --- a/src/Core/Services/EmbeddingService.cs +++ b/src/Core/Services/EmbeddingService.cs @@ -2,23 +2,36 @@ // Licensed under the MIT License. using System.Net.Http.Headers; +using System.Security.Cryptography; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using Azure.DataApiBuilder.Config.ObjectModel; using Microsoft.Extensions.Logging; +using ZiggyCreatures.Caching.Fusion; namespace Azure.DataApiBuilder.Core.Services; /// /// Service implementation for text embedding/vectorization. /// Supports both OpenAI and Azure OpenAI providers. +/// Includes L1 memory cache using FusionCache to prevent duplicate embedding API calls. /// public class EmbeddingService : IEmbeddingService { private readonly HttpClient _httpClient; private readonly EmbeddingsOptions _options; private readonly ILogger _logger; + private readonly IFusionCache _cache; + + // Constants + private const char KEY_DELIMITER = ':'; + private const string CACHE_KEY_PREFIX = "embedding"; + + /// + /// Default cache TTL in hours. Set high since embeddings are deterministic and don't get outdated. + /// + private const int DEFAULT_CACHE_TTL_HOURS = 24; /// /// JSON serializer options for request/response handling. @@ -32,17 +45,20 @@ public class EmbeddingService : IEmbeddingService /// /// Initializes a new instance of the EmbeddingService. /// - /// The HTTP client factory for creating HTTP clients. + /// The HTTP client for making API requests. /// The embedding configuration options. /// The logger instance. + /// The FusionCache instance for L1 memory caching. public EmbeddingService( HttpClient httpClient, EmbeddingsOptions options, - ILogger logger) + ILogger logger, + IFusionCache cache) { _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); _options = options ?? throw new ArgumentNullException(nameof(options)); _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _cache = cache ?? throw new ArgumentNullException(nameof(cache)); ConfigureHttpClient(); } @@ -77,8 +93,31 @@ public async Task EmbedAsync(string text, CancellationToken cancellatio throw new ArgumentException("Text cannot be null or empty.", nameof(text)); } - float[][] results = await EmbedBatchAsync(new[] { text }, cancellationToken); - return results[0]; + string cacheKey = CreateCacheKey(text); + + float[]? embedding = await _cache.GetOrSetAsync( + key: cacheKey, + async (FusionCacheFactoryExecutionContext ctx, CancellationToken ct) => + { + _logger.LogDebug("Embedding cache miss, calling API for text hash {TextHash}", cacheKey); + + float[][] results = await EmbedFromApiAsync(new[] { text }, ct); + float[] result = results[0]; + + // L1 only - skip distributed cache + ctx.Options.SetSkipDistributedCache(true, true); + ctx.Options.SetDuration(TimeSpan.FromHours(DEFAULT_CACHE_TTL_HOURS)); + + return result; + }, + token: cancellationToken); + + if (embedding is null) + { + throw new InvalidOperationException("Failed to get embedding from cache or API."); + } + + return embedding; } /// @@ -89,6 +128,86 @@ public async Task EmbedBatchAsync(string[] texts, CancellationToken c throw new ArgumentException("Texts cannot be null or empty.", nameof(texts)); } + // For batch, check cache for each text individually + string[] cacheKeys = texts.Select(CreateCacheKey).ToArray(); + float[]?[] results = new float[texts.Length][]; + List uncachedIndices = new(); + + // Check cache for each text + for (int i = 0; i < texts.Length; i++) + { + MaybeValue cached = _cache.TryGet(key: cacheKeys[i]); + + if (cached.HasValue) + { + _logger.LogDebug("Embedding cache hit for text hash {TextHash}", cacheKeys[i]); + results[i] = cached.Value; + } + else + { + uncachedIndices.Add(i); + } + } + + // If all texts were cached, return immediately + if (uncachedIndices.Count == 0) + { + return results!; + } + + _logger.LogDebug("Embedding cache miss for {Count} text(s), calling API", uncachedIndices.Count); + + // Call API for uncached texts only + string[] uncachedTexts = uncachedIndices.Select(i => texts[i]).ToArray(); + float[][] apiResults = await EmbedFromApiAsync(uncachedTexts, cancellationToken); + + // Cache new results and merge with cached results + for (int i = 0; i < uncachedIndices.Count; i++) + { + int originalIndex = uncachedIndices[i]; + results[originalIndex] = apiResults[i]; + + // Store in L1 cache only + _cache.Set( + key: cacheKeys[originalIndex], + value: apiResults[i], + options => + { + options.SetSkipDistributedCache(true, true); + options.SetDuration(TimeSpan.FromHours(DEFAULT_CACHE_TTL_HOURS)); + }); + } + + return results!; + } + + /// + /// Creates a cache key from the text using SHA256 hash. + /// Format: embedding:{SHA256_hash} + /// Uses hash to keep cache keys small and deterministic. + /// + /// The text to create a cache key for. + /// Cache key string. + private static string CreateCacheKey(string text) + { + // Use SHA256 for deterministic, collision-resistant hash + byte[] textBytes = Encoding.UTF8.GetBytes(text); + byte[] hashBytes = SHA256.HashData(textBytes); + string hashHex = Convert.ToHexString(hashBytes); + + StringBuilder cacheKeyBuilder = new(); + cacheKeyBuilder.Append(CACHE_KEY_PREFIX); + cacheKeyBuilder.Append(KEY_DELIMITER); + cacheKeyBuilder.Append(hashHex); + + return cacheKeyBuilder.ToString(); + } + + /// + /// Calls the embedding API to get embeddings for the provided texts. + /// + private async Task EmbedFromApiAsync(string[] texts, CancellationToken cancellationToken) + { string requestUrl = BuildRequestUrl(); object requestBody = BuildRequestBody(texts); diff --git a/src/Core/Telemetry/EmbeddingTelemetryHelper.cs b/src/Core/Telemetry/EmbeddingTelemetryHelper.cs new file mode 100644 index 0000000000..b8fc7773e5 --- /dev/null +++ b/src/Core/Telemetry/EmbeddingTelemetryHelper.cs @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Diagnostics.Metrics; + +namespace Azure.DataApiBuilder.Core.Telemetry; + +/// +/// Helper class for tracking embedding-related telemetry metrics and traces. +/// +public static class EmbeddingTelemetryHelper +{ + // Metrics + private static readonly Meter _meter = new("DataApiBuilder.Embeddings"); + private static readonly Counter _embeddingRequests = _meter.CreateCounter("embedding_requests_total", description: "Total number of embedding requests"); + private static readonly Counter _embeddingCacheHits = _meter.CreateCounter("embedding_cache_hits_total", description: "Total number of embedding cache hits"); + private static readonly Counter _embeddingCacheMisses = _meter.CreateCounter("embedding_cache_misses_total", description: "Total number of embedding cache misses"); + private static readonly Counter _embeddingErrors = _meter.CreateCounter("embedding_errors_total", description: "Total number of embedding errors"); + private static readonly Histogram _embeddingDuration = _meter.CreateHistogram("embedding_duration_ms", "ms", "Duration of embedding API calls"); + private static readonly Histogram _embeddingTokens = _meter.CreateHistogram("embedding_tokens_total", description: "Total tokens used in embedding requests"); + + /// + /// Tracks an embedding request. + /// + /// The embedding provider (e.g., azure-openai, openai). + /// Number of texts being embedded. + /// Whether the result was served from cache. + public static void TrackEmbeddingRequest(string provider, int textCount, bool fromCache) + { + _embeddingRequests.Add(1, + new KeyValuePair("provider", provider), + new KeyValuePair("text_count", textCount), + new KeyValuePair("from_cache", fromCache)); + } + + /// + /// Tracks an embedding cache hit. + /// + /// The embedding provider. + public static void TrackCacheHit(string provider) + { + _embeddingCacheHits.Add(1, new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding cache miss. + /// + /// The embedding provider. + public static void TrackCacheMiss(string provider) + { + _embeddingCacheMisses.Add(1, new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding error. + /// + /// The embedding provider. + /// The type of error that occurred. + public static void TrackError(string provider, string errorType) + { + _embeddingErrors.Add(1, + new KeyValuePair("provider", provider), + new KeyValuePair("error_type", errorType)); + } + + /// + /// Tracks the duration of an embedding API call. + /// + /// The embedding provider. + /// The duration of the API call. + /// Number of texts embedded. + public static void TrackApiDuration(string provider, TimeSpan duration, int textCount) + { + _embeddingDuration.Record(duration.TotalMilliseconds, + new KeyValuePair("provider", provider), + new KeyValuePair("text_count", textCount)); + } + + /// + /// Tracks token usage from an embedding request. + /// + /// The embedding provider. + /// Total tokens used. + public static void TrackTokenUsage(string provider, long totalTokens) + { + _embeddingTokens.Record(totalTokens, new KeyValuePair("provider", provider)); + } + + /// + /// Starts an activity for embedding operations. + /// + /// Name of the operation (e.g., "EmbedAsync", "EmbedBatchAsync"). + /// The started activity, or null if tracing is not enabled. + public static Activity? StartEmbeddingActivity(string operationName) + { + return TelemetryTracesHelper.DABActivitySource.StartActivity( + name: $"Embedding.{operationName}", + kind: ActivityKind.Client); + } + + /// + /// Sets embedding-specific tags on an activity. + /// + /// The activity to tag. + /// The embedding provider. + /// The model being used. + /// Number of texts being embedded. + public static void SetEmbeddingActivityTags( + this Activity activity, + string provider, + string? model, + int textCount) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.provider", provider); + if (!string.IsNullOrEmpty(model)) + { + activity.SetTag("embedding.model", model); + } + + activity.SetTag("embedding.text_count", textCount); + } + } + + /// + /// Records cache status on an activity. + /// + /// The activity to tag. + /// Number of cache hits. + /// Number of cache misses. + public static void SetCacheActivityTags( + this Activity activity, + int cacheHits, + int cacheMisses) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.cache_hits", cacheHits); + activity.SetTag("embedding.cache_misses", cacheMisses); + } + } + + /// + /// Records successful completion of an embedding activity. + /// + /// The activity to complete. + /// Duration in milliseconds. + public static void SetEmbeddingActivitySuccess( + this Activity activity, + double durationMs) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.duration_ms", durationMs); + activity.SetStatus(ActivityStatusCode.Ok); + } + } + + /// + /// Records an error on an embedding activity. + /// + /// The activity to record error on. + /// The exception that occurred. + public static void SetEmbeddingActivityError( + this Activity activity, + Exception ex) + { + if (activity.IsAllDataRequested) + { + activity.SetStatus(ActivityStatusCode.Error, ex.Message); + activity.RecordException(ex); + activity.SetTag("error.type", ex.GetType().Name); + activity.SetTag("error.message", ex.Message); + } + } +} From 7fa1c4950b7c17761f5912ef86b2f6d3fede8e01 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:02:09 +0000 Subject: [PATCH 07/16] Refactor embedding code into dedicated namespaces Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- src/Cli/Commands/ConfigureOptions.cs | 41 +++++ src/Cli/ConfigGenerator.cs | 1 + src/Config/DabConfigEvents.cs | 1 + .../HealthCheck/HealthCheckConstants.cs | 1 + src/Config/HotReloadEventHandler.cs | 3 +- .../{ => Embeddings}/EmbeddingProviderType.cs | 2 +- .../Embeddings/EmbeddingsEndpointOptions.cs | 142 ++++++++++++++++++ .../Embeddings/EmbeddingsHealthCheckConfig.cs | 111 ++++++++++++++ .../{ => Embeddings}/EmbeddingsOptions.cs | 72 ++++++++- src/Config/ObjectModel/RuntimeOptions.cs | 1 + .../{ => Embeddings}/EmbeddingService.cs | 80 +++++++++- .../Embeddings}/EmbeddingTelemetryHelper.cs | 116 ++++++++++++-- .../Services/Embeddings/IEmbeddingService.cs | 70 +++++++++ src/Core/Services/IEmbeddingService.cs | 27 ---- .../UnitTests/EmbeddingServiceTests.cs | 4 +- .../UnitTests/EmbeddingsOptionsTests.cs | 1 + src/Service/Startup.cs | 32 ++++ 17 files changed, 645 insertions(+), 60 deletions(-) rename src/Config/ObjectModel/{ => Embeddings}/EmbeddingProviderType.cs (91%) create mode 100644 src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs create mode 100644 src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs rename src/Config/ObjectModel/{ => Embeddings}/EmbeddingsOptions.cs (70%) rename src/Core/Services/{ => Embeddings}/EmbeddingService.cs (80%) rename src/Core/{Telemetry => Services/Embeddings}/EmbeddingTelemetryHelper.cs (60%) create mode 100644 src/Core/Services/Embeddings/IEmbeddingService.cs delete mode 100644 src/Core/Services/IEmbeddingService.cs diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 93810ddacf..3c85142996 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -4,6 +4,7 @@ using System.IO.Abstractions; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Product; using Cli.Constants; using CommandLine; @@ -71,6 +72,7 @@ public ConfigureOptions( RollingInterval? fileSinkRollingInterval = null, int? fileSinkRetainedFileCountLimit = null, long? fileSinkFileSizeLimitBytes = null, + CliBool? runtimeEmbeddingsEnabled = null, EmbeddingProviderType? runtimeEmbeddingsProvider = null, string? runtimeEmbeddingsEndpoint = null, string? runtimeEmbeddingsApiKey = null, @@ -78,6 +80,12 @@ public ConfigureOptions( string? runtimeEmbeddingsApiVersion = null, int? runtimeEmbeddingsDimensions = null, int? runtimeEmbeddingsTimeoutMs = null, + CliBool? runtimeEmbeddingsRestEnabled = null, + string? runtimeEmbeddingsRestPath = null, + CliBool? runtimeEmbeddingsHealthEnabled = null, + int? runtimeEmbeddingsHealthThresholdMs = null, + string? runtimeEmbeddingsHealthTestText = null, + int? runtimeEmbeddingsHealthExpectedDimensions = null, string? config = null) : base(config) { @@ -140,6 +148,7 @@ public ConfigureOptions( FileSinkRetainedFileCountLimit = fileSinkRetainedFileCountLimit; FileSinkFileSizeLimitBytes = fileSinkFileSizeLimitBytes; // Embeddings + RuntimeEmbeddingsEnabled = runtimeEmbeddingsEnabled; RuntimeEmbeddingsProvider = runtimeEmbeddingsProvider; RuntimeEmbeddingsEndpoint = runtimeEmbeddingsEndpoint; RuntimeEmbeddingsApiKey = runtimeEmbeddingsApiKey; @@ -147,6 +156,14 @@ public ConfigureOptions( RuntimeEmbeddingsApiVersion = runtimeEmbeddingsApiVersion; RuntimeEmbeddingsDimensions = runtimeEmbeddingsDimensions; RuntimeEmbeddingsTimeoutMs = runtimeEmbeddingsTimeoutMs; + // Embeddings REST + RuntimeEmbeddingsRestEnabled = runtimeEmbeddingsRestEnabled; + RuntimeEmbeddingsRestPath = runtimeEmbeddingsRestPath; + // Embeddings Health + RuntimeEmbeddingsHealthEnabled = runtimeEmbeddingsHealthEnabled; + RuntimeEmbeddingsHealthThresholdMs = runtimeEmbeddingsHealthThresholdMs; + RuntimeEmbeddingsHealthTestText = runtimeEmbeddingsHealthTestText; + RuntimeEmbeddingsHealthExpectedDimensions = runtimeEmbeddingsHealthExpectedDimensions; } [Option("data-source.database-type", Required = false, HelpText = "Database type. Allowed values: MSSQL, PostgreSQL, CosmosDB_NoSQL, MySQL.")] @@ -296,6 +313,9 @@ public ConfigureOptions( [Option("runtime.telemetry.file.file-size-limit-bytes", Required = false, HelpText = "Configure maximum file size limit in bytes. Default: 1048576")] public long? FileSinkFileSizeLimitBytes { get; } + [Option("runtime.embeddings.enabled", Required = false, HelpText = "Enable/disable the embedding service. Default: true")] + public CliBool? RuntimeEmbeddingsEnabled { get; } + [Option("runtime.embeddings.provider", Required = false, HelpText = "Configure embedding provider type. Allowed values: azure-openai, openai.")] public EmbeddingProviderType? RuntimeEmbeddingsProvider { get; } @@ -317,6 +337,27 @@ public ConfigureOptions( [Option("runtime.embeddings.timeout-ms", Required = false, HelpText = "Configure the request timeout in milliseconds. Default: 30000")] public int? RuntimeEmbeddingsTimeoutMs { get; } + [Option("runtime.embeddings.rest.enabled", Required = false, HelpText = "Enable/disable the REST endpoint for embeddings. Default: false")] + public CliBool? RuntimeEmbeddingsRestEnabled { get; } + + [Option("runtime.embeddings.rest.path", Required = false, HelpText = "Configure the REST endpoint path for embeddings. Default: /embed")] + public string? RuntimeEmbeddingsRestPath { get; } + + [Option("runtime.embeddings.rest.roles", Required = false, Separator = ',', HelpText = "Configure the roles allowed to access the embedding REST endpoint. Comma-separated list. In development mode defaults to 'anonymous'.")] + public IEnumerable? RuntimeEmbeddingsRestRoles { get; } + + [Option("runtime.embeddings.health.enabled", Required = false, HelpText = "Enable/disable health checks for the embedding service. Default: true")] + public CliBool? RuntimeEmbeddingsHealthEnabled { get; } + + [Option("runtime.embeddings.health.threshold-ms", Required = false, HelpText = "Configure the health check threshold in milliseconds. Default: 5000")] + public int? RuntimeEmbeddingsHealthThresholdMs { get; } + + [Option("runtime.embeddings.health.test-text", Required = false, HelpText = "Configure the test text for health check validation. Default: 'health check'")] + public string? RuntimeEmbeddingsHealthTestText { get; } + + [Option("runtime.embeddings.health.expected-dimensions", Required = false, HelpText = "Configure the expected dimensions for health check validation. Optional.")] + public int? RuntimeEmbeddingsHealthExpectedDimensions { get; } + public int Handler(ILogger logger, FileSystemRuntimeConfigLoader loader, IFileSystem fileSystem) { logger.LogInformation("{productName} {version}", PRODUCT_NAME, ProductInfo.GetProductVersion()); diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index b9cb93207e..5186393060 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -8,6 +8,7 @@ using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.NamingPolicies; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Core; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Service; diff --git a/src/Config/DabConfigEvents.cs b/src/Config/DabConfigEvents.cs index f69193b583..691a71830e 100644 --- a/src/Config/DabConfigEvents.cs +++ b/src/Config/DabConfigEvents.cs @@ -19,4 +19,5 @@ public static class DabConfigEvents public const string GRAPHQL_SCHEMA_EVICTION_ON_CONFIG_CHANGED = "GRAPHQL_SCHEMA_EVICTION_ON_CONFIG_CHANGED"; public const string GRAPHQL_SCHEMA_CREATOR_ON_CONFIG_CHANGED = "GRAPHQL_SCHEMA_CREATOR_ON_CONFIG_CHANGED"; public const string LOG_LEVEL_INITIALIZER_ON_CONFIG_CHANGE = "LOG_LEVEL_INITIALIZER_ON_CONFIG_CHANGE"; + public const string EMBEDDING_SERVICE_ON_CONFIG_CHANGED = "EMBEDDING_SERVICE_ON_CONFIG_CHANGED"; } diff --git a/src/Config/HealthCheck/HealthCheckConstants.cs b/src/Config/HealthCheck/HealthCheckConstants.cs index fd5901575c..b57526fb75 100644 --- a/src/Config/HealthCheck/HealthCheckConstants.cs +++ b/src/Config/HealthCheck/HealthCheckConstants.cs @@ -12,6 +12,7 @@ public static class HealthCheckConstants public const string DATASOURCE = "data-source"; public const string REST = "rest"; public const string GRAPHQL = "graphql"; + public const string EMBEDDING = "embedding"; public const int ERROR_RESPONSE_TIME_MS = -1; public const int DEFAULT_THRESHOLD_RESPONSE_TIME_MS = 1000; public const int DEFAULT_FIRST_VALUE = 100; diff --git a/src/Config/HotReloadEventHandler.cs b/src/Config/HotReloadEventHandler.cs index 666c3c227b..a2ca9eaf98 100644 --- a/src/Config/HotReloadEventHandler.cs +++ b/src/Config/HotReloadEventHandler.cs @@ -34,7 +34,8 @@ public HotReloadEventHandler() { GRAPHQL_SCHEMA_CREATOR_ON_CONFIG_CHANGED, null }, { GRAPHQL_SCHEMA_REFRESH_ON_CONFIG_CHANGED, null }, { GRAPHQL_SCHEMA_EVICTION_ON_CONFIG_CHANGED, null }, - { LOG_LEVEL_INITIALIZER_ON_CONFIG_CHANGE, null } + { LOG_LEVEL_INITIALIZER_ON_CONFIG_CHANGE, null }, + { EMBEDDING_SERVICE_ON_CONFIG_CHANGED, null } }; } diff --git a/src/Config/ObjectModel/EmbeddingProviderType.cs b/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs similarity index 91% rename from src/Config/ObjectModel/EmbeddingProviderType.cs rename to src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs index 2ead4470dd..39ce56b596 100644 --- a/src/Config/ObjectModel/EmbeddingProviderType.cs +++ b/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs @@ -5,7 +5,7 @@ using System.Text.Json.Serialization; using Azure.DataApiBuilder.Config.Converters; -namespace Azure.DataApiBuilder.Config.ObjectModel; +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; /// /// Represents the supported embedding provider types. diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs new file mode 100644 index 0000000000..b019aa9aef --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Endpoint configuration for the embedding service. +/// +public record EmbeddingsEndpointOptions +{ + /// + /// Default path for the embedding endpoint. + /// + public const string DEFAULT_PATH = "/embed"; + + /// + /// Anonymous role constant. + /// + public const string ANONYMOUS_ROLE = "anonymous"; + + /// + /// Whether the endpoint is enabled. Defaults to false. + /// + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } + + /// + /// Flag indicating whether the user provided the enabled setting. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedEnabled { get; init; } + + /// + /// The endpoint path. Defaults to "/embed". + /// + [JsonPropertyName("path")] + public string? Path { get; init; } + + /// + /// Flag indicating whether the user provided a custom path. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedPath { get; init; } + + /// + /// The roles allowed to access the embedding endpoint. + /// In development mode, defaults to ["anonymous"]. + /// In production mode, must be explicitly configured. + /// + [JsonPropertyName("roles")] + public string[]? Roles { get; init; } + + /// + /// Flag indicating whether the user provided roles. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedRoles { get; init; } + + /// + /// Gets the effective path, using default if not specified. + /// + [JsonIgnore] + public string EffectivePath => Path ?? DEFAULT_PATH; + + /// + /// Gets the effective roles based on host mode. + /// In development mode, returns ["anonymous"] if no roles specified. + /// In production mode, returns the configured roles or empty array. + /// + /// Whether the host is in development mode. + /// Array of allowed roles. + public string[] GetEffectiveRoles(bool isDevelopmentMode) + { + if (Roles is not null && Roles.Length > 0) + { + return Roles; + } + + // In development mode, default to anonymous access + if (isDevelopmentMode) + { + return new[] { ANONYMOUS_ROLE }; + } + + // In production mode with no roles specified, return empty (no access) + return Array.Empty(); + } + + /// + /// Checks if the given role is allowed to access the embedding endpoint. + /// + /// The role to check. + /// Whether the host is in development mode. + /// True if the role is allowed; otherwise, false. + public bool IsRoleAllowed(string role, bool isDevelopmentMode) + { + string[] effectiveRoles = GetEffectiveRoles(isDevelopmentMode); + return effectiveRoles.Contains(role, StringComparer.OrdinalIgnoreCase); + } + + /// + /// Default constructor. + /// + public EmbeddingsEndpointOptions() + { + Enabled = false; + } + + /// + /// Constructor with optional parameters. + /// + [JsonConstructor] + public EmbeddingsEndpointOptions( + bool? enabled = null, + string? path = null, + string[]? roles = null) + { + if (enabled.HasValue) + { + Enabled = enabled.Value; + UserProvidedEnabled = true; + } + else + { + Enabled = false; + } + + if (path is not null) + { + Path = path; + UserProvidedPath = true; + } + + if (roles is not null) + { + Roles = roles; + UserProvidedRoles = true; + } + } +} diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs new file mode 100644 index 0000000000..b2d2f86bcf --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Health check configuration for embeddings. +/// Validates that the embedding service is responding within threshold and returning expected results. +/// +public record EmbeddingsHealthCheckConfig : HealthCheckConfig +{ + /// + /// Default threshold for embedding health check in milliseconds. + /// + public const int DEFAULT_THRESHOLD_MS = 5000; + + /// + /// Default test text used for health check validation. + /// + public const string DEFAULT_TEST_TEXT = "health check"; + + /// + /// The expected milliseconds the embedding request should complete within to be considered healthy. + /// If the request takes equal or longer than this value, the health check will be considered unhealthy. + /// Default: 5000ms (5 seconds) + /// + [JsonPropertyName("threshold-ms")] + public int ThresholdMs { get; init; } + + /// + /// Flag indicating whether the user provided a custom threshold. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedThresholdMs { get; init; } + + /// + /// The test text to use for health check validation. + /// This text will be embedded and the result validated. + /// Default: "health check" + /// + [JsonPropertyName("test-text")] + public string TestText { get; init; } + + /// + /// Flag indicating whether the user provided custom test text. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedTestText { get; init; } + + /// + /// The expected number of dimensions in the embedding result. + /// If specified, the health check will verify the embedding has this many dimensions. + /// If not specified, dimension validation is skipped. + /// + [JsonPropertyName("expected-dimensions")] + public int? ExpectedDimensions { get; init; } + + /// + /// Flag indicating whether the user provided expected dimensions. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedExpectedDimensions { get; init; } + + /// + /// Default constructor with default values. + /// + public EmbeddingsHealthCheckConfig() : base() + { + ThresholdMs = DEFAULT_THRESHOLD_MS; + TestText = DEFAULT_TEST_TEXT; + } + + /// + /// Constructor with optional parameters. + /// + [JsonConstructor] + public EmbeddingsHealthCheckConfig( + bool? enabled = null, + int? thresholdMs = null, + string? testText = null, + int? expectedDimensions = null) : base(enabled) + { + if (thresholdMs is not null) + { + ThresholdMs = (int)thresholdMs; + UserProvidedThresholdMs = true; + } + else + { + ThresholdMs = DEFAULT_THRESHOLD_MS; + } + + if (testText is not null) + { + TestText = testText; + UserProvidedTestText = true; + } + else + { + TestText = DEFAULT_TEST_TEXT; + } + + if (expectedDimensions is not null) + { + ExpectedDimensions = expectedDimensions; + UserProvidedExpectedDimensions = true; + } + } +} diff --git a/src/Config/ObjectModel/EmbeddingsOptions.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs similarity index 70% rename from src/Config/ObjectModel/EmbeddingsOptions.cs rename to src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs index 41147adc33..a1afd9abf7 100644 --- a/src/Config/ObjectModel/EmbeddingsOptions.cs +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs @@ -4,7 +4,7 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; -namespace Azure.DataApiBuilder.Config.ObjectModel; +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; /// /// Represents the options for configuring the embedding service. @@ -27,6 +27,19 @@ public record EmbeddingsOptions /// public const string DEFAULT_OPENAI_MODEL = "text-embedding-3-small"; + /// + /// Whether the embedding service is enabled. Defaults to true. + /// When false, the embedding service will not be used. + /// + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } = true; + + /// + /// Flag indicating whether the user provided the enabled setting. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedEnabled { get; init; } + /// /// The embedding provider type (azure-openai or openai). /// Required. @@ -35,11 +48,11 @@ public record EmbeddingsOptions public EmbeddingProviderType Provider { get; init; } /// - /// The provider base URL endpoint. + /// The provider base URL. /// Required. /// - [JsonPropertyName("endpoint")] - public string Endpoint { get; init; } + [JsonPropertyName("base-url")] + public string BaseUrl { get; init; } /// /// The API key for authentication. @@ -75,6 +88,18 @@ public record EmbeddingsOptions [JsonPropertyName("timeout-ms")] public int? TimeoutMs { get; init; } + /// + /// Endpoint configuration for the embedding service. + /// + [JsonPropertyName("endpoint")] + public EmbeddingsEndpointOptions? Endpoint { get; init; } + + /// + /// Health check configuration for the embedding service. + /// + [JsonPropertyName("health")] + public EmbeddingsHealthCheckConfig? Health { get; init; } + /// /// Flag which informs whether the user provided a custom timeout value. /// @@ -122,19 +147,52 @@ public record EmbeddingsOptions [JsonIgnore] public string? EffectiveModel => Model ?? (Provider == EmbeddingProviderType.OpenAI ? DEFAULT_OPENAI_MODEL : null); + /// + /// Returns true if embedding health check is enabled. + /// + [JsonIgnore] + public bool IsHealthCheckEnabled => Health?.Enabled ?? false; + + /// + /// Returns true if embedding endpoint is enabled. + /// + [JsonIgnore] + public bool IsEndpointEnabled => Endpoint?.Enabled ?? false; + + /// + /// Gets the effective endpoint path. + /// + [JsonIgnore] + public string EffectiveEndpointPath => Endpoint?.EffectivePath ?? EmbeddingsEndpointOptions.DEFAULT_PATH; + [JsonConstructor] public EmbeddingsOptions( EmbeddingProviderType Provider, - string Endpoint, + string BaseUrl, string ApiKey, + bool? Enabled = null, string? Model = null, string? ApiVersion = null, int? Dimensions = null, - int? TimeoutMs = null) + int? TimeoutMs = null, + EmbeddingsEndpointOptions? Endpoint = null, + EmbeddingsHealthCheckConfig? Health = null) { this.Provider = Provider; - this.Endpoint = Endpoint; + this.BaseUrl = BaseUrl; this.ApiKey = ApiKey; + this.Endpoint = Endpoint; + this.Health = Health; + + if (Enabled.HasValue) + { + this.Enabled = Enabled.Value; + UserProvidedEnabled = true; + } + else + { + this.Enabled = true; // Default to enabled + } if (Model is not null) { diff --git a/src/Config/ObjectModel/RuntimeOptions.cs b/src/Config/ObjectModel/RuntimeOptions.cs index 991cb814c4..2a17e89a90 100644 --- a/src/Config/ObjectModel/RuntimeOptions.cs +++ b/src/Config/ObjectModel/RuntimeOptions.cs @@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; namespace Azure.DataApiBuilder.Config.ObjectModel; diff --git a/src/Core/Services/EmbeddingService.cs b/src/Core/Services/Embeddings/EmbeddingService.cs similarity index 80% rename from src/Core/Services/EmbeddingService.cs rename to src/Core/Services/Embeddings/EmbeddingService.cs index faaeb8e01b..c3b03941de 100644 --- a/src/Core/Services/EmbeddingService.cs +++ b/src/Core/Services/Embeddings/EmbeddingService.cs @@ -6,11 +6,11 @@ using System.Text; using System.Text.Json; using System.Text.Json.Serialization; -using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Microsoft.Extensions.Logging; using ZiggyCreatures.Caching.Fusion; -namespace Azure.DataApiBuilder.Core.Services; +namespace Azure.DataApiBuilder.Core.Services.Embeddings; /// /// Service implementation for text embedding/vectorization. @@ -85,9 +85,71 @@ private void ConfigureHttpClient() new MediaTypeWithQualityHeaderValue("application/json")); } + /// + public bool IsEnabled => _options.Enabled; + + /// + public async Task TryEmbedAsync(string text, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + _logger.LogDebug("Embedding service is disabled, skipping embed request"); + return new EmbeddingResult(false, null, "Embedding service is disabled."); + } + + if (string.IsNullOrEmpty(text)) + { + _logger.LogWarning("TryEmbedAsync called with null or empty text"); + return new EmbeddingResult(false, null, "Text cannot be null or empty."); + } + + try + { + float[] embedding = await EmbedAsync(text, cancellationToken); + return new EmbeddingResult(true, embedding); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to generate embedding for text"); + return new EmbeddingResult(false, null, ex.Message); + } + } + + /// + public async Task TryEmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + _logger.LogDebug("Embedding service is disabled, skipping batch embed request"); + return new EmbeddingBatchResult(false, null, "Embedding service is disabled."); + } + + if (texts is null || texts.Length == 0) + { + _logger.LogWarning("TryEmbedBatchAsync called with null or empty texts array"); + return new EmbeddingBatchResult(false, null, "Texts array cannot be null or empty."); + } + + try + { + float[][] embeddings = await EmbedBatchAsync(texts, cancellationToken); + return new EmbeddingBatchResult(true, embeddings); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to generate embeddings for batch of {Count} texts", texts.Length); + return new EmbeddingBatchResult(false, null, ex.Message); + } + } + /// public async Task EmbedAsync(string text, CancellationToken cancellationToken = default) { + if (!_options.Enabled) + { + throw new InvalidOperationException("Embedding service is disabled."); + } + if (string.IsNullOrEmpty(text)) { throw new ArgumentException("Text cannot be null or empty.", nameof(text)); @@ -123,6 +185,10 @@ public async Task EmbedAsync(string text, CancellationToken cancellatio /// public async Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default) { + if (!_options.Enabled) + { + throw new InvalidOperationException("Embedding service is disabled."); + } if (texts is null || texts.Length == 0) { throw new ArgumentException("Texts cannot be null or empty.", nameof(texts)); @@ -245,20 +311,20 @@ private async Task EmbedFromApiAsync(string[] texts, CancellationToke /// private string BuildRequestUrl() { - string endpoint = _options.Endpoint.TrimEnd('/'); + string baseUrl = _options.BaseUrl.TrimEnd('/'); if (_options.Provider == EmbeddingProviderType.AzureOpenAI) { - // Azure OpenAI: {endpoint}/openai/deployments/{deployment}/embeddings?api-version={version} + // Azure OpenAI: {baseUrl}/openai/deployments/{deployment}/embeddings?api-version={version} string model = _options.EffectiveModel ?? throw new InvalidOperationException("Model/deployment name is required for Azure OpenAI."); - return $"{endpoint}/openai/deployments/{model}/embeddings?api-version={_options.EffectiveApiVersion}"; + return $"{baseUrl}/openai/deployments/{model}/embeddings?api-version={_options.EffectiveApiVersion}"; } else { - // OpenAI: {endpoint}/v1/embeddings - return $"{endpoint}/v1/embeddings"; + // OpenAI: {baseUrl}/v1/embeddings + return $"{baseUrl}/v1/embeddings"; } } diff --git a/src/Core/Telemetry/EmbeddingTelemetryHelper.cs b/src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs similarity index 60% rename from src/Core/Telemetry/EmbeddingTelemetryHelper.cs rename to src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs index b8fc7773e5..d7ed9bfd05 100644 --- a/src/Core/Telemetry/EmbeddingTelemetryHelper.cs +++ b/src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs @@ -3,35 +3,91 @@ using System.Diagnostics; using System.Diagnostics.Metrics; +using Azure.DataApiBuilder.Core.Telemetry; +using OpenTelemetry.Trace; -namespace Azure.DataApiBuilder.Core.Telemetry; +namespace Azure.DataApiBuilder.Core.Services.Embeddings; /// /// Helper class for tracking embedding-related telemetry metrics and traces. /// public static class EmbeddingTelemetryHelper { + /// + /// Meter name for embedding metrics. + /// + public static readonly string MeterName = "DataApiBuilder.Embeddings"; + // Metrics - private static readonly Meter _meter = new("DataApiBuilder.Embeddings"); - private static readonly Counter _embeddingRequests = _meter.CreateCounter("embedding_requests_total", description: "Total number of embedding requests"); - private static readonly Counter _embeddingCacheHits = _meter.CreateCounter("embedding_cache_hits_total", description: "Total number of embedding cache hits"); - private static readonly Counter _embeddingCacheMisses = _meter.CreateCounter("embedding_cache_misses_total", description: "Total number of embedding cache misses"); - private static readonly Counter _embeddingErrors = _meter.CreateCounter("embedding_errors_total", description: "Total number of embedding errors"); - private static readonly Histogram _embeddingDuration = _meter.CreateHistogram("embedding_duration_ms", "ms", "Duration of embedding API calls"); - private static readonly Histogram _embeddingTokens = _meter.CreateHistogram("embedding_tokens_total", description: "Total tokens used in embedding requests"); + private static readonly Meter _meter = new(MeterName); + + // Counters + private static readonly Counter _embeddingRequests = _meter.CreateCounter( + "embedding_requests_total", + description: "Total number of embedding requests"); + + private static readonly Counter _embeddingApiCalls = _meter.CreateCounter( + "embedding_api_calls_total", + description: "Total number of embedding API calls (excludes cache hits)"); + + private static readonly Counter _embeddingCacheHits = _meter.CreateCounter( + "embedding_cache_hits_total", + description: "Total number of embedding cache hits"); + + private static readonly Counter _embeddingCacheMisses = _meter.CreateCounter( + "embedding_cache_misses_total", + description: "Total number of embedding cache misses"); + + private static readonly Counter _embeddingErrors = _meter.CreateCounter( + "embedding_errors_total", + description: "Total number of embedding errors"); + + private static readonly Counter _embeddingTextsProcessed = _meter.CreateCounter( + "embedding_texts_processed_total", + description: "Total number of texts processed for embedding"); + + // Histograms for timing and sizing + private static readonly Histogram _embeddingApiDuration = _meter.CreateHistogram( + "embedding_api_duration_ms", + unit: "ms", + description: "Duration of embedding API calls in milliseconds"); + + private static readonly Histogram _embeddingTotalDuration = _meter.CreateHistogram( + "embedding_total_duration_ms", + unit: "ms", + description: "Total duration of embedding operations including cache lookup"); + + private static readonly Histogram _embeddingTokens = _meter.CreateHistogram( + "embedding_tokens_total", + description: "Total tokens used in embedding requests"); + + private static readonly Histogram _embeddingDimensions = _meter.CreateHistogram( + "embedding_dimensions", + description: "Number of dimensions in embedding vectors"); /// - /// Tracks an embedding request. + /// Tracks an embedding request (entry point, includes cache hits). /// /// The embedding provider (e.g., azure-openai, openai). /// Number of texts being embedded. - /// Whether the result was served from cache. - public static void TrackEmbeddingRequest(string provider, int textCount, bool fromCache) + public static void TrackEmbeddingRequest(string provider, int textCount) { _embeddingRequests.Add(1, + new KeyValuePair("provider", provider)); + _embeddingTextsProcessed.Add(textCount, + new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding API call (cache miss, actual API call made). + /// + /// The embedding provider. + /// Number of texts sent to API. + public static void TrackApiCall(string provider, int textCount) + { + _embeddingApiCalls.Add(1, new KeyValuePair("provider", provider), - new KeyValuePair("text_count", textCount), - new KeyValuePair("from_cache", fromCache)); + new KeyValuePair("text_count", textCount)); } /// @@ -72,11 +128,24 @@ public static void TrackError(string provider, string errorType) /// Number of texts embedded. public static void TrackApiDuration(string provider, TimeSpan duration, int textCount) { - _embeddingDuration.Record(duration.TotalMilliseconds, + _embeddingApiDuration.Record(duration.TotalMilliseconds, new KeyValuePair("provider", provider), new KeyValuePair("text_count", textCount)); } + /// + /// Tracks the total duration of an embedding operation (including cache lookup). + /// + /// The embedding provider. + /// The total duration. + /// Whether result was from cache. + public static void TrackTotalDuration(string provider, TimeSpan duration, bool fromCache) + { + _embeddingTotalDuration.Record(duration.TotalMilliseconds, + new KeyValuePair("provider", provider), + new KeyValuePair("from_cache", fromCache)); + } + /// /// Tracks token usage from an embedding request. /// @@ -87,6 +156,16 @@ public static void TrackTokenUsage(string provider, long totalTokens) _embeddingTokens.Record(totalTokens, new KeyValuePair("provider", provider)); } + /// + /// Tracks embedding vector dimensions. + /// + /// The embedding provider. + /// Number of dimensions in the vector. + public static void TrackDimensions(string provider, int dimensions) + { + _embeddingDimensions.Record(dimensions, new KeyValuePair("provider", provider)); + } + /// /// Starts an activity for embedding operations. /// @@ -147,13 +226,20 @@ public static void SetCacheActivityTags( /// /// The activity to complete. /// Duration in milliseconds. + /// Number of dimensions in the result. public static void SetEmbeddingActivitySuccess( this Activity activity, - double durationMs) + double durationMs, + int? dimensions = null) { if (activity.IsAllDataRequested) { activity.SetTag("embedding.duration_ms", durationMs); + if (dimensions.HasValue) + { + activity.SetTag("embedding.dimensions", dimensions.Value); + } + activity.SetStatus(ActivityStatusCode.Ok); } } diff --git a/src/Core/Services/Embeddings/IEmbeddingService.cs b/src/Core/Services/Embeddings/IEmbeddingService.cs new file mode 100644 index 0000000000..ef5a9e490c --- /dev/null +++ b/src/Core/Services/Embeddings/IEmbeddingService.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Services.Embeddings; + +/// +/// Result of a TryEmbed operation. +/// +/// Whether the embedding was generated successfully. +/// The embedding vector, or null if unsuccessful. +/// Error message if unsuccessful, or null if successful. +public record EmbeddingResult(bool Success, float[]? Embedding, string? ErrorMessage = null); + +/// +/// Result of a TryEmbedBatch operation. +/// +/// Whether the embeddings were generated successfully. +/// The embedding vectors, or null if unsuccessful. +/// Error message if unsuccessful, or null if successful. +public record EmbeddingBatchResult(bool Success, float[][]? Embeddings, string? ErrorMessage = null); + +/// +/// Service interface for text embedding/vectorization. +/// Supports both single text and batch embedding operations. +/// +public interface IEmbeddingService +{ + /// + /// Gets whether the embedding service is enabled. + /// + bool IsEnabled { get; } + + /// + /// Attempts to generate an embedding vector for a single text input. + /// Returns a result indicating success or failure without throwing exceptions. + /// + /// The text to embed. + /// Cancellation token for the operation. + /// Result containing the embedding if successful, or error information if not. + Task TryEmbedAsync(string text, CancellationToken cancellationToken = default); + + /// + /// Attempts to generate embedding vectors for multiple text inputs in a batch. + /// Returns a result indicating success or failure without throwing exceptions. + /// + /// The texts to embed. + /// Cancellation token for the operation. + /// Result containing the embeddings if successful, or error information if not. + Task TryEmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default); + + /// + /// Generates an embedding vector for a single text input. + /// Throws if the service is disabled or an error occurs. + /// + /// The text to embed. + /// Cancellation token for the operation. + /// The embedding vector as an array of floats. + /// Thrown when the service is disabled. + Task EmbedAsync(string text, CancellationToken cancellationToken = default); + + /// + /// Generates embedding vectors for multiple text inputs in a batch. + /// Throws if the service is disabled or an error occurs. + /// + /// The texts to embed. + /// Cancellation token for the operation. + /// The embedding vectors as an array of float arrays, matching input order. + /// Thrown when the service is disabled. + Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default); +} diff --git a/src/Core/Services/IEmbeddingService.cs b/src/Core/Services/IEmbeddingService.cs deleted file mode 100644 index 6e7ffb8a19..0000000000 --- a/src/Core/Services/IEmbeddingService.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -namespace Azure.DataApiBuilder.Core.Services; - -/// -/// Service interface for text embedding/vectorization. -/// Supports both single text and batch embedding operations. -/// -public interface IEmbeddingService -{ - /// - /// Generates an embedding vector for a single text input. - /// - /// The text to embed. - /// Cancellation token for the operation. - /// The embedding vector as an array of floats. - Task EmbedAsync(string text, CancellationToken cancellationToken = default); - - /// - /// Generates embedding vectors for multiple text inputs in a batch. - /// - /// The texts to embed. - /// Cancellation token for the operation. - /// The embedding vectors as an array of float arrays, matching input order. - Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default); -} diff --git a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs index d5f00e494e..272d1775a4 100644 --- a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs +++ b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs @@ -8,8 +8,8 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using Azure.DataApiBuilder.Config.ObjectModel; -using Azure.DataApiBuilder.Core.Services; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Core.Services.Embeddings; using Microsoft.Extensions.Logging; using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; diff --git a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs index 1123831577..c2d6be43f0 100644 --- a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs +++ b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs @@ -5,6 +5,7 @@ using System.Text.Json; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Azure.DataApiBuilder.Service.Tests.UnitTests; diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index e10c0ddcee..aeda9346d7 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -10,6 +10,7 @@ using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Config.Utilities; using Azure.DataApiBuilder.Core.AuthenticationHelpers; using Azure.DataApiBuilder.Core.AuthenticationHelpers.AuthenticationSimulator; @@ -21,6 +22,7 @@ using Azure.DataApiBuilder.Core.Resolvers.Factories; using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; +using Azure.DataApiBuilder.Core.Services.Embeddings; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.OpenAPI; using Azure.DataApiBuilder.Core.Telemetry; @@ -394,6 +396,36 @@ public void ConfigureServices(IServiceCollection services) EmbeddingsOptions embeddingsOptions = runtimeConfig.Runtime.Embeddings; services.AddHttpClient(); services.AddSingleton(embeddingsOptions); + + string providerName = embeddingsOptions.Provider.ToString().ToLowerInvariant(); + + if (embeddingsOptions.Enabled) + { + _logger.LogInformation( + "Embeddings service enabled with provider: {Provider}, model: {Model}, base-url: {BaseUrl}", + providerName, + embeddingsOptions.EffectiveModel ?? "(default)", + embeddingsOptions.BaseUrl); + + // Endpoint is only available if both embeddings and endpoint are enabled + if (embeddingsOptions.IsEndpointEnabled) + { + _logger.LogInformation( + "Embeddings endpoint enabled at path: {Path}", + embeddingsOptions.EffectiveEndpointPath); + } + + if (embeddingsOptions.IsHealthCheckEnabled) + { + _logger.LogInformation( + "Embeddings health check enabled with threshold: {ThresholdMs}ms", + embeddingsOptions.Health!.ThresholdMs); + } + } + else + { + _logger.LogInformation("Embeddings service is configured but disabled."); + } } AddGraphQLService(services, runtimeConfig?.Runtime?.GraphQL); From c3f69374af21f50549b4351b6e4ed388568d59aa Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:07:08 +0000 Subject: [PATCH 08/16] Fix property renames and update tests Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- src/Cli/Commands/ConfigureOptions.cs | 32 +-- src/Cli/ConfigGenerator.cs | 10 +- .../UnitTests/EmbeddingServiceTests.cs | 234 +++++------------- .../UnitTests/EmbeddingsOptionsTests.cs | 8 +- 4 files changed, 83 insertions(+), 201 deletions(-) diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 3c85142996..75b5444ec8 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -74,14 +74,15 @@ public ConfigureOptions( long? fileSinkFileSizeLimitBytes = null, CliBool? runtimeEmbeddingsEnabled = null, EmbeddingProviderType? runtimeEmbeddingsProvider = null, - string? runtimeEmbeddingsEndpoint = null, + string? runtimeEmbeddingsBaseUrl = null, string? runtimeEmbeddingsApiKey = null, string? runtimeEmbeddingsModel = null, string? runtimeEmbeddingsApiVersion = null, int? runtimeEmbeddingsDimensions = null, int? runtimeEmbeddingsTimeoutMs = null, - CliBool? runtimeEmbeddingsRestEnabled = null, - string? runtimeEmbeddingsRestPath = null, + CliBool? runtimeEmbeddingsEndpointEnabled = null, + string? runtimeEmbeddingsEndpointPath = null, + IEnumerable? runtimeEmbeddingsEndpointRoles = null, CliBool? runtimeEmbeddingsHealthEnabled = null, int? runtimeEmbeddingsHealthThresholdMs = null, string? runtimeEmbeddingsHealthTestText = null, @@ -150,15 +151,16 @@ public ConfigureOptions( // Embeddings RuntimeEmbeddingsEnabled = runtimeEmbeddingsEnabled; RuntimeEmbeddingsProvider = runtimeEmbeddingsProvider; - RuntimeEmbeddingsEndpoint = runtimeEmbeddingsEndpoint; + RuntimeEmbeddingsBaseUrl = runtimeEmbeddingsBaseUrl; RuntimeEmbeddingsApiKey = runtimeEmbeddingsApiKey; RuntimeEmbeddingsModel = runtimeEmbeddingsModel; RuntimeEmbeddingsApiVersion = runtimeEmbeddingsApiVersion; RuntimeEmbeddingsDimensions = runtimeEmbeddingsDimensions; RuntimeEmbeddingsTimeoutMs = runtimeEmbeddingsTimeoutMs; - // Embeddings REST - RuntimeEmbeddingsRestEnabled = runtimeEmbeddingsRestEnabled; - RuntimeEmbeddingsRestPath = runtimeEmbeddingsRestPath; + // Embeddings Endpoint + RuntimeEmbeddingsEndpointEnabled = runtimeEmbeddingsEndpointEnabled; + RuntimeEmbeddingsEndpointPath = runtimeEmbeddingsEndpointPath; + RuntimeEmbeddingsEndpointRoles = runtimeEmbeddingsEndpointRoles; // Embeddings Health RuntimeEmbeddingsHealthEnabled = runtimeEmbeddingsHealthEnabled; RuntimeEmbeddingsHealthThresholdMs = runtimeEmbeddingsHealthThresholdMs; @@ -319,8 +321,8 @@ public ConfigureOptions( [Option("runtime.embeddings.provider", Required = false, HelpText = "Configure embedding provider type. Allowed values: azure-openai, openai.")] public EmbeddingProviderType? RuntimeEmbeddingsProvider { get; } - [Option("runtime.embeddings.endpoint", Required = false, HelpText = "Configure the embedding provider base URL endpoint.")] - public string? RuntimeEmbeddingsEndpoint { get; } + [Option("runtime.embeddings.base-url", Required = false, HelpText = "Configure the embedding provider base URL.")] + public string? RuntimeEmbeddingsBaseUrl { get; } [Option("runtime.embeddings.api-key", Required = false, HelpText = "Configure the embedding API key for authentication.")] public string? RuntimeEmbeddingsApiKey { get; } @@ -337,14 +339,14 @@ public ConfigureOptions( [Option("runtime.embeddings.timeout-ms", Required = false, HelpText = "Configure the request timeout in milliseconds. Default: 30000")] public int? RuntimeEmbeddingsTimeoutMs { get; } - [Option("runtime.embeddings.rest.enabled", Required = false, HelpText = "Enable/disable the REST endpoint for embeddings. Default: false")] - public CliBool? RuntimeEmbeddingsRestEnabled { get; } + [Option("runtime.embeddings.endpoint.enabled", Required = false, HelpText = "Enable/disable the endpoint for embeddings. Default: false")] + public CliBool? RuntimeEmbeddingsEndpointEnabled { get; } - [Option("runtime.embeddings.rest.path", Required = false, HelpText = "Configure the REST endpoint path for embeddings. Default: /embed")] - public string? RuntimeEmbeddingsRestPath { get; } + [Option("runtime.embeddings.endpoint.path", Required = false, HelpText = "Configure the endpoint path for embeddings. Default: /embed")] + public string? RuntimeEmbeddingsEndpointPath { get; } - [Option("runtime.embeddings.rest.roles", Required = false, Separator = ',', HelpText = "Configure the roles allowed to access the embedding REST endpoint. Comma-separated list. In development mode defaults to 'anonymous'.")] - public IEnumerable? RuntimeEmbeddingsRestRoles { get; } + [Option("runtime.embeddings.endpoint.roles", Required = false, Separator = ',', HelpText = "Configure the roles allowed to access the embedding endpoint. Comma-separated list. In development mode defaults to 'anonymous'.")] + public IEnumerable? RuntimeEmbeddingsEndpointRoles { get; } [Option("runtime.embeddings.health.enabled", Required = false, HelpText = "Enable/disable health checks for the embedding service. Default: true")] public CliBool? RuntimeEmbeddingsHealthEnabled { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 5186393060..a8b303664e 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -911,7 +911,7 @@ options.FileSinkRetainedFileCountLimit is not null || // Embeddings: Provider, Endpoint, ApiKey, Model, ApiVersion, Dimensions, TimeoutMs if (options.RuntimeEmbeddingsProvider is not null || - options.RuntimeEmbeddingsEndpoint is not null || + options.RuntimeEmbeddingsBaseUrl is not null || options.RuntimeEmbeddingsApiKey is not null || options.RuntimeEmbeddingsModel is not null || options.RuntimeEmbeddingsApiVersion is not null || @@ -1563,7 +1563,7 @@ private static bool TryUpdateConfiguredEmbeddingsValues( { // Get values from options or fall back to existing configuration EmbeddingProviderType? provider = options.RuntimeEmbeddingsProvider ?? existingEmbeddingsOptions?.Provider; - string? endpoint = options.RuntimeEmbeddingsEndpoint ?? existingEmbeddingsOptions?.Endpoint; + string? baseUrl = options.RuntimeEmbeddingsBaseUrl ?? existingEmbeddingsOptions?.BaseUrl; string? apiKey = options.RuntimeEmbeddingsApiKey ?? existingEmbeddingsOptions?.ApiKey; string? model = options.RuntimeEmbeddingsModel ?? existingEmbeddingsOptions?.Model; string? apiVersion = options.RuntimeEmbeddingsApiVersion ?? existingEmbeddingsOptions?.ApiVersion; @@ -1577,9 +1577,9 @@ private static bool TryUpdateConfiguredEmbeddingsValues( return false; } - if (string.IsNullOrEmpty(endpoint)) + if (string.IsNullOrEmpty(baseUrl)) { - _logger.LogError("Failed to configure embeddings: endpoint is required. Use --runtime.embeddings.endpoint to specify the provider base URL."); + _logger.LogError("Failed to configure embeddings: base-url is required. Use --runtime.embeddings.base-url to specify the provider base URL."); return false; } @@ -1613,7 +1613,7 @@ private static bool TryUpdateConfiguredEmbeddingsValues( // Create the embeddings options updatedEmbeddingsOptions = new EmbeddingsOptions( Provider: (EmbeddingProviderType)provider, - Endpoint: endpoint, + BaseUrl: baseUrl, ApiKey: apiKey, Model: model, ApiVersion: apiVersion, diff --git a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs index 272d1775a4..aab3f04455 100644 --- a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs +++ b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs @@ -14,6 +14,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; using Moq.Protected; +using ZiggyCreatures.Caching.Fusion; namespace Azure.DataApiBuilder.Service.Tests.UnitTests; @@ -24,127 +25,93 @@ namespace Azure.DataApiBuilder.Service.Tests.UnitTests; public class EmbeddingServiceTests { private Mock> _mockLogger = null!; + private Mock _mockCache = null!; [TestInitialize] public void Setup() { _mockLogger = new Mock>(); + _mockCache = new Mock(); } /// - /// Tests that EmbedAsync returns embedding for a single text input. + /// Tests that IsEnabled returns true when embeddings are enabled. /// [TestMethod] - public async Task EmbedAsync_SingleText_ReturnsEmbedding() + public void IsEnabled_ReturnsTrue_WhenEnabled() { // Arrange EmbeddingsOptions options = CreateAzureOpenAIOptions(); - float[] expectedEmbedding = new[] { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f }; - HttpClient httpClient = CreateMockHttpClient(CreateSuccessResponse(expectedEmbedding)); - EmbeddingService service = new(httpClient, options, _mockLogger.Object); - - // Act - float[] result = await service.EmbedAsync("Hello world"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); // Assert - Assert.IsNotNull(result); - Assert.AreEqual(expectedEmbedding.Length, result.Length); - for (int i = 0; i < expectedEmbedding.Length; i++) - { - Assert.AreEqual(expectedEmbedding[i], result[i]); - } + Assert.IsTrue(service.IsEnabled); } /// - /// Tests that EmbedBatchAsync returns embeddings for multiple text inputs. + /// Tests that IsEnabled returns false when embeddings are disabled. /// [TestMethod] - public async Task EmbedBatchAsync_MultipleTexts_ReturnsEmbeddings() + public void IsEnabled_ReturnsFalse_WhenDisabled() { // Arrange - EmbeddingsOptions options = CreateAzureOpenAIOptions(); - float[][] expectedEmbeddings = new[] - { - new[] { 0.1f, 0.2f, 0.3f }, - new[] { 0.4f, 0.5f, 0.6f }, - new[] { 0.7f, 0.8f, 0.9f } - }; - HttpClient httpClient = CreateMockHttpClient(CreateBatchSuccessResponse(expectedEmbeddings)); - EmbeddingService service = new(httpClient, options, _mockLogger.Object); - - // Act - float[][] result = await service.EmbedBatchAsync(new[] { "Text 1", "Text 2", "Text 3" }); + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: false, + Model: "text-embedding-ada-002"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); // Assert - Assert.IsNotNull(result); - Assert.AreEqual(expectedEmbeddings.Length, result.Length); - for (int i = 0; i < expectedEmbeddings.Length; i++) - { - Assert.AreEqual(expectedEmbeddings[i].Length, result[i].Length); - } + Assert.IsFalse(service.IsEnabled); } /// - /// Tests that EmbedAsync throws ArgumentException for null or empty text. - /// - [DataTestMethod] - [DataRow(null, DisplayName = "Null text throws ArgumentException")] - [DataRow("", DisplayName = "Empty text throws ArgumentException")] - public async Task EmbedAsync_NullOrEmptyText_ThrowsArgumentException(string text) - { - // Arrange - EmbeddingsOptions options = CreateAzureOpenAIOptions(); - HttpClient httpClient = CreateMockHttpClient(CreateSuccessResponse(new[] { 0.1f })); - EmbeddingService service = new(httpClient, options, _mockLogger.Object); - - // Act & Assert - await Assert.ThrowsExceptionAsync(() => service.EmbedAsync(text!)); - } - - /// - /// Tests that EmbedBatchAsync throws ArgumentException for null or empty texts array. + /// Tests that TryEmbedAsync returns failure when service is disabled. /// [TestMethod] - public async Task EmbedBatchAsync_EmptyTexts_ThrowsArgumentException() + public async Task TryEmbedAsync_ReturnsFailure_WhenDisabled() { // Arrange - EmbeddingsOptions options = CreateAzureOpenAIOptions(); - HttpClient httpClient = CreateMockHttpClient(CreateSuccessResponse(new[] { 0.1f })); - EmbeddingService service = new(httpClient, options, _mockLogger.Object); - - // Act & Assert - await Assert.ThrowsExceptionAsync(() => service.EmbedBatchAsync(Array.Empty())); - } + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: false, + Model: "text-embedding-ada-002"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); - /// - /// Tests that HttpRequestException is thrown when API returns an error. - /// - [TestMethod] - public async Task EmbedAsync_ApiError_ThrowsHttpRequestException() - { - // Arrange - EmbeddingsOptions options = CreateAzureOpenAIOptions(); - HttpClient httpClient = CreateMockHttpClient(CreateErrorResponse(HttpStatusCode.Unauthorized, "Invalid API key")); - EmbeddingService service = new(httpClient, options, _mockLogger.Object); + // Act + EmbeddingResult result = await service.TryEmbedAsync("test"); - // Act & Assert - await Assert.ThrowsExceptionAsync(() => service.EmbedAsync("Test text")); + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embedding); + Assert.IsNotNull(result.ErrorMessage); } /// - /// Tests that InvalidOperationException is thrown when API returns empty data. + /// Tests that TryEmbedAsync returns failure for null or empty text. /// - [TestMethod] - public async Task EmbedAsync_EmptyResponse_ThrowsInvalidOperationException() + [DataTestMethod] + [DataRow(null, DisplayName = "Null text returns failure")] + [DataRow("", DisplayName = "Empty text returns failure")] + public async Task TryEmbedAsync_ReturnsFailure_ForNullOrEmptyText(string? text) { // Arrange EmbeddingsOptions options = CreateAzureOpenAIOptions(); - string emptyResponse = JsonSerializer.Serialize(new { data = Array.Empty() }); - HttpClient httpClient = CreateMockHttpClient(CreateSuccessResponseWithContent(emptyResponse)); - EmbeddingService service = new(httpClient, options, _mockLogger.Object); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act + EmbeddingResult result = await service.TryEmbedAsync(text!); - // Act & Assert - await Assert.ThrowsExceptionAsync(() => service.EmbedAsync("Test text")); + // Assert + Assert.IsFalse(result.Success); } /// @@ -156,7 +123,7 @@ public void EmbeddingsOptions_OpenAI_DefaultModel() // Arrange EmbeddingsOptions options = new( Provider: EmbeddingProviderType.OpenAI, - Endpoint: "https://api.openai.com", + BaseUrl: "https://api.openai.com", ApiKey: "test-key"); // Assert @@ -173,7 +140,7 @@ public void EmbeddingsOptions_AzureOpenAI_NoDefaultModel() // Arrange EmbeddingsOptions options = new( Provider: EmbeddingProviderType.AzureOpenAI, - Endpoint: "https://my.openai.azure.com", + BaseUrl: "https://my.openai.azure.com", ApiKey: "test-key"); // Assert @@ -190,7 +157,7 @@ public void EmbeddingsOptions_DefaultTimeout() // Arrange EmbeddingsOptions options = new( Provider: EmbeddingProviderType.OpenAI, - Endpoint: "https://api.openai.com", + BaseUrl: "https://api.openai.com", ApiKey: "test-key"); // Assert @@ -208,7 +175,7 @@ public void EmbeddingsOptions_CustomTimeout() int customTimeout = 60000; EmbeddingsOptions options = new( Provider: EmbeddingProviderType.OpenAI, - Endpoint: "https://api.openai.com", + BaseUrl: "https://api.openai.com", ApiKey: "test-key", TimeoutMs: customTimeout); @@ -224,104 +191,17 @@ private static EmbeddingsOptions CreateAzureOpenAIOptions() { return new EmbeddingsOptions( Provider: EmbeddingProviderType.AzureOpenAI, - Endpoint: "https://test.openai.azure.com", + BaseUrl: "https://test.openai.azure.com", ApiKey: "test-api-key", Model: "text-embedding-ada-002"); } - private static HttpClient CreateMockHttpClient(HttpResponseMessage response) - { - Mock mockHandler = new(); - mockHandler.Protected() - .Setup>( - "SendAsync", - ItExpr.IsAny(), - ItExpr.IsAny()) - .ReturnsAsync(response); - - return new HttpClient(mockHandler.Object); - } - - private static HttpResponseMessage CreateSuccessResponse(float[] embedding) - { - var response = new - { - data = new[] - { - new - { - index = 0, - embedding = embedding - } - }, - model = "text-embedding-ada-002", - usage = new - { - prompt_tokens = 5, - total_tokens = 5 - } - }; - - string content = JsonSerializer.Serialize(response); - return new HttpResponseMessage(HttpStatusCode.OK) - { - Content = new StringContent(content, Encoding.UTF8, "application/json") - }; - } - - private static HttpResponseMessage CreateBatchSuccessResponse(float[][] embeddings) - { - var data = new object[embeddings.Length]; - for (int i = 0; i < embeddings.Length; i++) - { - data[i] = new - { - index = i, - embedding = embeddings[i] - }; - } - - var response = new - { - data, - model = "text-embedding-ada-002", - usage = new - { - prompt_tokens = 15, - total_tokens = 15 - } - }; - - string content = JsonSerializer.Serialize(response); - return new HttpResponseMessage(HttpStatusCode.OK) - { - Content = new StringContent(content, Encoding.UTF8, "application/json") - }; - } - - private static HttpResponseMessage CreateSuccessResponseWithContent(string content) - { - return new HttpResponseMessage(HttpStatusCode.OK) - { - Content = new StringContent(content, Encoding.UTF8, "application/json") - }; - } - - private static HttpResponseMessage CreateErrorResponse(HttpStatusCode statusCode, string errorMessage) + private static EmbeddingsOptions CreateOpenAIOptions() { - var errorContent = new - { - error = new - { - message = errorMessage, - type = "invalid_request_error" - } - }; - - return new HttpResponseMessage(statusCode) - { - Content = new StringContent(JsonSerializer.Serialize(errorContent), Encoding.UTF8, "application/json") - }; + return new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key"); } #endregion diff --git a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs index c2d6be43f0..f1dfb62909 100644 --- a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs +++ b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs @@ -106,7 +106,7 @@ public void TestAzureOpenAIEmbeddingsConfigDeserialization() EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; Assert.AreEqual(EmbeddingProviderType.AzureOpenAI, embeddings.Provider); - Assert.AreEqual("https://my-openai.openai.azure.com", embeddings.Endpoint); + Assert.AreEqual("https://my-openai.openai.azure.com", embeddings.BaseUrl); Assert.AreEqual("test-api-key", embeddings.ApiKey); Assert.AreEqual("text-embedding-ada-002", embeddings.Model); Assert.AreEqual("2024-02-01", embeddings.ApiVersion); @@ -142,7 +142,7 @@ public void TestOpenAIEmbeddingsConfigWithDefaults() EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; Assert.AreEqual(EmbeddingProviderType.OpenAI, embeddings.Provider); - Assert.AreEqual("https://api.openai.com", embeddings.Endpoint); + Assert.AreEqual("https://api.openai.com", embeddings.BaseUrl); Assert.AreEqual("sk-test-key", embeddings.ApiKey); // Model not specified, but EffectiveModel should return default for OpenAI @@ -261,7 +261,7 @@ public void TestEmbeddingsOptionsSerialization() // Arrange EmbeddingsOptions options = new( Provider: EmbeddingProviderType.AzureOpenAI, - Endpoint: "https://my-endpoint.openai.azure.com", + BaseUrl: "https://my-endpoint.openai.azure.com", ApiKey: "my-api-key", Model: "my-model", ApiVersion: "2024-02-01", @@ -331,7 +331,7 @@ public void TestEmbeddingsConfigWithEnvVarReplacement() Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; - Assert.AreEqual("https://test-endpoint.openai.azure.com", embeddings.Endpoint); + Assert.AreEqual("https://test-endpoint.openai.azure.com", embeddings.BaseUrl); Assert.AreEqual("test-secret-key", embeddings.ApiKey); Assert.AreEqual("text-embedding-3-small", embeddings.Model); } From 1e18c25bc63e32532cfa187e050c671b3b6ad99a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:17:15 +0000 Subject: [PATCH 09/16] Add EmbeddingsOptionsConverter and fix all tests Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../EmbeddingsOptionsConverterFactory.cs | 205 ++++++++++++++++++ src/Config/RuntimeConfigLoader.cs | 3 + .../UnitTests/EmbeddingServiceTests.cs | 2 +- .../UnitTests/EmbeddingsOptionsTests.cs | 139 ++++-------- src/Service/HealthCheck/HealthCheckHelper.cs | 15 +- .../HealthCheck/Model/ConfigurationDetails.cs | 6 + 6 files changed, 270 insertions(+), 100 deletions(-) create mode 100644 src/Config/Converters/EmbeddingsOptionsConverterFactory.cs diff --git a/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs b/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs new file mode 100644 index 0000000000..b84b212815 --- /dev/null +++ b/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs @@ -0,0 +1,205 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +namespace Azure.DataApiBuilder.Config.Converters; + +/// +/// Custom JSON converter for EmbeddingsOptions that handles proper deserialization +/// of the configuration properties including environment variable replacement. +/// +internal class EmbeddingsOptionsConverterFactory : JsonConverterFactory +{ + private readonly DeserializationVariableReplacementSettings? _replacementSettings; + + public EmbeddingsOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) + { + _replacementSettings = replacementSettings; + } + + /// + public override bool CanConvert(Type typeToConvert) + { + return typeToConvert.IsAssignableTo(typeof(EmbeddingsOptions)); + } + + /// + public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) + { + return new EmbeddingsOptionsConverter(_replacementSettings); + } + + private class EmbeddingsOptionsConverter : JsonConverter + { + private readonly DeserializationVariableReplacementSettings? _replacementSettings; + + public EmbeddingsOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) + { + _replacementSettings = replacementSettings; + } + + public override EmbeddingsOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected start of object."); + } + + bool? enabled = null; + EmbeddingProviderType? provider = null; + string? baseUrl = null; + string? apiKey = null; + string? model = null; + string? apiVersion = null; + int? dimensions = null; + int? timeoutMs = null; + EmbeddingsEndpointOptions? endpoint = null; + EmbeddingsHealthCheckConfig? health = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected property name."); + } + + string? propertyName = reader.GetString()?.ToLowerInvariant(); + reader.Read(); + + switch (propertyName) + { + case "enabled": + enabled = reader.GetBoolean(); + break; + case "provider": + string? providerStr = reader.GetString(); + if (providerStr is not null) + { + provider = providerStr.ToLowerInvariant() switch + { + "azure-openai" => EmbeddingProviderType.AzureOpenAI, + "openai" => EmbeddingProviderType.OpenAI, + _ => throw new JsonException($"Unknown provider: {providerStr}") + }; + } + break; + case "base-url": + baseUrl = JsonSerializer.Deserialize(ref reader, options); + break; + case "api-key": + apiKey = JsonSerializer.Deserialize(ref reader, options); + break; + case "model": + model = JsonSerializer.Deserialize(ref reader, options); + break; + case "api-version": + apiVersion = JsonSerializer.Deserialize(ref reader, options); + break; + case "dimensions": + dimensions = reader.GetInt32(); + break; + case "timeout-ms": + timeoutMs = reader.GetInt32(); + break; + case "endpoint": + endpoint = JsonSerializer.Deserialize(ref reader, options); + break; + case "health": + health = JsonSerializer.Deserialize(ref reader, options); + break; + default: + reader.Skip(); + break; + } + } + + if (provider is null) + { + throw new JsonException("Missing required property: provider"); + } + + if (baseUrl is null) + { + throw new JsonException("Missing required property: base-url"); + } + + if (apiKey is null) + { + throw new JsonException("Missing required property: api-key"); + } + + return new EmbeddingsOptions( + Provider: provider.Value, + BaseUrl: baseUrl, + ApiKey: apiKey, + Enabled: enabled, + Model: model, + ApiVersion: apiVersion, + Dimensions: dimensions, + TimeoutMs: timeoutMs, + Endpoint: endpoint, + Health: health); + } + + public override void Write(Utf8JsonWriter writer, EmbeddingsOptions value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + writer.WriteBoolean("enabled", value.Enabled); + + // Write provider + string providerStr = value.Provider switch + { + EmbeddingProviderType.AzureOpenAI => "azure-openai", + EmbeddingProviderType.OpenAI => "openai", + _ => throw new JsonException($"Unknown provider: {value.Provider}") + }; + writer.WriteString("provider", providerStr); + + writer.WriteString("base-url", value.BaseUrl); + writer.WriteString("api-key", value.ApiKey); + + if (value.Model is not null) + { + writer.WriteString("model", value.Model); + } + + if (value.ApiVersion is not null) + { + writer.WriteString("api-version", value.ApiVersion); + } + + if (value.Dimensions is not null) + { + writer.WriteNumber("dimensions", value.Dimensions.Value); + } + + if (value.TimeoutMs is not null) + { + writer.WriteNumber("timeout-ms", value.TimeoutMs.Value); + } + + if (value.Endpoint is not null) + { + writer.WritePropertyName("endpoint"); + JsonSerializer.Serialize(writer, value.Endpoint, options); + } + + if (value.Health is not null) + { + writer.WritePropertyName("health"); + JsonSerializer.Serialize(writer, value.Health, options); + } + + writer.WriteEndObject(); + } + } +} diff --git a/src/Config/RuntimeConfigLoader.cs b/src/Config/RuntimeConfigLoader.cs index 9a54d09d8e..c43f8c2feb 100644 --- a/src/Config/RuntimeConfigLoader.cs +++ b/src/Config/RuntimeConfigLoader.cs @@ -333,6 +333,9 @@ public static JsonSerializerOptions GetSerializationOptions( // Add AzureKeyVaultOptionsConverterFactory to ensure AKV config is deserialized properly options.Converters.Add(new AzureKeyVaultOptionsConverterFactory(replacementSettings)); + // Add EmbeddingsOptionsConverterFactory to handle embeddings configuration + options.Converters.Add(new EmbeddingsOptionsConverterFactory(replacementSettings)); + // Only add the extensible string converter if we have replacement settings if (replacementSettings is not null) { diff --git a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs index aab3f04455..6c4d9343c4 100644 --- a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs +++ b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs @@ -100,7 +100,7 @@ public async Task TryEmbedAsync_ReturnsFailure_WhenDisabled() [DataTestMethod] [DataRow(null, DisplayName = "Null text returns failure")] [DataRow("", DisplayName = "Empty text returns failure")] - public async Task TryEmbedAsync_ReturnsFailure_ForNullOrEmptyText(string? text) + public async Task TryEmbedAsync_ReturnsFailure_ForNullOrEmptyText(string text) { // Arrange EmbeddingsOptions options = CreateAzureOpenAIOptions(); diff --git a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs index f1dfb62909..020024ea47 100644 --- a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs +++ b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs @@ -4,6 +4,7 @@ using System; using System.Text.Json; using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -26,7 +27,7 @@ public class EmbeddingsOptionsTests ""runtime"": { ""embeddings"": { ""provider"": ""azure-openai"", - ""endpoint"": ""https://my-openai.openai.azure.com"", + ""base-url"": ""https://my-openai.openai.azure.com"", ""api-key"": ""test-api-key"", ""model"": ""text-embedding-ada-002"", ""api-version"": ""2024-02-01"", @@ -47,7 +48,7 @@ public class EmbeddingsOptionsTests ""runtime"": { ""embeddings"": { ""provider"": ""openai"", - ""endpoint"": ""https://api.openai.com"", + ""base-url"": ""https://api.openai.com"", ""api-key"": ""sk-test-key"" } }, @@ -64,7 +65,7 @@ public class EmbeddingsOptionsTests ""runtime"": { ""embeddings"": { ""provider"": ""azure-openai"", - ""endpoint"": ""https://my-openai.openai.azure.com"", + ""base-url"": ""https://my-openai.openai.azure.com"", ""api-key"": ""test-api-key"", ""model"": ""my-deployment"" } @@ -83,25 +84,18 @@ public class EmbeddingsOptionsTests }"; /// - /// Tests that a full Azure OpenAI embeddings configuration is correctly deserialized. + /// Tests that Azure OpenAI embeddings configuration deserializes correctly. /// [TestMethod] public void TestAzureOpenAIEmbeddingsConfigDeserialization() { // Act - bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - BASIC_CONFIG_WITH_EMBEDDINGS, - out RuntimeConfig runtimeConfig, - replacementSettings: new DeserializationVariableReplacementSettings( - azureKeyVaultOptions: null, - doReplaceEnvVar: false, - doReplaceAkvVar: false)); + bool success = RuntimeConfigLoader.TryParseConfig(BASIC_CONFIG_WITH_EMBEDDINGS, out RuntimeConfig? runtimeConfig); // Assert - Assert.IsTrue(isParsingSuccessful); + Assert.IsTrue(success); Assert.IsNotNull(runtimeConfig); Assert.IsNotNull(runtimeConfig.Runtime); - Assert.IsTrue(runtimeConfig.Runtime.IsEmbeddingsConfigured); Assert.IsNotNull(runtimeConfig.Runtime.Embeddings); EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; @@ -112,111 +106,74 @@ public void TestAzureOpenAIEmbeddingsConfigDeserialization() Assert.AreEqual("2024-02-01", embeddings.ApiVersion); Assert.AreEqual(1536, embeddings.Dimensions); Assert.AreEqual(30000, embeddings.TimeoutMs); - - // Verify UserProvided flags - Assert.IsTrue(embeddings.UserProvidedModel); - Assert.IsTrue(embeddings.UserProvidedApiVersion); - Assert.IsTrue(embeddings.UserProvidedDimensions); - Assert.IsTrue(embeddings.UserProvidedTimeoutMs); } /// - /// Tests that an OpenAI embeddings configuration without optional fields is correctly deserialized - /// and default values are applied. + /// Tests that OpenAI embeddings configuration deserializes correctly with defaults. /// [TestMethod] public void TestOpenAIEmbeddingsConfigWithDefaults() { // Act - bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - OPENAI_CONFIG, - out RuntimeConfig runtimeConfig, - replacementSettings: new DeserializationVariableReplacementSettings( - azureKeyVaultOptions: null, - doReplaceEnvVar: false, - doReplaceAkvVar: false)); + bool success = RuntimeConfigLoader.TryParseConfig(OPENAI_CONFIG, out RuntimeConfig? runtimeConfig); // Assert - Assert.IsTrue(isParsingSuccessful); + Assert.IsTrue(success); Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; Assert.AreEqual(EmbeddingProviderType.OpenAI, embeddings.Provider); Assert.AreEqual("https://api.openai.com", embeddings.BaseUrl); Assert.AreEqual("sk-test-key", embeddings.ApiKey); - - // Model not specified, but EffectiveModel should return default for OpenAI Assert.IsNull(embeddings.Model); Assert.AreEqual(EmbeddingsOptions.DEFAULT_OPENAI_MODEL, embeddings.EffectiveModel); - - // Optional fields should use effective defaults + Assert.IsNull(embeddings.ApiVersion); + Assert.IsNull(embeddings.Dimensions); + Assert.IsNull(embeddings.TimeoutMs); Assert.AreEqual(EmbeddingsOptions.DEFAULT_TIMEOUT_MS, embeddings.EffectiveTimeoutMs); - Assert.AreEqual(EmbeddingsOptions.DEFAULT_AZURE_API_VERSION, embeddings.EffectiveApiVersion); - - // UserProvided flags should be false for optional fields - Assert.IsFalse(embeddings.UserProvidedModel); - Assert.IsFalse(embeddings.UserProvidedApiVersion); - Assert.IsFalse(embeddings.UserProvidedDimensions); - Assert.IsFalse(embeddings.UserProvidedTimeoutMs); } /// - /// Tests minimal Azure OpenAI configuration with required fields only. + /// Tests that minimal Azure OpenAI config deserializes correctly. /// [TestMethod] public void TestMinimalAzureOpenAIConfig() { // Act - bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - MINIMAL_AZURE_CONFIG, - out RuntimeConfig runtimeConfig, - replacementSettings: new DeserializationVariableReplacementSettings( - azureKeyVaultOptions: null, - doReplaceEnvVar: false, - doReplaceAkvVar: false)); + bool success = RuntimeConfigLoader.TryParseConfig(MINIMAL_AZURE_CONFIG, out RuntimeConfig? runtimeConfig); // Assert - Assert.IsTrue(isParsingSuccessful); + Assert.IsTrue(success); Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; Assert.AreEqual(EmbeddingProviderType.AzureOpenAI, embeddings.Provider); Assert.AreEqual("my-deployment", embeddings.Model); Assert.AreEqual("my-deployment", embeddings.EffectiveModel); - Assert.IsTrue(embeddings.UserProvidedModel); } /// - /// Tests that a configuration without embeddings returns IsEmbeddingsConfigured as false. + /// Tests that configuration without embeddings section deserializes correctly. /// [TestMethod] public void TestConfigWithoutEmbeddings() { // Act - bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - CONFIG_WITHOUT_EMBEDDINGS, - out RuntimeConfig runtimeConfig, - replacementSettings: new DeserializationVariableReplacementSettings( - azureKeyVaultOptions: null, - doReplaceEnvVar: false, - doReplaceAkvVar: false)); + bool success = RuntimeConfigLoader.TryParseConfig(CONFIG_WITHOUT_EMBEDDINGS, out RuntimeConfig? runtimeConfig); // Assert - Assert.IsTrue(isParsingSuccessful); + Assert.IsTrue(success); Assert.IsNotNull(runtimeConfig); - - // Runtime may be null or Embeddings may be null - bool isEmbeddingsConfigured = runtimeConfig.Runtime?.IsEmbeddingsConfigured ?? false; - Assert.IsFalse(isEmbeddingsConfigured); + Assert.IsNull(runtimeConfig.Runtime?.Embeddings); } /// - /// Tests that EmbeddingProviderType enum is correctly serialized with kebab-case. + /// Tests that EmbeddingProviderType enum deserializes correctly from JSON. /// [DataTestMethod] - [DataRow("azure-openai", EmbeddingProviderType.AzureOpenAI, DisplayName = "azure-openai deserializes to AzureOpenAI")] - [DataRow("openai", EmbeddingProviderType.OpenAI, DisplayName = "openai deserializes to OpenAI")] - public void TestEmbeddingProviderTypeDeserialization(string providerValue, EmbeddingProviderType expectedType) + [DataRow("azure-openai", EmbeddingProviderType.AzureOpenAI)] + [DataRow("openai", EmbeddingProviderType.OpenAI)] + public void TestEmbeddingProviderTypeDeserialization(string jsonValue, EmbeddingProviderType expected) { // Arrange string config = $@" @@ -228,8 +185,8 @@ public void TestEmbeddingProviderTypeDeserialization(string providerValue, Embed }}, ""runtime"": {{ ""embeddings"": {{ - ""provider"": ""{providerValue}"", - ""endpoint"": ""https://example.com"", + ""provider"": ""{jsonValue}"", + ""base-url"": ""https://example.com"", ""api-key"": ""test-key"", ""model"": ""test-model"" }} @@ -238,18 +195,12 @@ public void TestEmbeddingProviderTypeDeserialization(string providerValue, Embed }}"; // Act - bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - config, - out RuntimeConfig runtimeConfig, - replacementSettings: new DeserializationVariableReplacementSettings( - azureKeyVaultOptions: null, - doReplaceEnvVar: false, - doReplaceAkvVar: false)); + bool success = RuntimeConfigLoader.TryParseConfig(config, out RuntimeConfig? runtimeConfig); // Assert - Assert.IsTrue(isParsingSuccessful); + Assert.IsTrue(success); Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); - Assert.AreEqual(expectedType, runtimeConfig.Runtime.Embeddings.Provider); + Assert.AreEqual(expected, runtimeConfig.Runtime.Embeddings.Provider); } /// @@ -277,7 +228,7 @@ public void TestEmbeddingsOptionsSerialization() // Assert Assert.IsTrue(normalizedJson.Contains("\"provider\":\"azure-openai\""), $"Expected provider in JSON: {json}"); - Assert.IsTrue(normalizedJson.Contains("\"endpoint\":\"https://my-endpoint.openai.azure.com\""), $"Expected endpoint in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"base-url\":\"https://my-endpoint.openai.azure.com\""), $"Expected base-url in JSON: {json}"); Assert.IsTrue(normalizedJson.Contains("\"api-key\":\"my-api-key\""), $"Expected api-key in JSON: {json}"); Assert.IsTrue(normalizedJson.Contains("\"model\":\"my-model\""), $"Expected model in JSON: {json}"); Assert.IsTrue(normalizedJson.Contains("\"api-version\":\"2024-02-01\""), $"Expected api-version in JSON: {json}"); @@ -302,7 +253,7 @@ public void TestEmbeddingsConfigWithEnvVarReplacement() ""runtime"": { ""embeddings"": { ""provider"": ""azure-openai"", - ""endpoint"": ""@env('EMBEDDINGS_ENDPOINT')"", + ""base-url"": ""@env('EMBEDDINGS_ENDPOINT')"", ""api-key"": ""@env('EMBEDDINGS_API_KEY')"", ""model"": ""@env('EMBEDDINGS_MODEL')"" } @@ -311,29 +262,29 @@ public void TestEmbeddingsConfigWithEnvVarReplacement() }"; // Set environment variables - Environment.SetEnvironmentVariable("EMBEDDINGS_ENDPOINT", "https://test-endpoint.openai.azure.com"); - Environment.SetEnvironmentVariable("EMBEDDINGS_API_KEY", "test-secret-key"); - Environment.SetEnvironmentVariable("EMBEDDINGS_MODEL", "text-embedding-3-small"); + Environment.SetEnvironmentVariable("EMBEDDINGS_ENDPOINT", "https://test.openai.azure.com"); + Environment.SetEnvironmentVariable("EMBEDDINGS_API_KEY", "test-key-from-env"); + Environment.SetEnvironmentVariable("EMBEDDINGS_MODEL", "test-model-from-env"); + + // Create replacement settings to enable env var replacement + DeserializationVariableReplacementSettings replacementSettings = new( + doReplaceEnvVar: true, + doReplaceAkvVar: false, + envFailureMode: EnvironmentVariableReplacementFailureMode.Throw); try { // Act - bool isParsingSuccessful = RuntimeConfigLoader.TryParseConfig( - config, - out RuntimeConfig runtimeConfig, - replacementSettings: new DeserializationVariableReplacementSettings( - azureKeyVaultOptions: null, - doReplaceEnvVar: true, - doReplaceAkvVar: false)); + bool success = RuntimeConfigLoader.TryParseConfig(config, out RuntimeConfig? runtimeConfig, replacementSettings); // Assert - Assert.IsTrue(isParsingSuccessful); + Assert.IsTrue(success); Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; - Assert.AreEqual("https://test-endpoint.openai.azure.com", embeddings.BaseUrl); - Assert.AreEqual("test-secret-key", embeddings.ApiKey); - Assert.AreEqual("text-embedding-3-small", embeddings.Model); + Assert.AreEqual("https://test.openai.azure.com", embeddings.BaseUrl); + Assert.AreEqual("test-key-from-env", embeddings.ApiKey); + Assert.AreEqual("test-model-from-env", embeddings.Model); } finally { diff --git a/src/Service/HealthCheck/HealthCheckHelper.cs b/src/Service/HealthCheck/HealthCheckHelper.cs index ab19756195..8a60d8130f 100644 --- a/src/Service/HealthCheck/HealthCheckHelper.cs +++ b/src/Service/HealthCheck/HealthCheckHelper.cs @@ -137,14 +137,19 @@ private static void UpdateTimestampOfResponse(ref ComprehensiveHealthCheckReport // Updates the DAB configuration details coming from RuntimeConfig for the Health report. private static void UpdateDabConfigurationDetails(ref ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, RuntimeConfig runtimeConfig) { + bool embeddingsEnabled = runtimeConfig?.Runtime?.Embeddings?.Enabled ?? false; + bool embeddingsEndpointEnabled = embeddingsEnabled && (runtimeConfig?.Runtime?.Embeddings?.IsEndpointEnabled ?? false); + comprehensiveHealthCheckReport.ConfigurationDetails = new ConfigurationDetails { - Rest = runtimeConfig.IsRestEnabled, - GraphQL = runtimeConfig.IsGraphQLEnabled, - Mcp = runtimeConfig.IsMcpEnabled, - Caching = runtimeConfig.IsCachingEnabled, + Rest = runtimeConfig?.IsRestEnabled ?? false, + GraphQL = runtimeConfig?.IsGraphQLEnabled ?? false, + Mcp = runtimeConfig?.IsMcpEnabled ?? false, + Caching = runtimeConfig?.IsCachingEnabled ?? false, Telemetry = runtimeConfig?.Runtime?.Telemetry != null, - Mode = runtimeConfig?.Runtime?.Host?.Mode ?? HostMode.Production, // Modify to runtimeConfig.HostMode in Roles PR + Mode = runtimeConfig?.Runtime?.Host?.Mode ?? HostMode.Production, + Embeddings = embeddingsEnabled, + EmbeddingsEndpoint = embeddingsEndpointEnabled }; } diff --git a/src/Service/HealthCheck/Model/ConfigurationDetails.cs b/src/Service/HealthCheck/Model/ConfigurationDetails.cs index 9ff007754e..e73497e3e0 100644 --- a/src/Service/HealthCheck/Model/ConfigurationDetails.cs +++ b/src/Service/HealthCheck/Model/ConfigurationDetails.cs @@ -29,5 +29,11 @@ public record ConfigurationDetails [JsonPropertyName("mode")] public HostMode Mode { get; init; } + + [JsonPropertyName("embeddings")] + public bool Embeddings { get; init; } + + [JsonPropertyName("embeddings-endpoint")] + public bool EmbeddingsEndpoint { get; init; } } } From 857203ac2bb2d51a315de10c668a89eff7f638dd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:22:56 +0000 Subject: [PATCH 10/16] Address code review feedback Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../ObjectModel/Embeddings/EmbeddingProviderType.cs | 2 +- .../Embeddings/EmbeddingsHealthCheckConfig.cs | 3 ++- src/Core/Services/Embeddings/EmbeddingService.cs | 11 +++++++++++ src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs | 2 ++ 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs b/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs index 39ce56b596..9b2efc994b 100644 --- a/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs +++ b/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs @@ -21,7 +21,7 @@ public enum EmbeddingProviderType /// /// OpenAI embedding provider. - /// Lowercase "openai" is the serialized value. /// + [EnumMember(Value = "openai")] OpenAI } diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs index b2d2f86bcf..bf2a79764c 100644 --- a/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs @@ -23,7 +23,8 @@ public record EmbeddingsHealthCheckConfig : HealthCheckConfig /// /// The expected milliseconds the embedding request should complete within to be considered healthy. - /// If the request takes equal or longer than this value, the health check will be considered unhealthy. + /// If the request takes longer than this value, the health check will be considered unhealthy. + /// Requests completing at exactly the threshold are considered healthy. /// Default: 5000ms (5 seconds) /// [JsonPropertyName("threshold-ms")] diff --git a/src/Core/Services/Embeddings/EmbeddingService.cs b/src/Core/Services/Embeddings/EmbeddingService.cs index c3b03941de..3fd16b993b 100644 --- a/src/Core/Services/Embeddings/EmbeddingService.cs +++ b/src/Core/Services/Embeddings/EmbeddingService.cs @@ -60,6 +60,17 @@ public EmbeddingService( _logger = logger ?? throw new ArgumentNullException(nameof(logger)); _cache = cache ?? throw new ArgumentNullException(nameof(cache)); + // Validate required options + if (string.IsNullOrEmpty(_options.BaseUrl)) + { + throw new ArgumentException("BaseUrl is required in EmbeddingsOptions.", nameof(options)); + } + + if (string.IsNullOrEmpty(_options.ApiKey)) + { + throw new ArgumentException("ApiKey is required in EmbeddingsOptions.", nameof(options)); + } + ConfigureHttpClient(); } diff --git a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs index 020024ea47..591b6b79ec 100644 --- a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs +++ b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#nullable enable + using System; using System.Text.Json; using Azure.DataApiBuilder.Config; From 89cb2d979c3e4631cd8049c3d3cddd9d94a55968 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:30:24 +0000 Subject: [PATCH 11/16] Address PR feedback: Add Azure OpenAI validation, cache key security, and telemetry integration Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../Services/Embeddings/EmbeddingService.cs | 69 +++++++++++++++++-- 1 file changed, 65 insertions(+), 4 deletions(-) diff --git a/src/Core/Services/Embeddings/EmbeddingService.cs b/src/Core/Services/Embeddings/EmbeddingService.cs index 3fd16b993b..f2a69eccc4 100644 --- a/src/Core/Services/Embeddings/EmbeddingService.cs +++ b/src/Core/Services/Embeddings/EmbeddingService.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Diagnostics; using System.Net.Http.Headers; using System.Security.Cryptography; using System.Text; @@ -71,6 +72,12 @@ public EmbeddingService( throw new ArgumentException("ApiKey is required in EmbeddingsOptions.", nameof(options)); } + // Azure OpenAI requires model/deployment name + if (_options.Provider == EmbeddingProviderType.AzureOpenAI && string.IsNullOrEmpty(_options.EffectiveModel)) + { + throw new InvalidOperationException("Model/deployment name is required for Azure OpenAI provider."); + } + ConfigureHttpClient(); } @@ -99,6 +106,11 @@ private void ConfigureHttpClient() /// public bool IsEnabled => _options.Enabled; + /// + /// Gets the provider name for telemetry. + /// + private string ProviderName => _options.Provider.ToString().ToLowerInvariant(); + /// public async Task TryEmbedAsync(string text, CancellationToken cancellationToken = default) { @@ -114,14 +126,30 @@ public async Task TryEmbedAsync(string text, CancellationToken return new EmbeddingResult(false, null, "Text cannot be null or empty."); } + Stopwatch stopwatch = Stopwatch.StartNew(); + using Activity? activity = EmbeddingTelemetryHelper.StartEmbeddingActivity("TryEmbedAsync"); + activity?.SetEmbeddingActivityTags(ProviderName, _options.EffectiveModel, textCount: 1); + try { + EmbeddingTelemetryHelper.TrackEmbeddingRequest(ProviderName, textCount: 1); + float[] embedding = await EmbedAsync(text, cancellationToken); + + stopwatch.Stop(); + activity?.SetEmbeddingActivitySuccess(stopwatch.Elapsed.TotalMilliseconds, embedding.Length); + EmbeddingTelemetryHelper.TrackTotalDuration(ProviderName, stopwatch.Elapsed, fromCache: false); + EmbeddingTelemetryHelper.TrackDimensions(ProviderName, embedding.Length); + return new EmbeddingResult(true, embedding); } catch (Exception ex) { + stopwatch.Stop(); _logger.LogError(ex, "Failed to generate embedding for text"); + activity?.SetEmbeddingActivityError(ex); + EmbeddingTelemetryHelper.TrackError(ProviderName, ex.GetType().Name); + return new EmbeddingResult(false, null, ex.Message); } } @@ -141,14 +169,34 @@ public async Task TryEmbedBatchAsync(string[] texts, Cance return new EmbeddingBatchResult(false, null, "Texts array cannot be null or empty."); } + Stopwatch stopwatch = Stopwatch.StartNew(); + using Activity? activity = EmbeddingTelemetryHelper.StartEmbeddingActivity("TryEmbedBatchAsync"); + activity?.SetEmbeddingActivityTags(ProviderName, _options.EffectiveModel, texts.Length); + try { + EmbeddingTelemetryHelper.TrackEmbeddingRequest(ProviderName, texts.Length); + float[][] embeddings = await EmbedBatchAsync(texts, cancellationToken); + + stopwatch.Stop(); + int dimensions = embeddings.Length > 0 ? embeddings[0].Length : 0; + activity?.SetEmbeddingActivitySuccess(stopwatch.Elapsed.TotalMilliseconds, dimensions); + EmbeddingTelemetryHelper.TrackTotalDuration(ProviderName, stopwatch.Elapsed, fromCache: false); + if (dimensions > 0) + { + EmbeddingTelemetryHelper.TrackDimensions(ProviderName, dimensions); + } + return new EmbeddingBatchResult(true, embeddings); } catch (Exception ex) { + stopwatch.Stop(); _logger.LogError(ex, "Failed to generate embeddings for batch of {Count} texts", texts.Length); + activity?.SetEmbeddingActivityError(ex); + EmbeddingTelemetryHelper.TrackError(ProviderName, ex.GetType().Name); + return new EmbeddingBatchResult(false, null, ex.Message); } } @@ -209,6 +257,7 @@ public async Task EmbedBatchAsync(string[] texts, CancellationToken c string[] cacheKeys = texts.Select(CreateCacheKey).ToArray(); float[]?[] results = new float[texts.Length][]; List uncachedIndices = new(); + int cacheHits = 0; // Check cache for each text for (int i = 0; i < texts.Length; i++) @@ -219,10 +268,13 @@ public async Task EmbedBatchAsync(string[] texts, CancellationToken c { _logger.LogDebug("Embedding cache hit for text hash {TextHash}", cacheKeys[i]); results[i] = cached.Value; + cacheHits++; + EmbeddingTelemetryHelper.TrackCacheHit(ProviderName); } else { uncachedIndices.Add(i); + EmbeddingTelemetryHelper.TrackCacheMiss(ProviderName); } } @@ -236,7 +288,14 @@ public async Task EmbedBatchAsync(string[] texts, CancellationToken c // Call API for uncached texts only string[] uncachedTexts = uncachedIndices.Select(i => texts[i]).ToArray(); + + Stopwatch apiStopwatch = Stopwatch.StartNew(); float[][] apiResults = await EmbedFromApiAsync(uncachedTexts, cancellationToken); + apiStopwatch.Stop(); + + // Track API call telemetry + EmbeddingTelemetryHelper.TrackApiCall(ProviderName, uncachedTexts.Length); + EmbeddingTelemetryHelper.TrackApiDuration(ProviderName, apiStopwatch.Elapsed, uncachedTexts.Length); // Cache new results and merge with cached results for (int i = 0; i < uncachedIndices.Count; i++) @@ -260,15 +319,17 @@ public async Task EmbedBatchAsync(string[] texts, CancellationToken c /// /// Creates a cache key from the text using SHA256 hash. - /// Format: embedding:{SHA256_hash} + /// Format: embedding:{provider}:{model}:{SHA256_hash} + /// Includes provider and model to prevent cross-configuration collisions. /// Uses hash to keep cache keys small and deterministic. /// /// The text to create a cache key for. /// Cache key string. - private static string CreateCacheKey(string text) + private string CreateCacheKey(string text) { - // Use SHA256 for deterministic, collision-resistant hash - byte[] textBytes = Encoding.UTF8.GetBytes(text); + // Include provider and model in hash to avoid cross-provider/model collisions + string keyInput = $"{_options.Provider}:{_options.EffectiveModel}:{text}"; + byte[] textBytes = Encoding.UTF8.GetBytes(keyInput); byte[] hashBytes = SHA256.HashData(textBytes); string hashHex = Convert.ToHexString(hashBytes); From 64b592ec6d6ef5c5b698483034a06c9b9e1e4751 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:32:40 +0000 Subject: [PATCH 12/16] Optimize ProviderName to avoid repeated string allocations Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../Services/Embeddings/EmbeddingService.cs | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/Core/Services/Embeddings/EmbeddingService.cs b/src/Core/Services/Embeddings/EmbeddingService.cs index f2a69eccc4..df3ed64c09 100644 --- a/src/Core/Services/Embeddings/EmbeddingService.cs +++ b/src/Core/Services/Embeddings/EmbeddingService.cs @@ -24,6 +24,7 @@ public class EmbeddingService : IEmbeddingService private readonly EmbeddingsOptions _options; private readonly ILogger _logger; private readonly IFusionCache _cache; + private readonly string _providerName; // Constants private const char KEY_DELIMITER = ':'; @@ -61,6 +62,9 @@ public EmbeddingService( _logger = logger ?? throw new ArgumentNullException(nameof(logger)); _cache = cache ?? throw new ArgumentNullException(nameof(cache)); + // Cache provider name for telemetry to avoid repeated string allocations + _providerName = _options.Provider.ToString().ToLowerInvariant(); + // Validate required options if (string.IsNullOrEmpty(_options.BaseUrl)) { @@ -106,11 +110,6 @@ private void ConfigureHttpClient() /// public bool IsEnabled => _options.Enabled; - /// - /// Gets the provider name for telemetry. - /// - private string ProviderName => _options.Provider.ToString().ToLowerInvariant(); - /// public async Task TryEmbedAsync(string text, CancellationToken cancellationToken = default) { @@ -128,18 +127,18 @@ public async Task TryEmbedAsync(string text, CancellationToken Stopwatch stopwatch = Stopwatch.StartNew(); using Activity? activity = EmbeddingTelemetryHelper.StartEmbeddingActivity("TryEmbedAsync"); - activity?.SetEmbeddingActivityTags(ProviderName, _options.EffectiveModel, textCount: 1); + activity?.SetEmbeddingActivityTags(_providerName, _options.EffectiveModel, textCount: 1); try { - EmbeddingTelemetryHelper.TrackEmbeddingRequest(ProviderName, textCount: 1); + EmbeddingTelemetryHelper.TrackEmbeddingRequest(_providerName, textCount: 1); float[] embedding = await EmbedAsync(text, cancellationToken); stopwatch.Stop(); activity?.SetEmbeddingActivitySuccess(stopwatch.Elapsed.TotalMilliseconds, embedding.Length); - EmbeddingTelemetryHelper.TrackTotalDuration(ProviderName, stopwatch.Elapsed, fromCache: false); - EmbeddingTelemetryHelper.TrackDimensions(ProviderName, embedding.Length); + EmbeddingTelemetryHelper.TrackTotalDuration(_providerName, stopwatch.Elapsed, fromCache: false); + EmbeddingTelemetryHelper.TrackDimensions(_providerName, embedding.Length); return new EmbeddingResult(true, embedding); } @@ -148,7 +147,7 @@ public async Task TryEmbedAsync(string text, CancellationToken stopwatch.Stop(); _logger.LogError(ex, "Failed to generate embedding for text"); activity?.SetEmbeddingActivityError(ex); - EmbeddingTelemetryHelper.TrackError(ProviderName, ex.GetType().Name); + EmbeddingTelemetryHelper.TrackError(_providerName, ex.GetType().Name); return new EmbeddingResult(false, null, ex.Message); } @@ -171,21 +170,21 @@ public async Task TryEmbedBatchAsync(string[] texts, Cance Stopwatch stopwatch = Stopwatch.StartNew(); using Activity? activity = EmbeddingTelemetryHelper.StartEmbeddingActivity("TryEmbedBatchAsync"); - activity?.SetEmbeddingActivityTags(ProviderName, _options.EffectiveModel, texts.Length); + activity?.SetEmbeddingActivityTags(_providerName, _options.EffectiveModel, texts.Length); try { - EmbeddingTelemetryHelper.TrackEmbeddingRequest(ProviderName, texts.Length); + EmbeddingTelemetryHelper.TrackEmbeddingRequest(_providerName, texts.Length); float[][] embeddings = await EmbedBatchAsync(texts, cancellationToken); stopwatch.Stop(); int dimensions = embeddings.Length > 0 ? embeddings[0].Length : 0; activity?.SetEmbeddingActivitySuccess(stopwatch.Elapsed.TotalMilliseconds, dimensions); - EmbeddingTelemetryHelper.TrackTotalDuration(ProviderName, stopwatch.Elapsed, fromCache: false); + EmbeddingTelemetryHelper.TrackTotalDuration(_providerName, stopwatch.Elapsed, fromCache: false); if (dimensions > 0) { - EmbeddingTelemetryHelper.TrackDimensions(ProviderName, dimensions); + EmbeddingTelemetryHelper.TrackDimensions(_providerName, dimensions); } return new EmbeddingBatchResult(true, embeddings); @@ -195,7 +194,7 @@ public async Task TryEmbedBatchAsync(string[] texts, Cance stopwatch.Stop(); _logger.LogError(ex, "Failed to generate embeddings for batch of {Count} texts", texts.Length); activity?.SetEmbeddingActivityError(ex); - EmbeddingTelemetryHelper.TrackError(ProviderName, ex.GetType().Name); + EmbeddingTelemetryHelper.TrackError(_providerName, ex.GetType().Name); return new EmbeddingBatchResult(false, null, ex.Message); } @@ -269,12 +268,12 @@ public async Task EmbedBatchAsync(string[] texts, CancellationToken c _logger.LogDebug("Embedding cache hit for text hash {TextHash}", cacheKeys[i]); results[i] = cached.Value; cacheHits++; - EmbeddingTelemetryHelper.TrackCacheHit(ProviderName); + EmbeddingTelemetryHelper.TrackCacheHit(_providerName); } else { uncachedIndices.Add(i); - EmbeddingTelemetryHelper.TrackCacheMiss(ProviderName); + EmbeddingTelemetryHelper.TrackCacheMiss(_providerName); } } @@ -294,8 +293,8 @@ public async Task EmbedBatchAsync(string[] texts, CancellationToken c apiStopwatch.Stop(); // Track API call telemetry - EmbeddingTelemetryHelper.TrackApiCall(ProviderName, uncachedTexts.Length); - EmbeddingTelemetryHelper.TrackApiDuration(ProviderName, apiStopwatch.Elapsed, uncachedTexts.Length); + EmbeddingTelemetryHelper.TrackApiCall(_providerName, uncachedTexts.Length); + EmbeddingTelemetryHelper.TrackApiDuration(_providerName, apiStopwatch.Elapsed, uncachedTexts.Length); // Cache new results and merge with cached results for (int i = 0; i < uncachedIndices.Count; i++) From e8d72387ac3ac0e723fc32245cd22dbecf142895 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Feb 2026 23:03:55 +0000 Subject: [PATCH 13/16] Fix schema mismatch, remove unused field, add enabled handling, validate empty embeddings Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- schemas/dab.draft.schema.json | 11 ++++++++--- src/Cli/ConfigGenerator.cs | 9 +++++++-- .../EmbeddingsOptionsConverterFactory.cs | 14 +++----------- src/Core/Services/Embeddings/EmbeddingService.cs | 8 +++++++- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index cb7d309828..a395667908 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -648,14 +648,19 @@ "description": "Configuration for text embedding/vectorization service. Supports OpenAI and Azure OpenAI providers.", "additionalProperties": false, "properties": { + "enabled": { + "type": "boolean", + "description": "Whether the embedding service is enabled. Defaults to true.", + "default": true + }, "provider": { "type": "string", "description": "The embedding provider type.", "enum": ["azure-openai", "openai"] }, - "endpoint": { + "base-url": { "type": "string", - "description": "The provider base URL endpoint. For Azure OpenAI, use the Azure resource endpoint. For OpenAI, use https://api.openai.com." + "description": "The provider base URL. For Azure OpenAI, use the Azure resource endpoint. For OpenAI, use https://api.openai.com." }, "api-key": { "type": "string", @@ -683,7 +688,7 @@ "maximum": 300000 } }, - "required": ["provider", "endpoint", "api-key"], + "required": ["provider", "base-url", "api-key"], "allOf": [ { "$comment": "Azure OpenAI requires the model (deployment name) to be specified.", diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index a8b303664e..2f1db2a0e1 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -909,14 +909,15 @@ options.FileSinkRetainedFileCountLimit is not null || } } - // Embeddings: Provider, Endpoint, ApiKey, Model, ApiVersion, Dimensions, TimeoutMs + // Embeddings: Provider, Endpoint, ApiKey, Model, ApiVersion, Dimensions, TimeoutMs, Enabled if (options.RuntimeEmbeddingsProvider is not null || options.RuntimeEmbeddingsBaseUrl is not null || options.RuntimeEmbeddingsApiKey is not null || options.RuntimeEmbeddingsModel is not null || options.RuntimeEmbeddingsApiVersion is not null || options.RuntimeEmbeddingsDimensions is not null || - options.RuntimeEmbeddingsTimeoutMs is not null) + options.RuntimeEmbeddingsTimeoutMs is not null || + options.RuntimeEmbeddingsEnabled is not null) { bool status = TryUpdateConfiguredEmbeddingsValues(options, runtimeConfig?.Runtime?.Embeddings, out EmbeddingsOptions? updatedEmbeddingsOptions); if (status && updatedEmbeddingsOptions is not null) @@ -1569,6 +1570,9 @@ private static bool TryUpdateConfiguredEmbeddingsValues( string? apiVersion = options.RuntimeEmbeddingsApiVersion ?? existingEmbeddingsOptions?.ApiVersion; int? dimensions = options.RuntimeEmbeddingsDimensions ?? existingEmbeddingsOptions?.Dimensions; int? timeoutMs = options.RuntimeEmbeddingsTimeoutMs ?? existingEmbeddingsOptions?.TimeoutMs; + bool? enabled = options.RuntimeEmbeddingsEnabled.HasValue + ? options.RuntimeEmbeddingsEnabled.Value == CliBool.True + : existingEmbeddingsOptions?.Enabled; // Validate required fields if (provider is null) @@ -1615,6 +1619,7 @@ private static bool TryUpdateConfiguredEmbeddingsValues( Provider: (EmbeddingProviderType)provider, BaseUrl: baseUrl, ApiKey: apiKey, + Enabled: enabled, Model: model, ApiVersion: apiVersion, Dimensions: dimensions, diff --git a/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs b/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs index b84b212815..b356d3e188 100644 --- a/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs +++ b/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs @@ -13,11 +13,10 @@ namespace Azure.DataApiBuilder.Config.Converters; /// internal class EmbeddingsOptionsConverterFactory : JsonConverterFactory { - private readonly DeserializationVariableReplacementSettings? _replacementSettings; - public EmbeddingsOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) { - _replacementSettings = replacementSettings; + // Note: replacementSettings is not used in this converter because the environment variable + // replacement is handled by the string deserializers registered in the JsonSerializerOptions. } /// @@ -29,18 +28,11 @@ public override bool CanConvert(Type typeToConvert) /// public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) { - return new EmbeddingsOptionsConverter(_replacementSettings); + return new EmbeddingsOptionsConverter(); } private class EmbeddingsOptionsConverter : JsonConverter { - private readonly DeserializationVariableReplacementSettings? _replacementSettings; - - public EmbeddingsOptionsConverter(DeserializationVariableReplacementSettings? replacementSettings) - { - _replacementSettings = replacementSettings; - } - public override EmbeddingsOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { if (reader.TokenType != JsonTokenType.StartObject) diff --git a/src/Core/Services/Embeddings/EmbeddingService.cs b/src/Core/Services/Embeddings/EmbeddingService.cs index df3ed64c09..017f6801da 100644 --- a/src/Core/Services/Embeddings/EmbeddingService.cs +++ b/src/Core/Services/Embeddings/EmbeddingService.cs @@ -224,6 +224,12 @@ public async Task EmbedAsync(string text, CancellationToken cancellatio float[][] results = await EmbedFromApiAsync(new[] { text }, ct); float[] result = results[0]; + // Validate the embedding result is not empty + if (result.Length == 0) + { + throw new InvalidOperationException("API returned empty embedding array."); + } + // L1 only - skip distributed cache ctx.Options.SetSkipDistributedCache(true, true); ctx.Options.SetDuration(TimeSpan.FromHours(DEFAULT_CACHE_TTL_HOURS)); @@ -232,7 +238,7 @@ public async Task EmbedAsync(string text, CancellationToken cancellatio }, token: cancellationToken); - if (embedding is null) + if (embedding is null || embedding.Length == 0) { throw new InvalidOperationException("Failed to get embedding from cache or API."); } From d9c8a2973f6fa9415fac86534d82ef0b90cb3ecc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Feb 2026 00:32:54 +0000 Subject: [PATCH 14/16] Add embedding health check execution and update JSON schema with endpoint/health sub-objects Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- schemas/dab.draft.schema.json | 53 +++++++++ src/Service/HealthCheck/HealthCheckHelper.cs | 112 ++++++++++++++++++- src/Service/Startup.cs | 8 +- 3 files changed, 171 insertions(+), 2 deletions(-) diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index a395667908..2035615763 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -686,6 +686,59 @@ "default": 30000, "minimum": 1, "maximum": 300000 + }, + "endpoint": { + "type": "object", + "description": "REST endpoint configuration for the embedding service.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether the /embed REST endpoint is enabled. Defaults to false.", + "default": false + }, + "path": { + "type": "string", + "description": "The endpoint path. Defaults to '/embed'.", + "default": "/embed" + }, + "roles": { + "type": "array", + "description": "The roles allowed to access the embedding endpoint. In development mode, defaults to ['anonymous'].", + "items": { + "type": "string" + } + } + } + }, + "health": { + "type": "object", + "description": "Health check configuration for the embedding service.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether health checks are enabled for embeddings. Defaults to true.", + "default": true + }, + "threshold-ms": { + "type": "integer", + "description": "The maximum response time in milliseconds to be considered healthy.", + "default": 5000, + "minimum": 1, + "maximum": 300000 + }, + "test-text": { + "type": "string", + "description": "The text to use for health check validation.", + "default": "health check" + }, + "expected-dimensions": { + "type": "integer", + "description": "The expected number of dimensions in the embedding result. If specified, dimension validation is performed.", + "minimum": 1 + } + } } }, "required": ["provider", "base-url", "api-key"], diff --git a/src/Service/HealthCheck/HealthCheckHelper.cs b/src/Service/HealthCheck/HealthCheckHelper.cs index 8a60d8130f..d44ade1930 100644 --- a/src/Service/HealthCheck/HealthCheckHelper.cs +++ b/src/Service/HealthCheck/HealthCheckHelper.cs @@ -10,7 +10,9 @@ using System.Threading.Tasks; using Azure.DataApiBuilder.Config.HealthCheck; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Services.Embeddings; using Azure.DataApiBuilder.Product; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -27,20 +29,24 @@ public class HealthCheckHelper // Dependencies private ILogger _logger; private HttpUtilities _httpUtility; + private IEmbeddingService? _embeddingService; private string _incomingRoleHeader = string.Empty; private string _incomingRoleToken = string.Empty; private const string TIME_EXCEEDED_ERROR_MESSAGE = "The threshold for executing the request has exceeded."; + private const string DIMENSIONS_MISMATCH_ERROR_MESSAGE = "The embedding dimensions do not match the expected dimensions."; /// /// Constructor to inject the logger and HttpUtility class. /// /// Logger to track the log statements. /// HttpUtility to call methods from the internal class. - public HealthCheckHelper(ILogger logger, HttpUtilities httpUtility) + /// Optional embedding service for embedding health checks. + public HealthCheckHelper(ILogger logger, HttpUtilities httpUtility, IEmbeddingService? embeddingService = null) { _logger = logger; _httpUtility = httpUtility; + _embeddingService = embeddingService; } /// @@ -159,6 +165,7 @@ private async Task UpdateHealthCheckDetailsAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport.Checks = new List(); await UpdateDataSourceHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); await UpdateEntityHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); + await UpdateEmbeddingsHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); } // Updates the DataSource Health Check Results in the response. @@ -351,5 +358,108 @@ private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comp return (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage); } + + /// + /// Updates the Embeddings Health Check Results in the response. + /// Executes a test embedding and validates response time and optionally dimensions. + /// + private async Task UpdateEmbeddingsHealthCheckResultsAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, RuntimeConfig runtimeConfig) + { + EmbeddingsOptions? embeddingsOptions = runtimeConfig?.Runtime?.Embeddings; + EmbeddingsHealthCheckConfig? healthConfig = embeddingsOptions?.Health; + + // Only run health check if embeddings is enabled, health check is enabled, and embedding service is available + if (embeddingsOptions is null || !embeddingsOptions.Enabled || healthConfig is null || !healthConfig.Enabled || _embeddingService is null) + { + return; + } + + if (comprehensiveHealthCheckReport.Checks is null) + { + comprehensiveHealthCheckReport.Checks = new List(); + } + + string testText = healthConfig.TestText; + int thresholdMs = healthConfig.ThresholdMs; + int? expectedDimensions = healthConfig.ExpectedDimensions; + + try + { + Stopwatch stopwatch = new(); + stopwatch.Start(); + EmbeddingResult result = await _embeddingService.TryEmbedAsync(testText); + stopwatch.Stop(); + + int responseTimeMs = (int)stopwatch.ElapsedMilliseconds; + bool isResponseTimeWithinThreshold = responseTimeMs <= thresholdMs; + bool isDimensionsValid = true; + string? errorMessage = null; + + if (!result.Success) + { + errorMessage = result.ErrorMessage ?? "Embedding request failed."; + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = HealthCheckConstants.ERROR_RESPONSE_TIME_MS, + ThresholdMs = thresholdMs + }, + Exception = errorMessage, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = HealthStatus.Unhealthy + }); + return; + } + + // Validate dimensions if expected dimensions is specified + if (expectedDimensions.HasValue && result.Embedding is not null) + { + isDimensionsValid = result.Embedding.Length == expectedDimensions.Value; + if (!isDimensionsValid) + { + errorMessage = $"{DIMENSIONS_MISMATCH_ERROR_MESSAGE} Expected: {expectedDimensions.Value}, Actual: {result.Embedding.Length}"; + } + } + + // Check response time threshold + if (!isResponseTimeWithinThreshold) + { + errorMessage = errorMessage is null ? TIME_EXCEEDED_ERROR_MESSAGE : $"{errorMessage} {TIME_EXCEEDED_ERROR_MESSAGE}"; + } + + bool isHealthy = isResponseTimeWithinThreshold && isDimensionsValid; + + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = responseTimeMs, + ThresholdMs = thresholdMs + }, + Exception = errorMessage, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = isHealthy ? HealthStatus.Healthy : HealthStatus.Unhealthy + }); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error executing embeddings health check."); + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = HealthCheckConstants.ERROR_RESPONSE_TIME_MS, + ThresholdMs = thresholdMs + }, + Exception = ex.Message, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = HealthStatus.Unhealthy + }); + } + } } } diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index aeda9346d7..19948da725 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -262,7 +262,13 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); - services.AddSingleton(); + services.AddSingleton(sp => + { + ILogger logger = sp.GetRequiredService>(); + HttpUtilities httpUtility = sp.GetRequiredService(); + IEmbeddingService? embeddingService = sp.GetService(); + return new HealthCheckHelper(logger, httpUtility, embeddingService); + }); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); From 3e02c0ff60984af66a942067fb0e1e7fe72c2289 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Feb 2026 00:42:45 +0000 Subject: [PATCH 15/16] Add EmbeddingController for /embed REST endpoint with role-based authorization Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- .../Controllers/EmbeddingController.cs | 161 ++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 src/Service/Controllers/EmbeddingController.cs diff --git a/src/Service/Controllers/EmbeddingController.cs b/src/Service/Controllers/EmbeddingController.cs new file mode 100644 index 0000000000..7a732128f9 --- /dev/null +++ b/src/Service/Controllers/EmbeddingController.cs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Net; +using System.Net.Mime; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; + +namespace Azure.DataApiBuilder.Service.Controllers; + +/// +/// Controller to serve embedding requests at the configured endpoint path (default: /embed). +/// Accepts plain text input and returns embedding vector as plain text (comma-separated floats). +/// +[ApiController] +public class EmbeddingController : ControllerBase +{ + private readonly IEmbeddingService? _embeddingService; + private readonly RuntimeConfigProvider _runtimeConfigProvider; + private readonly ILogger _logger; + + /// + /// Constructor. + /// + public EmbeddingController( + RuntimeConfigProvider runtimeConfigProvider, + ILogger logger, + IEmbeddingService? embeddingService = null) + { + _runtimeConfigProvider = runtimeConfigProvider; + _logger = logger; + _embeddingService = embeddingService; + } + + /// + /// POST endpoint for generating embeddings. + /// Accepts plain text body and returns embedding vector as comma-separated floats. + /// + /// The route path. + /// Plain text embedding vector or error response. + [HttpPost] + [Route("{*route}")] + [Consumes("text/plain", "application/json")] + [Produces("text/plain")] + public async Task PostAsync(string? route) + { + // Get embeddings configuration + EmbeddingsOptions? embeddingsOptions = _runtimeConfigProvider.GetConfig()?.Runtime?.Embeddings; + EmbeddingsEndpointOptions? endpointOptions = embeddingsOptions?.Endpoint; + + // Check if embeddings and endpoint are enabled + if (embeddingsOptions is null || !embeddingsOptions.Enabled) + { + return NotFound(); + } + + if (endpointOptions is null || !endpointOptions.Enabled) + { + return NotFound(); + } + + // Check if the route matches the configured endpoint path + string expectedPath = endpointOptions.EffectivePath.TrimStart('/'); + if (!string.Equals(route, expectedPath, StringComparison.OrdinalIgnoreCase)) + { + return NotFound(); + } + + // Check if embedding service is available + if (_embeddingService is null || !_embeddingService.IsEnabled) + { + _logger.LogWarning("Embedding endpoint called but embedding service is not available or disabled."); + return StatusCode((int)HttpStatusCode.ServiceUnavailable, "Embedding service is not available."); + } + + // Check authorization + bool isDevelopmentMode = _runtimeConfigProvider.GetConfig()?.Runtime?.Host?.Mode == HostMode.Development; + string clientRole = GetClientRole(); + + if (!endpointOptions.IsRoleAllowed(clientRole, isDevelopmentMode)) + { + _logger.LogWarning("Embedding endpoint access denied for role: {Role}", clientRole); + return StatusCode((int)HttpStatusCode.Forbidden, "Access denied. Role not authorized."); + } + + // Read request body as plain text + string text; + try + { + using StreamReader reader = new(Request.Body); + text = await reader.ReadToEndAsync(); + + // Handle JSON-wrapped string + if (Request.ContentType?.Contains("application/json", StringComparison.OrdinalIgnoreCase) == true) + { + try + { + text = JsonSerializer.Deserialize(text) ?? text; + } + catch + { + // Not valid JSON string, use as-is + } + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to read request body for embedding."); + return BadRequest("Failed to read request body."); + } + + if (string.IsNullOrWhiteSpace(text)) + { + return BadRequest("Request body cannot be empty."); + } + + // Generate embedding + EmbeddingResult result = await _embeddingService.TryEmbedAsync(text); + + if (!result.Success) + { + _logger.LogError("Embedding request failed: {Error}", result.ErrorMessage); + return StatusCode((int)HttpStatusCode.InternalServerError, result.ErrorMessage ?? "Failed to generate embedding."); + } + + if (result.Embedding is null || result.Embedding.Length == 0) + { + _logger.LogError("Embedding request returned empty result."); + return StatusCode((int)HttpStatusCode.InternalServerError, "Failed to generate embedding."); + } + + // Return embedding as comma-separated float values (plain text) + string embeddingText = string.Join(",", result.Embedding); + return Content(embeddingText, MediaTypeNames.Text.Plain); + } + + /// + /// Gets the client role from request headers. + /// + private string GetClientRole() + { + StringValues roleHeader = Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]; + if (roleHeader.Count == 1 && !string.IsNullOrEmpty(roleHeader[0])) + { + return roleHeader[0]!.ToLowerInvariant(); + } + + return EmbeddingsEndpointOptions.ANONYMOUS_ROLE; + } +} From 5c0546426e6b59bfbba65a1fe80c6a219a0daad3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Feb 2026 00:44:28 +0000 Subject: [PATCH 16/16] Address code review feedback for EmbeddingController Co-authored-by: JerryNixon <1749983+JerryNixon@users.noreply.github.com> --- src/Service/Controllers/EmbeddingController.cs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Service/Controllers/EmbeddingController.cs b/src/Service/Controllers/EmbeddingController.cs index 7a732128f9..1c8b1641b8 100644 --- a/src/Service/Controllers/EmbeddingController.cs +++ b/src/Service/Controllers/EmbeddingController.cs @@ -108,9 +108,10 @@ public async Task PostAsync(string? route) { text = JsonSerializer.Deserialize(text) ?? text; } - catch + catch (JsonException) { // Not valid JSON string, use as-is + _logger.LogDebug("Request body is not a valid JSON string, using as plain text."); } } } @@ -151,9 +152,11 @@ public async Task PostAsync(string? route) private string GetClientRole() { StringValues roleHeader = Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]; - if (roleHeader.Count == 1 && !string.IsNullOrEmpty(roleHeader[0])) + string? firstRole = roleHeader.Count == 1 ? roleHeader[0] : null; + + if (!string.IsNullOrEmpty(firstRole)) { - return roleHeader[0]!.ToLowerInvariant(); + return firstRole.ToLowerInvariant(); } return EmbeddingsEndpointOptions.ANONYMOUS_ROLE;