Skip to content

Add custom service key support to AddKeyedOllamaApiClient #741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 2, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.Extensions.Hosting;
/// <param name="hostBuilder">The <see cref="IHostApplicationBuilder"/> with which services are being registered.</param>
/// <param name="serviceKey">The service key used to register the <see cref="OllamaApiClient"/> service, if any.</param>
/// <param name="disableTracing">A flag to indicate whether tracing should be disabled.</param>
public class AspireOllamaApiClientBuilder(IHostApplicationBuilder hostBuilder, string serviceKey, bool disableTracing)
public class AspireOllamaApiClientBuilder(IHostApplicationBuilder hostBuilder, object serviceKey, bool disableTracing)
{
/// <summary>
/// The host application builder used to configure the application.
Expand All @@ -18,7 +18,7 @@ public class AspireOllamaApiClientBuilder(IHostApplicationBuilder hostBuilder, s
/// <summary>
/// Gets the service key used to register the <see cref="OllamaApiClient"/> service, if any.
/// </summary>
public string ServiceKey { get; } = serviceKey;
public object ServiceKey { get; } = serviceKey;

/// <summary>
/// Gets a flag indicating whether tracing should be disabled.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,25 @@ public static ChatClientBuilder AddKeyedChatClient(
this AspireOllamaApiClientBuilder builder)
{
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentException.ThrowIfNullOrEmpty(builder.ServiceKey, nameof(builder.ServiceKey));

return builder.AddKeyedChatClient(builder.ServiceKey);
}

/// <summary>
/// Registers a keyed singleton <see cref="IChatClient"/> in the services provided by the <paramref name="builder"/> using the specified service key.
/// </summary>
/// <param name="builder">An <see cref="AspireOllamaApiClientBuilder" />.</param>
/// <param name="serviceKey">The service key to use for registering the <see cref="IChatClient"/>.</param>
/// <returns>A <see cref="ChatClientBuilder"/> that can be used to build a pipeline around the inner <see cref="IChatClient"/>.</returns>
public static ChatClientBuilder AddKeyedChatClient(
this AspireOllamaApiClientBuilder builder,
object serviceKey)
{
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentNullException.ThrowIfNull(serviceKey, nameof(serviceKey));

return builder.HostBuilder.Services.AddKeyedChatClient(
builder.ServiceKey,
serviceKey,
services => CreateInnerChatClient(services, builder));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,24 @@ public static EmbeddingGeneratorBuilder<string, Embedding<float>> AddKeyedEmbedd
this AspireOllamaApiClientBuilder builder)
{
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentException.ThrowIfNullOrEmpty(builder.ServiceKey, nameof(builder.ServiceKey));
return builder.AddKeyedEmbeddingGenerator(builder.ServiceKey);
}

/// <summary>
/// Registers a keyed singleton <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> in the services provided by the <paramref name="builder"/> using the specified service key.
/// </summary>
/// <param name="builder">An <see cref="AspireOllamaApiClientBuilder" />.</param>
/// <param name="serviceKey">The service key to use for registering the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</param>
/// <returns>A <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> that can be used to build a pipeline around the inner <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</returns>
public static EmbeddingGeneratorBuilder<string, Embedding<float>> AddKeyedEmbeddingGenerator(
this AspireOllamaApiClientBuilder builder,
object serviceKey)
{
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentNullException.ThrowIfNull(serviceKey, nameof(serviceKey));

return builder.HostBuilder.Services.AddKeyedEmbeddingGenerator(
builder.ServiceKey,
serviceKey,
services => CreateInnerEmbeddingGenerator(services, builder));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using OllamaSharp;
using System.Data.Common;

Expand Down Expand Up @@ -43,6 +42,37 @@ public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApp
return AddOllamaClientInternal(builder, $"{DefaultConfigSectionName}:{connectionName}", connectionName, serviceKey: connectionName, configureSettings: configureSettings);
}

/// <summary>
/// Adds <see cref="IOllamaApiClient"/> services to the container using the specified <paramref name="serviceKey"/>.
/// </summary>
/// <param name="builder">The <see cref="IHostApplicationBuilder" /> to read config from and add services to.</param>
/// <param name="serviceKey">A unique key that identifies this instance of the Ollama client service.</param>
/// <param name="connectionName">A name used to retrieve the connection string from the ConnectionStrings configuration section.</param>
/// <param name="configureSettings">An optional delegate that can be used for customizing options. It's invoked after the settings are read from the configuration.</param>
/// <exception cref="UriFormatException">Thrown when no Ollama endpoint is provided.</exception>
public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApplicationBuilder builder, object serviceKey, string connectionName, Action<OllamaSharpSettings>? configureSettings = null)
{
ArgumentNullException.ThrowIfNull(serviceKey, nameof(serviceKey));
ArgumentException.ThrowIfNullOrWhiteSpace(connectionName, nameof(connectionName));
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
return AddOllamaClientInternal(builder, $"{DefaultConfigSectionName}:{connectionName}", connectionName, serviceKey: serviceKey, configureSettings: configureSettings);
}

