Skip to content

Commit ffeb099

Browse files
feiyun0112, markwallace-microsoft, stephentoub, rogerbarreto
authored
1 parent f703783 commit ffeb099

8 files changed

+517
-0
lines changed

dotnet/Directory.Packages.props

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,9 @@
130130
<PrivateAssets>all</PrivateAssets>
131131
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
132132
</PackageReference>
133+
<!-- OnnxRuntimeGenAI -->
134+
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.3.0"/>
135+
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Version="0.3.0"/>
136+
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Version="0.3.0"/>
133137
</ItemGroup>
134138
</Project>
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using Microsoft.Extensions.DependencyInjection;
4+
using Microsoft.SemanticKernel;
5+
using Microsoft.SemanticKernel.ChatCompletion;
6+
using Microsoft.SemanticKernel.Connectors.Onnx;
7+
using Xunit;
8+
9+
namespace SemanticKernel.Connectors.Onnx.UnitTests;
10+
11+
/// <summary>
/// Unit tests for <see cref="OnnxKernelBuilderExtensions"/>.
/// </summary>
public class OnnxExtensionsTests
{
    [Fact]
    public void AddOnnxRuntimeGenAIChatCompletionToServiceCollection()
    {
        // Arrange: register the ONNX chat completion service and the kernel
        // directly on a plain service collection.
        var collection = new ServiceCollection();
        collection.AddOnnxRuntimeGenAIChatCompletion("modelId", "modelPath");
        collection.AddKernel();

        // Act: resolve the kernel, then ask it for the chat completion service.
        var kernel = collection.BuildServiceProvider().GetRequiredService<Kernel>();
        var service = kernel.GetRequiredService<IChatCompletionService>();

        // Assert: the ONNX implementation was registered and is resolvable.
        Assert.NotNull(service);
        Assert.IsType<OnnxRuntimeGenAIChatCompletionService>(service);
    }

    [Fact]
    public void AddOnnxRuntimeGenAIChatCompletionToKernelBuilder()
    {
        // Arrange: register via the kernel-builder extension instead of the
        // service-collection extension.
        var collection = new ServiceCollection();
        var kernelBuilder = collection.AddKernel();
        kernelBuilder.AddOnnxRuntimeGenAIChatCompletion("modelId", "modelPath");

        // Act
        var kernel = collection.BuildServiceProvider().GetRequiredService<Kernel>();
        var service = kernel.GetRequiredService<IChatCompletionService>();

        // Assert: the ONNX implementation was registered and is resolvable.
        Assert.NotNull(service);
        Assert.IsType<OnnxRuntimeGenAIChatCompletionService>(service);
    }
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System.Text.Json;
4+
using Microsoft.SemanticKernel;
5+
using Microsoft.SemanticKernel.Connectors.Onnx;
6+
using Xunit;
7+
8+
namespace SemanticKernel.Connectors.Onnx.UnitTests;
9+
10+
/// <summary>
/// Unit tests for <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/>.
/// </summary>
public class OnnxRuntimeGenAIPromptExecutionSettingsTests
{
    [Fact]
    public void FromExecutionSettingsWhenAlreadyOnnxShouldReturnSame()
    {
        // Arrange
        var executionSettings = new OnnxRuntimeGenAIPromptExecutionSettings();

        // Act
        var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

        // Assert: an already-specialized instance must be returned as-is, not copied.
        Assert.Same(executionSettings, onnxExecutionSettings);
    }

    [Fact]
    public void FromExecutionSettingsWhenNullShouldReturnDefaultSettings()
    {
        // Arrange
        PromptExecutionSettings? executionSettings = null;

        // Act
        var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

        // Assert: every specialized setting defaults to null when no settings were provided.
        Assert.Null(onnxExecutionSettings.TopK);
        Assert.Null(onnxExecutionSettings.TopP);
        Assert.Null(onnxExecutionSettings.Temperature);
        Assert.Null(onnxExecutionSettings.RepetitionPenalty);
        Assert.Null(onnxExecutionSettings.PastPresentShareBuffer);
        Assert.Null(onnxExecutionSettings.NumReturnSequences);
        Assert.Null(onnxExecutionSettings.NumBeams);
        Assert.Null(onnxExecutionSettings.NoRepeatNgramSize);
        Assert.Null(onnxExecutionSettings.MinTokens);
        Assert.Null(onnxExecutionSettings.MaxTokens);
        Assert.Null(onnxExecutionSettings.LengthPenalty);
        Assert.Null(onnxExecutionSettings.DiversityPenalty);
        Assert.Null(onnxExecutionSettings.EarlyStopping);
        Assert.Null(onnxExecutionSettings.DoSample);
    }

