From 101bbc461131cf3d19e92a0e555fda4db483078a Mon Sep 17 00:00:00 2001 From: Sun Yuhan <1085481446@qq.com> Date: Tue, 3 Jun 2025 11:44:13 +0800 Subject: [PATCH] fix: Fixed the issue where tool call information was lost when using DefaultChatOptions. Signed-off-by: Sun Yuhan <1085481446@qq.com> --- .../chat/client/DefaultChatClientUtils.java | 20 +++++++ .../client/DefaultChatClientUtilsTests.java | 60 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java index 10f623e2b70..f68708e7893 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -27,8 +27,10 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.DefaultChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.util.Assert; @@ -39,6 +41,7 @@ * Utilities for supporting the {@link DefaultChatClient} implementation. * * @author Thomas Vitale + * @author Sun Yuhan * @since 1.0.0 */ final class DefaultChatClientUtils { @@ -94,6 +97,23 @@ static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClient */ ChatOptions processedChatOptions = inputRequest.getChatOptions(); + + if (processedChatOptions instanceof DefaultChatOptions defaultChatOptions) { + if (!inputRequest.getToolNames().isEmpty() || !inputRequest.getToolCallbacks().isEmpty() + || !CollectionUtils.isEmpty(inputRequest.getToolContext())) { + processedChatOptions = DefaultToolCallingChatOptions.builder() + .model(defaultChatOptions.getModel()) + .frequencyPenalty(defaultChatOptions.getFrequencyPenalty()) + .maxTokens(defaultChatOptions.getMaxTokens()) + .presencePenalty(defaultChatOptions.getPresencePenalty()) + .stopSequences(defaultChatOptions.getStopSequences()) + .temperature(defaultChatOptions.getTemperature()) + .topK(defaultChatOptions.getTopK()) + .topP(defaultChatOptions.getTopP()) + .build(); + } + } + if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { if (!inputRequest.getToolNames().isEmpty()) { Set toolNames = ToolCallingChatOptions diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java index 9d4d4962069..7b8d5491f90 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -26,6 +26,7 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.DefaultChatOptions; import org.springframework.ai.content.Media; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.template.TemplateRenderer; @@ -43,6 +44,7 @@ * Unit tests for {@link DefaultChatClientUtils}. * * @author Thomas Vitale + * @author Sun Yuhan */ class DefaultChatClientUtilsTests { @@ -322,6 +324,64 @@ void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() { .containsAllEntriesOf(toolContext2); } + @Test + void whenToolNamesAndChatOptionsAreDefaultChatOptions() { + Set toolNames1 = Set.of("toolA", "toolB"); + DefaultChatOptions chatOptions = new DefaultChatOptions(); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolNames(toolNames1.toArray(new String[0])); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames1); + } + + @Test + void whenToolCallbacksAndChatOptionsAreDefaultChatOptions() { + ToolCallback toolCallback1 = new TestToolCallback("tool1"); + DefaultChatOptions chatOptions = new DefaultChatOptions(); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolCallbacks(toolCallback1); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolCallbacks()).containsExactlyInAnyOrder(toolCallback1); + } + + @Test + void whenToolContextAndChatOptionsAreDefaultChatOptions() { + Map toolContext1 = Map.of("key1", "value1"); + DefaultChatOptions chatOptions = new DefaultChatOptions(); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolContext(toolContext1); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext1); + } + @Test void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() { Map advisorParams = Map.of("key1", "value1", "key2", "value2");