/// <summary>
/// Adds <see cref="IOllamaApiClient"/> services to the container using the specified <paramref name="serviceKey"/>.
/// </summary>
/// <param name="builder">The <see cref="IHostApplicationBuilder" /> to read config from and add services to.</param>
/// <param name="serviceKey">A unique key that identifies this instance of the Ollama client service.</param>
/// <param name="settings">The settings required to configure the <see cref="IOllamaApiClient"/>.</param>
/// <exception cref="UriFormatException">Thrown when no Ollama endpoint is provided.</exception>
public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApplicationBuilder builder, object serviceKey, OllamaSharpSettings settings)
{
ArgumentNullException.ThrowIfNull(serviceKey, nameof(serviceKey));
ArgumentNullException.ThrowIfNull(builder, nameof(builder));
ArgumentNullException.ThrowIfNull(settings, nameof(settings));
return AddOllamaClientInternal(builder, DefaultConfigSectionName, serviceKey.ToString() ?? "default", serviceKey: serviceKey, configureSettings: null, settings: settings);
}

/// <summary>
/// Adds <see cref="IOllamaApiClient"/> and <see cref="IChatClient"/> services to the container.
/// </summary>
Expand Down Expand Up @@ -105,11 +135,15 @@ private static AspireOllamaApiClientBuilder AddOllamaClientInternal(
IHostApplicationBuilder builder,
string configurationSectionName,
string connectionName,
string? serviceKey = null,
Action<OllamaSharpSettings>? configureSettings = null)
object? serviceKey = null,
Action<OllamaSharpSettings>? configureSettings = null,
OllamaSharpSettings? settings = null)
{
OllamaSharpSettings settings = new();
builder.Configuration.GetSection(configurationSectionName).Bind(settings);
settings ??= new();
if (string.IsNullOrEmpty(settings.Endpoint?.ToString()))
{
builder.Configuration.GetSection(configurationSectionName).Bind(settings);
}

if (builder.Configuration.GetConnectionString(connectionName) is string connectionString)
{
Expand Down
Copy link
Member

@Alirexaa Alirexaa Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file should not be change at this PR. We have a seperate workflow for changing this file.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, sorry I keep forgetting that in the review

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot undo the changes in this file

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted the changes to the API surface file as requested. This auto-generated file should not be manually modified. Fixed in commit 63729e3.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted the changes to the API surface file as requested. This auto-generated file should not be manually modified. Fixed in commit 63729e3.

Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,41 @@ namespace Microsoft.Extensions.Hosting
{
public partial class AspireOllamaApiClientBuilder
{
public AspireOllamaApiClientBuilder(IHostApplicationBuilder hostBuilder, string serviceKey, bool disableTracing) { }
public AspireOllamaApiClientBuilder(IHostApplicationBuilder hostBuilder, object? serviceKey, bool disableTracing) { }

public bool DisableTracing { get { throw null; } }

public IHostApplicationBuilder HostBuilder { get { throw null; } }

public string ServiceKey { get { throw null; } }
public object? ServiceKey { get { throw null; } }
}

public static partial class AspireOllamaChatClientExtensions
{
public static AI.ChatClientBuilder AddChatClient(this AspireOllamaApiClientBuilder builder) { throw null; }

public static AI.ChatClientBuilder AddKeyedChatClient(this AspireOllamaApiClientBuilder builder) { throw null; }

public static AI.ChatClientBuilder AddKeyedChatClient(this AspireOllamaApiClientBuilder builder, object serviceKey) { throw null; }
}

public static partial class AspireOllamaEmbeddingGeneratorExtensions
{
public static AI.EmbeddingGeneratorBuilder<string, AI.Embedding<float>> AddEmbeddingGenerator(this AspireOllamaApiClientBuilder builder) { throw null; }

public static AI.EmbeddingGeneratorBuilder<string, AI.Embedding<float>> AddKeyedEmbeddingGenerator(this AspireOllamaApiClientBuilder builder) { throw null; }

public static AI.EmbeddingGeneratorBuilder<string, AI.Embedding<float>> AddKeyedEmbeddingGenerator(this AspireOllamaApiClientBuilder builder, object serviceKey) { throw null; }
}

public static partial class AspireOllamaSharpExtensions
{
public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApplicationBuilder builder, string connectionName, System.Action<CommunityToolkit.Aspire.OllamaSharp.OllamaSharpSettings>? configureSettings = null) { throw null; }

public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApplicationBuilder builder, object serviceKey, string connectionName, System.Action<CommunityToolkit.Aspire.OllamaSharp.OllamaSharpSettings>? configureSettings = null) { throw null; }

public static AspireOllamaApiClientBuilder AddKeyedOllamaApiClient(this IHostApplicationBuilder builder, object serviceKey, CommunityToolkit.Aspire.OllamaSharp.OllamaSharpSettings settings) { throw null; }

[System.Obsolete("This approach to registering IChatClient is deprecated, use AddKeyedOllamaApiClient().AddChatClient() instead.")]
public static void AddKeyedOllamaSharpChatClient(this IHostApplicationBuilder builder, string connectionName, System.Action<CommunityToolkit.Aspire.OllamaSharp.OllamaSharpSettings>? configureSettings = null) { }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,92 @@ public void CanSetMultipleKeyedClients()
Assert.NotEqual(client, client3);
}

[Fact]
public void CanSetMultipleKeyedClientsWithCustomServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}"),
new KeyValuePair<string, string?>("ConnectionStrings:Ollama2", "Endpoint=https://localhost:5002/"),
new KeyValuePair<string, string?>("ConnectionStrings:Ollama3", "Endpoint=https://localhost:5003/")
]);