    [Fact]
    public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecialized()
    {
        // Arrange: JSON using the snake_case wire names for every supported setting.
        string jsonSettings = """
            {
                "top_k": 2,
                "top_p": 0.9,
                "temperature": 0.5,
                "repetition_penalty": 0.1,
                "past_present_share_buffer": true,
                "num_return_sequences": 200,
                "num_beams": 20,
                "no_repeat_ngram_size": 15,
                "min_tokens": 10,
                "max_tokens": 100,
                "length_penalty": 0.2,
                "diversity_penalty": 0.3,
                "early_stopping": false,
                "do_sample": true
            }
            """;

        // Act: round-trip through the base type, then convert to the specialized type.
        var executionSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(jsonSettings);
        var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

        // Assert: each JSON property landed on the matching strongly-typed property.
        Assert.Equal(2, onnxExecutionSettings.TopK);
        Assert.Equal(0.9f, onnxExecutionSettings.TopP);
        Assert.Equal(0.5f, onnxExecutionSettings.Temperature);
        Assert.Equal(0.1f, onnxExecutionSettings.RepetitionPenalty);
        Assert.True(onnxExecutionSettings.PastPresentShareBuffer);
        Assert.Equal(200, onnxExecutionSettings.NumReturnSequences);
        Assert.Equal(20, onnxExecutionSettings.NumBeams);
        Assert.Equal(15, onnxExecutionSettings.NoRepeatNgramSize);
        Assert.Equal(10, onnxExecutionSettings.MinTokens);
        Assert.Equal(100, onnxExecutionSettings.MaxTokens);
        Assert.Equal(0.2f, onnxExecutionSettings.LengthPenalty);
        Assert.Equal(0.3f, onnxExecutionSettings.DiversityPenalty);
        Assert.False(onnxExecutionSettings.EarlyStopping);
        Assert.True(onnxExecutionSettings.DoSample);
    }
}

dotnet/src/Connectors/Connectors.Onnx/Connectors.Onnx.csproj

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,10 @@
2525
<PackageReference Include="System.Numerics.Tensors" />
2626
</ItemGroup>
2727

28+
<ItemGroup>
29+
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI" Condition=" '$(Configuration)' == 'Debug' OR '$(Configuration)' == 'Release' " />
30+
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Condition=" '$(Configuration)' == 'Debug_Cuda' OR '$(Configuration)' == 'Release_Cuda' " />
31+
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Condition=" '$(Configuration)' == 'Debug_DirectML' OR '$(Configuration)' == 'Release_DirectML' " />
32+
</ItemGroup>
33+
2834
</Project>

dotnet/src/Connectors/Connectors.Onnx/OnnxKernelBuilderExtensions.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
using System.IO;
44
using Microsoft.Extensions.DependencyInjection;
5+
using Microsoft.Extensions.Logging;
6+
using Microsoft.SemanticKernel.ChatCompletion;
57
using Microsoft.SemanticKernel.Connectors.Onnx;
68
using Microsoft.SemanticKernel.Embeddings;
79

@@ -14,6 +16,29 @@ namespace Microsoft.SemanticKernel;
1416
/// </summary>
1517
public static class OnnxKernelBuilderExtensions
1618
{
19+
/// <summary>
/// Add OnnxRuntimeGenAI Chat Completion services to the kernel builder.
/// </summary>
/// <param name="builder">The kernel builder.</param>
/// <param name="modelId">Model Id.</param>
/// <param name="modelPath">The generative AI ONNX model path.</param>
/// <param name="serviceId">The optional service ID.</param>
/// <returns>The updated kernel builder.</returns>
public static IKernelBuilder AddOnnxRuntimeGenAIChatCompletion(
    this IKernelBuilder builder,
    string modelId,
    string modelPath,
    string? serviceId = null)
{
    // Register as a keyed singleton so multiple ONNX services can coexist under
    // distinct service IDs; the service itself is created lazily on first resolve.
    builder.Services.AddKeyedSingleton<IChatCompletionService>(
        serviceId,
        (provider, _) =>
        {
            // The logger factory is optional; GetService returns null when none is registered.
            var loggerFactory = provider.GetService<ILoggerFactory>();
            return new OnnxRuntimeGenAIChatCompletionService(
                modelId,
                modelPath: modelPath,
                loggerFactory: loggerFactory);
        });

    return builder;
}
41+
1742
/// <summary>Adds a text embedding generation service using a BERT ONNX model.</summary>
1843
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
1944
/// <param name="onnxModelPath">The path to the ONNX model file.</param>

0 commit comments

Comments (0)