11// Copyright (c) Microsoft. All rights reserved.
2+ using System . Net . Http . Headers ;
23using System . Net . Http . Json ;
34using System . Text . Json . Serialization ;
5+ using Azure . Core ;
6+ using Azure . Identity ;
47using KernelMemory . Core . Config . Enums ;
8+ using KernelMemory . Core . Http ;
59using Microsoft . Extensions . Logging ;
610
711namespace KernelMemory . Core . Embeddings . Providers ;
812
913/// <summary>
1014/// Azure OpenAI embedding generator implementation.
1115/// Communicates with Azure OpenAI Service.
12- /// Supports API key authentication ( managed identity would require Azure.Identity package) .
16+ /// Supports API key authentication or managed identity via <see cref="DefaultAzureCredential"/> .
1317/// </summary>
1418public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
1519{
1620 private readonly HttpClient _httpClient ;
1721 private readonly string _endpoint ;
1822 private readonly string _deployment ;
19- private readonly string _apiKey ;
23+ private readonly string ? _apiKey ;
24+ private readonly bool _useManagedIdentity ;
25+ private readonly TokenCredential ? _credential ;
26+ private readonly int _batchSize ;
2027 private readonly ILogger < AzureOpenAIEmbeddingGenerator > _logger ;
28+ private readonly Func < TimeSpan , CancellationToken , Task > _delayAsync ;
2129
2230 /// <inheritdoc />
2331 public EmbeddingsTypes ProviderType => EmbeddingsTypes . AzureOpenAI ;
@@ -38,35 +46,52 @@ public sealed class AzureOpenAIEmbeddingGenerator : IEmbeddingGenerator
3846 /// <param name="endpoint">Azure OpenAI endpoint (e.g., https://myservice.openai.azure.com).</param>
3947 /// <param name="deployment">Deployment name in Azure.</param>
4048 /// <param name="model">Model name for identification.</param>
41- /// <param name="apiKey">Azure OpenAI API key.</param>
49+ /// <param name="apiKey">Azure OpenAI API key (required unless <paramref name="useManagedIdentity"/> is true) .</param>
4250 /// <param name="vectorDimensions">Vector dimensions produced by the model.</param>
4351 /// <param name="isNormalized">Whether vectors are normalized.</param>
4452 /// <param name="logger">Logger instance.</param>
53+ /// <param name="batchSize">Maximum number of texts per API request.</param>
54+ /// <param name="useManagedIdentity">Whether to authenticate using managed identity.</param>
55+ /// <param name="credential">Optional token credential (used for testing); defaults to <see cref="DefaultAzureCredential"/>.</param>
56+ /// <param name="delayAsync">Optional delay function for retries (used for fast unit tests).</param>
4557 public AzureOpenAIEmbeddingGenerator (
4658 HttpClient httpClient ,
4759 string endpoint ,
4860 string deployment ,
4961 string model ,
50- string apiKey ,
62+ string ? apiKey ,
5163 int vectorDimensions ,
5264 bool isNormalized ,
53- ILogger < AzureOpenAIEmbeddingGenerator > logger )
65+ ILogger < AzureOpenAIEmbeddingGenerator > logger ,
66+ int batchSize ,
67+ bool useManagedIdentity ,
68+ TokenCredential ? credential = null ,
69+ Func < TimeSpan , CancellationToken , Task > ? delayAsync = null )
5470 {
5571 ArgumentNullException . ThrowIfNull ( httpClient , nameof ( httpClient ) ) ;
5672 ArgumentNullException . ThrowIfNull ( endpoint , nameof ( endpoint ) ) ;
5773 ArgumentNullException . ThrowIfNull ( deployment , nameof ( deployment ) ) ;
5874 ArgumentNullException . ThrowIfNull ( model , nameof ( model ) ) ;
59- ArgumentNullException . ThrowIfNull ( apiKey , nameof ( apiKey ) ) ;
6075 ArgumentNullException . ThrowIfNull ( logger , nameof ( logger ) ) ;
76+ ArgumentOutOfRangeException . ThrowIfLessThan ( batchSize , 1 , nameof ( batchSize ) ) ;
6177
6278 this . _httpClient = httpClient ;
6379 this . _endpoint = endpoint . TrimEnd ( '/' ) ;
6480 this . _deployment = deployment ;
6581 this . _apiKey = apiKey ;
82+ this . _useManagedIdentity = useManagedIdentity ;
83+ this . _credential = credential ;
84+ this . _batchSize = batchSize ;
6685 this . ModelName = model ;
6786 this . VectorDimensions = vectorDimensions ;
6887 this . IsNormalized = isNormalized ;
6988 this . _logger = logger ;
89+ this . _delayAsync = delayAsync ?? Task . Delay ;
90+
91+ if ( ! this . _useManagedIdentity && string . IsNullOrWhiteSpace ( this . _apiKey ) )
92+ {
93+ throw new ArgumentException ( "Azure OpenAI API key is required when not using managed identity" , nameof ( apiKey ) ) ;
94+ }
7095
7196 this . _logger . LogDebug ( "AzureOpenAIEmbeddingGenerator initialized: {Endpoint}, deployment: {Deployment}, model: {Model}" ,
7297 this . _endpoint , this . _deployment , this . ModelName ) ;
@@ -88,21 +113,53 @@ public async Task<EmbeddingResult[]> GenerateAsync(IEnumerable<string> texts, Ca
88113 return [ ] ;
89114 }
90115
116+ var allResults = new List < EmbeddingResult > ( textArray . Length ) ;
117+ foreach ( var chunk in Chunk ( textArray , this . _batchSize ) )
118+ {
119+ var chunkResults = await this . GenerateBatchAsync ( chunk , ct ) . ConfigureAwait ( false ) ;
120+ allResults . AddRange ( chunkResults ) ;
121+ }
122+
123+ return allResults . ToArray ( ) ;
124+ }
125+
126+ private async Task < EmbeddingResult [ ] > GenerateBatchAsync ( string [ ] textArray , CancellationToken ct )
127+ {
91128 var url = $ "{ this . _endpoint } /openai/deployments/{ this . _deployment } /embeddings?api-version={ Constants . EmbeddingDefaults . AzureOpenAIApiVersion } ";
92129
93130 var request = new AzureEmbeddingRequest
94131 {
95132 Input = textArray
96133 } ;
97134
98- using var httpRequest = new HttpRequestMessage ( HttpMethod . Post , url ) ;
99- httpRequest . Headers . Add ( "api-key" , this . _apiKey ) ;
100- httpRequest . Content = JsonContent . Create ( request ) ;
135+ var bearerToken = this . _useManagedIdentity
136+ ? await this . GetManagedIdentityTokenAsync ( ct ) . ConfigureAwait ( false )
137+ : null ;
101138
102139 this . _logger . LogTrace ( "Calling Azure OpenAI embeddings API: deployment={Deployment}, batch size: {BatchSize}" ,
103140 this . _deployment , textArray . Length ) ;
104141
105- var response = await this . _httpClient . SendAsync ( httpRequest , ct ) . ConfigureAwait ( false ) ;
142+ using var response = await HttpRetryPolicy . SendAsync (
143+ this . _httpClient ,
144+ requestFactory : ( ) =>
145+ {
146+ var httpRequest = new HttpRequestMessage ( HttpMethod . Post , url ) ;
147+ if ( bearerToken != null )
148+ {
149+ httpRequest . Headers . Authorization = new AuthenticationHeaderValue ( "Bearer" , bearerToken ) ;
150+ }
151+ else
152+ {
153+ httpRequest . Headers . Add ( "api-key" , this . _apiKey ) ;
154+ }
155+
156+ httpRequest . Content = JsonContent . Create ( request ) ;
157+ return httpRequest ;
158+ } ,
159+ this . _logger ,
160+ ct ,
161+ delayAsync : this . _delayAsync ) . ConfigureAwait ( false ) ;
162+
106163 response . EnsureSuccessStatusCode ( ) ;
107164
108165 var result = await response . Content . ReadFromJsonAsync < AzureEmbeddingResponse > ( ct ) . ConfigureAwait ( false ) ;
@@ -141,6 +198,26 @@ public async Task<EmbeddingResult[]> GenerateAsync(IEnumerable<string> texts, Ca
141198 return results ;
142199 }
143200
/// <summary>
/// Fallback credential used when no <see cref="TokenCredential"/> was injected.
/// Created on first use and then shared, so that any token caching performed by the
/// credential is reused instead of rebuilding the credential chain for every batch.
/// </summary>
private static readonly Lazy<TokenCredential> s_defaultCredential = new(() => new DefaultAzureCredential());

/// <summary>
/// Acquires an access token for the Azure Cognitive Services scope, using the injected
/// credential when available and the shared <see cref="DefaultAzureCredential"/> otherwise.
/// </summary>
/// <param name="ct">Cancellation token forwarded to the credential.</param>
/// <returns>The raw access token string to send as a Bearer token.</returns>
private async Task<string> GetManagedIdentityTokenAsync(CancellationToken ct)
{
    // Never construct a credential per call: this method runs once per embedding batch.
    var credential = this._credential ?? s_defaultCredential.Value;

    var token = await credential.GetTokenAsync(
        new TokenRequestContext(["https://cognitiveservices.azure.com/.default"]),
        ct).ConfigureAwait(false);

    return token.Token;
}
209+
/// <summary>
/// Splits <paramref name="items"/> into consecutive arrays of at most
/// <paramref name="chunkSize"/> elements, preserving order. The final chunk may be shorter.
/// </summary>
/// <param name="items">Source texts to split.</param>
/// <param name="chunkSize">Maximum number of elements per chunk; must be at least 1.</param>
/// <returns>A sequence of newly allocated arrays covering every element of <paramref name="items"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="items"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="chunkSize"/> is less than 1.</exception>
private static IEnumerable<string[]> Chunk(string[] items, int chunkSize)
{
    // Delegate to the framework implementation: it validates its arguments up front
    // (a chunk size of 0 would otherwise never advance and loop forever) and
    // returns a fresh array per chunk.
    return Enumerable.Chunk(items, chunkSize);
}
220+
144221 /// <summary>
145222 /// Request body for Azure OpenAI embeddings API.
146223 /// </summary>
0 commit comments