From ace3d1d079f7020603402a2100d8e8d660b61605 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 4 Sep 2025 16:01:04 -0700 Subject: [PATCH 1/9] [WIP] Basic prototype --- eng/packages/General.props | 1 + eng/packages/TestOnly.props | 2 +- .../ToolReduction/IToolReductionStrategy.cs | 41 +++ .../Microsoft.Extensions.AI.csproj | 2 + ...hatClientBuilderToolReductionExtensions.cs | 32 ++ .../EmbeddingToolReductionStrategy.cs | 197 ++++++++++++ .../ToolReduction/ToolReducingChatClient.cs | 97 ++++++ .../ChatClientIntegrationTests.cs | 158 ++++++++++ ...oft.Extensions.AI.Integration.Tests.csproj | 1 + .../ToolReductionTests.cs | 291 ++++++++++++++++++ .../OpenAIChatClientIntegrationTests.cs | 4 + 11 files changed, 825 insertions(+), 1 deletion(-) create mode 100644 src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs create mode 100644 src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs create mode 100644 src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs create mode 100644 src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs create mode 100644 test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs diff --git a/eng/packages/General.props b/eng/packages/General.props index 253fb51ce1b..314b351f590 100644 --- a/eng/packages/General.props +++ b/eng/packages/General.props @@ -27,6 +27,7 @@ + diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props index dcfc7c03525..96668e0c00d 100644 --- a/eng/packages/TestOnly.props +++ b/eng/packages/TestOnly.props @@ -22,7 +22,7 @@ - + diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs new file mode 100644 index 00000000000..029eeae47a1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ToolReduction/IToolReductionStrategy.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a strategy capable of selecting a reduced set of tools for a chat request. +/// +/// +/// A tool reduction strategy is invoked prior to sending a request to an underlying , +/// enabling scenarios where a large tool catalog must be trimmed to fit provider limits or to improve model +/// tool selection quality. +/// +/// The implementation should return a non- enumerable. Returning the original +/// instance indicates no change. Returning a different enumerable indicates +/// the caller may replace the existing tool list. +/// +/// +[Experimental("MEAI001")] +public interface IToolReductionStrategy +{ + /// + /// Selects the tools that should be included for a specific request. + /// + /// The chat messages for the request. This is an to avoid premature materialization. + /// The chat options for the request (may be ). + /// A token to observe cancellation. + /// + /// A (possibly reduced) enumerable of instances. Must never be . + /// Returning the same instance referenced by . signals no change. + /// + Task> SelectToolsForRequestAsync( + IEnumerable messages, + ChatOptions? options, + CancellationToken cancellationToken = default); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index fe431ca21e5..a5b192f94b7 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -33,6 +33,7 @@ true true true + true false @@ -44,6 +45,7 @@ + diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs new file mode 100644 index 00000000000..5a644267328 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ChatClientBuilderToolReductionExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Extension methods for adding tool reduction middleware to a chat client pipeline. +[Experimental("MEAI001")] +public static class ChatClientBuilderToolReductionExtensions +{ + /// + /// Adds tool reduction to the chat client pipeline using the specified . + /// + /// The chat client builder. + /// The reduction strategy. + /// The original builder for chaining. + /// If or is . + /// + /// This should typically appear in the pipeline before function invocation middleware so that only the reduced tools + /// are exposed to the underlying provider. + /// + public static ChatClientBuilder UseToolReduction(this ChatClientBuilder builder, IToolReductionStrategy strategy) + { + _ = Throw.IfNull(builder); + _ = Throw.IfNull(strategy); + + return builder.Use(inner => new ToolReducingChatClient(inner, strategy)); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs new file mode 100644 index 00000000000..3738dcad0de --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs @@ -0,0 +1,197 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Numerics.Tensors; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A tool reduction strategy that ranks tools by embedding similarity to the current conversation context. +/// +/// +/// The strategy embeds each tool (name + description by default) once (cached) and embeds the current +/// conversation content each request. It then selects the top toolLimit tools by similarity. +/// +[Experimental("MEAI001")] +public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy +{ + private readonly ConditionalWeakTable> _toolEmbeddingsCache = new(); + private readonly IEmbeddingGenerator> _embeddingGenerator; + private readonly int _toolLimit; + + /// + /// Initializes a new instance of the class. + /// + /// Embedding generator used to produce embeddings. + /// Maximum number of tools to return. Must be greater than zero. + public EmbeddingToolReductionStrategy( + IEmbeddingGenerator> embeddingGenerator, + int toolLimit) + { + _embeddingGenerator = Throw.IfNull(embeddingGenerator); + _toolLimit = Throw.IfLessThanOrEqual(toolLimit, min: 0); + } + + /// + /// Gets or sets a delegate used to produce the text to embed for a tool. + /// Defaults to: Name + "\n" + Description (omitting empty parts). + /// + public Func EmbeddingTextFactory + { + get => field ??= static t => + { + if (string.IsNullOrWhiteSpace(t.Name)) + { + return t.Description; + } + + if (string.IsNullOrWhiteSpace(t.Description)) + { + return t.Name; + } + + return t.Name + "\n" + t.Description; + }; + set => field = Throw.IfNull(value); + } + + /// + /// Gets or sets a similarity function applied to (query, tool) embedding vectors. Defaults to cosine similarity. + /// + public Func, ReadOnlyMemory, float> Similarity + { + get => field ??= static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span); + set => field = Throw.IfNull(value); + } + + /// + /// Gets or sets a value indicating whether tool embeddings are cached. Defaults to . + /// + public bool EnableEmbeddingCaching { get; set; } = true; + + /// + /// Gets or sets a value indicating whether to preserve original ordering of selected tools. + /// If (default), tools are ordered by descending similarity. + /// If , the top-N tools by similarity are re-emitted in their original order. + /// + public bool PreserveOriginalOrdering { get; set; } + + /// + /// Gets or sets the maximum number of most recent messages to include when forming the query embedding. + /// Defaults to (all messages). + /// + public int MaxMessagesForQueryEmbedding { get; set; } = int.MaxValue; + + /// + public async Task> SelectToolsForRequestAsync( + IEnumerable messages, + ChatOptions? options, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + if (options?.Tools is not { Count: > 0 } tools) + { + return options?.Tools ?? []; + } + + Debug.Assert(_toolLimit > 0, "Expected the tool count limit to be greater than zero."); + + if (tools.Count <= _toolLimit) + { + // No reduction necessary. + return tools; + } + + // Build query text from recent messages. + var messageTexts = messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s)); + var queryText = string.Join("\n", messageTexts); + if (string.IsNullOrWhiteSpace(queryText)) + { + // We couldn't build a meaningful query, likely because the message list was empty. + // We'll just return a truncated list of tools. + return tools.Take(_toolLimit); + } + + // Ensure embeddings for any uncached tools are generated in a batch. + var toolEmbeddings = await GetToolEmbeddingsAsync(tools, cancellationToken).ConfigureAwait(false); + + // Generate the query embedding. + var queryEmbedding = await _embeddingGenerator.GenerateAsync(queryText, cancellationToken: cancellationToken).ConfigureAwait(false); + var queryVector = queryEmbedding.Vector; + + // Compute rankings. + var ranked = tools + .Zip(toolEmbeddings, static (tool, embedding) => (Tool: tool, Embedding: embedding)) + .Select((t, i) => (t.Tool, Index: i, Score: Similarity(queryVector, t.Embedding.Vector))) + .OrderByDescending(t => t.Score) + .Take(_toolLimit); + + if (PreserveOriginalOrdering) + { + ranked = ranked.OrderBy(t => t.Index); + } + + return ranked.Select(t => t.Tool); + } + + private async Task>> GetToolEmbeddingsAsync(IList tools, CancellationToken cancellationToken) + { + if (!EnableEmbeddingCaching) + { + // Embed all tools in one batch; do not store in cache. + return await ComputeEmbeddingsAsync(tools.Select(t => EmbeddingTextFactory(t)), expectedCount: tools.Count); + } + + var result = new Embedding[tools.Count]; + var cacheMisses = new List<(AITool Tool, int Index)>(tools.Count); + + for (var i = 0; i < tools.Count; i++) + { + if (_toolEmbeddingsCache.TryGetValue(tools[i], out var embedding)) + { + result[i] = embedding; + } + else + { + cacheMisses.Add((tools[i], i)); + } + } + + if (cacheMisses.Count == 0) + { + return result; + } + + var uncachedEmbeddings = await ComputeEmbeddingsAsync(cacheMisses.Select(t => EmbeddingTextFactory(t.Tool)), expectedCount: cacheMisses.Count); + + for (var i = 0; i < cacheMisses.Count; i++) + { + var embedding = uncachedEmbeddings[i]; + result[cacheMisses[i].Index] = embedding; + _toolEmbeddingsCache.Add(cacheMisses[i].Tool, embedding); + } + + return result; + + async ValueTask>> ComputeEmbeddingsAsync(IEnumerable texts, int expectedCount) + { + var embeddings = await _embeddingGenerator.GenerateAsync(texts, cancellationToken: cancellationToken).ConfigureAwait(false); + if (embeddings.Count != expectedCount) + { + Throw.InvalidOperationException($"Expected {expectedCount} embeddings, got {embeddings.Count}."); + } + + return embeddings; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs new file mode 100644 index 00000000000..01fec30e8d9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs @@ -0,0 +1,97 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that applies a tool reduction strategy before invoking the inner client. +/// +/// +/// Insert this into a pipeline (typically before function invocation middleware) to automatically +/// reduce the tool list carried on for each request. +/// +[Experimental("MEAI001")] +public sealed class ToolReducingChatClient : DelegatingChatClient +{ + private readonly IToolReductionStrategy _strategy; + + /// + /// Initializes a new instance of the class. + /// + /// The inner client. + /// The tool reduction strategy to apply. + /// Thrown if any argument is . + public ToolReducingChatClient(IChatClient innerClient, IToolReductionStrategy strategy) + : base(innerClient) + { + _strategy = Throw.IfNull(strategy); + } + + /// + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + options = await ApplyReductionAsync(messages, options, cancellationToken).ConfigureAwait(false); + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + } + + /// + public override async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + options = await ApplyReductionAsync(messages, options, cancellationToken).ConfigureAwait(false); + + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } + + private async Task ApplyReductionAsync( + IEnumerable messages, + ChatOptions? options, + CancellationToken cancellationToken) + { + // If there are no options or no tools, skip. + if (options?.Tools is not { Count: > 1 }) + { + return options; + } + + IEnumerable reduced; + try + { + reduced = await _strategy.SelectToolsForRequestAsync(messages, options, cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + return options; + } + + // If strategy returned the same list instance (or reference equality), assume no change. + if (ReferenceEquals(reduced, options.Tools)) + { + return options; + } + + // Materialize and compare counts; if unchanged and tools have identical ordering and references, keep original. + if (reduced is not IList reducedList) + { + reducedList = reduced.ToList(); + } + + // Clone options to avoid mutating a possibly shared instance. + var cloned = options.Clone(); + cloned.Tools = reducedList; + return cloned; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index c87625cf143..8a59904fc46 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -41,6 +41,8 @@ protected ChatClientIntegrationTests() protected IChatClient? ChatClient { get; } + protected IEmbeddingGenerator>? EmbeddingGenerator { get; private set; } + public void Dispose() { ChatClient?.Dispose(); @@ -49,6 +51,13 @@ public void Dispose() protected abstract IChatClient? CreateChatClient(); + /// + /// Optionally supplies an embedding generator for integration tests that exercise + /// embedding-based components (e.g., tool reduction). Default returns null and + /// tests depending on embeddings will skip if not overridden. + /// + protected virtual IEmbeddingGenerator>? CreateEmbeddingGenerator() => null; + [ConditionalFact] public virtual async Task GetResponseAsync_SingleRequestMessage() { @@ -1395,6 +1404,144 @@ public void Dispose() } } + [ConditionalFact] + public virtual async Task ToolReduction_SingleRelevantToolSelected() + { + SkipIfNotEnabled(); + EnsureEmbeddingGenerator(); + + // Strategy: pick top 1 tool + var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 1); + + // Define several tools with clearly distinct domains + var weatherTool = AIFunctionFactory.Create( + () => "Weather data", + new AIFunctionFactoryOptions + { + Name = "GetWeatherForecast", + Description = "Returns weather forecast and temperature for a given city." + }); + + var stockTool = AIFunctionFactory.Create( + () => "Stock data", + new AIFunctionFactoryOptions + { + Name = "GetStockQuote", + Description = "Retrieves live stock market price for a company ticker symbol." + }); + + var translateTool = AIFunctionFactory.Create( + () => "Translated text", + new AIFunctionFactoryOptions + { + Name = "TranslateText", + Description = "Translates text between human languages." + }); + + var mathTool = AIFunctionFactory.Create( + () => 42, + new AIFunctionFactoryOptions + { + Name = "SolveMath", + Description = "Solves arithmetic or algebraic math problems." + }); + + var allTools = new List { weatherTool, stockTool, translateTool, mathTool }; + + IList? capturedTools = null; + + using var client = ChatClient! + .AsBuilder() + .UseToolReduction(strategy) + // Capture the tools after reduction, before invoking the underlying model. + .Use((messages, options, next, ct) => + { + capturedTools = options?.Tools; + return next(messages, options, ct); + }) + .Build(); + + var question = "What will the weather be in Paris tomorrow?"; + _ = await client.GetResponseAsync([new(ChatRole.User, question)], new ChatOptions + { + Tools = allTools + }); + + Assert.NotNull(capturedTools); + Assert.Single(capturedTools!); + Assert.Equal("GetWeatherForecast", capturedTools![0].Name); + } + + [ConditionalFact] + public virtual async Task ToolReduction_MultiConceptQuery_SelectsTwoRelevantTools() + { + SkipIfNotEnabled(); + EnsureEmbeddingGenerator(); + + var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2); + + var weatherTool = AIFunctionFactory.Create( + () => "Weather data", + new AIFunctionFactoryOptions + { + Name = "GetWeatherForecast", + Description = "Returns weather forecast and temperature for a given city." + }); + + var translateTool = AIFunctionFactory.Create( + () => "Translated text", + new AIFunctionFactoryOptions + { + Name = "TranslateText", + Description = "Translates text between human languages." + }); + + var stockTool = AIFunctionFactory.Create( + () => "Stock data", + new AIFunctionFactoryOptions + { + Name = "GetStockQuote", + Description = "Retrieves live stock market price for a company ticker symbol." + }); + + var mathTool = AIFunctionFactory.Create( + () => 42, + new AIFunctionFactoryOptions + { + Name = "SolveMath", + Description = "Solves arithmetic or algebraic math problems." + }); + + var allTools = new List { weatherTool, translateTool, stockTool, mathTool }; + + IList? capturedTools = null; + + using var client = ChatClient! + .AsBuilder() + .UseToolReduction(strategy) + .Use((messages, options, next, ct) => + { + capturedTools = options?.Tools; + return next(messages, options, ct); + }) + .Build(); + + // Query intentionally references two distinct semantic domains: weather + translation. + var question = "Please translate 'good morning' into Spanish and also tell me the weather forecast for Barcelona."; + _ = await client.GetResponseAsync([new(ChatRole.User, question)], new ChatOptions + { + Tools = allTools + }); + + Assert.NotNull(capturedTools); + Assert.Equal(2, capturedTools!.Count); + + // Order is not guaranteed; assert membership. + var names = capturedTools.Select(t => t.Name).ToList(); + Assert.Contains("GetWeatherForecast", names); + Assert.Contains("TranslateText", names); + } + [MemberNotNull(nameof(ChatClient))] protected void SkipIfNotEnabled() { @@ -1405,4 +1552,15 @@ protected void SkipIfNotEnabled() throw new SkipTestException("Client is not enabled."); } } + + [MemberNotNull(nameof(EmbeddingGenerator))] + protected void EnsureEmbeddingGenerator() + { + EmbeddingGenerator ??= CreateEmbeddingGenerator(); + + if (EmbeddingGenerator is null) + { + throw new SkipTestException("Embedding generator is not enabled."); + } + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj index 0fc4698c4e4..06b0e82ca75 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj @@ -47,6 +47,7 @@ + diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs new file mode 100644 index 00000000000..ea95705dfb5 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs @@ -0,0 +1,291 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ToolReductionTests +{ + [Fact] + public async Task Strategy_NoReduction_WhenToolsBelowLimit() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 5); + + var tools = CreateTools("Weather", "Math"); + var options = new ChatOptions { Tools = tools }; + + var result = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "Tell me about weather") }, + options); + + Assert.Same(tools, result); + } + + [Fact] + public async Task Strategy_Reduces_ToLimit_BySimilarity() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + + var tools = CreateTools( + "Weather", + "Translate", + "Math", + "Jokes"); + + var options = new ChatOptions { Tools = tools }; + + var messages = new[] + { + new ChatMessage(ChatRole.User, "Can you do some weather math for forecasting?") + }; + + var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList(); + + Assert.Equal(2, reduced.Count); + + // Only assert membership; ordering is an implementation detail when scores tie. + Assert.Contains(reduced, t => t.Name == "Weather"); + Assert.Contains(reduced, t => t.Name == "Math"); + } + + [Fact] + public async Task Strategy_PreserveOriginalOrdering_ReordersAfterSelection() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2) + { + PreserveOriginalOrdering = true + }; + + var tools = CreateTools("Math", "Translate", "Weather"); + var options = new ChatOptions { Tools = tools }; + + var messages = new[] { new ChatMessage(ChatRole.User, "Explain weather math please") }; + + var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList(); + + Assert.Equal(2, reduced.Count); + Assert.Contains(reduced, t => t.Name == "Math"); + Assert.Contains(reduced, t => t.Name == "Weather"); + + // With PreserveOriginalOrdering the original relative order (Math before Weather) is maintained. + Assert.Equal("Math", reduced[0].Name); + Assert.Equal("Weather", reduced[1].Name); + } + + [Fact] + public async Task Strategy_EmptyQuery_FallsBackToFirstN() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + + var tools = CreateTools("A", "B", "C"); + var options = new ChatOptions { Tools = tools }; + + var messages = new[] { new ChatMessage(ChatRole.User, " ") }; + + var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList(); + + Assert.Equal(2, reduced.Count); + Assert.Equal("A", reduced[0].Name); + Assert.Equal("B", reduced[1].Name); + } + + [Fact] + public async Task Strategy_Caching_AvoidsReEmbeddingTools() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + + var tools = CreateTools("Weather", "Math", "Jokes"); + var options = new ChatOptions { Tools = tools }; + var messages = new[] { new ChatMessage(ChatRole.User, "weather") }; + + _ = await strategy.SelectToolsForRequestAsync(messages, options); + int afterFirst = gen.TotalValueInputs; + + _ = await strategy.SelectToolsForRequestAsync(messages, options); + int afterSecond = gen.TotalValueInputs; + + Assert.Equal(afterFirst + 1, afterSecond); + } + + [Fact] + public async Task Strategy_CachingDisabled_ReEmbedsToolsEachCall() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + EnableEmbeddingCaching = false + }; + + var tools = CreateTools("Weather", "Math"); + var options = new ChatOptions { Tools = tools }; + var messages = new[] { new ChatMessage(ChatRole.User, "weather") }; + + _ = await strategy.SelectToolsForRequestAsync(messages, options); + int afterFirst = gen.TotalValueInputs; + + _ = await strategy.SelectToolsForRequestAsync(messages, options); + int afterSecond = gen.TotalValueInputs; + + Assert.Equal(afterFirst + tools.Count + 1, afterSecond); + } + + [Fact] + public void Strategy_Constructor_ThrowsWhenToolLimitIsLessThanOrEqualToZero() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + Assert.Throws(() => new EmbeddingToolReductionStrategy(gen, toolLimit: 0)); + } + + [Fact] + public async Task ToolReducingChatClient_ReducesTools_ForGetResponseAsync() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + var tools = CreateTools("Weather", "Math", "Translate", "Jokes"); + + IList? observedTools = null; + + using var inner = new TestChatClient + { + GetResponseAsyncCallback = (messages, options, ct) => + { + observedTools = options?.Tools; + return Task.FromResult(new ChatResponse()); + } + }; + + using var client = inner + .AsBuilder() + .UseToolReduction(strategy) + .Build(); + + await client.GetResponseAsync( + new[] { new ChatMessage(ChatRole.User, "weather math please") }, + new ChatOptions { Tools = tools }); + + Assert.NotNull(observedTools); + Assert.Equal(2, observedTools!.Count); + Assert.Contains(observedTools, t => t.Name == "Weather"); + Assert.Contains(observedTools, t => t.Name == "Math"); + } + + [Fact] + public async Task ToolReducingChatClient_ReducesTools_ForStreaming() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + var tools = CreateTools("Weather", "Math"); + + IList? observedTools = null; + + using var inner = new TestChatClient + { + GetStreamingResponseAsyncCallback = (messages, options, ct) => + { + observedTools = options?.Tools; + return AsyncEnumerable.Empty(); + } + }; + + using var client = inner + .AsBuilder() + .UseToolReduction(strategy) + .Build(); + + await foreach (var _ in client.GetStreamingResponseAsync( + new[] { new ChatMessage(ChatRole.User, "math") }, + new ChatOptions { Tools = tools })) + { + // Consume + } + + Assert.NotNull(observedTools); + Assert.Single(observedTools!); + Assert.Equal("Math", observedTools![0].Name); + } + + private static List CreateTools(params string[] names) => + names.Select(n => (AITool)new SimpleTool(n, $"Description about {n}")).ToList(); + + private sealed class SimpleTool : AITool + { + private readonly string _name; + private readonly string _description; + + public SimpleTool(string name, string description) + { + _name = name; + _description = description; + } + + public override string Name => _name; + public override string Description => _description; + } + + /// + /// Deterministic embedding generator producing sparse keyword indicator vectors. + /// Each dimension corresponds to a known keyword. Cosine similarity then reflects + /// pure keyword overlap (non-overlapping keywords contribute nothing), avoiding + /// false ties for tools unrelated to the query. + /// + private sealed class DeterministicTestEmbeddingGenerator : IEmbeddingGenerator> + { + private static readonly string[] _keywords = + [ + "weather","forecast","temperature","math","calculate","sum","translate","language","joke" + ]; + + // +1 bias dimension (last) to avoid zero magnitude vectors when no keywords present. + private static int VectorLength => _keywords.Length + 1; + + public int TotalValueInputs { get; private set; } + + public Task>> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + var list = new List>(); + + foreach (var v in values) + { + TotalValueInputs++; + var vec = new float[VectorLength]; + if (!string.IsNullOrWhiteSpace(v)) + { + var lower = v.ToLowerInvariant(); + for (int i = 0; i < _keywords.Length; i++) + { + if (lower.Contains(_keywords[i])) + { + vec[i] = 1f; + } + } + } + + vec[VectorLength - 1] = 1f; // bias + list.Add(new Embedding(vec)); + } + + return Task.FromResult(new GeneratedEmbeddings>(list)); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + + public void Dispose() + { + // No-op + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs index 6322e3d6b64..a9e08a58e52 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs @@ -8,4 +8,8 @@ public class OpenAIChatClientIntegrationTests : ChatClientIntegrationTests protected override IChatClient? CreateChatClient() => IntegrationTestHelpers.GetOpenAIClient() ?.GetChatClient(TestRunnerConfiguration.Instance["OpenAI:ChatModel"] ?? "gpt-4o-mini").AsIChatClient(); + + protected override IEmbeddingGenerator>? CreateEmbeddingGenerator() => + IntegrationTestHelpers.GetOpenAIClient() + ?.GetEmbeddingClient(TestRunnerConfiguration.Instance["OpenAI:EmbeddingModel"] ?? "text-embedding-3-small").AsIEmbeddingGenerator(); } From c7608c135b51e917d64c042e1bf2da50f4897caa Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Mon, 8 Sep 2025 13:53:02 -0700 Subject: [PATCH 2/9] Some refinements + more tests --- .../EmbeddingToolReductionStrategy.cs | 29 ++- .../ToolReductionTests.cs | 235 +++++++++++++++++- 2 files changed, 244 insertions(+), 20 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs index 3738dcad0de..152073b172d 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs @@ -45,7 +45,7 @@ public EmbeddingToolReductionStrategy( /// Gets or sets a delegate used to produce the text to embed for a tool. /// Defaults to: Name + "\n" + Description (omitting empty parts). /// - public Func EmbeddingTextFactory + public Func ToolEmbeddingTextFactory { get => field ??= static t => { @@ -64,6 +64,20 @@ public Func EmbeddingTextFactory set => field = Throw.IfNull(value); } + /// + /// Gets or sets the factory function used to generate a single text string from a collection of chat messages for + /// embedding purposes. + /// + public Func, string> MessagesEmbeddingTextFactory + { + get => field ??= static messages => + { + var messageTexts = messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s)); + return string.Join("\n", messageTexts); + }; + set => field = Throw.IfNull(value); + } + /// /// Gets or sets a similarity function applied to (query, tool) embedding vectors. Defaults to cosine similarity. /// @@ -85,12 +99,6 @@ public Func, ReadOnlyMemory, float> Similarity /// public bool PreserveOriginalOrdering { get; set; } - /// - /// Gets or sets the maximum number of most recent messages to include when forming the query embedding. - /// Defaults to (all messages). - /// - public int MaxMessagesForQueryEmbedding { get; set; } = int.MaxValue; - /// public async Task> SelectToolsForRequestAsync( IEnumerable messages, @@ -113,8 +121,7 @@ public async Task> SelectToolsForRequestAsync( } // Build query text from recent messages. - var messageTexts = messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s)); - var queryText = string.Join("\n", messageTexts); + var queryText = MessagesEmbeddingTextFactory(messages); if (string.IsNullOrWhiteSpace(queryText)) { // We couldn't build a meaningful query, likely because the message list was empty. @@ -149,7 +156,7 @@ private async Task>> GetToolEmbeddingsAsync(IList if (!EnableEmbeddingCaching) { // Embed all tools in one batch; do not store in cache. - return await ComputeEmbeddingsAsync(tools.Select(t => EmbeddingTextFactory(t)), expectedCount: tools.Count); + return await ComputeEmbeddingsAsync(tools.Select(t => ToolEmbeddingTextFactory(t)), expectedCount: tools.Count); } var result = new Embedding[tools.Count]; @@ -172,7 +179,7 @@ private async Task>> GetToolEmbeddingsAsync(IList return result; } - var uncachedEmbeddings = await ComputeEmbeddingsAsync(cacheMisses.Select(t => EmbeddingTextFactory(t.Tool)), expectedCount: cacheMisses.Count); + var uncachedEmbeddings = await ComputeEmbeddingsAsync(cacheMisses.Select(t => ToolEmbeddingTextFactory(t.Tool)), expectedCount: cacheMisses.Count); for (var i = 0; i < cacheMisses.Count; i++) { diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs index ea95705dfb5..45e83a5f757 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs @@ -13,7 +13,14 @@ namespace Microsoft.Extensions.AI; public class ToolReductionTests { [Fact] - public async Task Strategy_NoReduction_WhenToolsBelowLimit() + public void EmbeddingToolReductionStrategy_Constructor_ThrowsWhenToolLimitIsLessThanOrEqualToZero() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + Assert.Throws(() => new EmbeddingToolReductionStrategy(gen, toolLimit: 0)); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_NoReduction_WhenToolsBelowLimit() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 5); @@ -29,7 +36,7 @@ public async Task Strategy_NoReduction_WhenToolsBelowLimit() } [Fact] - public async Task Strategy_Reduces_ToLimit_BySimilarity() + public async Task EmbeddingToolReductionStrategy_Reduces_ToLimit_BySimilarity() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); @@ -57,7 +64,7 @@ public async Task Strategy_Reduces_ToLimit_BySimilarity() } [Fact] - public async Task Strategy_PreserveOriginalOrdering_ReordersAfterSelection() + public async Task EmbeddingToolReductionStrategy_PreserveOriginalOrdering_ReordersAfterSelection() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2) @@ -82,7 +89,7 @@ public async Task Strategy_PreserveOriginalOrdering_ReordersAfterSelection() } [Fact] - public async Task Strategy_EmptyQuery_FallsBackToFirstN() + public async Task EmbeddingToolReductionStrategy_EmptyQuery_FallsBackToFirstN() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); @@ -100,7 +107,7 @@ public async Task Strategy_EmptyQuery_FallsBackToFirstN() } [Fact] - public async Task Strategy_Caching_AvoidsReEmbeddingTools() + public async Task EmbeddingToolReductionStrategy_Caching_AvoidsReEmbeddingTools() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); @@ -115,11 +122,12 @@ public async Task Strategy_Caching_AvoidsReEmbeddingTools() _ = await strategy.SelectToolsForRequestAsync(messages, options); int afterSecond = gen.TotalValueInputs; + // The additional embedding generator call is for the message list. Assert.Equal(afterFirst + 1, afterSecond); } [Fact] - public async Task Strategy_CachingDisabled_ReEmbedsToolsEachCall() + public async Task EmbeddingToolReductionStrategy_CachingDisabled_ReEmbedsToolsEachCall() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) @@ -137,14 +145,168 @@ public async Task Strategy_CachingDisabled_ReEmbedsToolsEachCall() _ = await strategy.SelectToolsForRequestAsync(messages, options); int afterSecond = gen.TotalValueInputs; + // The additional embedding generator call is for the message list. Assert.Equal(afterFirst + tools.Count + 1, afterSecond); } [Fact] - public void Strategy_Constructor_ThrowsWhenToolLimitIsLessThanOrEqualToZero() + public async Task EmbeddingToolReductionStrategy_OptionsNullOrNoTools_ReturnsEmptyOrOriginal() { using var gen = new DeterministicTestEmbeddingGenerator(); - Assert.Throws(() => new EmbeddingToolReductionStrategy(gen, toolLimit: 0)); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + + // Null options => empty result + var empty = await strategy.SelectToolsForRequestAsync(new[] { new ChatMessage(ChatRole.User, "anything") }, null); + Assert.Empty(empty); + + // Empty tools list => returns that same list + var options = new ChatOptions { Tools = [] }; + var result = await strategy.SelectToolsForRequestAsync(new[] { new ChatMessage(ChatRole.User, "weather") }, options); + Assert.Same(options.Tools, result); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_CustomSimilarity_InvertsOrdering() + { + using var gen = new VectorBasedTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + // Custom similarity chooses the *smallest* first component instead of largest cosine. + Similarity = (q, t) => -t.Span[0] + }; + + var highTool = new SimpleTool("HighScore", "alpha"); // vector[0] = 10 + var lowTool = new SimpleTool("LowScore", "beta"); // vector[0] = 1 + gen.VectorSelector = text => text.Contains("alpha") ? 10f : 1f; + + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "Pick something") }, + new ChatOptions { Tools = [highTool, lowTool] })).ToList(); + + Assert.Single(reduced); + + // Because we negated similarity, lowest underlying value wins + Assert.Equal("LowScore", reduced[0].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextFactory_EmptyDescription_UsesNameOnly() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); + + var target = new SimpleTool("ComputeSum", description: ""); + var filler = new SimpleTool("Other", "Unrelated"); + await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "math") }, + new ChatOptions { Tools = new List { target, filler } }); + + Assert.Contains("ComputeSum", recorder.Inputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextFactory_EmptyName_UsesDescriptionOnly() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); + + var target = new SimpleTool("", description: "Translates between languages."); + var filler = new SimpleTool("Other", "Unrelated"); + await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "translate") }, + new ChatOptions { Tools = [target, filler] }); + + Assert.Contains("Translates between languages.", recorder.Inputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_CustomEmbeddingTextFactory_Applied() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1) + { + ToolEmbeddingTextFactory = t => $"NAME:{t.Name}|DESC:{t.Description}" + }; + + var target = new SimpleTool("WeatherTool", "Gets forecast."); + var filler = new SimpleTool("Other", "Irrelevant"); + await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "weather") }, + new ChatOptions { Tools = [target, filler] }); + + Assert.Contains("NAME:WeatherTool|DESC:Gets forecast.", recorder.Inputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_CustomFiltersMessages() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + + // Tools we want to discriminate between + var tools = CreateTools("Weather", "Math", "Translate"); + + // Earlier message mentions weather, last message mentions math + var messages = new[] + { + new ChatMessage(ChatRole.User, "Please tell me the weather tomorrow."), + new ChatMessage(ChatRole.Assistant, "Sure, I can help."), + new ChatMessage(ChatRole.User, "Now instead solve a math problem.") + }; + + // Only consider the last message (so "Math" should clearly win) + strategy.MessagesEmbeddingTextFactory = msgs => msgs.LastOrDefault()?.Text ?? string.Empty; + + var reduced = (await strategy.SelectToolsForRequestAsync( + messages, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Single(reduced); + Assert.Equal("Math", reduced[0].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_EmptyResult_FallbacksToFirstN() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + + var tools = CreateTools("Alpha", "Beta", "Gamma"); + + // Return only whitespace so strategy should truncate deterministically + strategy.MessagesEmbeddingTextFactory = _ => " "; + + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "Irrelevant content") }, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Equal(2, reduced.Count); + Assert.Equal("Alpha", reduced[0].Name); + Assert.Equal("Beta", reduced[1].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_InvokedOnce() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + + var tools = CreateTools("Weather", "Math"); + int invocationCount = 0; + + strategy.MessagesEmbeddingTextFactory = msgs => + { + invocationCount++; + + // Default-like behavior + return string.Join("\n", msgs.Select(m => m.Text)); + }; + + _ = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "weather and math") }, + new ChatOptions { Tools = tools }); + + Assert.Equal(1, invocationCount); } [Fact] @@ -274,7 +436,7 @@ public Task>> GenerateAsync( } } - vec[VectorLength - 1] = 1f; // bias + vec[^1] = 1f; // bias list.Add(new Embedding(vec)); } @@ -288,4 +450,59 @@ public void Dispose() // No-op } } + + private sealed class RecordingEmbeddingGenerator : IEmbeddingGenerator> + { + public List Inputs { get; } = new(); + + public Task>> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + var list = new List>(); + foreach (var v in values) + { + Inputs.Add(v); + + // Basic 2-dim vector (length encodes a bit of variability) + list.Add(new Embedding(new float[] { v.Length, 1f })); + } + + return Task.FromResult(new GeneratedEmbeddings>(list)); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + public void Dispose() + { + // No-op + } + } + + private sealed class VectorBasedTestEmbeddingGenerator : IEmbeddingGenerator> + { + // External control for choosing first component of embedding. + public Func VectorSelector { get; set; } = _ => 1f; + + public Task>> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + var list = new List>(); + foreach (var v in values) + { + float val = VectorSelector(v); + list.Add(new Embedding(new float[] { val, 1f })); + } + + return Task.FromResult(new GeneratedEmbeddings>(list)); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + public void Dispose() + { + // No-op + } + } } From a23834996e82d20045e115258b61e15db897cdb3 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Mon, 8 Sep 2025 14:20:11 -0700 Subject: [PATCH 3/9] Remove usage of 'field' --- .../Microsoft.Extensions.AI.csproj | 1 - .../EmbeddingToolReductionStrategy.cs | 54 +++++++++++-------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index a5b192f94b7..014738b1ee4 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -33,7 +33,6 @@ true true true - true false diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs index 152073b172d..448e6634e6b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs @@ -14,6 +14,8 @@ namespace Microsoft.Extensions.AI; +#pragma warning disable IDE0032 // Use auto property, suppressed until repo updates to C# 14 + /// /// A tool reduction strategy that ranks tools by embedding similarity to the current conversation context. /// @@ -28,6 +30,29 @@ public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy private readonly IEmbeddingGenerator> _embeddingGenerator; private readonly int _toolLimit; + private Func _toolEmbeddingTextFactory = static t => + { + if (string.IsNullOrWhiteSpace(t.Name)) + { + return t.Description; + } + + if (string.IsNullOrWhiteSpace(t.Description)) + { + return t.Name; + } + + return t.Name + "\n" + t.Description; + }; + + private Func, string> _messagesEmbeddingTextFactory = static messages => + { + var messageTexts = messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s)); + return string.Join("\n", messageTexts); + }; + + private Func, ReadOnlyMemory, float> _similarity = static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span); + /// /// Initializes a new instance of the class. /// @@ -47,21 +72,8 @@ public EmbeddingToolReductionStrategy( /// public Func ToolEmbeddingTextFactory { - get => field ??= static t => - { - if (string.IsNullOrWhiteSpace(t.Name)) - { - return t.Description; - } - - if (string.IsNullOrWhiteSpace(t.Description)) - { - return t.Name; - } - - return t.Name + "\n" + t.Description; - }; - set => field = Throw.IfNull(value); + get => _toolEmbeddingTextFactory; + set => _toolEmbeddingTextFactory = Throw.IfNull(value); } /// @@ -70,12 +82,8 @@ public Func ToolEmbeddingTextFactory /// public Func, string> MessagesEmbeddingTextFactory { - get => field ??= static messages => - { - var messageTexts = messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s)); - return string.Join("\n", messageTexts); - }; - set => field = Throw.IfNull(value); + get => _messagesEmbeddingTextFactory; + set => _messagesEmbeddingTextFactory = Throw.IfNull(value); } /// @@ -83,8 +91,8 @@ public Func, string> MessagesEmbeddingTextFactory /// public Func, ReadOnlyMemory, float> Similarity { - get => field ??= static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span); - set => field = Throw.IfNull(value); + get => _similarity; + set => _similarity = Throw.IfNull(value); } /// From 06cedcc83ca33562cdb24bc3b1e64cd1574e08ec Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Mon, 8 Sep 2025 14:42:08 -0700 Subject: [PATCH 4/9] Update System.Linq.Async version --- eng/packages/TestOnly.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props index 96668e0c00d..100a2301e5d 100644 --- a/eng/packages/TestOnly.props +++ b/eng/packages/TestOnly.props @@ -22,7 +22,7 @@ - + From 56f4bdd16d7968cc149bddee3ea20c529dfe91c0 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 11 Sep 2025 12:03:06 -0700 Subject: [PATCH 5/9] Don't handle `OperationCanceledException` --- .../ToolReduction/ToolReducingChatClient.cs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs index 01fec30e8d9..062a06715c0 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs @@ -67,15 +67,7 @@ public override async IAsyncEnumerable GetStreamingResponseA return options; } - IEnumerable reduced; - try - { - reduced = await _strategy.SelectToolsForRequestAsync(messages, options, cancellationToken).ConfigureAwait(false); - } - catch (OperationCanceledException) - { - return options; - } + var reduced = await _strategy.SelectToolsForRequestAsync(messages, options, cancellationToken).ConfigureAwait(false); // If strategy returned the same list instance (or reference equality), assume no change. if (ReferenceEquals(reduced, options.Tools)) From b2583ecd0751a60b2276885822d87ec26b531954 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 11 Sep 2025 13:12:42 -0700 Subject: [PATCH 6/9] More PR feedback --- eng/packages/TestOnly.props | 1 - .../EmbeddingToolReductionStrategy.cs | 28 +++++++++++++++---- ...oft.Extensions.AI.Integration.Tests.csproj | 1 - .../ToolReductionTests.cs | 9 +++++- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props index 100a2301e5d..e4f0d146ba9 100644 --- a/eng/packages/TestOnly.props +++ b/eng/packages/TestOnly.props @@ -22,7 +22,6 @@ - diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs index 448e6634e6b..215cc6ddd7b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs @@ -68,8 +68,10 @@ public EmbeddingToolReductionStrategy( /// /// Gets or sets a delegate used to produce the text to embed for a tool. - /// Defaults to: Name + "\n" + Description (omitting empty parts). /// + /// + /// Defaults to: Name + "\n" + Description (omitting empty parts). + /// public Func ToolEmbeddingTextFactory { get => _toolEmbeddingTextFactory; @@ -87,8 +89,11 @@ public Func, string> MessagesEmbeddingTextFactory } /// - /// Gets or sets a similarity function applied to (query, tool) embedding vectors. Defaults to cosine similarity. + /// Gets or sets a similarity function applied to (query, tool) embedding vectors. /// + /// + /// Defaults to cosine similarity. + /// public Func, ReadOnlyMemory, float> Similarity { get => _similarity; @@ -117,6 +122,9 @@ public async Task> SelectToolsForRequestAsync( if (options?.Tools is not { Count: > 0 } tools) { + // Prefer the original tools list reference if possible. + // This allows ToolReducingChatClient to avoid unnecessarily copying ChatOptions. + // When no reduction is performed. return options?.Tools ?? []; } @@ -156,7 +164,9 @@ public async Task> SelectToolsForRequestAsync( ranked = ranked.OrderBy(t => t.Index); } - return ranked.Select(t => t.Tool); + return ranked + .Select(t => t.Tool) + .ToList(); } private async Task>> GetToolEmbeddingsAsync(IList tools, CancellationToken cancellationToken) @@ -164,7 +174,10 @@ private async Task>> GetToolEmbeddingsAsync(IList if (!EnableEmbeddingCaching) { // Embed all tools in one batch; do not store in cache. - return await ComputeEmbeddingsAsync(tools.Select(t => ToolEmbeddingTextFactory(t)), expectedCount: tools.Count); + return await ComputeEmbeddingsAsync( + texts: tools.Select(t => ToolEmbeddingTextFactory(t)), + expectedCount: tools.Count, + cancellationToken); } var result = new Embedding[tools.Count]; @@ -187,7 +200,10 @@ private async Task>> GetToolEmbeddingsAsync(IList return result; } - var uncachedEmbeddings = await ComputeEmbeddingsAsync(cacheMisses.Select(t => ToolEmbeddingTextFactory(t.Tool)), expectedCount: cacheMisses.Count); + var uncachedEmbeddings = await ComputeEmbeddingsAsync( + texts: cacheMisses.Select(t => ToolEmbeddingTextFactory(t.Tool)), + expectedCount: cacheMisses.Count, + cancellationToken); for (var i = 0; i < cacheMisses.Count; i++) { @@ -198,7 +214,7 @@ private async Task>> GetToolEmbeddingsAsync(IList return result; - async ValueTask>> ComputeEmbeddingsAsync(IEnumerable texts, int expectedCount) + async ValueTask>> ComputeEmbeddingsAsync(IEnumerable texts, int expectedCount, CancellationToken cancellationToken) { var embeddings = await _embeddingGenerator.GenerateAsync(texts, cancellationToken: cancellationToken).ConfigureAwait(false); if (embeddings.Count != expectedCount) diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj index 06b0e82ca75..0fc4698c4e4 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj @@ -47,7 +47,6 @@ - diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs index 45e83a5f757..03dfd4f6060 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs @@ -356,7 +356,7 @@ public async Task ToolReducingChatClient_ReducesTools_ForStreaming() GetStreamingResponseAsyncCallback = (messages, options, ct) => { observedTools = options?.Tools; - return AsyncEnumerable.Empty(); + return EmptyAsyncEnumerable(); } }; @@ -380,6 +380,13 @@ public async Task ToolReducingChatClient_ReducesTools_ForStreaming() private static List CreateTools(params string[] names) => names.Select(n => (AITool)new SimpleTool(n, $"Description about {n}")).ToList(); +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + private static async IAsyncEnumerable EmptyAsyncEnumerable() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + { + yield break; + } + private sealed class SimpleTool : AITool { private readonly string _name; From 491ff7284ecfdb65a00c9c1c0a07408e4ea7029c Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 18 Sep 2025 15:57:29 -0700 Subject: [PATCH 7/9] PR feedback + other improvements --- .../EmbeddingToolReductionStrategy.cs | 254 +++++++---- .../ToolReduction/ToolReducingChatClient.cs | 2 +- .../ChatClientIntegrationTests.cs | 224 +++++++--- .../ToolReductionTests.cs | 400 ++++++++++++------ 4 files changed, 617 insertions(+), 263 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs index 215cc6ddd7b..82d66c680c6 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs @@ -2,12 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Numerics.Tensors; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -30,7 +31,7 @@ public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy private readonly IEmbeddingGenerator> _embeddingGenerator; private readonly int _toolLimit; - private Func _toolEmbeddingTextFactory = static t => + private Func _toolEmbeddingTextSelector = static t => { if (string.IsNullOrWhiteSpace(t.Name)) { @@ -42,22 +43,46 @@ public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy return t.Name; } - return t.Name + "\n" + t.Description; + return t.Name + Environment.NewLine + t.Description; }; - private Func, string> _messagesEmbeddingTextFactory = static messages => + private Func, string> _messagesEmbeddingTextSelector = static messages => { - var messageTexts = messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s)); - return string.Join("\n", messageTexts); + var sb = new StringBuilder(); + foreach (var message in messages) + { + var contents = message.Contents; + for (var i = 0; i < contents.Count; i++) + { + string text; + switch (contents[i]) + { + case TextContent content: + text = content.Text; + break; + case TextReasoningContent content: + text = content.Text; + break; + default: + continue; + } + + _ = sb.AppendLine(text); + } + } + + return sb.ToString(); }; private Func, ReadOnlyMemory, float> _similarity = static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span); + private Func _isRequiredTool = static _ => false; + /// /// Initializes a new instance of the class. /// /// Embedding generator used to produce embeddings. - /// Maximum number of tools to return. Must be greater than zero. + /// Maximum number of tools to return, excluding required tools. Must be greater than zero. public EmbeddingToolReductionStrategy( IEmbeddingGenerator> embeddingGenerator, int toolLimit) @@ -67,25 +92,25 @@ public EmbeddingToolReductionStrategy( } /// - /// Gets or sets a delegate used to produce the text to embed for a tool. + /// Gets or sets the selector used to generate a single text string from a tool. /// /// /// Defaults to: Name + "\n" + Description (omitting empty parts). /// - public Func ToolEmbeddingTextFactory + public Func ToolEmbeddingTextSelector { - get => _toolEmbeddingTextFactory; - set => _toolEmbeddingTextFactory = Throw.IfNull(value); + get => _toolEmbeddingTextSelector; + set => _toolEmbeddingTextSelector = Throw.IfNull(value); } /// - /// Gets or sets the factory function used to generate a single text string from a collection of chat messages for + /// Gets or sets the selector used to generate a single text string from a collection of chat messages for /// embedding purposes. /// - public Func, string> MessagesEmbeddingTextFactory + public Func, string> MessagesEmbeddingTextSelector { - get => _messagesEmbeddingTextFactory; - set => _messagesEmbeddingTextFactory = Throw.IfNull(value); + get => _messagesEmbeddingTextSelector; + set => _messagesEmbeddingTextSelector = Throw.IfNull(value); } /// @@ -101,9 +126,19 @@ public Func, ReadOnlyMemory, float> Similarity } /// - /// Gets or sets a value indicating whether tool embeddings are cached. Defaults to . + /// Gets or sets a function that determines whether a tool is required (always included). /// - public bool EnableEmbeddingCaching { get; set; } = true; + /// + /// If this returns , the tool is included regardless of ranking and does not count against + /// the configured non-required tool limit. A tool explicitly named by (when + /// is non-null) is also treated as required, independent + /// of this delegate's result. + /// + public Func IsRequiredTool + { + get => _isRequiredTool; + set => _isRequiredTool = Throw.IfNull(value); + } /// /// Gets or sets a value indicating whether to preserve original ordering of selected tools. @@ -132,97 +167,164 @@ public async Task> SelectToolsForRequestAsync( if (tools.Count <= _toolLimit) { - // No reduction necessary. + // Since the total number of tools doesn't exceed the configured tool limit, + // there's no need to determine which tools are optional, i.e., subject to reduction. + // We can return the original tools list early. return tools; } - // Build query text from recent messages. - var queryText = MessagesEmbeddingTextFactory(messages); - if (string.IsNullOrWhiteSpace(queryText)) + var toolRankingInfoArray = ArrayPool.Shared.Rent(tools.Count); + try { - // We couldn't build a meaningful query, likely because the message list was empty. - // We'll just return a truncated list of tools. - return tools.Take(_toolLimit); - } + var toolRankingInfoMemory = toolRankingInfoArray.AsMemory(start: 0, length: tools.Count); + + // We allocate tool rankings in a contiguous chunk of memory, but partition them such that + // required tools come first and are immediately followed by optional tools. + // This allows us to separately rank optional tools by similarity score, but then later re-order + // the top N tools (including required tools) to preserve their original relative order. + var (requiredTools, optionalTools) = PartitionToolRankings(toolRankingInfoMemory, tools, options.ToolMode); + + if (optionalTools.Length <= _toolLimit) + { + // There aren't enough optional tools to require reduction, so we'll return the original + // tools list. + return tools; + } + + // Build query text from recent messages. + var queryText = MessagesEmbeddingTextSelector(messages); + if (string.IsNullOrWhiteSpace(queryText)) + { + // We couldn't build a meaningful query, likely because the message list was empty. + // We'll just return the original tools list. + return tools; + } + + var queryEmbedding = await _embeddingGenerator.GenerateAsync(queryText, cancellationToken: cancellationToken).ConfigureAwait(false); - // Ensure embeddings for any uncached tools are generated in a batch. - var toolEmbeddings = await GetToolEmbeddingsAsync(tools, cancellationToken).ConfigureAwait(false); + // Compute and populate similarity scores in the tool ranking info. + await ComputeSimilarityScoresAsync(optionalTools, queryEmbedding, cancellationToken); - // Generate the query embedding. - var queryEmbedding = await _embeddingGenerator.GenerateAsync(queryText, cancellationToken: cancellationToken).ConfigureAwait(false); - var queryVector = queryEmbedding.Vector; + var topTools = toolRankingInfoMemory.Slice(start: 0, length: requiredTools.Length + _toolLimit); +#if NET + optionalTools.Span.Sort(AIToolRankingInfo.CompareByDescendingSimilarityScore); + if (PreserveOriginalOrdering) + { + topTools.Span.Sort(AIToolRankingInfo.CompareByOriginalIndex); + } +#else + Array.Sort(toolRankingInfoArray, index: requiredTools.Length, length: optionalTools.Length, AIToolRankingInfo.CompareByDescendingSimilarityScore); + if (PreserveOriginalOrdering) + { + Array.Sort(toolRankingInfoArray, index: 0, length: topTools.Length, AIToolRankingInfo.CompareByOriginalIndex); + } +#endif + return ToToolList(topTools.Span); - // Compute rankings. - var ranked = tools - .Zip(toolEmbeddings, static (tool, embedding) => (Tool: tool, Embedding: embedding)) - .Select((t, i) => (t.Tool, Index: i, Score: Similarity(queryVector, t.Embedding.Vector))) - .OrderByDescending(t => t.Score) - .Take(_toolLimit); + static List ToToolList(ReadOnlySpan toolInfo) + { + var result = new List(capacity: toolInfo.Length); + foreach (var info in toolInfo) + { + result.Add(info.Tool); + } - if (PreserveOriginalOrdering) + return result; + } + } + finally { - ranked = ranked.OrderBy(t => t.Index); + ArrayPool.Shared.Return(toolRankingInfoArray); } - - return ranked - .Select(t => t.Tool) - .ToList(); } - private async Task>> GetToolEmbeddingsAsync(IList tools, CancellationToken cancellationToken) + private (Memory RequiredTools, Memory OptionalTools) PartitionToolRankings( + Memory toolRankingInfo, IList tools, ChatToolMode? toolMode) { - if (!EnableEmbeddingCaching) + // Always include a tool if its name matches the required function name. + var requiredFunctionName = (toolMode as RequiredChatToolMode)?.RequiredFunctionName; + var nextRequiredToolIndex = 0; + var nextOptionalToolIndex = tools.Count - 1; + for (var i = 0; i < toolRankingInfo.Length; i++) { - // Embed all tools in one batch; do not store in cache. - return await ComputeEmbeddingsAsync( - texts: tools.Select(t => ToolEmbeddingTextFactory(t)), - expectedCount: tools.Count, - cancellationToken); + var tool = tools[i]; + var isRequiredByToolMode = requiredFunctionName is not null && string.Equals(requiredFunctionName, tool.Name, StringComparison.Ordinal); + var toolIndex = isRequiredByToolMode || IsRequiredTool(tool) + ? nextRequiredToolIndex++ + : nextOptionalToolIndex--; + toolRankingInfo.Span[toolIndex] = new AIToolRankingInfo(tool, originalIndex: i); } - var result = new Embedding[tools.Count]; - var cacheMisses = new List<(AITool Tool, int Index)>(tools.Count); + return ( + RequiredTools: toolRankingInfo.Slice(0, nextRequiredToolIndex), + OptionalTools: toolRankingInfo.Slice(nextRequiredToolIndex)); + } - for (var i = 0; i < tools.Count; i++) + private async Task ComputeSimilarityScoresAsync(Memory toolInfo, Embedding queryEmbedding, CancellationToken cancellationToken) + { + var anyCacheMisses = false; + List cacheMissToolEmbeddingTexts = null!; + List cacheMissToolInfoIndexes = null!; + for (var i = 0; i < toolInfo.Length; i++) { - if (_toolEmbeddingsCache.TryGetValue(tools[i], out var embedding)) + ref var info = ref toolInfo.Span[i]; + if (_toolEmbeddingsCache.TryGetValue(info.Tool, out var toolEmbedding)) { - result[i] = embedding; + info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector); } else { - cacheMisses.Add((tools[i], i)); + if (!anyCacheMisses) + { + anyCacheMisses = true; + cacheMissToolEmbeddingTexts = []; + cacheMissToolInfoIndexes = []; + } + + var text = ToolEmbeddingTextSelector(info.Tool); + cacheMissToolEmbeddingTexts.Add(text); + cacheMissToolInfoIndexes.Add(i); } } - if (cacheMisses.Count == 0) + if (!anyCacheMisses) { - return result; + // There were no cache misses; no more work to do. + return; } - var uncachedEmbeddings = await ComputeEmbeddingsAsync( - texts: cacheMisses.Select(t => ToolEmbeddingTextFactory(t.Tool)), - expectedCount: cacheMisses.Count, - cancellationToken); - - for (var i = 0; i < cacheMisses.Count; i++) + var uncachedEmbeddings = await _embeddingGenerator.GenerateAsync(cacheMissToolEmbeddingTexts, cancellationToken: cancellationToken).ConfigureAwait(false); + if (uncachedEmbeddings.Count != cacheMissToolEmbeddingTexts.Count) { - var embedding = uncachedEmbeddings[i]; - result[cacheMisses[i].Index] = embedding; - _toolEmbeddingsCache.Add(cacheMisses[i].Tool, embedding); + throw new InvalidOperationException($"Expected {cacheMissToolEmbeddingTexts.Count} embeddings, got {uncachedEmbeddings.Count}."); } - return result; - - async ValueTask>> ComputeEmbeddingsAsync(IEnumerable texts, int expectedCount, CancellationToken cancellationToken) + for (var i = 0; i < uncachedEmbeddings.Count; i++) { - var embeddings = await _embeddingGenerator.GenerateAsync(texts, cancellationToken: cancellationToken).ConfigureAwait(false); - if (embeddings.Count != expectedCount) - { - Throw.InvalidOperationException($"Expected {expectedCount} embeddings, got {embeddings.Count}."); - } - - return embeddings; + var toolInfoIndex = cacheMissToolInfoIndexes[i]; + var toolEmbedding = uncachedEmbeddings[i]; + ref var info = ref toolInfo.Span[toolInfoIndex]; + info.SimilarityScore = Similarity(queryEmbedding.Vector, toolEmbedding.Vector); + _toolEmbeddingsCache.Add(info.Tool, toolEmbedding); } } + + private struct AIToolRankingInfo(AITool tool, int originalIndex) + { + public static readonly Comparer CompareByDescendingSimilarityScore + = Comparer.Create(static (a, b) => + { + var result = b.SimilarityScore.CompareTo(a.SimilarityScore); + return result != 0 + ? result + : a.OriginalIndex.CompareTo(b.OriginalIndex); // Stabilize ties. + }); + + public static readonly Comparer CompareByOriginalIndex + = Comparer.Create(static (a, b) => a.OriginalIndex.CompareTo(b.OriginalIndex)); + + public AITool Tool { get; } = tool; + public int OriginalIndex { get; } = originalIndex; + public float SimilarityScore { get; set; } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs index 062a06715c0..6a5d6d925fc 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/ToolReducingChatClient.cs @@ -62,7 +62,7 @@ public override async IAsyncEnumerable GetStreamingResponseA CancellationToken cancellationToken) { // If there are no options or no tools, skip. - if (options?.Tools is not { Count: > 1 }) + if (options?.Tools is not { Count: > 0 }) { return options; } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index 8a59904fc46..e90edc26f4d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -1405,15 +1405,15 @@ public void Dispose() } [ConditionalFact] - public virtual async Task ToolReduction_SingleRelevantToolSelected() + public virtual async Task ToolReduction_DynamicSelection_RespectsConversationHistory() { SkipIfNotEnabled(); EnsureEmbeddingGenerator(); - // Strategy: pick top 1 tool - var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 1); + // Limit to 2 so that, once the conversation references both weather and translation, + // both tools can be included even if the latest user turn only mentions one of them. + var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2); - // Define several tools with clearly distinct domains var weatherTool = AIFunctionFactory.Create( () => "Weather data", new AIFunctionFactoryOptions @@ -1422,14 +1422,6 @@ public virtual async Task ToolReduction_SingleRelevantToolSelected() Description = "Returns weather forecast and temperature for a given city." }); - var stockTool = AIFunctionFactory.Create( - () => "Stock data", - new AIFunctionFactoryOptions - { - Name = "GetStockQuote", - Description = "Retrieves live stock market price for a company ticker symbol." - }); - var translateTool = AIFunctionFactory.Create( () => "Translated text", new AIFunctionFactoryOptions @@ -1443,103 +1435,215 @@ public virtual async Task ToolReduction_SingleRelevantToolSelected() new AIFunctionFactoryOptions { Name = "SolveMath", - Description = "Solves arithmetic or algebraic math problems." + Description = "Solves basic math problems." }); - var allTools = new List { weatherTool, stockTool, translateTool, mathTool }; + var allTools = new List { weatherTool, translateTool, mathTool }; - IList? capturedTools = null; + IList? firstTurnTools = null; + IList? secondTurnTools = null; using var client = ChatClient! .AsBuilder() .UseToolReduction(strategy) - // Capture the tools after reduction, before invoking the underlying model. - .Use((messages, options, next, ct) => + .Use(async (messages, options, next, ct) => { - capturedTools = options?.Tools; - return next(messages, options, ct); + // Capture the (possibly reduced) tool list for each turn. + if (firstTurnTools is null) + { + firstTurnTools = options?.Tools; + } + else + { + secondTurnTools ??= options?.Tools; + } + + await next(messages, options, ct); }) + .UseFunctionInvocation() .Build(); - var question = "What will the weather be in Paris tomorrow?"; - _ = await client.GetResponseAsync([new(ChatRole.User, question)], new ChatOptions - { - Tools = allTools - }); + // Maintain chat history across turns. + List history = []; + + // Turn 1: Ask a weather question. + history.Add(new ChatMessage(ChatRole.User, "What will the weather be in Seattle tomorrow?")); + var firstResponse = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools }); + history.AddMessages(firstResponse); // Append assistant reply. + + Assert.NotNull(firstTurnTools); + Assert.Contains(firstTurnTools, t => t.Name == "GetWeatherForecast"); - Assert.NotNull(capturedTools); - Assert.Single(capturedTools!); - Assert.Equal("GetWeatherForecast", capturedTools![0].Name); + // Turn 2: Ask a translation question. Even though only translation is mentioned now, + // conversation history still contains a weather request. Expect BOTH weather + translation tools. + history.Add(new ChatMessage(ChatRole.User, "Please translate 'good evening' into French.")); + var secondResponse = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools }); + history.AddMessages(secondResponse); + + Assert.NotNull(secondTurnTools); + Assert.Equal(2, secondTurnTools.Count); // Should have filled both slots with the two relevant domains. + Assert.Contains(secondTurnTools, t => t.Name == "GetWeatherForecast"); + Assert.Contains(secondTurnTools, t => t.Name == "TranslateText"); + + // Ensure unrelated tool was excluded. + Assert.DoesNotContain(secondTurnTools, t => t.Name == "SolveMath"); } [ConditionalFact] - public virtual async Task ToolReduction_MultiConceptQuery_SelectsTwoRelevantTools() + public virtual async Task ToolReduction_RequireSpecificToolPreservedAndOrdered() { SkipIfNotEnabled(); EnsureEmbeddingGenerator(); - var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2); - - var weatherTool = AIFunctionFactory.Create( - () => "Weather data", - new AIFunctionFactoryOptions - { - Name = "GetWeatherForecast", - Description = "Returns weather forecast and temperature for a given city." - }); + // Limit would normally reduce to 1, but required tool plus another should remain. + var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 1); var translateTool = AIFunctionFactory.Create( () => "Translated text", new AIFunctionFactoryOptions { Name = "TranslateText", - Description = "Translates text between human languages." + Description = "Translates phrases between languages." }); - var stockTool = AIFunctionFactory.Create( - () => "Stock data", + var weatherTool = AIFunctionFactory.Create( + () => "Weather data", new AIFunctionFactoryOptions { - Name = "GetStockQuote", - Description = "Retrieves live stock market price for a company ticker symbol." + Name = "GetWeatherForecast", + Description = "Returns forecast data for a city." }); - var mathTool = AIFunctionFactory.Create( - () => 42, + var tools = new List { translateTool, weatherTool }; + + IList? captured = null; + + using var client = ChatClient! + .AsBuilder() + .UseToolReduction(strategy) + .UseFunctionInvocation() + .Use((messages, options, next, ct) => + { + captured = options?.Tools; + return next(messages, options, ct); + }) + .Build(); + + var history = new List + { + new(ChatRole.User, "What will the weather be like in Redmond next week?.") + }; + + var response = await client.GetResponseAsync(history, new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.RequireSpecific(translateTool.Name) + }); + history.AddMessages(response); + + Assert.NotNull(captured); + Assert.Equal(2, captured!.Count); + Assert.Equal("TranslateText", captured[0].Name); // Required should appear first. + Assert.Equal("GetWeatherForecast", captured[1].Name); + } + + [ConditionalFact] + public virtual async Task ToolReduction_ToolRemovedAfterFirstUse_NotInvokedAgain() + { + SkipIfNotEnabled(); + EnsureEmbeddingGenerator(); + + int weatherInvocationCount = 0; + + var weatherTool = AIFunctionFactory.Create( + () => + { + weatherInvocationCount++; + return "Sunny and dry."; + }, new AIFunctionFactoryOptions { - Name = "SolveMath", - Description = "Solves arithmetic or algebraic math problems." + Name = "GetWeather", + Description = "Gets the weather forecast for a given location." }); - var allTools = new List { weatherTool, translateTool, stockTool, mathTool }; + // Strategy exposes tools only on the first request, then removes them. + var removalStrategy = new RemoveToolAfterFirstUseStrategy(); - IList? capturedTools = null; + IList? firstTurnTools = null; + IList? secondTurnTools = null; using var client = ChatClient! .AsBuilder() - .UseToolReduction(strategy) + // Place capture immediately after reduction so it's invoked exactly once per user request. + .UseToolReduction(removalStrategy) .Use((messages, options, next, ct) => { - capturedTools = options?.Tools; + if (firstTurnTools is null) + { + firstTurnTools = options?.Tools; + } + else + { + secondTurnTools ??= options?.Tools; + } + return next(messages, options, ct); }) + .UseFunctionInvocation() .Build(); - // Query intentionally references two distinct semantic domains: weather + translation. - var question = "Please translate 'good morning' into Spanish and also tell me the weather forecast for Barcelona."; - _ = await client.GetResponseAsync([new(ChatRole.User, question)], new ChatOptions + List history = []; + + // Turn 1 + history.Add(new ChatMessage(ChatRole.User, "What's the weather like tomorrow in Seattle?")); + var firstResponse = await client.GetResponseAsync(history, new ChatOptions + { + Tools = [weatherTool], + ToolMode = ChatToolMode.RequireAny + }); + history.AddMessages(firstResponse); + + Assert.Equal(1, weatherInvocationCount); + Assert.NotNull(firstTurnTools); + Assert.Contains(firstTurnTools!, t => t.Name == "GetWeather"); + + // Turn 2 (tool removed by strategy even though caller supplies it again) + history.Add(new ChatMessage(ChatRole.User, "And what about next week?")); + var secondResponse = await client.GetResponseAsync(history, new ChatOptions { - Tools = allTools + Tools = [weatherTool] }); + history.AddMessages(secondResponse); + + Assert.Equal(1, weatherInvocationCount); // Not invoked again. + Assert.NotNull(secondTurnTools); + Assert.Empty(secondTurnTools!); // Strategy removed the tool set. + + // Response text shouldn't just echo the tool's stub output. + Assert.DoesNotContain("Sunny and dry.", secondResponse.Text, StringComparison.OrdinalIgnoreCase); + } + + // Test-only custom strategy: include tools on first request, then remove them afterward. + private sealed class RemoveToolAfterFirstUseStrategy : IToolReductionStrategy + { + private bool _used; - Assert.NotNull(capturedTools); - Assert.Equal(2, capturedTools!.Count); + public Task> SelectToolsForRequestAsync( + IEnumerable messages, + ChatOptions? options, + CancellationToken cancellationToken = default) + { + if (!_used && options?.Tools is { Count: > 0 }) + { + _used = true; + // Returning the same instance signals no change. + return Task.FromResult>(options.Tools); + } - // Order is not guaranteed; assert membership. - var names = capturedTools.Select(t => t.Name).ToList(); - Assert.Contains("GetWeatherForecast", names); - Assert.Contains("TranslateText", names); + // After first use, remove all tools. + return Task.FromResult>(Array.Empty()); + } } [MemberNotNull(nameof(ChatClient))] diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs index 03dfd4f6060..bd894c760d5 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs @@ -35,18 +35,31 @@ public async Task EmbeddingToolReductionStrategy_NoReduction_WhenToolsBelowLimit Assert.Same(tools, result); } + [Fact] + public async Task EmbeddingToolReductionStrategy_NoReduction_WhenOptionalToolsBelowLimit() + { + // 1 required + 2 optional, limit = 2 (optional count == limit) => original list returned + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2) + { + IsRequiredTool = t => t.Name == "Req" + }; + + var tools = CreateTools("Req", "Opt1", "Opt2"); + var result = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "anything") }, + new ChatOptions { Tools = tools }); + + Assert.Same(tools, result); + } + [Fact] public async Task EmbeddingToolReductionStrategy_Reduces_ToLimit_BySimilarity() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); - var tools = CreateTools( - "Weather", - "Translate", - "Math", - "Jokes"); - + var tools = CreateTools("Weather", "Translate", "Math", "Jokes"); var options = new ChatOptions { Tools = tools }; var messages = new[] @@ -57,8 +70,6 @@ public async Task EmbeddingToolReductionStrategy_Reduces_ToLimit_BySimilarity() var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList(); Assert.Equal(2, reduced.Count); - - // Only assert membership; ordering is an implementation detail when scores tie. Assert.Contains(reduced, t => t.Name == "Weather"); Assert.Contains(reduced, t => t.Name == "Math"); } @@ -73,39 +84,15 @@ public async Task EmbeddingToolReductionStrategy_PreserveOriginalOrdering_Reorde }; var tools = CreateTools("Math", "Translate", "Weather"); - var options = new ChatOptions { Tools = tools }; - - var messages = new[] { new ChatMessage(ChatRole.User, "Explain weather math please") }; - - var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList(); + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "Explain weather math please") }, + new ChatOptions { Tools = tools })).ToList(); Assert.Equal(2, reduced.Count); - Assert.Contains(reduced, t => t.Name == "Math"); - Assert.Contains(reduced, t => t.Name == "Weather"); - - // With PreserveOriginalOrdering the original relative order (Math before Weather) is maintained. Assert.Equal("Math", reduced[0].Name); Assert.Equal("Weather", reduced[1].Name); } - [Fact] - public async Task EmbeddingToolReductionStrategy_EmptyQuery_FallsBackToFirstN() - { - using var gen = new DeterministicTestEmbeddingGenerator(); - var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); - - var tools = CreateTools("A", "B", "C"); - var options = new ChatOptions { Tools = tools }; - - var messages = new[] { new ChatMessage(ChatRole.User, " ") }; - - var reduced = (await strategy.SelectToolsForRequestAsync(messages, options)).ToList(); - - Assert.Equal(2, reduced.Count); - Assert.Equal("A", reduced[0].Name); - Assert.Equal("B", reduced[1].Name); - } - [Fact] public async Task EmbeddingToolReductionStrategy_Caching_AvoidsReEmbeddingTools() { @@ -113,55 +100,31 @@ public async Task EmbeddingToolReductionStrategy_Caching_AvoidsReEmbeddingTools( var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); var tools = CreateTools("Weather", "Math", "Jokes"); - var options = new ChatOptions { Tools = tools }; var messages = new[] { new ChatMessage(ChatRole.User, "weather") }; - _ = await strategy.SelectToolsForRequestAsync(messages, options); + _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); int afterFirst = gen.TotalValueInputs; - _ = await strategy.SelectToolsForRequestAsync(messages, options); + _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); int afterSecond = gen.TotalValueInputs; - // The additional embedding generator call is for the message list. + // +1 for second query embedding only Assert.Equal(afterFirst + 1, afterSecond); } - [Fact] - public async Task EmbeddingToolReductionStrategy_CachingDisabled_ReEmbedsToolsEachCall() - { - using var gen = new DeterministicTestEmbeddingGenerator(); - var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) - { - EnableEmbeddingCaching = false - }; - - var tools = CreateTools("Weather", "Math"); - var options = new ChatOptions { Tools = tools }; - var messages = new[] { new ChatMessage(ChatRole.User, "weather") }; - - _ = await strategy.SelectToolsForRequestAsync(messages, options); - int afterFirst = gen.TotalValueInputs; - - _ = await strategy.SelectToolsForRequestAsync(messages, options); - int afterSecond = gen.TotalValueInputs; - - // The additional embedding generator call is for the message list. - Assert.Equal(afterFirst + tools.Count + 1, afterSecond); - } - [Fact] public async Task EmbeddingToolReductionStrategy_OptionsNullOrNoTools_ReturnsEmptyOrOriginal() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); - // Null options => empty result - var empty = await strategy.SelectToolsForRequestAsync(new[] { new ChatMessage(ChatRole.User, "anything") }, null); + var empty = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "anything") }, null); Assert.Empty(empty); - // Empty tools list => returns that same list var options = new ChatOptions { Tools = [] }; - var result = await strategy.SelectToolsForRequestAsync(new[] { new ChatMessage(ChatRole.User, "weather") }, options); + var result = await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "weather") }, options); Assert.Same(options.Tools, result); } @@ -171,12 +134,11 @@ public async Task EmbeddingToolReductionStrategy_CustomSimilarity_InvertsOrderin using var gen = new VectorBasedTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) { - // Custom similarity chooses the *smallest* first component instead of largest cosine. Similarity = (q, t) => -t.Span[0] }; - var highTool = new SimpleTool("HighScore", "alpha"); // vector[0] = 10 - var lowTool = new SimpleTool("LowScore", "beta"); // vector[0] = 1 + var highTool = new SimpleTool("HighScore", "alpha"); + var lowTool = new SimpleTool("LowScore", "beta"); gen.VectorSelector = text => text.Contains("alpha") ? 10f : 1f; var reduced = (await strategy.SelectToolsForRequestAsync( @@ -184,35 +146,50 @@ public async Task EmbeddingToolReductionStrategy_CustomSimilarity_InvertsOrderin new ChatOptions { Tools = [highTool, lowTool] })).ToList(); Assert.Single(reduced); - - // Because we negated similarity, lowest underlying value wins Assert.Equal("LowScore", reduced[0].Name); } [Fact] - public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextFactory_EmptyDescription_UsesNameOnly() + public async Task EmbeddingToolReductionStrategy_TieDeterminism_PrefersLowerOriginalIndex() + { + // Generator returns identical vectors so similarity ties; we expect original order preserved + using var gen = new ConstantEmbeddingGenerator(3); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); + + var tools = CreateTools("T1", "T2", "T3", "T4"); + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "any") }, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Equal(2, reduced.Count); + Assert.Equal("T1", reduced[0].Name); + Assert.Equal("T2", reduced[1].Name); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextSelector_EmptyDescription_UsesNameOnly() { using var recorder = new RecordingEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); var target = new SimpleTool("ComputeSum", description: ""); var filler = new SimpleTool("Other", "Unrelated"); - await strategy.SelectToolsForRequestAsync( + _ = await strategy.SelectToolsForRequestAsync( new[] { new ChatMessage(ChatRole.User, "math") }, - new ChatOptions { Tools = new List { target, filler } }); + new ChatOptions { Tools = [target, filler] }); Assert.Contains("ComputeSum", recorder.Inputs); } [Fact] - public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextFactory_EmptyName_UsesDescriptionOnly() + public async Task EmbeddingToolReductionStrategy_DefaultEmbeddingTextSelector_EmptyName_UsesDescriptionOnly() { using var recorder = new RecordingEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); var target = new SimpleTool("", description: "Translates between languages."); var filler = new SimpleTool("Other", "Unrelated"); - await strategy.SelectToolsForRequestAsync( + _ = await strategy.SelectToolsForRequestAsync( new[] { new ChatMessage(ChatRole.User, "translate") }, new ChatOptions { Tools = [target, filler] }); @@ -220,17 +197,17 @@ await strategy.SelectToolsForRequestAsync( } [Fact] - public async Task EmbeddingToolReductionStrategy_CustomEmbeddingTextFactory_Applied() + public async Task EmbeddingToolReductionStrategy_CustomEmbeddingTextSelector_Applied() { using var recorder = new RecordingEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1) { - ToolEmbeddingTextFactory = t => $"NAME:{t.Name}|DESC:{t.Description}" + ToolEmbeddingTextSelector = t => $"NAME:{t.Name}|DESC:{t.Description}" }; var target = new SimpleTool("WeatherTool", "Gets forecast."); var filler = new SimpleTool("Other", "Irrelevant"); - await strategy.SelectToolsForRequestAsync( + _ = await strategy.SelectToolsForRequestAsync( new[] { new ChatMessage(ChatRole.User, "weather") }, new ChatOptions { Tools = [target, filler] }); @@ -238,15 +215,13 @@ await strategy.SelectToolsForRequestAsync( } [Fact] - public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_CustomFiltersMessages() + public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_CustomFiltersMessages() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); - // Tools we want to discriminate between var tools = CreateTools("Weather", "Math", "Translate"); - // Earlier message mentions weather, last message mentions math var messages = new[] { new ChatMessage(ChatRole.User, "Please tell me the weather tomorrow."), @@ -254,8 +229,7 @@ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_Cu new ChatMessage(ChatRole.User, "Now instead solve a math problem.") }; - // Only consider the last message (so "Math" should clearly win) - strategy.MessagesEmbeddingTextFactory = msgs => msgs.LastOrDefault()?.Text ?? string.Empty; + strategy.MessagesEmbeddingTextSelector = msgs => msgs.LastOrDefault()?.Text ?? string.Empty; var reduced = (await strategy.SelectToolsForRequestAsync( messages, @@ -266,27 +240,7 @@ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_Cu } [Fact] - public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_EmptyResult_FallbacksToFirstN() - { - using var gen = new DeterministicTestEmbeddingGenerator(); - var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 2); - - var tools = CreateTools("Alpha", "Beta", "Gamma"); - - // Return only whitespace so strategy should truncate deterministically - strategy.MessagesEmbeddingTextFactory = _ => " "; - - var reduced = (await strategy.SelectToolsForRequestAsync( - new[] { new ChatMessage(ChatRole.User, "Irrelevant content") }, - new ChatOptions { Tools = tools })).ToList(); - - Assert.Equal(2, reduced.Count); - Assert.Equal("Alpha", reduced[0].Name); - Assert.Equal("Beta", reduced[1].Name); - } - - [Fact] - public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_InvokedOnce() + public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_InvokedOnce() { using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); @@ -294,11 +248,9 @@ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_In var tools = CreateTools("Weather", "Math"); int invocationCount = 0; - strategy.MessagesEmbeddingTextFactory = msgs => + strategy.MessagesEmbeddingTextSelector = msgs => { invocationCount++; - - // Default-like behavior return string.Join("\n", msgs.Select(m => m.Text)); }; @@ -309,6 +261,102 @@ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextFactory_In Assert.Equal(1, invocationCount); } + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultMessagesEmbeddingTextSelector_IncludesReasoningContent() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); + var tools = CreateTools("Weather", "Math"); + + var reasoningLine = "Thinking about the best way to get tomorrow's forecast..."; + var answerLine = "Tomorrow will be sunny."; + var userLine = "What's the weather tomorrow?"; + + var messages = new[] + { + new ChatMessage(ChatRole.User, userLine), + new ChatMessage(ChatRole.Assistant, + [ + new TextReasoningContent(reasoningLine), + new TextContent(answerLine) + ]) + }; + + _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); + + string queryInput = recorder.Inputs[0]; + + Assert.Contains(userLine, queryInput); + Assert.Contains(reasoningLine, queryInput); + Assert.Contains(answerLine, queryInput); + + var userIndex = queryInput.IndexOf(userLine, StringComparison.Ordinal); + var reasoningIndex = queryInput.IndexOf(reasoningLine, StringComparison.Ordinal); + var answerIndex = queryInput.IndexOf(answerLine, StringComparison.Ordinal); + Assert.True(userIndex >= 0 && reasoningIndex > userIndex && answerIndex > reasoningIndex); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_DefaultMessagesEmbeddingTextSelector_SkipsNonTextContent() + { + using var recorder = new RecordingEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(recorder, toolLimit: 1); + var tools = CreateTools("Alpha", "Beta"); + + var textOnly = "Provide translation."; + var messages = new[] + { + new ChatMessage(ChatRole.User, + [ + new DataContent(new byte[] { 1, 2, 3 }, "application/octet-stream"), + new TextContent(textOnly) + ]) + }; + + _ = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); + + var queryInput = recorder.Inputs[0]; + Assert.Contains(textOnly, queryInput); + Assert.DoesNotContain("application/octet-stream", queryInput, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_RequiredToolAlwaysIncluded() + { + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + IsRequiredTool = t => t.Name == "Core" + }; + + var tools = CreateTools("Core", "Weather", "Math"); + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "math") }, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Equal(2, reduced.Count); // required + one optional (limit=1) + Assert.Contains(reduced, t => t.Name == "Core"); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_MultipleRequiredTools_ExceedLimit_AllRequiredIncluded() + { + // 3 required, limit=1 => expect 3 required + 1 ranked optional = 4 total + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + IsRequiredTool = t => t.Name.StartsWith("R", StringComparison.Ordinal) + }; + + var tools = CreateTools("R1", "R2", "R3", "Weather", "Math"); + var reduced = (await strategy.SelectToolsForRequestAsync( + new[] { new ChatMessage(ChatRole.User, "weather math") }, + new ChatOptions { Tools = tools })).ToList(); + + Assert.Equal(4, reduced.Count); + Assert.Equal(3, reduced.Count(t => t.Name.StartsWith("R"))); + } + [Fact] public async Task ToolReducingChatClient_ReducesTools_ForGetResponseAsync() { @@ -327,10 +375,7 @@ public async Task ToolReducingChatClient_ReducesTools_ForGetResponseAsync() } }; - using var client = inner - .AsBuilder() - .UseToolReduction(strategy) - .Build(); + using var client = inner.AsBuilder().UseToolReduction(strategy).Build(); await client.GetResponseAsync( new[] { new ChatMessage(ChatRole.User, "weather math please") }, @@ -360,10 +405,7 @@ public async Task ToolReducingChatClient_ReducesTools_ForStreaming() } }; - using var client = inner - .AsBuilder() - .UseToolReduction(strategy) - .Build(); + using var client = inner.AsBuilder().UseToolReduction(strategy).Build(); await foreach (var _ in client.GetStreamingResponseAsync( new[] { new ChatMessage(ChatRole.User, "math") }, @@ -377,15 +419,83 @@ public async Task ToolReducingChatClient_ReducesTools_ForStreaming() Assert.Equal("Math", observedTools![0].Name); } + [Fact] + public async Task EmbeddingToolReductionStrategy_EmptyQuery_NoReduction() + { + // Arrange: more tools than limit so we'd normally reduce, but query is empty -> return full list unchanged. + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1); + + var tools = CreateTools("ToolA", "ToolB", "ToolC"); + var options = new ChatOptions { Tools = tools }; + + // Empty / whitespace message text produces empty query. + var messages = new[] { new ChatMessage(ChatRole.User, " ") }; + + // Act + var result = await strategy.SelectToolsForRequestAsync(messages, options); + + // Assert: same reference (no reduction), and generator not invoked at all. + Assert.Same(tools, result); + Assert.Equal(0, gen.TotalValueInputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_EmptyQuery_NoReduction_WithRequiredTool() + { + // Arrange: required tool + optional tools; still should return original set when query is empty. + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + IsRequiredTool = t => t.Name == "Req" + }; + + var tools = CreateTools("Req", "Optional1", "Optional2"); + var options = new ChatOptions { Tools = tools }; + + var messages = new[] { new ChatMessage(ChatRole.User, " ") }; + + // Act + var result = await strategy.SelectToolsForRequestAsync(messages, options); + + // Assert + Assert.Same(tools, result); + Assert.Equal(0, gen.TotalValueInputs); + } + + [Fact] + public async Task EmbeddingToolReductionStrategy_EmptyQuery_ViaCustomMessagesSelector_NoReduction() + { + // Arrange: force empty query through custom selector returning whitespace. + using var gen = new DeterministicTestEmbeddingGenerator(); + var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) + { + MessagesEmbeddingTextSelector = _ => " " + }; + + var tools = CreateTools("One", "Two"); + var messages = new[] + { + new ChatMessage(ChatRole.User, "This content will be ignored by custom selector.") + }; + + // Act + var result = await strategy.SelectToolsForRequestAsync(messages, new ChatOptions { Tools = tools }); + + // Assert: no reduction and no embeddings generated. + Assert.Same(tools, result); + Assert.Equal(0, gen.TotalValueInputs); + } + private static List CreateTools(params string[] names) => names.Select(n => (AITool)new SimpleTool(n, $"Description about {n}")).ToList(); -#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously +#pragma warning disable CS1998 private static async IAsyncEnumerable EmptyAsyncEnumerable() -#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { yield break; } +#pragma warning restore CS1998 private sealed class SimpleTool : AITool { @@ -488,19 +598,39 @@ public void Dispose() private sealed class VectorBasedTestEmbeddingGenerator : IEmbeddingGenerator> { - // External control for choosing first component of embedding. public Func VectorSelector { get; set; } = _ => 1f; - - public Task>> GenerateAsync( - IEnumerable values, - EmbeddingGenerationOptions? options = null, - CancellationToken cancellationToken = default) + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) { var list = new List>(); foreach (var v in values) { - float val = VectorSelector(v); - list.Add(new Embedding(new float[] { val, 1f })); + list.Add(new Embedding(new float[] { VectorSelector(v), 1f })); + } + + return Task.FromResult(new GeneratedEmbeddings>(list)); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + public void Dispose() + { + // No-op + } + } + + private sealed class ConstantEmbeddingGenerator : IEmbeddingGenerator> + { + private readonly float[] _vector; + public ConstantEmbeddingGenerator(int dims) + { + _vector = Enumerable.Repeat(1f, dims).ToArray(); + } + + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + var list = new List>(); + foreach (var _ in values) + { + list.Add(new Embedding(_vector)); } return Task.FromResult(new GeneratedEmbeddings>(list)); @@ -512,4 +642,22 @@ public void Dispose() // No-op } } + + private sealed class TestChatClient : IChatClient + { + public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; } + public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; } + + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + (GetResponseAsyncCallback ?? throw new InvalidOperationException())(messages, options, cancellationToken); + + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + (GetStreamingResponseAsyncCallback ?? throw new InvalidOperationException())(messages, options, cancellationToken); + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + public void Dispose() + { + // No-op + } + } } From b3f7f85a3a0c9d995e3e87bc763ab34d3f07d960 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 14 Oct 2025 16:53:35 -0700 Subject: [PATCH 8/9] Enable asynchronous chat history embedding text generation --- .../EmbeddingToolReductionStrategy.cs | 8 +- .../ChatClientIntegrationTests.cs | 95 +++++++++++++++++++ .../ToolReductionTests.cs | 6 +- 3 files changed, 102 insertions(+), 7 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs index 82d66c680c6..f9e4c60995a 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ToolReduction/EmbeddingToolReductionStrategy.cs @@ -46,7 +46,7 @@ public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy return t.Name + Environment.NewLine + t.Description; }; - private Func, string> _messagesEmbeddingTextSelector = static messages => + private Func, ValueTask> _messagesEmbeddingTextSelector = static messages => { var sb = new StringBuilder(); foreach (var message in messages) @@ -71,7 +71,7 @@ public sealed class EmbeddingToolReductionStrategy : IToolReductionStrategy } } - return sb.ToString(); + return new ValueTask(sb.ToString()); }; private Func, ReadOnlyMemory, float> _similarity = static (a, b) => TensorPrimitives.CosineSimilarity(a.Span, b.Span); @@ -107,7 +107,7 @@ public Func ToolEmbeddingTextSelector /// Gets or sets the selector used to generate a single text string from a collection of chat messages for /// embedding purposes. /// - public Func, string> MessagesEmbeddingTextSelector + public Func, ValueTask> MessagesEmbeddingTextSelector { get => _messagesEmbeddingTextSelector; set => _messagesEmbeddingTextSelector = Throw.IfNull(value); @@ -192,7 +192,7 @@ public async Task> SelectToolsForRequestAsync( } // Build query text from recent messages. - var queryText = MessagesEmbeddingTextSelector(messages); + var queryText = await MessagesEmbeddingTextSelector(messages).ConfigureAwait(false); if (string.IsNullOrWhiteSpace(queryText)) { // We couldn't build a meaningful query, likely because the message list was empty. diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index 4a99a6efa7b..e016448f064 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -1624,6 +1624,101 @@ public virtual async Task ToolReduction_ToolRemovedAfterFirstUse_NotInvokedAgain Assert.DoesNotContain("Sunny and dry.", secondResponse.Text, StringComparison.OrdinalIgnoreCase); } + [ConditionalFact] + public virtual async Task ToolReduction_MessagesEmbeddingTextSelector_UsesChatClientToAnalyzeConversation() + { + SkipIfNotEnabled(); + EnsureEmbeddingGenerator(); + + // Create tools for different domains. + var weatherTool = AIFunctionFactory.Create( + () => "Weather data", + new AIFunctionFactoryOptions + { + Name = "GetWeatherForecast", + Description = "Returns weather forecast and temperature for a given city." + }); + + var translateTool = AIFunctionFactory.Create( + () => "Translated text", + new AIFunctionFactoryOptions + { + Name = "TranslateText", + Description = "Translates text between human languages." + }); + + var mathTool = AIFunctionFactory.Create( + () => 42, + new AIFunctionFactoryOptions + { + Name = "SolveMath", + Description = "Solves basic math problems." + }); + + var allTools = new List { weatherTool, translateTool, mathTool }; + + // Track the analysis result from the chat client used in the selector. + string? capturedAnalysis = null; + + var strategy = new EmbeddingToolReductionStrategy(EmbeddingGenerator, toolLimit: 2) + { + // Use a chat client to analyze the conversation and extract relevant tool categories. + MessagesEmbeddingTextSelector = async messages => + { + var conversationText = string.Join("\n", messages.Select(m => $"{m.Role}: {m.Text}")); + + var analysisPrompt = $""" + Analyze the following conversation and identify what kinds of tools would be most helpful. + Focus on the key topics and tasks being discussed. + Respond with a brief summary of the relevant tool categories (e.g., "weather", "translation", "math"). + + Conversation: + {conversationText} + + Relevant tool categories: + """; + + var response = await ChatClient!.GetResponseAsync(analysisPrompt); + capturedAnalysis = response.Text; + + // Return the analysis as the query text for embedding-based tool selection. + return capturedAnalysis; + } + }; + + IList? selectedTools = null; + + using var client = ChatClient! + .AsBuilder() + .UseToolReduction(strategy) + .Use(async (messages, options, next, ct) => + { + selectedTools = options?.Tools; + await next(messages, options, ct); + }) + .UseFunctionInvocation() + .Build(); + + // Conversation that clearly indicates weather-related needs. + List history = []; + history.Add(new ChatMessage(ChatRole.User, "What will the weather be like in London tomorrow?")); + + var response = await client.GetResponseAsync(history, new ChatOptions { Tools = allTools }); + history.AddMessages(response); + + // Verify that the chat client was used to analyze the conversation. + Assert.NotNull(capturedAnalysis); + Assert.True( + capturedAnalysis.IndexOf("weather", StringComparison.OrdinalIgnoreCase) >= 0 || + capturedAnalysis.IndexOf("forecast", StringComparison.OrdinalIgnoreCase) >= 0, + $"Expected analysis to mention weather or forecast: {capturedAnalysis}"); + + // Verify that the tool selection was influenced by the analysis. + Assert.NotNull(selectedTools); + Assert.True(selectedTools.Count <= 2, $"Expected at most 2 tools, got {selectedTools.Count}"); + Assert.Contains(selectedTools, t => t.Name == "GetWeatherForecast"); + } + // Test-only custom strategy: include tools on first request, then remove them afterward. private sealed class RemoveToolAfterFirstUseStrategy : IToolReductionStrategy { diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs index bd894c760d5..96c9adc6311 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ToolReductionTests.cs @@ -229,7 +229,7 @@ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_C new ChatMessage(ChatRole.User, "Now instead solve a math problem.") }; - strategy.MessagesEmbeddingTextSelector = msgs => msgs.LastOrDefault()?.Text ?? string.Empty; + strategy.MessagesEmbeddingTextSelector = msgs => new ValueTask(msgs.LastOrDefault()?.Text ?? string.Empty); var reduced = (await strategy.SelectToolsForRequestAsync( messages, @@ -251,7 +251,7 @@ public async Task EmbeddingToolReductionStrategy_MessagesEmbeddingTextSelector_I strategy.MessagesEmbeddingTextSelector = msgs => { invocationCount++; - return string.Join("\n", msgs.Select(m => m.Text)); + return new ValueTask(string.Join("\n", msgs.Select(m => m.Text))); }; _ = await strategy.SelectToolsForRequestAsync( @@ -470,7 +470,7 @@ public async Task EmbeddingToolReductionStrategy_EmptyQuery_ViaCustomMessagesSel using var gen = new DeterministicTestEmbeddingGenerator(); var strategy = new EmbeddingToolReductionStrategy(gen, toolLimit: 1) { - MessagesEmbeddingTextSelector = _ => " " + MessagesEmbeddingTextSelector = _ => new ValueTask(" ") }; var tools = CreateTools("One", "Two"); From 1b4661d028b7ce3717191cc4221a5029e7b0e75b Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Wed, 15 Oct 2025 08:55:52 -0700 Subject: [PATCH 9/9] Update test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../ChatClientIntegrationTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index e016448f064..992e86a1184 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -1531,7 +1531,7 @@ public virtual async Task ToolReduction_RequireSpecificToolPreservedAndOrdered() var history = new List { - new(ChatRole.User, "What will the weather be like in Redmond next week?.") + new(ChatRole.User, "What will the weather be like in Redmond next week?") }; var response = await client.GetResponseAsync(history, new ChatOptions