Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
internal sealed class AIFunctionKernelFunction : KernelFunction
{
private readonly AIFunction _aiFunction;
private readonly JsonElement _jsonSchema;

public AIFunctionKernelFunction(AIFunction aiFunction) :
base(
Expand All @@ -33,14 +34,19 @@ public AIFunctionKernelFunction(AIFunction aiFunction) :
{
// Kernel functions created from AI functions are always fully qualified
this._aiFunction = aiFunction;
this._jsonSchema = aiFunction.JsonSchema.Clone();
}

private AIFunctionKernelFunction(AIFunctionKernelFunction other, string? pluginName) :
base(other.Name, pluginName, other.Description, other.Metadata.Parameters, AbstractionsJsonContext.Default.Options, other.Metadata.ReturnParameter)
{
this._aiFunction = other._aiFunction;
this._jsonSchema = other._jsonSchema;
}

/// <inheritdoc />
public override JsonElement JsonSchema => this._jsonSchema;

public override KernelFunction Clone(string? pluginName = null)
{
// Should allow null but not empty or whitespace
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
Expand Down Expand Up @@ -117,6 +119,41 @@ public void ShouldPreserveDescriptionFromAIFunction()
Assert.Equal("This is a test description", sut.Description);
}

[Fact]
public void ShouldPreserveRootLevelSchemaDefinitionsFromAIFunction()
{
// Arrange
var aiFunction = new TestAIFunctionWithSchema("TestFunction", """
{
"type": "object",
"properties": {
"node": {
"$ref": "#/$defs/Node"
}
},
"$defs": {
"Node": {
"type": "object",
"properties": {
"name": {
"type": "string"
}
}
}
}
}
""");

// Act
AIFunctionKernelFunction sut = new(aiFunction);
var schema = sut.JsonSchema;

// Assert
Assert.True(schema.TryGetProperty("$defs", out var defs));
Assert.True(defs.TryGetProperty("Node", out _));
Assert.Equal("#/$defs/Node", schema.GetProperty("properties").GetProperty("node").GetProperty("$ref").GetString());
}

[Fact]
public async Task ShouldInvokeUnderlyingAIFunctionWhenInvoked()
{
Expand All @@ -133,6 +170,27 @@ public async Task ShouldInvokeUnderlyingAIFunctionWhenInvoked()
Assert.True(testAIFunction.WasInvoked);
}

[Fact]
public async Task ShouldInvokeUnderlyingAIFunctionWhenInvokedAsStreaming()
{
// Arrange
var testAIFunction = new TestAIFunction("TestFunction");
AIFunctionKernelFunction sut = new(testAIFunction);
var kernel = new Kernel();
var streamed = new List<string>();

// Act
await foreach (var chunk in sut.InvokeStreamingAsync<string>(kernel, []))
{
streamed.Add(chunk);
}

// Assert
Assert.True(testAIFunction.WasInvoked);
Assert.Single(streamed);
Assert.Equal("Test result", streamed[0]);
}

[Fact]
public void ShouldCloneCorrectlyWithNewPluginName()
{
Expand Down Expand Up @@ -227,4 +285,24 @@ public TestAIFunction(string name, string description = "")
return ValueTask.FromResult<object?>("Test result");
}
}

private sealed class TestAIFunctionWithSchema : AIFunction
{
private readonly JsonElement _schema;

public TestAIFunctionWithSchema(string name, string jsonSchema)
{
this.Name = name;
this._schema = JsonDocument.Parse(jsonSchema).RootElement.Clone();
}

public override string Name { get; }

public override JsonElement JsonSchema => this._schema;

protected override ValueTask<object?> InvokeCoreAsync(AIFunctionArguments? arguments = null, CancellationToken cancellationToken = default)
{
return ValueTask.FromResult<object?>("Test result");
}
}
}