diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index f5a1e8cd11a..6fe9d2e22cb 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -22,10 +22,14 @@ import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import io.micrometer.observation.Observation; @@ -570,7 +574,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final List media = new ArrayList<>(); - private final List toolNames = new ArrayList<>(); + private final Set toolNames = new LinkedHashSet<>(); private final List toolCallbacks = new ArrayList<>(); @@ -606,9 +610,9 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, @Nullable String systemText, Map systemParams, - List toolCallbacks, List messages, List toolNames, List media, - @Nullable ChatOptions chatOptions, List advisors, Map advisorParams, - ObservationRegistry observationRegistry, + List toolCallbacks, List messages, Collection toolNames, + List media, @Nullable ChatOptions chatOptions, List advisors, + Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext, @Nullable TemplateRenderer templateRenderer) { @@ -685,7 +689,7 @@ public List getMedia() { return this.media; } - public List getToolNames() { + public Set getToolNames() { return this.toolNames; } @@ -701,6 +705,10 @@ public TemplateRenderer getTemplateRenderer() { return this.templateRenderer; } + public boolean hasToolConfiguration() { + return !this.toolNames.isEmpty() || !this.toolCallbacks.isEmpty() || !this.toolContext.isEmpty(); + } + /** * Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose * settings are replicated from this {@code ChatClientRequest}. @@ -778,7 +786,7 @@ public ChatClientRequestSpec options(T options) { public ChatClientRequestSpec toolNames(String... toolNames) { Assert.notNull(toolNames, "toolNames cannot be null"); Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); - this.toolNames.addAll(List.of(toolNames)); + Collections.addAll(this.toolNames, toolNames); return this; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index 10f623e2b70..2b1894712e2 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -94,22 +94,40 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient */ ChatOptions processedChatOptions = inputRequest.getChatOptions(); - if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { - if (!inputRequest.getToolNames().isEmpty()) { - Set toolNames = ToolCallingChatOptions - .mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); - toolCallingChatOptions.setToolNames(toolNames); + if (inputRequest.hasToolConfiguration()) { + if (processedChatOptions == null) { + ToolCallingChatOptions.Builder builder = ToolCallingChatOptions.builder(); + if (!inputRequest.getToolNames().isEmpty()) { + builder.toolNames(inputRequest.getToolNames()); + } + if (!inputRequest.getToolCallbacks().isEmpty()) { + List toolCallbacks = inputRequest.getToolCallbacks(); + ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); + builder.toolCallbacks(inputRequest.getToolCallbacks()); + } + if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { + builder.toolContext(inputRequest.getToolContext()); + } + + processedChatOptions = builder.build(); } - if (!inputRequest.getToolCallbacks().isEmpty()) { - List toolCallbacks = ToolCallingChatOptions - .mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); - ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); - toolCallingChatOptions.setToolCallbacks(toolCallbacks); - } - if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { - Map toolContext = ToolCallingChatOptions.mergeToolContext(inputRequest.getToolContext(), - toolCallingChatOptions.getToolContext()); - toolCallingChatOptions.setToolContext(toolContext); + else if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { + if (!inputRequest.getToolNames().isEmpty()) { + Set toolNames = ToolCallingChatOptions.mergeToolNames( + new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); + toolCallingChatOptions.setToolNames(toolNames); + } + if (!inputRequest.getToolCallbacks().isEmpty()) { + List toolCallbacks = ToolCallingChatOptions + .mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); + ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); + toolCallingChatOptions.setToolCallbacks(toolCallbacks); + } + if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { + Map toolContext = ToolCallingChatOptions + .mergeToolContext(inputRequest.getToolContext(), toolCallingChatOptions.getToolContext()); + toolCallingChatOptions.setToolContext(toolContext); + } } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java index 9d4d4962069..0f8788f326b 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -322,6 +322,60 @@ void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() { .containsAllEntriesOf(toolContext2); } + @Test + void whenToolNamesWithoutChatOptionsAreProvidedThenToolCallingChatOptionsAreSet() { + List toolNames = List.of("tool1", "tool2"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolNames(toolNames.toArray(new String[0])); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames); + } + + @Test + void whenToolCallbacksWithoutChatOptionsAreProvidedThenToolCallingChatOptionsAreSet() { + ToolCallback toolCallback = new TestToolCallback("tool1"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolCallbacks(toolCallback); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolCallbacks()).contains(toolCallback); + } + + @Test + void whenToolContextWithoutChatOptionsIsProvidedThenToolCallingChatOptionsAreSet() { + Map toolContext = Map.of("key", "value"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .toolContext(toolContext); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext); + } + @Test void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() { Map advisorParams = Map.of("key1", "value1", "key2", "value2");