Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.Agents.AI.Workflows/Futures.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public static class Futures
/// <c>[Obsolete]</c> in v2.0.0 when the new behavior becomes default, and removed in v3.0.0.
/// </para>
/// <para>
/// <b>Interaction with <see cref="WorkflowHostingExtensions.AsAIAgent"/>.</b> When this flag
/// <b>Interaction with <see cref="WorkflowHostingExtensions.AsAIAgent(Workflow, string?, string?, string?, IWorkflowExecutionEnvironment?, bool, bool)"/>.</b> When this flag
/// is <see langword="true"/>, <see cref="AgentResponseEvent"/> joins
/// <see cref="AgentResponseUpdateEvent"/> in being forwarded out of the agent surface
/// unconditionally — neither honors the host's <c>includeWorkflowOutputsInResponse</c>
Expand Down
8 changes: 5 additions & 3 deletions dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ internal sealed class WorkflowHostAgent : AIAgent
private readonly IWorkflowExecutionEnvironment _executionEnvironment;
private readonly bool _includeExceptionDetails;
private readonly bool _includeWorkflowOutputsInResponse;
private readonly bool _filterToolCallMessages;
private readonly Task<ProtocolDescriptor> _describeTask;

private readonly ConcurrentDictionary<string, string> _assignedSessionIds = [];