// Use custom service keys instead of connection names
builder.AddKeyedOllamaApiClient("ChatModel", "Ollama");
builder.AddKeyedOllamaApiClient("VisionModel", "Ollama2");
builder.AddKeyedOllamaApiClient("EmbeddingModel", "Ollama3");

using var host = builder.Build();
var chatClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("ChatModel");
var visionClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("VisionModel");
var embeddingClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("EmbeddingModel");

Assert.Equal(Endpoint, chatClient.Uri);
Assert.Equal("https://localhost:5002/", visionClient.Uri?.ToString());
Assert.Equal("https://localhost:5003/", embeddingClient.Uri?.ToString());

Assert.NotEqual(chatClient, visionClient);
Assert.NotEqual(chatClient, embeddingClient);
Assert.NotEqual(visionClient, embeddingClient);
}

[Fact]
public void CanSetKeyedClientWithSettingsOverload()
{
var builder = Host.CreateEmptyApplicationBuilder(null);

var settings = new OllamaSharpSettings
{
Endpoint = Endpoint,
SelectedModel = "testmodel"
};

builder.AddKeyedOllamaApiClient("TestService", settings);

using var host = builder.Build();
var client = host.Services.GetRequiredKeyedService<IOllamaApiClient>("TestService");

Assert.Equal(Endpoint, client.Uri);
Assert.Equal("testmodel", client.SelectedModel);
}

[Fact]
public void CanUseSameConnectionWithDifferentServiceKeys()
{
// This test demonstrates the main use case from the issue:
// Using the same connection but different service keys for different models
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:LocalAI", $"Endpoint={Endpoint}")
]);

// Same connection, different service keys and models
builder.AddKeyedOllamaApiClient("ChatModel", "LocalAI", settings =>
{
settings.SelectedModel = "llama3.2";
});

builder.AddKeyedOllamaApiClient("VisionModel", "LocalAI", settings =>
{
settings.SelectedModel = "llava";
});

