Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 116 additions & 40 deletions dotnet/src/Agents/OpenAI/OpenAIResponseAgent.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
Expand Down Expand Up @@ -47,26 +47,45 @@ public OpenAIResponseAgent(ResponsesClient client, string? modelId = null)
public bool StoreEnabled { get; init; } = false;

/// <inheritdoc/>
public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> InvokeAsync(ICollection<ChatMessageContent> messages, AgentThread? thread = null, AgentInvokeOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> InvokeAsync(
ICollection<ChatMessageContent> messages,
AgentThread? thread = null,
AgentInvokeOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Verify.NotNull(messages);

AgentThread agentThread = await this.EnsureThreadExistsWithMessagesAsync(messages, thread, cancellationToken).ConfigureAwait(false);
AgentThread agentThread;
OpenAIResponseAgentInvokeOptions extensionsContextOptions;
IAsyncEnumerable<ChatMessageContent> invokeResults;

// Get the context contributions from the AIContextProviders.
OpenAIResponseAgentInvokeOptions extensionsContextOptions = await this.FinalizeInvokeOptionsAsync(messages, options, agentThread, cancellationToken).ConfigureAwait(false);
try
{
agentThread = await this.EnsureThreadExistsWithMessagesAsync(messages, thread, cancellationToken).ConfigureAwait(false);
extensionsContextOptions = await this.FinalizeInvokeOptionsAsync(messages, options, agentThread, cancellationToken).ConfigureAwait(false);

ChatHistory chatHistory = [.. messages];
invokeResults = ResponseThreadActions.InvokeAsync(
this,
chatHistory,
agentThread,
extensionsContextOptions,
cancellationToken);
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
throw new KernelException($"OpenAI provider error for agent '{this.Name}': {ex.Message}", ex);
}
Comment on lines +62 to +78

// Invoke responses with the updated chat history.
ChatHistory chatHistory = [.. messages];
var invokeResults = ResponseThreadActions.InvokeAsync(
this,
chatHistory,
agentThread,
extensionsContextOptions,
var errorMessagePrefix = $"OpenAI provider error for agent '{this.Name}': ";
var mappedResults = HandleProviderExceptionsAsync(
invokeResults,
result => result,
errorMessagePrefix,
cancellationToken);

// Notify the thread of new messages and return them to the caller.
await foreach (var result in invokeResults.ConfigureAwait(false))
// Yield results with additional error handling
await foreach (var result in mappedResults.ConfigureAwait(false))
{
if (options?.OnIntermediateMessage is not null)
{
Expand All @@ -79,37 +98,38 @@ public override async IAsyncEnumerable<AgentResponseItem<ChatMessageContent>> In
}

/// <inheritdoc/>
public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageContent>> InvokeStreamingAsync(ICollection<ChatMessageContent> messages, AgentThread? thread = null, AgentInvokeOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public override async IAsyncEnumerable<AgentResponseItem<StreamingChatMessageContent>> InvokeStreamingAsync(
ICollection<ChatMessageContent> messages,
AgentThread? thread = null,
AgentInvokeOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Verify.NotNull(messages);

AgentThread agentThread = await this.EnsureThreadExistsWithMessagesAsync(messages, thread, cancellationToken).ConfigureAwait(false);

// Get the context contributions from the AIContextProviders.
OpenAIResponseAgentInvokeOptions extensionsContextOptions = await this.FinalizeInvokeOptionsAsync(messages, options, agentThread, cancellationToken).ConfigureAwait(false);
AgentThread agentThread;
OpenAIResponseAgentInvokeOptions extensionsContextOptions;
ChatHistory chatHistory;
int messageIndex;
IAsyncEnumerable<StreamingChatMessageContent> invokeResults;

// Invoke responses with the updated chat history.
ChatHistory chatHistory = [.. messages];
int messageCount = chatHistory.Count;
int messageIndex = chatHistory.Count;
var invokeResults = ResponseThreadActions.InvokeStreamingAsync(
this,
chatHistory,
agentThread,
extensionsContextOptions,
cancellationToken);

// Return streaming chat message content to the caller.
await foreach (var result in invokeResults.ConfigureAwait(false))
try
{
// Notify the thread of any messages that were assembled from the streaming response during this iteration.
await NotifyMessagesAsync().ConfigureAwait(false);

yield return new(result, agentThread);
agentThread = await this.EnsureThreadExistsWithMessagesAsync(messages, thread, cancellationToken).ConfigureAwait(false);
extensionsContextOptions = await this.FinalizeInvokeOptionsAsync(messages, options, agentThread, cancellationToken).ConfigureAwait(false);

chatHistory = [.. messages];
messageIndex = chatHistory.Count;
invokeResults = ResponseThreadActions.InvokeStreamingAsync(
this,
chatHistory,
agentThread,
extensionsContextOptions,
cancellationToken);
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
throw new KernelException($"OpenAI provider error for agent '{this.Name}' during streaming: {ex.Message}", ex);
}

// Notify the thread of any remaining messages that were assembled from the streaming response after all iterations are complete.
await NotifyMessagesAsync().ConfigureAwait(false);

async Task NotifyMessagesAsync()
{
Expand All @@ -124,6 +144,19 @@ async Task NotifyMessagesAsync()
}
}
}

var errorMessagePrefix = $"OpenAI provider error for agent '{this.Name}' during streaming: ";
var mappedResults = HandleProviderExceptionsAsync(
invokeResults,
result => result,
errorMessagePrefix,
cancellationToken);

await foreach (var result in mappedResults.ConfigureAwait(false))
{
await NotifyMessagesAsync().ConfigureAwait(false);
yield return new(result, agentThread);
Comment on lines +155 to +158
}
}

/// <inheritdoc/>
Expand Down Expand Up @@ -196,4 +229,47 @@ options is null ?
};
return extensionsContextOptions;
}
}

