diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 920c0a4da6..2035615763 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -642,6 +642,145 @@ "default": 4 } } + }, + "embeddings": { + "type": "object", + "description": "Configuration for text embedding/vectorization service. Supports OpenAI and Azure OpenAI providers.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether the embedding service is enabled. Defaults to true.", + "default": true + }, + "provider": { + "type": "string", + "description": "The embedding provider type.", + "enum": ["azure-openai", "openai"] + }, + "base-url": { + "type": "string", + "description": "The provider base URL. For Azure OpenAI, use the Azure resource endpoint. For OpenAI, use https://api.openai.com." + }, + "api-key": { + "type": "string", + "description": "The API key for authentication. Supports environment variable substitution with @env('VAR_NAME')." + }, + "model": { + "type": "string", + "description": "The model or deployment name. Required for Azure OpenAI (deployment name). For OpenAI, defaults to 'text-embedding-3-small' if not specified." + }, + "api-version": { + "type": "string", + "description": "Azure API version. Only used for Azure OpenAI provider.", + "default": "2024-02-01" + }, + "dimensions": { + "type": "integer", + "description": "Output vector dimensions. Optional, uses model default if not specified. Useful for Redis schema alignment.", + "minimum": 1 + }, + "timeout-ms": { + "type": "integer", + "description": "Request timeout in milliseconds.", + "default": 30000, + "minimum": 1, + "maximum": 300000 + }, + "endpoint": { + "type": "object", + "description": "REST endpoint configuration for the embedding service.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether the /embed REST endpoint is enabled. Defaults to false.", + "default": false + }, + "path": { + "type": "string", + "description": "The endpoint path. Defaults to '/embed'.", + "default": "/embed" + }, + "roles": { + "type": "array", + "description": "The roles allowed to access the embedding endpoint. In development mode, defaults to ['anonymous'].", + "items": { + "type": "string" + } + } + } + }, + "health": { + "type": "object", + "description": "Health check configuration for the embedding service.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Whether health checks are enabled for embeddings. Defaults to true.", + "default": true + }, + "threshold-ms": { + "type": "integer", + "description": "The maximum response time in milliseconds to be considered healthy.", + "default": 5000, + "minimum": 1, + "maximum": 300000 + }, + "test-text": { + "type": "string", + "description": "The text to use for health check validation.", + "default": "health check" + }, + "expected-dimensions": { + "type": "integer", + "description": "The expected number of dimensions in the embedding result. If specified, dimension validation is performed.", + "minimum": 1 + } + } + } + }, + "required": ["provider", "base-url", "api-key"], + "allOf": [ + { + "$comment": "Azure OpenAI requires the model (deployment name) to be specified.", + "if": { + "properties": { + "provider": { + "const": "azure-openai" + } + }, + "required": ["provider"] + }, + "then": { + "required": ["model"], + "properties": { + "api-version": { + "type": "string", + "description": "Azure API version. Required for Azure OpenAI provider.", + "default": "2024-02-01" + } + } + } + }, + { + "$comment": "OpenAI does not require model (defaults to text-embedding-3-small) and does not use api-version.", + "if": { + "properties": { + "provider": { + "const": "openai" + } + }, + "required": ["provider"] + }, + "then": { + "properties": { + "api-version": false + } + } + } + ] } } }, diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index c3e0352249..75b5444ec8 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -4,6 +4,7 @@ using System.IO.Abstractions; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Product; using Cli.Constants; using CommandLine; @@ -71,6 +72,21 @@ public ConfigureOptions( RollingInterval? fileSinkRollingInterval = null, int? fileSinkRetainedFileCountLimit = null, long? fileSinkFileSizeLimitBytes = null, + CliBool? runtimeEmbeddingsEnabled = null, + EmbeddingProviderType? runtimeEmbeddingsProvider = null, + string? runtimeEmbeddingsBaseUrl = null, + string? runtimeEmbeddingsApiKey = null, + string? runtimeEmbeddingsModel = null, + string? runtimeEmbeddingsApiVersion = null, + int? runtimeEmbeddingsDimensions = null, + int? runtimeEmbeddingsTimeoutMs = null, + CliBool? runtimeEmbeddingsEndpointEnabled = null, + string? runtimeEmbeddingsEndpointPath = null, + IEnumerable? runtimeEmbeddingsEndpointRoles = null, + CliBool? runtimeEmbeddingsHealthEnabled = null, + int? runtimeEmbeddingsHealthThresholdMs = null, + string? runtimeEmbeddingsHealthTestText = null, + int? runtimeEmbeddingsHealthExpectedDimensions = null, string? config = null) : base(config) { @@ -132,6 +148,24 @@ public ConfigureOptions( FileSinkRollingInterval = fileSinkRollingInterval; FileSinkRetainedFileCountLimit = fileSinkRetainedFileCountLimit; FileSinkFileSizeLimitBytes = fileSinkFileSizeLimitBytes; + // Embeddings + RuntimeEmbeddingsEnabled = runtimeEmbeddingsEnabled; + RuntimeEmbeddingsProvider = runtimeEmbeddingsProvider; + RuntimeEmbeddingsBaseUrl = runtimeEmbeddingsBaseUrl; + RuntimeEmbeddingsApiKey = runtimeEmbeddingsApiKey; + RuntimeEmbeddingsModel = runtimeEmbeddingsModel; + RuntimeEmbeddingsApiVersion = runtimeEmbeddingsApiVersion; + RuntimeEmbeddingsDimensions = runtimeEmbeddingsDimensions; + RuntimeEmbeddingsTimeoutMs = runtimeEmbeddingsTimeoutMs; + // Embeddings Endpoint + RuntimeEmbeddingsEndpointEnabled = runtimeEmbeddingsEndpointEnabled; + RuntimeEmbeddingsEndpointPath = runtimeEmbeddingsEndpointPath; + RuntimeEmbeddingsEndpointRoles = runtimeEmbeddingsEndpointRoles; + // Embeddings Health + RuntimeEmbeddingsHealthEnabled = runtimeEmbeddingsHealthEnabled; + RuntimeEmbeddingsHealthThresholdMs = runtimeEmbeddingsHealthThresholdMs; + RuntimeEmbeddingsHealthTestText = runtimeEmbeddingsHealthTestText; + RuntimeEmbeddingsHealthExpectedDimensions = runtimeEmbeddingsHealthExpectedDimensions; } [Option("data-source.database-type", Required = false, HelpText = "Database type. Allowed values: MSSQL, PostgreSQL, CosmosDB_NoSQL, MySQL.")] @@ -281,6 +315,51 @@ public ConfigureOptions( [Option("runtime.telemetry.file.file-size-limit-bytes", Required = false, HelpText = "Configure maximum file size limit in bytes. Default: 1048576")] public long? FileSinkFileSizeLimitBytes { get; } + [Option("runtime.embeddings.enabled", Required = false, HelpText = "Enable/disable the embedding service. Default: true")] + public CliBool? RuntimeEmbeddingsEnabled { get; } + + [Option("runtime.embeddings.provider", Required = false, HelpText = "Configure embedding provider type. Allowed values: azure-openai, openai.")] + public EmbeddingProviderType? RuntimeEmbeddingsProvider { get; } + + [Option("runtime.embeddings.base-url", Required = false, HelpText = "Configure the embedding provider base URL.")] + public string? RuntimeEmbeddingsBaseUrl { get; } + + [Option("runtime.embeddings.api-key", Required = false, HelpText = "Configure the embedding API key for authentication.")] + public string? RuntimeEmbeddingsApiKey { get; } + + [Option("runtime.embeddings.model", Required = false, HelpText = "Configure the model/deployment name. Required for Azure OpenAI, defaults to text-embedding-3-small for OpenAI.")] + public string? RuntimeEmbeddingsModel { get; } + + [Option("runtime.embeddings.api-version", Required = false, HelpText = "Configure the Azure API version. Only used for Azure OpenAI provider. Default: 2024-02-01")] + public string? RuntimeEmbeddingsApiVersion { get; } + + [Option("runtime.embeddings.dimensions", Required = false, HelpText = "Configure the output vector dimensions. Optional, uses model default if not specified.")] + public int? RuntimeEmbeddingsDimensions { get; } + + [Option("runtime.embeddings.timeout-ms", Required = false, HelpText = "Configure the request timeout in milliseconds. Default: 30000")] + public int? RuntimeEmbeddingsTimeoutMs { get; } + + [Option("runtime.embeddings.endpoint.enabled", Required = false, HelpText = "Enable/disable the endpoint for embeddings. Default: false")] + public CliBool? RuntimeEmbeddingsEndpointEnabled { get; } + + [Option("runtime.embeddings.endpoint.path", Required = false, HelpText = "Configure the endpoint path for embeddings. Default: /embed")] + public string? RuntimeEmbeddingsEndpointPath { get; } + + [Option("runtime.embeddings.endpoint.roles", Required = false, Separator = ',', HelpText = "Configure the roles allowed to access the embedding endpoint. Comma-separated list. In development mode defaults to 'anonymous'.")] + public IEnumerable? RuntimeEmbeddingsEndpointRoles { get; } + + [Option("runtime.embeddings.health.enabled", Required = false, HelpText = "Enable/disable health checks for the embedding service. Default: true")] + public CliBool? RuntimeEmbeddingsHealthEnabled { get; } + + [Option("runtime.embeddings.health.threshold-ms", Required = false, HelpText = "Configure the health check threshold in milliseconds. Default: 5000")] + public int? RuntimeEmbeddingsHealthThresholdMs { get; } + + [Option("runtime.embeddings.health.test-text", Required = false, HelpText = "Configure the test text for health check validation. Default: 'health check'")] + public string? RuntimeEmbeddingsHealthTestText { get; } + + [Option("runtime.embeddings.health.expected-dimensions", Required = false, HelpText = "Configure the expected dimensions for health check validation. Optional.")] + public int? RuntimeEmbeddingsHealthExpectedDimensions { get; } + public int Handler(ILogger logger, FileSystemRuntimeConfigLoader loader, IFileSystem fileSystem) { logger.LogInformation("{productName} {version}", PRODUCT_NAME, ProductInfo.GetProductVersion()); diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 78a5e63a7d..2f1db2a0e1 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -8,6 +8,7 @@ using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.NamingPolicies; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Core; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Service; @@ -908,6 +909,27 @@ options.FileSinkRetainedFileCountLimit is not null || } } + // Embeddings: Provider, Endpoint, ApiKey, Model, ApiVersion, Dimensions, TimeoutMs, Enabled + if (options.RuntimeEmbeddingsProvider is not null || + options.RuntimeEmbeddingsBaseUrl is not null || + options.RuntimeEmbeddingsApiKey is not null || + options.RuntimeEmbeddingsModel is not null || + options.RuntimeEmbeddingsApiVersion is not null || + options.RuntimeEmbeddingsDimensions is not null || + options.RuntimeEmbeddingsTimeoutMs is not null || + options.RuntimeEmbeddingsEnabled is not null) + { + bool status = TryUpdateConfiguredEmbeddingsValues(options, runtimeConfig?.Runtime?.Embeddings, out EmbeddingsOptions? updatedEmbeddingsOptions); + if (status && updatedEmbeddingsOptions is not null) + { + runtimeConfig = runtimeConfig! with { Runtime = runtimeConfig.Runtime! with { Embeddings = updatedEmbeddingsOptions } }; + } + else + { + return false; + } + } + return runtimeConfig != null; } @@ -1522,6 +1544,97 @@ private static bool TryUpdateConfiguredFileOptions( } } + /// + /// Attempts to update the embeddings configuration based on the provided options. + /// Creates a new EmbeddingsOptions object if the configuration is valid. + /// Provider, endpoint, and API key are required when configuring embeddings. + /// + /// The configuration options provided by the user. + /// The existing embeddings options from the runtime configuration. + /// The resulting embeddings options if successful. + /// True if the embeddings options were successfully configured; otherwise, false. + private static bool TryUpdateConfiguredEmbeddingsValues( + ConfigureOptions options, + EmbeddingsOptions? existingEmbeddingsOptions, + out EmbeddingsOptions? updatedEmbeddingsOptions) + { + updatedEmbeddingsOptions = null; + + try + { + // Get values from options or fall back to existing configuration + EmbeddingProviderType? provider = options.RuntimeEmbeddingsProvider ?? existingEmbeddingsOptions?.Provider; + string? baseUrl = options.RuntimeEmbeddingsBaseUrl ?? existingEmbeddingsOptions?.BaseUrl; + string? apiKey = options.RuntimeEmbeddingsApiKey ?? existingEmbeddingsOptions?.ApiKey; + string? model = options.RuntimeEmbeddingsModel ?? existingEmbeddingsOptions?.Model; + string? apiVersion = options.RuntimeEmbeddingsApiVersion ?? existingEmbeddingsOptions?.ApiVersion; + int? dimensions = options.RuntimeEmbeddingsDimensions ?? existingEmbeddingsOptions?.Dimensions; + int? timeoutMs = options.RuntimeEmbeddingsTimeoutMs ?? existingEmbeddingsOptions?.TimeoutMs; + bool? enabled = options.RuntimeEmbeddingsEnabled.HasValue + ? options.RuntimeEmbeddingsEnabled.Value == CliBool.True + : existingEmbeddingsOptions?.Enabled; + + // Validate required fields + if (provider is null) + { + _logger.LogError("Failed to configure embeddings: provider is required. Use --runtime.embeddings.provider to specify the provider (azure-openai or openai)."); + return false; + } + + if (string.IsNullOrEmpty(baseUrl)) + { + _logger.LogError("Failed to configure embeddings: base-url is required. Use --runtime.embeddings.base-url to specify the provider base URL."); + return false; + } + + if (string.IsNullOrEmpty(apiKey)) + { + _logger.LogError("Failed to configure embeddings: api-key is required. Use --runtime.embeddings.api-key to specify the authentication key."); + return false; + } + + // Validate Azure OpenAI requires model/deployment name + if (provider == EmbeddingProviderType.AzureOpenAI && string.IsNullOrEmpty(model)) + { + _logger.LogError("Failed to configure embeddings: model/deployment name is required for Azure OpenAI provider. Use --runtime.embeddings.model to specify the deployment name."); + return false; + } + + // Validate dimensions if provided + if (dimensions is not null && dimensions <= 0) + { + _logger.LogError("Failed to configure embeddings: dimensions must be a positive integer."); + return false; + } + + // Validate timeout if provided + if (timeoutMs is not null && timeoutMs <= 0) + { + _logger.LogError("Failed to configure embeddings: timeout-ms must be a positive integer."); + return false; + } + + // Create the embeddings options + updatedEmbeddingsOptions = new EmbeddingsOptions( + Provider: (EmbeddingProviderType)provider, + BaseUrl: baseUrl, + ApiKey: apiKey, + Enabled: enabled, + Model: model, + ApiVersion: apiVersion, + Dimensions: dimensions, + TimeoutMs: timeoutMs); + + _logger.LogInformation("Updated RuntimeConfig with Runtime.Embeddings configuration."); + return true; + } + catch (Exception ex) + { + _logger.LogError("Failed to update RuntimeConfig.Embeddings with exception message: {exceptionMessage}.", ex.Message); + return false; + } + } + /// /// Parse permission string to create PermissionSetting array. /// diff --git a/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs b/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs new file mode 100644 index 0000000000..b356d3e188 --- /dev/null +++ b/src/Config/Converters/EmbeddingsOptionsConverterFactory.cs @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +namespace Azure.DataApiBuilder.Config.Converters; + +/// +/// Custom JSON converter for EmbeddingsOptions that handles proper deserialization +/// of the configuration properties including environment variable replacement. +/// +internal class EmbeddingsOptionsConverterFactory : JsonConverterFactory +{ + public EmbeddingsOptionsConverterFactory(DeserializationVariableReplacementSettings? replacementSettings = null) + { + // Note: replacementSettings is not used in this converter because the environment variable + // replacement is handled by the string deserializers registered in the JsonSerializerOptions. + } + + /// + public override bool CanConvert(Type typeToConvert) + { + return typeToConvert.IsAssignableTo(typeof(EmbeddingsOptions)); + } + + /// + public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options) + { + return new EmbeddingsOptionsConverter(); + } + + private class EmbeddingsOptionsConverter : JsonConverter + { + public override EmbeddingsOptions? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException("Expected start of object."); + } + + bool? enabled = null; + EmbeddingProviderType? provider = null; + string? baseUrl = null; + string? apiKey = null; + string? model = null; + string? apiVersion = null; + int? dimensions = null; + int? timeoutMs = null; + EmbeddingsEndpointOptions? endpoint = null; + EmbeddingsHealthCheckConfig? health = null; + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected property name."); + } + + string? propertyName = reader.GetString()?.ToLowerInvariant(); + reader.Read(); + + switch (propertyName) + { + case "enabled": + enabled = reader.GetBoolean(); + break; + case "provider": + string? providerStr = reader.GetString(); + if (providerStr is not null) + { + provider = providerStr.ToLowerInvariant() switch + { + "azure-openai" => EmbeddingProviderType.AzureOpenAI, + "openai" => EmbeddingProviderType.OpenAI, + _ => throw new JsonException($"Unknown provider: {providerStr}") + }; + } + break; + case "base-url": + baseUrl = JsonSerializer.Deserialize(ref reader, options); + break; + case "api-key": + apiKey = JsonSerializer.Deserialize(ref reader, options); + break; + case "model": + model = JsonSerializer.Deserialize(ref reader, options); + break; + case "api-version": + apiVersion = JsonSerializer.Deserialize(ref reader, options); + break; + case "dimensions": + dimensions = reader.GetInt32(); + break; + case "timeout-ms": + timeoutMs = reader.GetInt32(); + break; + case "endpoint": + endpoint = JsonSerializer.Deserialize(ref reader, options); + break; + case "health": + health = JsonSerializer.Deserialize(ref reader, options); + break; + default: + reader.Skip(); + break; + } + } + + if (provider is null) + { + throw new JsonException("Missing required property: provider"); + } + + if (baseUrl is null) + { + throw new JsonException("Missing required property: base-url"); + } + + if (apiKey is null) + { + throw new JsonException("Missing required property: api-key"); + } + + return new EmbeddingsOptions( + Provider: provider.Value, + BaseUrl: baseUrl, + ApiKey: apiKey, + Enabled: enabled, + Model: model, + ApiVersion: apiVersion, + Dimensions: dimensions, + TimeoutMs: timeoutMs, + Endpoint: endpoint, + Health: health); + } + + public override void Write(Utf8JsonWriter writer, EmbeddingsOptions value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + writer.WriteBoolean("enabled", value.Enabled); + + // Write provider + string providerStr = value.Provider switch + { + EmbeddingProviderType.AzureOpenAI => "azure-openai", + EmbeddingProviderType.OpenAI => "openai", + _ => throw new JsonException($"Unknown provider: {value.Provider}") + }; + writer.WriteString("provider", providerStr); + + writer.WriteString("base-url", value.BaseUrl); + writer.WriteString("api-key", value.ApiKey); + + if (value.Model is not null) + { + writer.WriteString("model", value.Model); + } + + if (value.ApiVersion is not null) + { + writer.WriteString("api-version", value.ApiVersion); + } + + if (value.Dimensions is not null) + { + writer.WriteNumber("dimensions", value.Dimensions.Value); + } + + if (value.TimeoutMs is not null) + { + writer.WriteNumber("timeout-ms", value.TimeoutMs.Value); + } + + if (value.Endpoint is not null) + { + writer.WritePropertyName("endpoint"); + JsonSerializer.Serialize(writer, value.Endpoint, options); + } + + if (value.Health is not null) + { + writer.WritePropertyName("health"); + JsonSerializer.Serialize(writer, value.Health, options); + } + + writer.WriteEndObject(); + } + } +} diff --git a/src/Config/DabConfigEvents.cs b/src/Config/DabConfigEvents.cs index f69193b583..691a71830e 100644 --- a/src/Config/DabConfigEvents.cs +++ b/src/Config/DabConfigEvents.cs @@ -19,4 +19,5 @@ public static class DabConfigEvents public const string GRAPHQL_SCHEMA_EVICTION_ON_CONFIG_CHANGED = "GRAPHQL_SCHEMA_EVICTION_ON_CONFIG_CHANGED"; public const string GRAPHQL_SCHEMA_CREATOR_ON_CONFIG_CHANGED = "GRAPHQL_SCHEMA_CREATOR_ON_CONFIG_CHANGED"; public const string LOG_LEVEL_INITIALIZER_ON_CONFIG_CHANGE = "LOG_LEVEL_INITIALIZER_ON_CONFIG_CHANGE"; + public const string EMBEDDING_SERVICE_ON_CONFIG_CHANGED = "EMBEDDING_SERVICE_ON_CONFIG_CHANGED"; } diff --git a/src/Config/HealthCheck/HealthCheckConstants.cs b/src/Config/HealthCheck/HealthCheckConstants.cs index fd5901575c..b57526fb75 100644 --- a/src/Config/HealthCheck/HealthCheckConstants.cs +++ b/src/Config/HealthCheck/HealthCheckConstants.cs @@ -12,6 +12,7 @@ public static class HealthCheckConstants public const string DATASOURCE = "data-source"; public const string REST = "rest"; public const string GRAPHQL = "graphql"; + public const string EMBEDDING = "embedding"; public const int ERROR_RESPONSE_TIME_MS = -1; public const int DEFAULT_THRESHOLD_RESPONSE_TIME_MS = 1000; public const int DEFAULT_FIRST_VALUE = 100; diff --git a/src/Config/HotReloadEventHandler.cs b/src/Config/HotReloadEventHandler.cs index 666c3c227b..a2ca9eaf98 100644 --- a/src/Config/HotReloadEventHandler.cs +++ b/src/Config/HotReloadEventHandler.cs @@ -34,7 +34,8 @@ public HotReloadEventHandler() { GRAPHQL_SCHEMA_CREATOR_ON_CONFIG_CHANGED, null }, { GRAPHQL_SCHEMA_REFRESH_ON_CONFIG_CHANGED, null }, { GRAPHQL_SCHEMA_EVICTION_ON_CONFIG_CHANGED, null }, - { LOG_LEVEL_INITIALIZER_ON_CONFIG_CHANGE, null } + { LOG_LEVEL_INITIALIZER_ON_CONFIG_CHANGE, null }, + { EMBEDDING_SERVICE_ON_CONFIG_CHANGED, null } }; } diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs b/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs new file mode 100644 index 0000000000..9b2efc994b --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingProviderType.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Runtime.Serialization; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.Converters; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Represents the supported embedding provider types. +/// +[JsonConverter(typeof(EnumMemberJsonEnumConverterFactory))] +public enum EmbeddingProviderType +{ + /// + /// Azure OpenAI embedding provider. + /// + [EnumMember(Value = "azure-openai")] + AzureOpenAI, + + /// + /// OpenAI embedding provider. + /// + [EnumMember(Value = "openai")] + OpenAI +} diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs new file mode 100644 index 0000000000..b019aa9aef --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsEndpointOptions.cs @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Endpoint configuration for the embedding service. +/// +public record EmbeddingsEndpointOptions +{ + /// + /// Default path for the embedding endpoint. + /// + public const string DEFAULT_PATH = "/embed"; + + /// + /// Anonymous role constant. + /// + public const string ANONYMOUS_ROLE = "anonymous"; + + /// + /// Whether the endpoint is enabled. Defaults to false. + /// + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } + + /// + /// Flag indicating whether the user provided the enabled setting. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedEnabled { get; init; } + + /// + /// The endpoint path. Defaults to "/embed". + /// + [JsonPropertyName("path")] + public string? Path { get; init; } + + /// + /// Flag indicating whether the user provided a custom path. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedPath { get; init; } + + /// + /// The roles allowed to access the embedding endpoint. + /// In development mode, defaults to ["anonymous"]. + /// In production mode, must be explicitly configured. + /// + [JsonPropertyName("roles")] + public string[]? Roles { get; init; } + + /// + /// Flag indicating whether the user provided roles. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedRoles { get; init; } + + /// + /// Gets the effective path, using default if not specified. + /// + [JsonIgnore] + public string EffectivePath => Path ?? DEFAULT_PATH; + + /// + /// Gets the effective roles based on host mode. + /// In development mode, returns ["anonymous"] if no roles specified. + /// In production mode, returns the configured roles or empty array. + /// + /// Whether the host is in development mode. + /// Array of allowed roles. + public string[] GetEffectiveRoles(bool isDevelopmentMode) + { + if (Roles is not null && Roles.Length > 0) + { + return Roles; + } + + // In development mode, default to anonymous access + if (isDevelopmentMode) + { + return new[] { ANONYMOUS_ROLE }; + } + + // In production mode with no roles specified, return empty (no access) + return Array.Empty(); + } + + /// + /// Checks if the given role is allowed to access the embedding endpoint. + /// + /// The role to check. + /// Whether the host is in development mode. + /// True if the role is allowed; otherwise, false. + public bool IsRoleAllowed(string role, bool isDevelopmentMode) + { + string[] effectiveRoles = GetEffectiveRoles(isDevelopmentMode); + return effectiveRoles.Contains(role, StringComparer.OrdinalIgnoreCase); + } + + /// + /// Default constructor. + /// + public EmbeddingsEndpointOptions() + { + Enabled = false; + } + + /// + /// Constructor with optional parameters. + /// + [JsonConstructor] + public EmbeddingsEndpointOptions( + bool? enabled = null, + string? path = null, + string[]? roles = null) + { + if (enabled.HasValue) + { + Enabled = enabled.Value; + UserProvidedEnabled = true; + } + else + { + Enabled = false; + } + + if (path is not null) + { + Path = path; + UserProvidedPath = true; + } + + if (roles is not null) + { + Roles = roles; + UserProvidedRoles = true; + } + } +} diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs new file mode 100644 index 0000000000..bf2a79764c --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsHealthCheckConfig.cs @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Health check configuration for embeddings. +/// Validates that the embedding service is responding within threshold and returning expected results. +/// +public record EmbeddingsHealthCheckConfig : HealthCheckConfig +{ + /// + /// Default threshold for embedding health check in milliseconds. + /// + public const int DEFAULT_THRESHOLD_MS = 5000; + + /// + /// Default test text used for health check validation. + /// + public const string DEFAULT_TEST_TEXT = "health check"; + + /// + /// The expected milliseconds the embedding request should complete within to be considered healthy. + /// If the request takes longer than this value, the health check will be considered unhealthy. + /// Requests completing at exactly the threshold are considered healthy. + /// Default: 5000ms (5 seconds) + /// + [JsonPropertyName("threshold-ms")] + public int ThresholdMs { get; init; } + + /// + /// Flag indicating whether the user provided a custom threshold. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedThresholdMs { get; init; } + + /// + /// The test text to use for health check validation. + /// This text will be embedded and the result validated. + /// Default: "health check" + /// + [JsonPropertyName("test-text")] + public string TestText { get; init; } + + /// + /// Flag indicating whether the user provided custom test text. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedTestText { get; init; } + + /// + /// The expected number of dimensions in the embedding result. + /// If specified, the health check will verify the embedding has this many dimensions. + /// If not specified, dimension validation is skipped. + /// + [JsonPropertyName("expected-dimensions")] + public int? ExpectedDimensions { get; init; } + + /// + /// Flag indicating whether the user provided expected dimensions. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedExpectedDimensions { get; init; } + + /// + /// Default constructor with default values. + /// + public EmbeddingsHealthCheckConfig() : base() + { + ThresholdMs = DEFAULT_THRESHOLD_MS; + TestText = DEFAULT_TEST_TEXT; + } + + /// + /// Constructor with optional parameters. + /// + [JsonConstructor] + public EmbeddingsHealthCheckConfig( + bool? enabled = null, + int? thresholdMs = null, + string? testText = null, + int? expectedDimensions = null) : base(enabled) + { + if (thresholdMs is not null) + { + ThresholdMs = (int)thresholdMs; + UserProvidedThresholdMs = true; + } + else + { + ThresholdMs = DEFAULT_THRESHOLD_MS; + } + + if (testText is not null) + { + TestText = testText; + UserProvidedTestText = true; + } + else + { + TestText = DEFAULT_TEST_TEXT; + } + + if (expectedDimensions is not null) + { + ExpectedDimensions = expectedDimensions; + UserProvidedExpectedDimensions = true; + } + } +} diff --git a/src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs b/src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs new file mode 100644 index 0000000000..a1afd9abf7 --- /dev/null +++ b/src/Config/ObjectModel/Embeddings/EmbeddingsOptions.cs @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel.Embeddings; + +/// +/// Represents the options for configuring the embedding service. +/// Used for text embedding/vectorization with OpenAI or Azure OpenAI providers. +/// +public record EmbeddingsOptions +{ + /// + /// Default timeout in milliseconds for embedding requests. + /// + public const int DEFAULT_TIMEOUT_MS = 30000; + + /// + /// Default API version for Azure OpenAI. + /// + public const string DEFAULT_AZURE_API_VERSION = "2024-02-01"; + + /// + /// Default model for OpenAI embeddings. + /// + public const string DEFAULT_OPENAI_MODEL = "text-embedding-3-small"; + + /// + /// Whether the embedding service is enabled. Defaults to true. + /// When false, the embedding service will not be used. + /// + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } = true; + + /// + /// Flag indicating whether the user provided the enabled setting. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public bool UserProvidedEnabled { get; init; } + + /// + /// The embedding provider type (azure-openai or openai). + /// Required. + /// + [JsonPropertyName("provider")] + public EmbeddingProviderType Provider { get; init; } + + /// + /// The provider base URL. + /// Required. + /// + [JsonPropertyName("base-url")] + public string BaseUrl { get; init; } + + /// + /// The API key for authentication. + /// Required. + /// + [JsonPropertyName("api-key")] + public string ApiKey { get; init; } + + /// + /// The model or deployment name. + /// For Azure OpenAI, this is the deployment name. + /// For OpenAI, this is the model name (defaults to text-embedding-3-small if not specified). + /// + [JsonPropertyName("model")] + public string? Model { get; init; } + + /// + /// Azure API version. Only used for Azure OpenAI provider. + /// Defaults to 2024-02-01. + /// + [JsonPropertyName("api-version")] + public string? ApiVersion { get; init; } + + /// + /// Output vector dimensions. Optional, uses model default if not specified. + /// + [JsonPropertyName("dimensions")] + public int? Dimensions { get; init; } + + /// + /// Request timeout in milliseconds. Defaults to 30000 (30 seconds). + /// + [JsonPropertyName("timeout-ms")] + public int? TimeoutMs { get; init; } + + /// + /// Endpoint configuration for the embedding service. + /// + [JsonPropertyName("endpoint")] + public EmbeddingsEndpointOptions? Endpoint { get; init; } + + /// + /// Health check configuration for the embedding service. + /// + [JsonPropertyName("health")] + public EmbeddingsHealthCheckConfig? Health { get; init; } + + /// + /// Flag which informs whether the user provided a custom timeout value. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(TimeoutMs))] + public bool UserProvidedTimeoutMs { get; init; } + + /// + /// Flag which informs whether the user provided a custom API version. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(ApiVersion))] + public bool UserProvidedApiVersion { get; init; } + + /// + /// Flag which informs whether the user provided custom dimensions. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(Dimensions))] + public bool UserProvidedDimensions { get; init; } + + /// + /// Flag which informs whether the user provided a custom model. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + [MemberNotNullWhen(true, nameof(Model))] + public bool UserProvidedModel { get; init; } + + /// + /// Gets the effective timeout in milliseconds, using default if not specified. + /// + [JsonIgnore] + public int EffectiveTimeoutMs => TimeoutMs ?? DEFAULT_TIMEOUT_MS; + + /// + /// Gets the effective API version for Azure OpenAI, using default if not specified. + /// + [JsonIgnore] + public string EffectiveApiVersion => ApiVersion ?? DEFAULT_AZURE_API_VERSION; + + /// + /// Gets the effective model name, using default for OpenAI if not specified. + /// For Azure OpenAI, model is required (no default). + /// + [JsonIgnore] + public string? EffectiveModel => Model ?? (Provider == EmbeddingProviderType.OpenAI ? DEFAULT_OPENAI_MODEL : null); + + /// + /// Returns true if embedding health check is enabled. + /// + [JsonIgnore] + public bool IsHealthCheckEnabled => Health?.Enabled ?? false; + + /// + /// Returns true if embedding endpoint is enabled. + /// + [JsonIgnore] + public bool IsEndpointEnabled => Endpoint?.Enabled ?? false; + + /// + /// Gets the effective endpoint path. + /// + [JsonIgnore] + public string EffectiveEndpointPath => Endpoint?.EffectivePath ?? EmbeddingsEndpointOptions.DEFAULT_PATH; + + [JsonConstructor] + public EmbeddingsOptions( + EmbeddingProviderType Provider, + string BaseUrl, + string ApiKey, + bool? Enabled = null, + string? Model = null, + string? ApiVersion = null, + int? Dimensions = null, + int? TimeoutMs = null, + EmbeddingsEndpointOptions? Endpoint = null, + EmbeddingsHealthCheckConfig? Health = null) + { + this.Provider = Provider; + this.BaseUrl = BaseUrl; + this.ApiKey = ApiKey; + this.Endpoint = Endpoint; + this.Health = Health; + + if (Enabled.HasValue) + { + this.Enabled = Enabled.Value; + UserProvidedEnabled = true; + } + else + { + this.Enabled = true; // Default to enabled + } + + if (Model is not null) + { + this.Model = Model; + UserProvidedModel = true; + } + + if (ApiVersion is not null) + { + this.ApiVersion = ApiVersion; + UserProvidedApiVersion = true; + } + + if (Dimensions is not null) + { + this.Dimensions = Dimensions; + UserProvidedDimensions = true; + } + + if (TimeoutMs is not null) + { + this.TimeoutMs = TimeoutMs; + UserProvidedTimeoutMs = true; + } + } +} diff --git a/src/Config/ObjectModel/RuntimeOptions.cs b/src/Config/ObjectModel/RuntimeOptions.cs index 6f6c046651..2a17e89a90 100644 --- a/src/Config/ObjectModel/RuntimeOptions.cs +++ b/src/Config/ObjectModel/RuntimeOptions.cs @@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; namespace Azure.DataApiBuilder.Config.ObjectModel; @@ -17,6 +18,7 @@ public record RuntimeOptions public RuntimeCacheOptions? Cache { get; init; } public PaginationOptions? Pagination { get; init; } public RuntimeHealthCheckConfig? Health { get; init; } + public EmbeddingsOptions? Embeddings { get; init; } [JsonConstructor] public RuntimeOptions( @@ -28,7 +30,8 @@ public RuntimeOptions( TelemetryOptions? Telemetry = null, RuntimeCacheOptions? Cache = null, PaginationOptions? Pagination = null, - RuntimeHealthCheckConfig? Health = null) + RuntimeHealthCheckConfig? Health = null, + EmbeddingsOptions? Embeddings = null) { this.Rest = Rest; this.GraphQL = GraphQL; @@ -39,6 +42,7 @@ public RuntimeOptions( this.Cache = Cache; this.Pagination = Pagination; this.Health = Health; + this.Embeddings = Embeddings; } /// @@ -74,4 +78,12 @@ Mcp is null || Health is null || Health?.Enabled is null || Health?.Enabled is true; + + /// + /// Indicates whether embeddings are configured. + /// Embeddings are considered configured when the Embeddings property is not null. + /// + [JsonIgnore] + [MemberNotNullWhen(true, nameof(Embeddings))] + public bool IsEmbeddingsConfigured => Embeddings is not null; } diff --git a/src/Config/RuntimeConfigLoader.cs b/src/Config/RuntimeConfigLoader.cs index 9a54d09d8e..c43f8c2feb 100644 --- a/src/Config/RuntimeConfigLoader.cs +++ b/src/Config/RuntimeConfigLoader.cs @@ -333,6 +333,9 @@ public static JsonSerializerOptions GetSerializationOptions( // Add AzureKeyVaultOptionsConverterFactory to ensure AKV config is deserialized properly options.Converters.Add(new AzureKeyVaultOptionsConverterFactory(replacementSettings)); + // Add EmbeddingsOptionsConverterFactory to handle embeddings configuration + options.Converters.Add(new EmbeddingsOptionsConverterFactory(replacementSettings)); + // Only add the extensible string converter if we have replacement settings if (replacementSettings is not null) { diff --git a/src/Core/Services/Embeddings/EmbeddingService.cs b/src/Core/Services/Embeddings/EmbeddingService.cs new file mode 100644 index 0000000000..017f6801da --- /dev/null +++ b/src/Core/Services/Embeddings/EmbeddingService.cs @@ -0,0 +1,491 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Net.Http.Headers; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Microsoft.Extensions.Logging; +using ZiggyCreatures.Caching.Fusion; + +namespace Azure.DataApiBuilder.Core.Services.Embeddings; + +/// +/// Service implementation for text embedding/vectorization. +/// Supports both OpenAI and Azure OpenAI providers. +/// Includes L1 memory cache using FusionCache to prevent duplicate embedding API calls. +/// +public class EmbeddingService : IEmbeddingService +{ + private readonly HttpClient _httpClient; + private readonly EmbeddingsOptions _options; + private readonly ILogger _logger; + private readonly IFusionCache _cache; + private readonly string _providerName; + + // Constants + private const char KEY_DELIMITER = ':'; + private const string CACHE_KEY_PREFIX = "embedding"; + + /// + /// Default cache TTL in hours. Set high since embeddings are deterministic and don't get outdated. + /// + private const int DEFAULT_CACHE_TTL_HOURS = 24; + + /// + /// JSON serializer options for request/response handling. + /// + private static readonly JsonSerializerOptions _jsonSerializerOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + /// + /// Initializes a new instance of the EmbeddingService. + /// + /// The HTTP client for making API requests. + /// The embedding configuration options. + /// The logger instance. + /// The FusionCache instance for L1 memory caching. + public EmbeddingService( + HttpClient httpClient, + EmbeddingsOptions options, + ILogger logger, + IFusionCache cache) + { + _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _cache = cache ?? throw new ArgumentNullException(nameof(cache)); + + // Cache provider name for telemetry to avoid repeated string allocations + _providerName = _options.Provider.ToString().ToLowerInvariant(); + + // Validate required options + if (string.IsNullOrEmpty(_options.BaseUrl)) + { + throw new ArgumentException("BaseUrl is required in EmbeddingsOptions.", nameof(options)); + } + + if (string.IsNullOrEmpty(_options.ApiKey)) + { + throw new ArgumentException("ApiKey is required in EmbeddingsOptions.", nameof(options)); + } + + // Azure OpenAI requires model/deployment name + if (_options.Provider == EmbeddingProviderType.AzureOpenAI && string.IsNullOrEmpty(_options.EffectiveModel)) + { + throw new InvalidOperationException("Model/deployment name is required for Azure OpenAI provider."); + } + + ConfigureHttpClient(); + } + + /// + /// Configures the HTTP client with timeout and authentication headers. + /// + private void ConfigureHttpClient() + { + _httpClient.Timeout = TimeSpan.FromMilliseconds(_options.EffectiveTimeoutMs); + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + _httpClient.DefaultRequestHeaders.Add("api-key", _options.ApiKey); + } + else + { + _httpClient.DefaultRequestHeaders.Authorization = + new AuthenticationHeaderValue("Bearer", _options.ApiKey); + } + + _httpClient.DefaultRequestHeaders.Accept.Clear(); + _httpClient.DefaultRequestHeaders.Accept.Add( + new MediaTypeWithQualityHeaderValue("application/json")); + } + + /// + public bool IsEnabled => _options.Enabled; + + /// + public async Task TryEmbedAsync(string text, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + _logger.LogDebug("Embedding service is disabled, skipping embed request"); + return new EmbeddingResult(false, null, "Embedding service is disabled."); + } + + if (string.IsNullOrEmpty(text)) + { + _logger.LogWarning("TryEmbedAsync called with null or empty text"); + return new EmbeddingResult(false, null, "Text cannot be null or empty."); + } + + Stopwatch stopwatch = Stopwatch.StartNew(); + using Activity? activity = EmbeddingTelemetryHelper.StartEmbeddingActivity("TryEmbedAsync"); + activity?.SetEmbeddingActivityTags(_providerName, _options.EffectiveModel, textCount: 1); + + try + { + EmbeddingTelemetryHelper.TrackEmbeddingRequest(_providerName, textCount: 1); + + float[] embedding = await EmbedAsync(text, cancellationToken); + + stopwatch.Stop(); + activity?.SetEmbeddingActivitySuccess(stopwatch.Elapsed.TotalMilliseconds, embedding.Length); + EmbeddingTelemetryHelper.TrackTotalDuration(_providerName, stopwatch.Elapsed, fromCache: false); + EmbeddingTelemetryHelper.TrackDimensions(_providerName, embedding.Length); + + return new EmbeddingResult(true, embedding); + } + catch (Exception ex) + { + stopwatch.Stop(); + _logger.LogError(ex, "Failed to generate embedding for text"); + activity?.SetEmbeddingActivityError(ex); + EmbeddingTelemetryHelper.TrackError(_providerName, ex.GetType().Name); + + return new EmbeddingResult(false, null, ex.Message); + } + } + + /// + public async Task TryEmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + _logger.LogDebug("Embedding service is disabled, skipping batch embed request"); + return new EmbeddingBatchResult(false, null, "Embedding service is disabled."); + } + + if (texts is null || texts.Length == 0) + { + _logger.LogWarning("TryEmbedBatchAsync called with null or empty texts array"); + return new EmbeddingBatchResult(false, null, "Texts array cannot be null or empty."); + } + + Stopwatch stopwatch = Stopwatch.StartNew(); + using Activity? activity = EmbeddingTelemetryHelper.StartEmbeddingActivity("TryEmbedBatchAsync"); + activity?.SetEmbeddingActivityTags(_providerName, _options.EffectiveModel, texts.Length); + + try + { + EmbeddingTelemetryHelper.TrackEmbeddingRequest(_providerName, texts.Length); + + float[][] embeddings = await EmbedBatchAsync(texts, cancellationToken); + + stopwatch.Stop(); + int dimensions = embeddings.Length > 0 ? embeddings[0].Length : 0; + activity?.SetEmbeddingActivitySuccess(stopwatch.Elapsed.TotalMilliseconds, dimensions); + EmbeddingTelemetryHelper.TrackTotalDuration(_providerName, stopwatch.Elapsed, fromCache: false); + if (dimensions > 0) + { + EmbeddingTelemetryHelper.TrackDimensions(_providerName, dimensions); + } + + return new EmbeddingBatchResult(true, embeddings); + } + catch (Exception ex) + { + stopwatch.Stop(); + _logger.LogError(ex, "Failed to generate embeddings for batch of {Count} texts", texts.Length); + activity?.SetEmbeddingActivityError(ex); + EmbeddingTelemetryHelper.TrackError(_providerName, ex.GetType().Name); + + return new EmbeddingBatchResult(false, null, ex.Message); + } + } + + /// + public async Task EmbedAsync(string text, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + throw new InvalidOperationException("Embedding service is disabled."); + } + + if (string.IsNullOrEmpty(text)) + { + throw new ArgumentException("Text cannot be null or empty.", nameof(text)); + } + + string cacheKey = CreateCacheKey(text); + + float[]? embedding = await _cache.GetOrSetAsync( + key: cacheKey, + async (FusionCacheFactoryExecutionContext ctx, CancellationToken ct) => + { + _logger.LogDebug("Embedding cache miss, calling API for text hash {TextHash}", cacheKey); + + float[][] results = await EmbedFromApiAsync(new[] { text }, ct); + float[] result = results[0]; + + // Validate the embedding result is not empty + if (result.Length == 0) + { + throw new InvalidOperationException("API returned empty embedding array."); + } + + // L1 only - skip distributed cache + ctx.Options.SetSkipDistributedCache(true, true); + ctx.Options.SetDuration(TimeSpan.FromHours(DEFAULT_CACHE_TTL_HOURS)); + + return result; + }, + token: cancellationToken); + + if (embedding is null || embedding.Length == 0) + { + throw new InvalidOperationException("Failed to get embedding from cache or API."); + } + + return embedding; + } + + /// + public async Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default) + { + if (!_options.Enabled) + { + throw new InvalidOperationException("Embedding service is disabled."); + } + if (texts is null || texts.Length == 0) + { + throw new ArgumentException("Texts cannot be null or empty.", nameof(texts)); + } + + // For batch, check cache for each text individually + string[] cacheKeys = texts.Select(CreateCacheKey).ToArray(); + float[]?[] results = new float[texts.Length][]; + List uncachedIndices = new(); + int cacheHits = 0; + + // Check cache for each text + for (int i = 0; i < texts.Length; i++) + { + MaybeValue cached = _cache.TryGet(key: cacheKeys[i]); + + if (cached.HasValue) + { + _logger.LogDebug("Embedding cache hit for text hash {TextHash}", cacheKeys[i]); + results[i] = cached.Value; + cacheHits++; + EmbeddingTelemetryHelper.TrackCacheHit(_providerName); + } + else + { + uncachedIndices.Add(i); + EmbeddingTelemetryHelper.TrackCacheMiss(_providerName); + } + } + + // If all texts were cached, return immediately + if (uncachedIndices.Count == 0) + { + return results!; + } + + _logger.LogDebug("Embedding cache miss for {Count} text(s), calling API", uncachedIndices.Count); + + // Call API for uncached texts only + string[] uncachedTexts = uncachedIndices.Select(i => texts[i]).ToArray(); + + Stopwatch apiStopwatch = Stopwatch.StartNew(); + float[][] apiResults = await EmbedFromApiAsync(uncachedTexts, cancellationToken); + apiStopwatch.Stop(); + + // Track API call telemetry + EmbeddingTelemetryHelper.TrackApiCall(_providerName, uncachedTexts.Length); + EmbeddingTelemetryHelper.TrackApiDuration(_providerName, apiStopwatch.Elapsed, uncachedTexts.Length); + + // Cache new results and merge with cached results + for (int i = 0; i < uncachedIndices.Count; i++) + { + int originalIndex = uncachedIndices[i]; + results[originalIndex] = apiResults[i]; + + // Store in L1 cache only + _cache.Set( + key: cacheKeys[originalIndex], + value: apiResults[i], + options => + { + options.SetSkipDistributedCache(true, true); + options.SetDuration(TimeSpan.FromHours(DEFAULT_CACHE_TTL_HOURS)); + }); + } + + return results!; + } + + /// + /// Creates a cache key from the text using SHA256 hash. + /// Format: embedding:{provider}:{model}:{SHA256_hash} + /// Includes provider and model to prevent cross-configuration collisions. + /// Uses hash to keep cache keys small and deterministic. + /// + /// The text to create a cache key for. + /// Cache key string. + private string CreateCacheKey(string text) + { + // Include provider and model in hash to avoid cross-provider/model collisions + string keyInput = $"{_options.Provider}:{_options.EffectiveModel}:{text}"; + byte[] textBytes = Encoding.UTF8.GetBytes(keyInput); + byte[] hashBytes = SHA256.HashData(textBytes); + string hashHex = Convert.ToHexString(hashBytes); + + StringBuilder cacheKeyBuilder = new(); + cacheKeyBuilder.Append(CACHE_KEY_PREFIX); + cacheKeyBuilder.Append(KEY_DELIMITER); + cacheKeyBuilder.Append(hashHex); + + return cacheKeyBuilder.ToString(); + } + + /// + /// Calls the embedding API to get embeddings for the provided texts. + /// + private async Task EmbedFromApiAsync(string[] texts, CancellationToken cancellationToken) + { + string requestUrl = BuildRequestUrl(); + object requestBody = BuildRequestBody(texts); + + string requestJson = JsonSerializer.Serialize(requestBody, _jsonSerializerOptions); + using HttpContent content = new StringContent(requestJson, Encoding.UTF8, "application/json"); + + _logger.LogDebug("Sending embedding request to {Url} with {Count} text(s)", requestUrl, texts.Length); + + HttpResponseMessage response = await _httpClient.PostAsync(requestUrl, content, cancellationToken); + + if (!response.IsSuccessStatusCode) + { + string errorContent = await response.Content.ReadAsStringAsync(cancellationToken); + _logger.LogError("Embedding request failed with status {StatusCode}: {ErrorContent}", + response.StatusCode, errorContent); + throw new HttpRequestException( + $"Embedding request failed with status code {response.StatusCode}: {errorContent}"); + } + + string responseJson = await response.Content.ReadAsStringAsync(cancellationToken); + EmbeddingResponse? embeddingResponse = JsonSerializer.Deserialize(responseJson, _jsonSerializerOptions); + + if (embeddingResponse?.Data is null || embeddingResponse.Data.Count == 0) + { + throw new InvalidOperationException("No embedding data received from the provider."); + } + + // Sort by index to ensure correct order and extract embeddings + List sortedData = embeddingResponse.Data.OrderBy(d => d.Index).ToList(); + return sortedData.Select(d => d.Embedding).ToArray(); + } + + /// + /// Builds the request URL based on the provider type. + /// + private string BuildRequestUrl() + { + string baseUrl = _options.BaseUrl.TrimEnd('/'); + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + // Azure OpenAI: {baseUrl}/openai/deployments/{deployment}/embeddings?api-version={version} + string model = _options.EffectiveModel + ?? throw new InvalidOperationException("Model/deployment name is required for Azure OpenAI."); + + return $"{baseUrl}/openai/deployments/{model}/embeddings?api-version={_options.EffectiveApiVersion}"; + } + else + { + // OpenAI: {baseUrl}/v1/embeddings + return $"{baseUrl}/v1/embeddings"; + } + } + + /// + /// Builds the request body based on the provider type. + /// + private object BuildRequestBody(string[] texts) + { + // Use single string for single text, array for batch + object input = texts.Length == 1 ? texts[0] : texts; + + if (_options.Provider == EmbeddingProviderType.AzureOpenAI) + { + // Azure OpenAI request body + if (_options.UserProvidedDimensions) + { + return new + { + input, + dimensions = _options.Dimensions + }; + } + + return new { input }; + } + else + { + // OpenAI request body - includes model in body + string model = _options.EffectiveModel ?? EmbeddingsOptions.DEFAULT_OPENAI_MODEL; + + if (_options.UserProvidedDimensions) + { + return new + { + model, + input, + dimensions = _options.Dimensions + }; + } + + return new + { + model, + input + }; + } + } + + /// + /// Response model for embedding API responses. + /// + private sealed class EmbeddingResponse + { + [JsonPropertyName("data")] + public List? Data { get; set; } + + [JsonPropertyName("model")] + public string? Model { get; set; } + + [JsonPropertyName("usage")] + public EmbeddingUsage? Usage { get; set; } + } + + /// + /// Individual embedding data in the response. + /// + private sealed class EmbeddingData + { + [JsonPropertyName("index")] + public int Index { get; set; } + + [JsonPropertyName("embedding")] + public float[] Embedding { get; set; } = Array.Empty(); + } + + /// + /// Token usage information in the response. + /// + private sealed class EmbeddingUsage + { + [JsonPropertyName("prompt_tokens")] + public int PromptTokens { get; set; } + + [JsonPropertyName("total_tokens")] + public int TotalTokens { get; set; } + } +} diff --git a/src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs b/src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs new file mode 100644 index 0000000000..d7ed9bfd05 --- /dev/null +++ b/src/Core/Services/Embeddings/EmbeddingTelemetryHelper.cs @@ -0,0 +1,264 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Diagnostics.Metrics; +using Azure.DataApiBuilder.Core.Telemetry; +using OpenTelemetry.Trace; + +namespace Azure.DataApiBuilder.Core.Services.Embeddings; + +/// +/// Helper class for tracking embedding-related telemetry metrics and traces. +/// +public static class EmbeddingTelemetryHelper +{ + /// + /// Meter name for embedding metrics. + /// + public static readonly string MeterName = "DataApiBuilder.Embeddings"; + + // Metrics + private static readonly Meter _meter = new(MeterName); + + // Counters + private static readonly Counter _embeddingRequests = _meter.CreateCounter( + "embedding_requests_total", + description: "Total number of embedding requests"); + + private static readonly Counter _embeddingApiCalls = _meter.CreateCounter( + "embedding_api_calls_total", + description: "Total number of embedding API calls (excludes cache hits)"); + + private static readonly Counter _embeddingCacheHits = _meter.CreateCounter( + "embedding_cache_hits_total", + description: "Total number of embedding cache hits"); + + private static readonly Counter _embeddingCacheMisses = _meter.CreateCounter( + "embedding_cache_misses_total", + description: "Total number of embedding cache misses"); + + private static readonly Counter _embeddingErrors = _meter.CreateCounter( + "embedding_errors_total", + description: "Total number of embedding errors"); + + private static readonly Counter _embeddingTextsProcessed = _meter.CreateCounter( + "embedding_texts_processed_total", + description: "Total number of texts processed for embedding"); + + // Histograms for timing and sizing + private static readonly Histogram _embeddingApiDuration = _meter.CreateHistogram( + "embedding_api_duration_ms", + unit: "ms", + description: "Duration of embedding API calls in milliseconds"); + + private static readonly Histogram _embeddingTotalDuration = _meter.CreateHistogram( + "embedding_total_duration_ms", + unit: "ms", + description: "Total duration of embedding operations including cache lookup"); + + private static readonly Histogram _embeddingTokens = _meter.CreateHistogram( + "embedding_tokens_total", + description: "Total tokens used in embedding requests"); + + private static readonly Histogram _embeddingDimensions = _meter.CreateHistogram( + "embedding_dimensions", + description: "Number of dimensions in embedding vectors"); + + /// + /// Tracks an embedding request (entry point, includes cache hits). + /// + /// The embedding provider (e.g., azure-openai, openai). + /// Number of texts being embedded. + public static void TrackEmbeddingRequest(string provider, int textCount) + { + _embeddingRequests.Add(1, + new KeyValuePair("provider", provider)); + _embeddingTextsProcessed.Add(textCount, + new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding API call (cache miss, actual API call made). + /// + /// The embedding provider. + /// Number of texts sent to API. + public static void TrackApiCall(string provider, int textCount) + { + _embeddingApiCalls.Add(1, + new KeyValuePair("provider", provider), + new KeyValuePair("text_count", textCount)); + } + + /// + /// Tracks an embedding cache hit. + /// + /// The embedding provider. + public static void TrackCacheHit(string provider) + { + _embeddingCacheHits.Add(1, new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding cache miss. + /// + /// The embedding provider. + public static void TrackCacheMiss(string provider) + { + _embeddingCacheMisses.Add(1, new KeyValuePair("provider", provider)); + } + + /// + /// Tracks an embedding error. + /// + /// The embedding provider. + /// The type of error that occurred. + public static void TrackError(string provider, string errorType) + { + _embeddingErrors.Add(1, + new KeyValuePair("provider", provider), + new KeyValuePair("error_type", errorType)); + } + + /// + /// Tracks the duration of an embedding API call. + /// + /// The embedding provider. + /// The duration of the API call. + /// Number of texts embedded. + public static void TrackApiDuration(string provider, TimeSpan duration, int textCount) + { + _embeddingApiDuration.Record(duration.TotalMilliseconds, + new KeyValuePair("provider", provider), + new KeyValuePair("text_count", textCount)); + } + + /// + /// Tracks the total duration of an embedding operation (including cache lookup). + /// + /// The embedding provider. + /// The total duration. + /// Whether result was from cache. + public static void TrackTotalDuration(string provider, TimeSpan duration, bool fromCache) + { + _embeddingTotalDuration.Record(duration.TotalMilliseconds, + new KeyValuePair("provider", provider), + new KeyValuePair("from_cache", fromCache)); + } + + /// + /// Tracks token usage from an embedding request. + /// + /// The embedding provider. + /// Total tokens used. + public static void TrackTokenUsage(string provider, long totalTokens) + { + _embeddingTokens.Record(totalTokens, new KeyValuePair("provider", provider)); + } + + /// + /// Tracks embedding vector dimensions. + /// + /// The embedding provider. + /// Number of dimensions in the vector. + public static void TrackDimensions(string provider, int dimensions) + { + _embeddingDimensions.Record(dimensions, new KeyValuePair("provider", provider)); + } + + /// + /// Starts an activity for embedding operations. + /// + /// Name of the operation (e.g., "EmbedAsync", "EmbedBatchAsync"). + /// The started activity, or null if tracing is not enabled. + public static Activity? StartEmbeddingActivity(string operationName) + { + return TelemetryTracesHelper.DABActivitySource.StartActivity( + name: $"Embedding.{operationName}", + kind: ActivityKind.Client); + } + + /// + /// Sets embedding-specific tags on an activity. + /// + /// The activity to tag. + /// The embedding provider. + /// The model being used. + /// Number of texts being embedded. + public static void SetEmbeddingActivityTags( + this Activity activity, + string provider, + string? model, + int textCount) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.provider", provider); + if (!string.IsNullOrEmpty(model)) + { + activity.SetTag("embedding.model", model); + } + + activity.SetTag("embedding.text_count", textCount); + } + } + + /// + /// Records cache status on an activity. + /// + /// The activity to tag. + /// Number of cache hits. + /// Number of cache misses. + public static void SetCacheActivityTags( + this Activity activity, + int cacheHits, + int cacheMisses) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.cache_hits", cacheHits); + activity.SetTag("embedding.cache_misses", cacheMisses); + } + } + + /// + /// Records successful completion of an embedding activity. + /// + /// The activity to complete. + /// Duration in milliseconds. + /// Number of dimensions in the result. + public static void SetEmbeddingActivitySuccess( + this Activity activity, + double durationMs, + int? dimensions = null) + { + if (activity.IsAllDataRequested) + { + activity.SetTag("embedding.duration_ms", durationMs); + if (dimensions.HasValue) + { + activity.SetTag("embedding.dimensions", dimensions.Value); + } + + activity.SetStatus(ActivityStatusCode.Ok); + } + } + + /// + /// Records an error on an embedding activity. + /// + /// The activity to record error on. + /// The exception that occurred. + public static void SetEmbeddingActivityError( + this Activity activity, + Exception ex) + { + if (activity.IsAllDataRequested) + { + activity.SetStatus(ActivityStatusCode.Error, ex.Message); + activity.RecordException(ex); + activity.SetTag("error.type", ex.GetType().Name); + activity.SetTag("error.message", ex.Message); + } + } +} diff --git a/src/Core/Services/Embeddings/IEmbeddingService.cs b/src/Core/Services/Embeddings/IEmbeddingService.cs new file mode 100644 index 0000000000..ef5a9e490c --- /dev/null +++ b/src/Core/Services/Embeddings/IEmbeddingService.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Services.Embeddings; + +/// +/// Result of a TryEmbed operation. +/// +/// Whether the embedding was generated successfully. +/// The embedding vector, or null if unsuccessful. +/// Error message if unsuccessful, or null if successful. +public record EmbeddingResult(bool Success, float[]? Embedding, string? ErrorMessage = null); + +/// +/// Result of a TryEmbedBatch operation. +/// +/// Whether the embeddings were generated successfully. +/// The embedding vectors, or null if unsuccessful. +/// Error message if unsuccessful, or null if successful. +public record EmbeddingBatchResult(bool Success, float[][]? Embeddings, string? ErrorMessage = null); + +/// +/// Service interface for text embedding/vectorization. +/// Supports both single text and batch embedding operations. +/// +public interface IEmbeddingService +{ + /// + /// Gets whether the embedding service is enabled. + /// + bool IsEnabled { get; } + + /// + /// Attempts to generate an embedding vector for a single text input. + /// Returns a result indicating success or failure without throwing exceptions. + /// + /// The text to embed. + /// Cancellation token for the operation. + /// Result containing the embedding if successful, or error information if not. + Task TryEmbedAsync(string text, CancellationToken cancellationToken = default); + + /// + /// Attempts to generate embedding vectors for multiple text inputs in a batch. + /// Returns a result indicating success or failure without throwing exceptions. + /// + /// The texts to embed. + /// Cancellation token for the operation. + /// Result containing the embeddings if successful, or error information if not. + Task TryEmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default); + + /// + /// Generates an embedding vector for a single text input. + /// Throws if the service is disabled or an error occurs. + /// + /// The text to embed. + /// Cancellation token for the operation. + /// The embedding vector as an array of floats. + /// Thrown when the service is disabled. + Task EmbedAsync(string text, CancellationToken cancellationToken = default); + + /// + /// Generates embedding vectors for multiple text inputs in a batch. + /// Throws if the service is disabled or an error occurs. + /// + /// The texts to embed. + /// Cancellation token for the operation. + /// The embedding vectors as an array of float arrays, matching input order. + /// Thrown when the service is disabled. + Task EmbedBatchAsync(string[] texts, CancellationToken cancellationToken = default); +} diff --git a/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs new file mode 100644 index 0000000000..6c4d9343c4 --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingServiceTests.cs @@ -0,0 +1,208 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Microsoft.Extensions.Logging; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Moq.Protected; +using ZiggyCreatures.Caching.Fusion; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for EmbeddingService. +/// +[TestClass] +public class EmbeddingServiceTests +{ + private Mock> _mockLogger = null!; + private Mock _mockCache = null!; + + [TestInitialize] + public void Setup() + { + _mockLogger = new Mock>(); + _mockCache = new Mock(); + } + + /// + /// Tests that IsEnabled returns true when embeddings are enabled. + /// + [TestMethod] + public void IsEnabled_ReturnsTrue_WhenEnabled() + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Assert + Assert.IsTrue(service.IsEnabled); + } + + /// + /// Tests that IsEnabled returns false when embeddings are disabled. + /// + [TestMethod] + public void IsEnabled_ReturnsFalse_WhenDisabled() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: false, + Model: "text-embedding-ada-002"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Assert + Assert.IsFalse(service.IsEnabled); + } + + /// + /// Tests that TryEmbedAsync returns failure when service is disabled. + /// + [TestMethod] + public async Task TryEmbedAsync_ReturnsFailure_WhenDisabled() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Enabled: false, + Model: "text-embedding-ada-002"); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act + EmbeddingResult result = await service.TryEmbedAsync("test"); + + // Assert + Assert.IsFalse(result.Success); + Assert.IsNull(result.Embedding); + Assert.IsNotNull(result.ErrorMessage); + } + + /// + /// Tests that TryEmbedAsync returns failure for null or empty text. + /// + [DataTestMethod] + [DataRow(null, DisplayName = "Null text returns failure")] + [DataRow("", DisplayName = "Empty text returns failure")] + public async Task TryEmbedAsync_ReturnsFailure_ForNullOrEmptyText(string text) + { + // Arrange + EmbeddingsOptions options = CreateAzureOpenAIOptions(); + HttpClient httpClient = new(); + EmbeddingService service = new(httpClient, options, _mockLogger.Object, _mockCache.Object); + + // Act + EmbeddingResult result = await service.TryEmbedAsync(text!); + + // Assert + Assert.IsFalse(result.Success); + } + + /// + /// Tests that EffectiveModel returns the default model for OpenAI when not specified. + /// + [TestMethod] + public void EmbeddingsOptions_OpenAI_DefaultModel() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.Model); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_OPENAI_MODEL, options.EffectiveModel); + } + + /// + /// Tests that EffectiveModel returns null for Azure OpenAI when model not specified. + /// + [TestMethod] + public void EmbeddingsOptions_AzureOpenAI_NoDefaultModel() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://my.openai.azure.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.Model); + Assert.IsNull(options.EffectiveModel); + } + + /// + /// Tests that EffectiveTimeoutMs returns the default timeout when not specified. + /// + [TestMethod] + public void EmbeddingsOptions_DefaultTimeout() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key"); + + // Assert + Assert.IsNull(options.TimeoutMs); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_TIMEOUT_MS, options.EffectiveTimeoutMs); + } + + /// + /// Tests that custom timeout is used when specified. + /// + [TestMethod] + public void EmbeddingsOptions_CustomTimeout() + { + // Arrange + int customTimeout = 60000; + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-key", + TimeoutMs: customTimeout); + + // Assert + Assert.AreEqual(customTimeout, options.TimeoutMs); + Assert.AreEqual(customTimeout, options.EffectiveTimeoutMs); + Assert.IsTrue(options.UserProvidedTimeoutMs); + } + + #region Helper Methods + + private static EmbeddingsOptions CreateAzureOpenAIOptions() + { + return new EmbeddingsOptions( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://test.openai.azure.com", + ApiKey: "test-api-key", + Model: "text-embedding-ada-002"); + } + + private static EmbeddingsOptions CreateOpenAIOptions() + { + return new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key"); + } + + #endregion +} diff --git a/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs new file mode 100644 index 0000000000..591b6b79ec --- /dev/null +++ b/src/Service.Tests/UnitTests/EmbeddingsOptionsTests.cs @@ -0,0 +1,299 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Text.Json; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.Converters; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Unit tests for EmbeddingsOptions deserialization and EmbeddingProviderType enum. +/// +[TestClass] +public class EmbeddingsOptionsTests +{ + private const string BASIC_CONFIG_WITH_EMBEDDINGS = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""base-url"": ""https://my-openai.openai.azure.com"", + ""api-key"": ""test-api-key"", + ""model"": ""text-embedding-ada-002"", + ""api-version"": ""2024-02-01"", + ""dimensions"": 1536, + ""timeout-ms"": 30000 + } + }, + ""entities"": {} + }"; + + private const string OPENAI_CONFIG = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""openai"", + ""base-url"": ""https://api.openai.com"", + ""api-key"": ""sk-test-key"" + } + }, + ""entities"": {} + }"; + + private const string MINIMAL_AZURE_CONFIG = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""base-url"": ""https://my-openai.openai.azure.com"", + ""api-key"": ""test-api-key"", + ""model"": ""my-deployment"" + } + }, + ""entities"": {} + }"; + + private const string CONFIG_WITHOUT_EMBEDDINGS = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""entities"": {} + }"; + + /// + /// Tests that Azure OpenAI embeddings configuration deserializes correctly. + /// + [TestMethod] + public void TestAzureOpenAIEmbeddingsConfigDeserialization() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(BASIC_CONFIG_WITH_EMBEDDINGS, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig); + Assert.IsNotNull(runtimeConfig.Runtime); + Assert.IsNotNull(runtimeConfig.Runtime.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.AzureOpenAI, embeddings.Provider); + Assert.AreEqual("https://my-openai.openai.azure.com", embeddings.BaseUrl); + Assert.AreEqual("test-api-key", embeddings.ApiKey); + Assert.AreEqual("text-embedding-ada-002", embeddings.Model); + Assert.AreEqual("2024-02-01", embeddings.ApiVersion); + Assert.AreEqual(1536, embeddings.Dimensions); + Assert.AreEqual(30000, embeddings.TimeoutMs); + } + + /// + /// Tests that OpenAI embeddings configuration deserializes correctly with defaults. + /// + [TestMethod] + public void TestOpenAIEmbeddingsConfigWithDefaults() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(OPENAI_CONFIG, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.OpenAI, embeddings.Provider); + Assert.AreEqual("https://api.openai.com", embeddings.BaseUrl); + Assert.AreEqual("sk-test-key", embeddings.ApiKey); + Assert.IsNull(embeddings.Model); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_OPENAI_MODEL, embeddings.EffectiveModel); + Assert.IsNull(embeddings.ApiVersion); + Assert.IsNull(embeddings.Dimensions); + Assert.IsNull(embeddings.TimeoutMs); + Assert.AreEqual(EmbeddingsOptions.DEFAULT_TIMEOUT_MS, embeddings.EffectiveTimeoutMs); + } + + /// + /// Tests that minimal Azure OpenAI config deserializes correctly. + /// + [TestMethod] + public void TestMinimalAzureOpenAIConfig() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(MINIMAL_AZURE_CONFIG, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual(EmbeddingProviderType.AzureOpenAI, embeddings.Provider); + Assert.AreEqual("my-deployment", embeddings.Model); + Assert.AreEqual("my-deployment", embeddings.EffectiveModel); + } + + /// + /// Tests that configuration without embeddings section deserializes correctly. + /// + [TestMethod] + public void TestConfigWithoutEmbeddings() + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(CONFIG_WITHOUT_EMBEDDINGS, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig); + Assert.IsNull(runtimeConfig.Runtime?.Embeddings); + } + + /// + /// Tests that EmbeddingProviderType enum deserializes correctly from JSON. + /// + [DataTestMethod] + [DataRow("azure-openai", EmbeddingProviderType.AzureOpenAI)] + [DataRow("openai", EmbeddingProviderType.OpenAI)] + public void TestEmbeddingProviderTypeDeserialization(string jsonValue, EmbeddingProviderType expected) + { + // Arrange + string config = $@" + {{ + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": {{ + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }}, + ""runtime"": {{ + ""embeddings"": {{ + ""provider"": ""{jsonValue}"", + ""base-url"": ""https://example.com"", + ""api-key"": ""test-key"", + ""model"": ""test-model"" + }} + }}, + ""entities"": {{}} + }}"; + + // Act + bool success = RuntimeConfigLoader.TryParseConfig(config, out RuntimeConfig? runtimeConfig); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + Assert.AreEqual(expected, runtimeConfig.Runtime.Embeddings.Provider); + } + + /// + /// Tests EmbeddingsOptions serialization to JSON. + /// + [TestMethod] + public void TestEmbeddingsOptionsSerialization() + { + // Arrange + EmbeddingsOptions options = new( + Provider: EmbeddingProviderType.AzureOpenAI, + BaseUrl: "https://my-endpoint.openai.azure.com", + ApiKey: "my-api-key", + Model: "my-model", + ApiVersion: "2024-02-01", + Dimensions: 1536, + TimeoutMs: 60000); + + // Act + JsonSerializerOptions serializerOptions = RuntimeConfigLoader.GetSerializationOptions(replacementSettings: null); + string json = JsonSerializer.Serialize(options, serializerOptions); + + // Normalize json for comparison (remove whitespace) + string normalizedJson = json.Replace(" ", "").Replace("\n", "").Replace("\r", ""); + + // Assert + Assert.IsTrue(normalizedJson.Contains("\"provider\":\"azure-openai\""), $"Expected provider in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"base-url\":\"https://my-endpoint.openai.azure.com\""), $"Expected base-url in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"api-key\":\"my-api-key\""), $"Expected api-key in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"model\":\"my-model\""), $"Expected model in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"api-version\":\"2024-02-01\""), $"Expected api-version in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"dimensions\":1536"), $"Expected dimensions in JSON: {json}"); + Assert.IsTrue(normalizedJson.Contains("\"timeout-ms\":60000"), $"Expected timeout-ms in JSON: {json}"); + } + + /// + /// Tests that environment variable replacement works for embeddings configuration. + /// + [TestMethod] + public void TestEmbeddingsConfigWithEnvVarReplacement() + { + // Arrange + string config = @" + { + ""$schema"": ""https://github.com/Azure/data-api-builder/releases/download/vmajor.minor.patch/dab.draft.schema.json"", + ""data-source"": { + ""database-type"": ""mssql"", + ""connection-string"": ""Server=test;Database=test;"" + }, + ""runtime"": { + ""embeddings"": { + ""provider"": ""azure-openai"", + ""base-url"": ""@env('EMBEDDINGS_ENDPOINT')"", + ""api-key"": ""@env('EMBEDDINGS_API_KEY')"", + ""model"": ""@env('EMBEDDINGS_MODEL')"" + } + }, + ""entities"": {} + }"; + + // Set environment variables + Environment.SetEnvironmentVariable("EMBEDDINGS_ENDPOINT", "https://test.openai.azure.com"); + Environment.SetEnvironmentVariable("EMBEDDINGS_API_KEY", "test-key-from-env"); + Environment.SetEnvironmentVariable("EMBEDDINGS_MODEL", "test-model-from-env"); + + // Create replacement settings to enable env var replacement + DeserializationVariableReplacementSettings replacementSettings = new( + doReplaceEnvVar: true, + doReplaceAkvVar: false, + envFailureMode: EnvironmentVariableReplacementFailureMode.Throw); + + try + { + // Act + bool success = RuntimeConfigLoader.TryParseConfig(config, out RuntimeConfig? runtimeConfig, replacementSettings); + + // Assert + Assert.IsTrue(success); + Assert.IsNotNull(runtimeConfig?.Runtime?.Embeddings); + + EmbeddingsOptions embeddings = runtimeConfig.Runtime.Embeddings; + Assert.AreEqual("https://test.openai.azure.com", embeddings.BaseUrl); + Assert.AreEqual("test-key-from-env", embeddings.ApiKey); + Assert.AreEqual("test-model-from-env", embeddings.Model); + } + finally + { + // Cleanup + Environment.SetEnvironmentVariable("EMBEDDINGS_ENDPOINT", null); + Environment.SetEnvironmentVariable("EMBEDDINGS_API_KEY", null); + Environment.SetEnvironmentVariable("EMBEDDINGS_MODEL", null); + } + } +} diff --git a/src/Service/Controllers/EmbeddingController.cs b/src/Service/Controllers/EmbeddingController.cs new file mode 100644 index 0000000000..1c8b1641b8 --- /dev/null +++ b/src/Service/Controllers/EmbeddingController.cs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Net; +using System.Net.Mime; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; + +namespace Azure.DataApiBuilder.Service.Controllers; + +/// +/// Controller to serve embedding requests at the configured endpoint path (default: /embed). +/// Accepts plain text input and returns embedding vector as plain text (comma-separated floats). +/// +[ApiController] +public class EmbeddingController : ControllerBase +{ + private readonly IEmbeddingService? _embeddingService; + private readonly RuntimeConfigProvider _runtimeConfigProvider; + private readonly ILogger _logger; + + /// + /// Constructor. + /// + public EmbeddingController( + RuntimeConfigProvider runtimeConfigProvider, + ILogger logger, + IEmbeddingService? embeddingService = null) + { + _runtimeConfigProvider = runtimeConfigProvider; + _logger = logger; + _embeddingService = embeddingService; + } + + /// + /// POST endpoint for generating embeddings. + /// Accepts plain text body and returns embedding vector as comma-separated floats. + /// + /// The route path. + /// Plain text embedding vector or error response. + [HttpPost] + [Route("{*route}")] + [Consumes("text/plain", "application/json")] + [Produces("text/plain")] + public async Task PostAsync(string? route) + { + // Get embeddings configuration + EmbeddingsOptions? embeddingsOptions = _runtimeConfigProvider.GetConfig()?.Runtime?.Embeddings; + EmbeddingsEndpointOptions? endpointOptions = embeddingsOptions?.Endpoint; + + // Check if embeddings and endpoint are enabled + if (embeddingsOptions is null || !embeddingsOptions.Enabled) + { + return NotFound(); + } + + if (endpointOptions is null || !endpointOptions.Enabled) + { + return NotFound(); + } + + // Check if the route matches the configured endpoint path + string expectedPath = endpointOptions.EffectivePath.TrimStart('/'); + if (!string.Equals(route, expectedPath, StringComparison.OrdinalIgnoreCase)) + { + return NotFound(); + } + + // Check if embedding service is available + if (_embeddingService is null || !_embeddingService.IsEnabled) + { + _logger.LogWarning("Embedding endpoint called but embedding service is not available or disabled."); + return StatusCode((int)HttpStatusCode.ServiceUnavailable, "Embedding service is not available."); + } + + // Check authorization + bool isDevelopmentMode = _runtimeConfigProvider.GetConfig()?.Runtime?.Host?.Mode == HostMode.Development; + string clientRole = GetClientRole(); + + if (!endpointOptions.IsRoleAllowed(clientRole, isDevelopmentMode)) + { + _logger.LogWarning("Embedding endpoint access denied for role: {Role}", clientRole); + return StatusCode((int)HttpStatusCode.Forbidden, "Access denied. Role not authorized."); + } + + // Read request body as plain text + string text; + try + { + using StreamReader reader = new(Request.Body); + text = await reader.ReadToEndAsync(); + + // Handle JSON-wrapped string + if (Request.ContentType?.Contains("application/json", StringComparison.OrdinalIgnoreCase) == true) + { + try + { + text = JsonSerializer.Deserialize(text) ?? text; + } + catch (JsonException) + { + // Not valid JSON string, use as-is + _logger.LogDebug("Request body is not a valid JSON string, using as plain text."); + } + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to read request body for embedding."); + return BadRequest("Failed to read request body."); + } + + if (string.IsNullOrWhiteSpace(text)) + { + return BadRequest("Request body cannot be empty."); + } + + // Generate embedding + EmbeddingResult result = await _embeddingService.TryEmbedAsync(text); + + if (!result.Success) + { + _logger.LogError("Embedding request failed: {Error}", result.ErrorMessage); + return StatusCode((int)HttpStatusCode.InternalServerError, result.ErrorMessage ?? "Failed to generate embedding."); + } + + if (result.Embedding is null || result.Embedding.Length == 0) + { + _logger.LogError("Embedding request returned empty result."); + return StatusCode((int)HttpStatusCode.InternalServerError, "Failed to generate embedding."); + } + + // Return embedding as comma-separated float values (plain text) + string embeddingText = string.Join(",", result.Embedding); + return Content(embeddingText, MediaTypeNames.Text.Plain); + } + + /// + /// Gets the client role from request headers. + /// + private string GetClientRole() + { + StringValues roleHeader = Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER]; + string? firstRole = roleHeader.Count == 1 ? roleHeader[0] : null; + + if (!string.IsNullOrEmpty(firstRole)) + { + return firstRole.ToLowerInvariant(); + } + + return EmbeddingsEndpointOptions.ANONYMOUS_ROLE; + } +} diff --git a/src/Service/HealthCheck/HealthCheckHelper.cs b/src/Service/HealthCheck/HealthCheckHelper.cs index ab19756195..d44ade1930 100644 --- a/src/Service/HealthCheck/HealthCheckHelper.cs +++ b/src/Service/HealthCheck/HealthCheckHelper.cs @@ -10,7 +10,9 @@ using System.Threading.Tasks; using Azure.DataApiBuilder.Config.HealthCheck; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Services.Embeddings; using Azure.DataApiBuilder.Product; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -27,20 +29,24 @@ public class HealthCheckHelper // Dependencies private ILogger _logger; private HttpUtilities _httpUtility; + private IEmbeddingService? _embeddingService; private string _incomingRoleHeader = string.Empty; private string _incomingRoleToken = string.Empty; private const string TIME_EXCEEDED_ERROR_MESSAGE = "The threshold for executing the request has exceeded."; + private const string DIMENSIONS_MISMATCH_ERROR_MESSAGE = "The embedding dimensions do not match the expected dimensions."; /// /// Constructor to inject the logger and HttpUtility class. /// /// Logger to track the log statements. /// HttpUtility to call methods from the internal class. - public HealthCheckHelper(ILogger logger, HttpUtilities httpUtility) + /// Optional embedding service for embedding health checks. + public HealthCheckHelper(ILogger logger, HttpUtilities httpUtility, IEmbeddingService? embeddingService = null) { _logger = logger; _httpUtility = httpUtility; + _embeddingService = embeddingService; } /// @@ -137,14 +143,19 @@ private static void UpdateTimestampOfResponse(ref ComprehensiveHealthCheckReport // Updates the DAB configuration details coming from RuntimeConfig for the Health report. private static void UpdateDabConfigurationDetails(ref ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, RuntimeConfig runtimeConfig) { + bool embeddingsEnabled = runtimeConfig?.Runtime?.Embeddings?.Enabled ?? false; + bool embeddingsEndpointEnabled = embeddingsEnabled && (runtimeConfig?.Runtime?.Embeddings?.IsEndpointEnabled ?? false); + comprehensiveHealthCheckReport.ConfigurationDetails = new ConfigurationDetails { - Rest = runtimeConfig.IsRestEnabled, - GraphQL = runtimeConfig.IsGraphQLEnabled, - Mcp = runtimeConfig.IsMcpEnabled, - Caching = runtimeConfig.IsCachingEnabled, + Rest = runtimeConfig?.IsRestEnabled ?? false, + GraphQL = runtimeConfig?.IsGraphQLEnabled ?? false, + Mcp = runtimeConfig?.IsMcpEnabled ?? false, + Caching = runtimeConfig?.IsCachingEnabled ?? false, Telemetry = runtimeConfig?.Runtime?.Telemetry != null, - Mode = runtimeConfig?.Runtime?.Host?.Mode ?? HostMode.Production, // Modify to runtimeConfig.HostMode in Roles PR + Mode = runtimeConfig?.Runtime?.Host?.Mode ?? HostMode.Production, + Embeddings = embeddingsEnabled, + EmbeddingsEndpoint = embeddingsEndpointEnabled }; } @@ -154,6 +165,7 @@ private async Task UpdateHealthCheckDetailsAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport.Checks = new List(); await UpdateDataSourceHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); await UpdateEntityHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); + await UpdateEmbeddingsHealthCheckResultsAsync(comprehensiveHealthCheckReport, runtimeConfig); } // Updates the DataSource Health Check Results in the response. @@ -346,5 +358,108 @@ private async Task PopulateEntityHealthAsync(ComprehensiveHealthCheckReport comp return (HealthCheckConstants.ERROR_RESPONSE_TIME_MS, errorMessage); } + + /// + /// Updates the Embeddings Health Check Results in the response. + /// Executes a test embedding and validates response time and optionally dimensions. + /// + private async Task UpdateEmbeddingsHealthCheckResultsAsync(ComprehensiveHealthCheckReport comprehensiveHealthCheckReport, RuntimeConfig runtimeConfig) + { + EmbeddingsOptions? embeddingsOptions = runtimeConfig?.Runtime?.Embeddings; + EmbeddingsHealthCheckConfig? healthConfig = embeddingsOptions?.Health; + + // Only run health check if embeddings is enabled, health check is enabled, and embedding service is available + if (embeddingsOptions is null || !embeddingsOptions.Enabled || healthConfig is null || !healthConfig.Enabled || _embeddingService is null) + { + return; + } + + if (comprehensiveHealthCheckReport.Checks is null) + { + comprehensiveHealthCheckReport.Checks = new List(); + } + + string testText = healthConfig.TestText; + int thresholdMs = healthConfig.ThresholdMs; + int? expectedDimensions = healthConfig.ExpectedDimensions; + + try + { + Stopwatch stopwatch = new(); + stopwatch.Start(); + EmbeddingResult result = await _embeddingService.TryEmbedAsync(testText); + stopwatch.Stop(); + + int responseTimeMs = (int)stopwatch.ElapsedMilliseconds; + bool isResponseTimeWithinThreshold = responseTimeMs <= thresholdMs; + bool isDimensionsValid = true; + string? errorMessage = null; + + if (!result.Success) + { + errorMessage = result.ErrorMessage ?? "Embedding request failed."; + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = HealthCheckConstants.ERROR_RESPONSE_TIME_MS, + ThresholdMs = thresholdMs + }, + Exception = errorMessage, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = HealthStatus.Unhealthy + }); + return; + } + + // Validate dimensions if expected dimensions is specified + if (expectedDimensions.HasValue && result.Embedding is not null) + { + isDimensionsValid = result.Embedding.Length == expectedDimensions.Value; + if (!isDimensionsValid) + { + errorMessage = $"{DIMENSIONS_MISMATCH_ERROR_MESSAGE} Expected: {expectedDimensions.Value}, Actual: {result.Embedding.Length}"; + } + } + + // Check response time threshold + if (!isResponseTimeWithinThreshold) + { + errorMessage = errorMessage is null ? TIME_EXCEEDED_ERROR_MESSAGE : $"{errorMessage} {TIME_EXCEEDED_ERROR_MESSAGE}"; + } + + bool isHealthy = isResponseTimeWithinThreshold && isDimensionsValid; + + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = responseTimeMs, + ThresholdMs = thresholdMs + }, + Exception = errorMessage, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = isHealthy ? HealthStatus.Healthy : HealthStatus.Unhealthy + }); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error executing embeddings health check."); + comprehensiveHealthCheckReport.Checks.Add(new HealthCheckResultEntry + { + Name = "embeddings", + ResponseTimeData = new ResponseTimeData + { + ResponseTimeMs = HealthCheckConstants.ERROR_RESPONSE_TIME_MS, + ThresholdMs = thresholdMs + }, + Exception = ex.Message, + Tags = new List { HealthCheckConstants.EMBEDDING }, + Status = HealthStatus.Unhealthy + }); + } + } } } diff --git a/src/Service/HealthCheck/Model/ConfigurationDetails.cs b/src/Service/HealthCheck/Model/ConfigurationDetails.cs index 9ff007754e..e73497e3e0 100644 --- a/src/Service/HealthCheck/Model/ConfigurationDetails.cs +++ b/src/Service/HealthCheck/Model/ConfigurationDetails.cs @@ -29,5 +29,11 @@ public record ConfigurationDetails [JsonPropertyName("mode")] public HostMode Mode { get; init; } + + [JsonPropertyName("embeddings")] + public bool Embeddings { get; init; } + + [JsonPropertyName("embeddings-endpoint")] + public bool EmbeddingsEndpoint { get; init; } } } diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index 333bf57234..19948da725 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -10,6 +10,7 @@ using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.Converters; using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Config.ObjectModel.Embeddings; using Azure.DataApiBuilder.Config.Utilities; using Azure.DataApiBuilder.Core.AuthenticationHelpers; using Azure.DataApiBuilder.Core.AuthenticationHelpers.AuthenticationSimulator; @@ -21,6 +22,7 @@ using Azure.DataApiBuilder.Core.Resolvers.Factories; using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; +using Azure.DataApiBuilder.Core.Services.Embeddings; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.OpenAPI; using Azure.DataApiBuilder.Core.Telemetry; @@ -260,7 +262,13 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); - services.AddSingleton(); + services.AddSingleton(sp => + { + ILogger logger = sp.GetRequiredService>(); + HttpUtilities httpUtility = sp.GetRequiredService(); + IEmbeddingService? embeddingService = sp.GetService(); + return new HealthCheckHelper(logger, httpUtility, embeddingService); + }); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -387,6 +395,45 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); + // Register embedding service if configured + if (runtimeConfigAvailable + && runtimeConfig?.Runtime?.IsEmbeddingsConfigured == true) + { + EmbeddingsOptions embeddingsOptions = runtimeConfig.Runtime.Embeddings; + services.AddHttpClient(); + services.AddSingleton(embeddingsOptions); + + string providerName = embeddingsOptions.Provider.ToString().ToLowerInvariant(); + + if (embeddingsOptions.Enabled) + { + _logger.LogInformation( + "Embeddings service enabled with provider: {Provider}, model: {Model}, base-url: {BaseUrl}", + providerName, + embeddingsOptions.EffectiveModel ?? "(default)", + embeddingsOptions.BaseUrl); + + // Endpoint is only available if both embeddings and endpoint are enabled + if (embeddingsOptions.IsEndpointEnabled) + { + _logger.LogInformation( + "Embeddings endpoint enabled at path: {Path}", + embeddingsOptions.EffectiveEndpointPath); + } + + if (embeddingsOptions.IsHealthCheckEnabled) + { + _logger.LogInformation( + "Embeddings health check enabled with threshold: {ThresholdMs}ms", + embeddingsOptions.Health!.ThresholdMs); + } + } + else + { + _logger.LogInformation("Embeddings service is configured but disabled."); + } + } + AddGraphQLService(services, runtimeConfig?.Runtime?.GraphQL); // Subscribe the GraphQL schema refresh method to the specific hot-reload event