using var host = builder.Build();
var chatClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("ChatModel");
var visionClient = host.Services.GetRequiredKeyedService<IOllamaApiClient>("VisionModel");

// Both use the same endpoint
Assert.Equal(Endpoint, chatClient.Uri);
Assert.Equal(Endpoint, visionClient.Uri);

// But have different models
Assert.Equal("llama3.2", chatClient.SelectedModel);
Assert.Equal("llava", visionClient.SelectedModel);

// And are different instances
Assert.NotEqual(chatClient, visionClient);
}

[Fact]
public void RegisteringChatClientAndEmbeddingGeneratorReturnsCorrectModelForServices()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,83 @@ public void CanChainUseMethodsCorrectly()

using var host = builder.Build();
var client = host.Services.GetRequiredService<IChatClient>();

var distributedCacheClient = Assert.IsType<DistributedCachingChatClient>(client);
var functionInvocationClient = Assert.IsType<FunctionInvokingChatClient>(GetInnerClient(distributedCacheClient));
var otelClient = Assert.IsType<OpenTelemetryChatClient>(GetInnerClient(functionInvocationClient));

Assert.IsType<IOllamaApiClient>(GetInnerClient(otelClient), exactMatch: false);
}

[Fact]
public void CanSetMultipleKeyedChatClientsWithCustomServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}"),
new KeyValuePair<string, string?>("ConnectionStrings:Ollama2", "Endpoint=https://localhost:5002/")
]);

// Use custom service keys for different chat clients
builder.AddKeyedOllamaApiClient("ChatModel", "Ollama").AddKeyedChatClient();
builder.AddKeyedOllamaApiClient("VisionModel", "Ollama2").AddKeyedChatClient();

using var host = builder.Build();
var chatClient = host.Services.GetRequiredKeyedService<IChatClient>("ChatModel");
var visionClient = host.Services.GetRequiredKeyedService<IChatClient>("VisionModel");

Assert.Equal(Endpoint, chatClient.GetService<ChatClientMetadata>()?.ProviderUri);
Assert.Equal("https://localhost:5002/", visionClient.GetService<ChatClientMetadata>()?.ProviderUri?.ToString());

Assert.NotEqual(chatClient, visionClient);
}

[Fact]
public void CanSetMultipleChatClientsWithDifferentServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}")
]);

// Use one Ollama API client with multiple chat clients using different service keys
var cb = builder.AddKeyedOllamaApiClient("OllamaKey", "Ollama");
cb.AddKeyedChatClient("ChatKey1");
cb.AddKeyedChatClient("ChatKey2");

using var host = builder.Build();
var chatClient1 = host.Services.GetRequiredKeyedService<IChatClient>("ChatKey1");
var chatClient2 = host.Services.GetRequiredKeyedService<IChatClient>("ChatKey2");

Assert.Equal(Endpoint, chatClient1.GetService<ChatClientMetadata>()?.ProviderUri);
Assert.Equal(Endpoint, chatClient2.GetService<ChatClientMetadata>()?.ProviderUri);

Assert.NotEqual(chatClient1, chatClient2);
}

[Fact]
public void CanMixChatClientsAndEmbeddingGeneratorsWithCustomServiceKeys()
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new KeyValuePair<string, string?>("ConnectionStrings:Ollama", $"Endpoint={Endpoint}")
]);

// Use one Ollama API client with both chat clients and embedding generators using different service keys
var cb = builder.AddKeyedOllamaApiClient("OllamaKey", "Ollama");
cb.AddKeyedChatClient("ChatKey1");
cb.AddKeyedEmbeddingGenerator("EmbeddingKey1");

using var host = builder.Build();
var chatClient1 = host.Services.GetRequiredKeyedService<IChatClient>("ChatKey1");
var embeddingGenerator = host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("EmbeddingKey1");

Assert.Equal(Endpoint, chatClient1.GetService<ChatClientMetadata>()?.ProviderUri);
Assert.Equal(Endpoint, embeddingGenerator.GetService<EmbeddingGeneratorMetadata>()?.ProviderUri);

Assert.Equal(chatClient1 as IOllamaApiClient, embeddingGenerator as IOllamaApiClient);
}

[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "get_InnerClient")]
private static extern IChatClient GetInnerClient(DelegatingChatClient client);
}
Loading
Loading