private static async IAsyncEnumerable<TResult> HandleProviderExceptionsAsync<TSource, TResult>(
IAsyncEnumerable<TSource> source,
Func<TSource, TResult> selector,
string errorMessagePrefix,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IAsyncEnumerator<TSource> enumerator;
try
{
enumerator = source.GetAsyncEnumerator(cancellationToken);
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
throw new KernelException($"{errorMessagePrefix}{ex.Message}", ex);
}

try
{
while (true)
{
TSource item;
try
{
if (!await enumerator.MoveNextAsync().ConfigureAwait(false))
{
break;
}
item = enumerator.Current;
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
throw new KernelException($"{errorMessagePrefix}{ex.Message}", ex);
}

yield return selector(item);
}
}
finally
{
await enumerator.DisposeAsync().ConfigureAwait(false);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using OpenAI;
using OpenAI.Responses;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents.OpenAI;

namespace SemanticKernel.Agents.UnitTests.OpenAI;

/// <summary>
/// Tests for the exception handling logic in OpenAIResponseAgent.
/// Verifies that provider exceptions are wrapped in KernelException.
/// </summary>
public class OpenAIResponseAgentExceptionTests : BaseOpenAIResponseClientTest
{
private const string AgentName = "TestAgent";

/// <summary>
/// Verifies that InvokeAsync wraps provider exceptions in a KernelException.
/// </summary>
[Fact]
public async Task InvokeAsync_WhenProviderFails_ShouldThrowKernelExceptionAsync()
{
// Arrange
var ex = new HttpRequestException("HTTP 429 Rate limit exceeded");
var agent = CreateAgentWithThrowingHandler(ex);

// Act & Assert
var exception = await Assert.ThrowsAsync<KernelException>(async () =>
{
await foreach (var item in agent.InvokeAsync("Hello"))
{
// Enumerate to trigger lazy execution
}
});

Assert.Contains($"OpenAI provider error for agent '{AgentName}':", exception.Message);
Assert.NotNull(exception.InnerException);
Assert.Contains("HTTP 429 Rate limit exceeded", exception.ToString());
}

/// <summary>
/// Verifies that InvokeStreamingAsync wraps provider exceptions in a KernelException.
/// </summary>
[Fact]
public async Task InvokeStreamingAsync_WhenProviderFails_ShouldThrowKernelExceptionAsync()
{
// Arrange
var ex = new HttpRequestException("HTTP 500 Internal Server Error");
var agent = CreateAgentWithThrowingHandler(ex);

// Act & Assert
var exception = await Assert.ThrowsAsync<KernelException>(async () =>
{
await foreach (var item in agent.InvokeStreamingAsync("Hello"))
{
// Enumerate to trigger lazy execution
}
});

Assert.Contains($"OpenAI provider error for agent '{AgentName}' during streaming:", exception.Message);
Assert.NotNull(exception.InnerException);
Assert.Contains("HTTP 500 Internal Server Error", exception.ToString());
}

private OpenAIResponseAgent CreateAgentWithThrowingHandler(Exception exceptionToThrow)
{
#pragma warning disable CA2000 // Dispose objects before losing scope
var handler = new ThrowingHttpMessageHandler(exceptionToThrow);
var httpClient = new HttpClient(handler);
var clientOptions = new OpenAIClientOptions()
{
Transport = new HttpClientPipelineTransport(httpClient)
};
var client = new ResponsesClient(new ApiKeyCredential("apiKey"), clientOptions);
return new OpenAIResponseAgent(client) { Name = AgentName };
#pragma warning restore CA2000 // Dispose objects before losing scope
}

private sealed class ThrowingHttpMessageHandler(Exception exceptionToThrow) : HttpMessageHandler
{
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
return Task.FromException<HttpResponseMessage>(exceptionToThrow);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
Expand Down Expand Up @@ -218,6 +218,29 @@ public async Task ItSendsBinaryContentCorrectlyAsync(bool useUriData)
/// </summary>
private const string PdfBase64Data = "JVBERi0xLjQKMSAwIG9iago8PC9UeXBlIC9DYXRhbG9nCi9QYWdlcyAyIDAgUgo+PgplbmRvYmoKMiAwIG9iago8PC9UeXBlIC9QYWdlcwovS2lkcyBbMyAwIFJdCi9Db3VudCAxCj4+CmVuZG9iagozIDAgb2JqCjw8L1R5cGUgL1BhZ2UKL1BhcmVudCAyIDAgUgovTWVkaWFCb3ggWzAgMCA1OTUgODQyXQovQ29udGVudHMgNSAwIFIKL1Jlc291cmNlcyA8PC9Qcm9jU2V0IFsvUERGIC9UZXh0XQovRm9udCA8PC9GMSA0IDAgUj4+Cj4+Cj4+CmVuZG9iago0IDAgb2JqCjw8L1R5cGUgL0ZvbnQKL1N1YnR5cGUgL1R5cGUxCi9OYW1lIC9GMQovQmFzZUZvbnQgL0hlbHZldGljYQovRW5jb2RpbmcgL01hY1JvbWFuRW5jb2RpbmcKPj4KZW5kb2JqCjUgMCBvYmoKPDwvTGVuZ3RoIDUzCj4+CnN0cmVhbQpCVAovRjEgMjAgVGYKMjIwIDQwMCBUZAooRHVtbXkgUERGKSBUagpFVAplbmRzdHJlYW0KZW5kb2JqCnhyZWYKMCA2CjAwMDAwMDAwMDAgNjU1MzUgZgowMDAwMDAwMDA5IDAwMDAwIG4KMDAwMDAwMDA2MyAwMDAwMCBuCjAwMDAwMDAxMjQgMDAwMDAgbgowMDAwMDAwMjc3IDAwMDAwIG4KMDAwMDAwMDM5MiAwMDAwMCBuCnRyYWlsZXIKPDwvU2l6ZSA2Ci9Sb290IDEgMCBSCj4+CnN0YXJ0eHJlZgo0OTUKJSVFT0YK";

[Fact]
public async Task GetChatMessageContentsAsyncUsesModelIdFromExecutionSettingsAsync()
{
// Arrange
string constructorModel = "fake-model-constructor";
string overriddenModel = "fake-model-overridden";
var sut = new GoogleAIGeminiChatCompletionService(constructorModel, "key", httpClient: this._httpClient);

var executionSettings = new GeminiPromptExecutionSettings { ModelId = overriddenModel };

// Act
var result = await sut.GetChatMessageContentsAsync(new ChatHistory { new ChatMessageContent(AuthorRole.User, "hello") }, executionSettings);

// Assert
Assert.NotNull(result);
Assert.NotNull(this._messageHandlerStub.RequestContent);
Assert.Equal(overriddenModel, result[0].ModelId);

// Verify the request URI uses the overridden model
Assert.NotNull(this._messageHandlerStub.RequestUri);
Assert.Contains($"/models/{overriddenModel}:generateContent", this._messageHandlerStub.RequestUri.AbsoluteUri);
}

public void Dispose()
{
this._httpClient.Dispose();
Expand Down
Loading
Loading