public WorkflowHostAgent(Workflow workflow, string? id = null, string? name = null, string? description = null, IWorkflowExecutionEnvironment? executionEnvironment = null, bool includeExceptionDetails = false, bool includeWorkflowOutputsInResponse = false)
public WorkflowHostAgent(Workflow workflow, string? id = null, string? name = null, string? description = null, IWorkflowExecutionEnvironment? executionEnvironment = null, bool includeExceptionDetails = false, bool includeWorkflowOutputsInResponse = false, bool filterToolCallMessages = false)
{
this._workflow = Throw.IfNull(workflow);

Expand All @@ -42,6 +43,7 @@ public WorkflowHostAgent(Workflow workflow, string? id = null, string? name = nu

this._includeExceptionDetails = includeExceptionDetails;
this._includeWorkflowOutputsInResponse = includeWorkflowOutputsInResponse;
this._filterToolCallMessages = filterToolCallMessages;

this._id = id;
this.Name = name;
Expand Down Expand Up @@ -74,7 +76,7 @@ private async ValueTask ValidateWorkflowAsync()
}

protected override ValueTask<AgentSession> CreateSessionCoreAsync(CancellationToken cancellationToken = default)
=> new(new WorkflowSession(this._workflow, this.GenerateNewId(), this._executionEnvironment, this._includeExceptionDetails, this._includeWorkflowOutputsInResponse));
=> new(new WorkflowSession(this._workflow, this.GenerateNewId(), this._executionEnvironment, this._includeExceptionDetails, this._includeWorkflowOutputsInResponse, this._filterToolCallMessages));

protected override ValueTask<JsonElement> SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
{
Expand All @@ -89,7 +91,7 @@ protected override ValueTask<JsonElement> SerializeSessionCoreAsync(AgentSession
}

protected override ValueTask<AgentSession> DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default)
=> new(new WorkflowSession(this._workflow, serializedState, this._executionEnvironment, this._includeExceptionDetails, this._includeWorkflowOutputsInResponse, jsonSerializerOptions));
=> new(new WorkflowSession(this._workflow, serializedState, this._executionEnvironment, this._includeExceptionDetails, this._includeWorkflowOutputsInResponse, this._filterToolCallMessages, jsonSerializerOptions));

private async ValueTask<WorkflowSession> UpdateSessionAsync(IEnumerable<ChatMessage> messages, AgentSession? session = null, CancellationToken cancellationToken = default)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,36 @@ public static AIAgent AsAIAgent(
return new WorkflowHostAgent(workflow, id, name, description, executionEnvironment, includeExceptionDetails, includeWorkflowOutputsInResponse);
}

/// <summary>
/// Convert a workflow with the appropriate primary input type to an <see cref="AIAgent"/>.
/// </summary>
/// <param name="workflow">The workflow to be hosted by the resulting <see cref="AIAgent"/></param>
/// <param name="filterToolCallMessages">If <see langword="true"/>, will remove <see cref="FunctionCallContent"/> and
/// <see cref="FunctionResultContent"/> from messages surfaced by the hosted workflow agent.</param>
/// <param name="id">A unique id for the hosting <see cref="AIAgent"/>.</param>
/// <param name="name">A name for the hosting <see cref="AIAgent"/>.</param>
/// <param name="description">A description for the hosting <see cref="AIAgent"/>.</param>
/// <param name="executionEnvironment">Specify the execution environment to use when running the workflows. See
/// <see cref="InProcessExecution.OffThread"/>, <see cref="InProcessExecution.Concurrent"/> and
/// <see cref="InProcessExecution.Lockstep"/> for the in-process environments.</param>
/// <param name="includeExceptionDetails">If <see langword="true"/>, will include <see cref="System.Exception.Message"/>
/// in the <see cref="ErrorContent"/> representing the workflow error.</param>
/// <param name="includeWorkflowOutputsInResponse">If <see langword="true"/>, will transform outgoing workflow outputs
/// into content in <see cref="AgentResponseUpdate"/>s or the <see cref="AgentResponse"/> as appropriate.</param>
/// <returns></returns>
public static AIAgent AsAIAgent(
this Workflow workflow,
bool filterToolCallMessages,
string? id = null,
string? name = null,
string? description = null,
IWorkflowExecutionEnvironment? executionEnvironment = null,
bool includeExceptionDetails = false,
bool includeWorkflowOutputsInResponse = false)
{
return new WorkflowHostAgent(workflow, id, name, description, executionEnvironment, includeExceptionDetails, includeWorkflowOutputsInResponse, filterToolCallMessages);
}

internal static FunctionCallContent ToFunctionCall(this ExternalRequest request)
{
Dictionary<string, object?> parameters = new()
Expand Down
90 changes: 85 additions & 5 deletions dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ internal sealed class WorkflowSession : AgentSession

private readonly bool _includeExceptionDetails;
private readonly bool _includeWorkflowOutputsInResponse;
private readonly bool _filterToolCallMessages;

private InMemoryCheckpointManager? _inMemoryCheckpointManager;

Expand Down Expand Up @@ -68,11 +69,12 @@ internal static bool VerifyCheckpointingConfiguration(IWorkflowExecutionEnvironm
return true;
}

public WorkflowSession(Workflow workflow, string sessionId, IWorkflowExecutionEnvironment executionEnvironment, bool includeExceptionDetails = false, bool includeWorkflowOutputsInResponse = false)
public WorkflowSession(Workflow workflow, string sessionId, IWorkflowExecutionEnvironment executionEnvironment, bool includeExceptionDetails = false, bool includeWorkflowOutputsInResponse = false, bool filterToolCallMessages = false)
{
this._workflow = Throw.IfNull(workflow);
this._includeExceptionDetails = includeExceptionDetails;
this._includeWorkflowOutputsInResponse = includeWorkflowOutputsInResponse;
this._filterToolCallMessages = filterToolCallMessages;

IWorkflowExecutionEnvironment env = Throw.IfNull(executionEnvironment);
if (VerifyCheckpointingConfiguration(env, out InProcessExecutionEnvironment? inProcEnv))
Expand All @@ -96,11 +98,12 @@ private CheckpointManager EnsureExternalizedInMemoryCheckpointing()
return new(this._inMemoryCheckpointManager ??= new());
}

public WorkflowSession(Workflow workflow, JsonElement serializedSession, IWorkflowExecutionEnvironment executionEnvironment, bool includeExceptionDetails = false, bool includeWorkflowOutputsInResponse = false, JsonSerializerOptions? jsonSerializerOptions = null)
public WorkflowSession(Workflow workflow, JsonElement serializedSession, IWorkflowExecutionEnvironment executionEnvironment, bool includeExceptionDetails = false, bool includeWorkflowOutputsInResponse = false, bool filterToolCallMessages = false, JsonSerializerOptions? jsonSerializerOptions = null)
{
this._workflow = Throw.IfNull(workflow);
this._includeExceptionDetails = includeExceptionDetails;
this._includeWorkflowOutputsInResponse = includeWorkflowOutputsInResponse;
this._filterToolCallMessages = filterToolCallMessages;

IWorkflowExecutionEnvironment env = Throw.IfNull(executionEnvironment);

Expand Down Expand Up @@ -173,6 +176,71 @@ public AgentResponseUpdate CreateUpdate(string responseId, object raw, ChatMessa
};
}

private AgentResponseUpdate? FilterToolCallContents(AgentResponseUpdate update)
{
if (!this._filterToolCallMessages || !ContainsToolCallContent(update.Contents))
{
return update;
}

List<AIContent> retainedContents = update.Contents
.Where(static content => !IsToolCallContent(content))
.ToList();

if (retainedContents.Count == 0)
{
return null;
}

return new AgentResponseUpdate
{
AdditionalProperties = update.AdditionalProperties,
AgentId = update.AgentId,
AuthorName = update.AuthorName,
Contents = retainedContents,
ContinuationToken = update.ContinuationToken,
CreatedAt = update.CreatedAt,
FinishReason = update.FinishReason,
MessageId = update.MessageId,
RawRepresentation = update.RawRepresentation,
ResponseId = update.ResponseId,
Role = update.Role,
};
}

private ChatMessage? FilterToolCallContents(ChatMessage message)
{
if (!this._filterToolCallMessages || !ContainsToolCallContent(message.Contents))
{
return message;
}

List<AIContent> retainedContents = message.Contents
.Where(static content => !IsToolCallContent(content))
.ToList();

if (retainedContents.Count == 0)
{
return null;
}

ChatMessage filteredMessage = message.Clone();
filteredMessage.Contents = retainedContents;
return filteredMessage;
}

private AgentResponseUpdate? CreateFilteredUpdate(string responseId, object raw, ChatMessage message)
{
ChatMessage? filteredMessage = this.FilterToolCallContents(message);
return filteredMessage == null ? null : this.CreateUpdate(responseId, raw, filteredMessage);
}

private static bool ContainsToolCallContent(IEnumerable<AIContent> contents)
=> contents.Any(static content => IsToolCallContent(content));

private static bool IsToolCallContent(AIContent content)
=> content is FunctionCallContent or FunctionResultContent;

private async ValueTask<ResumeRunResult> CreateOrResumeRunAsync(List<ChatMessage> messages, CancellationToken cancellationToken = default)
{
// The workflow is validated to be a ChatProtocol workflow by the WorkflowHostAgent before creating the session,
Expand Down Expand Up @@ -480,7 +548,11 @@ IAsyncEnumerable<AgentResponseUpdate> InvokeStageAsync(
switch (evt)
{
case AgentResponseUpdateEvent agentUpdate:
yield return agentUpdate.Update;
AgentResponseUpdate? filteredUpdate = this.FilterToolCallContents(agentUpdate.Update);
if (filteredUpdate != null)
{
yield return filteredUpdate;
}
break;

case RequestInfoEvent requestInfo:
Expand Down Expand Up @@ -553,7 +625,11 @@ IAsyncEnumerable<AgentResponseUpdate> InvokeStageAsync(
// as an output executor.
foreach (ChatMessage message in agentResponse.Response.Messages)
{
yield return this.CreateUpdate(this.LastResponseId, evt, message);
AgentResponseUpdate? filteredMessageUpdate = this.CreateFilteredUpdate(this.LastResponseId, evt, message);
if (filteredMessageUpdate != null)
{
yield return filteredMessageUpdate;
}
}
break;

Expand All @@ -576,7 +652,11 @@ IAsyncEnumerable<AgentResponseUpdate> InvokeStageAsync(

foreach (ChatMessage message in updateMessages)
{
yield return this.CreateUpdate(this.LastResponseId, evt, message);
AgentResponseUpdate? filteredMessageUpdate = this.CreateFilteredUpdate(this.LastResponseId, evt, message);
if (filteredMessageUpdate != null)
{
yield return filteredMessageUpdate;
}
}
break;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ public override ValueTask HandleAsync(string message, IWorkflowContext context,

public class WorkflowHostSmokeTests : AIAgentHostingExecutorTestsBase
{
private const string ToolCallText = "Before tool call.";
private const string FinalText = "Final answer.";

private sealed class AlwaysFailsAIAgent(bool failByThrowing) : AIAgent
{
private sealed class Session : AgentSession
Expand Down Expand Up @@ -263,6 +266,38 @@ private static Workflow CreateWorkflow(bool failByThrowing)
return new WorkflowBuilder(agent).Build();
}

private static Workflow CreateToolCallWorkflow()
{
List<ChatMessage> messages =
[
new(ChatRole.Assistant,
[
new TextContent(ToolCallText),
new FunctionCallContent("call_1", "get_data"),
])
{
MessageId = "tool-call-message",
},
new(ChatRole.Tool, [new FunctionResultContent("call_1", "tool result")])
{
MessageId = "tool-result-message",
},
new(ChatRole.Assistant, [new TextContent(FinalText)])
{
MessageId = "final-message",
},
];

TestReplayAgent agent = new(messages, TestAgentId, TestAgentName);
ExecutorBinding binding = agent.BindAsExecutor(new AIAgentHostOptions
{
EmitAgentUpdateEvents = true,
EmitAgentResponseEvents = true,
});

return new WorkflowBuilder(binding).Build();
}

[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
Expand Down Expand Up @@ -299,6 +334,53 @@ public async Task Test_AsAgent_ErrorContentStreamedOutAsync(bool includeExceptio
hadErrorContent.Should().BeTrue();
}

[Fact]
public async Task Test_AsAgent_DefaultPreservesToolCallMessagesAsync()
{
// Arrange
Workflow workflow = CreateToolCallWorkflow();
AIAgent workflowAgent = workflow.AsAIAgent("WorkflowAgent");

// Act
AgentResponse response = await workflowAgent.RunAsync(new ChatMessage(ChatRole.User, "Hello"));

// Assert
List<AIContent> contents = [.. response.Messages.SelectMany(message => message.Contents)];
contents.Should().Contain(content => content is FunctionCallContent);
contents.Should().Contain(content => content is FunctionResultContent);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task Test_AsAgent_FilterToolCallMessagesRemovesToolCallContentsAsync(bool runStreaming)
{
// Arrange
Workflow workflow = CreateToolCallWorkflow();
AIAgent workflowAgent = workflow.AsAIAgent(filterToolCallMessages: true, id: "WorkflowAgent");

// Act
List<AIContent> contents;
if (runStreaming)
{
List<AgentResponseUpdate> updates = await workflowAgent.RunStreamingAsync(new ChatMessage(ChatRole.User, "Hello")).ToListAsync();
contents = [.. updates.SelectMany(update => update.Contents)];
}
else
{
AgentResponse response = await workflowAgent.RunAsync(new ChatMessage(ChatRole.User, "Hello"));
contents = [.. response.Messages.SelectMany(message => message.Contents)];
}

// Assert
contents.Should().NotContain(content => content is FunctionCallContent);
contents.Should().NotContain(content => content is FunctionResultContent);

List<string> textContents = [.. contents.OfType<TextContent>().Select(content => content.Text)];
textContents.Should().Contain(ToolCallText);
textContents.Should().Contain(FinalText);
}

/// <summary>
/// Tests that when a workflow emits a RequestInfoEvent with FunctionCallContent data,
/// the AgentResponseUpdate preserves the original FunctionCallContent type.
Expand Down