diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index a75a274a797..9cd442f99d4 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -249,6 +249,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { generationMetadata = ChatGenerationMetadata.builder() .finishReason(ollamaResponse.doneReason()) + .metadata("thinking", ollamaResponse.message().thinking()) .build(); } @@ -460,7 +461,8 @@ else if (message instanceof ToolResponseMessage toolMessage) { OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel()) .stream(stream) .messages(ollamaMessages) - .options(requestOptions); + .options(requestOptions) + .think(requestOptions.getThink()); if (requestOptions.getFormat() != null) { requestBuilder.format(requestOptions.getFormat()); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java index e0ffc06c31d..b08fe066f45 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java @@ -51,6 +51,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Jonghoon Park + * @author Sun Yuhan * @since 0.8.0 */ // @formatter:off @@ -251,6 +252,7 @@ public Flux pullModel(PullModelRequest pullModelRequest) { * * @param role The role of the message of type {@link Role}. * @param content The content of the message. + * @param thinking The thinking of the model. 
* @param images The list of base64-encoded images to send with the message. * Requires multimodal models such as llava or bakllava. * @param toolCalls The relevant tool call. @@ -260,6 +262,7 @@ public Flux pullModel(PullModelRequest pullModelRequest) { public record Message( @JsonProperty("role") Role role, @JsonProperty("content") String content, + @JsonProperty("thinking") String thinking, @JsonProperty("images") List images, @JsonProperty("tool_calls") List toolCalls) { @@ -321,6 +324,7 @@ public static class Builder { private final Role role; private String content; + private String thinking; private List images; private List toolCalls; @@ -333,6 +337,11 @@ public Builder content(String content) { return this; } + public Builder thinking(String thinking) { + this.thinking = thinking; + return this; + } + public Builder images(List images) { this.images = images; return this; @@ -344,7 +353,7 @@ public Builder toolCalls(List toolCalls) { } public Message build() { - return new Message(this.role, this.content, this.images, this.toolCalls); + return new Message(this.role, this.content, this.thinking, this.images, this.toolCalls); } } } @@ -359,6 +368,7 @@ public Message build() { * @param keepAlive Controls how long the model will stay loaded into memory following this request (default: 5m). * @param tools List of tools the model has access to. * @param options Model-specific options. For example, "temperature" can be set through this field, if the model supports it. + * @param think The model should think before responding, if the model supports it. * You can use the {@link OllamaOptions} builder to create the options then {@link OllamaOptions#toMap()} to convert the options into a map. 
* * @see tools, - @JsonProperty("options") Map options + @JsonProperty("options") Map options, + @JsonProperty("think") Boolean think ) { public static Builder builder(String model) { @@ -448,6 +459,7 @@ public static class Builder { private String keepAlive; private List tools = List.of(); private Map options = Map.of(); + private Boolean think; public Builder(String model) { Assert.notNull(model, "The model can not be null."); @@ -492,8 +504,13 @@ public Builder options(OllamaOptions options) { return this; } + public Builder think(Boolean think) { + this.think = think; + return this; + } + public ChatRequest build() { - return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options); + return new ChatRequest(this.model, this.messages, this.stream, this.format, this.keepAlive, this.tools, this.options, this.think); } } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java index 588b86c5364..b8728874cc9 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApiHelper.java @@ -25,6 +25,7 @@ /** * @author Christian Tzolov + * @author Sun Yuhan * @since 1.0.0 */ public final class OllamaApiHelper { @@ -81,12 +82,18 @@ public static ChatResponse merge(ChatResponse previous, ChatResponse current) { private static OllamaApi.Message merge(OllamaApi.Message previous, OllamaApi.Message current) { String content = mergeContent(previous, current); + String thinking = mergeThinking(previous, current); OllamaApi.Message.Role role = (current.role() != null ? current.role() : previous.role()); role = (role != null ? 
role : OllamaApi.Message.Role.ASSISTANT); List images = mergeImages(previous, current); List toolCalls = mergeToolCall(previous, current); - return OllamaApi.Message.builder(role).content(content).images(images).toolCalls(toolCalls).build(); + return OllamaApi.Message.builder(role) + .content(content) + .thinking(thinking) + .images(images) + .toolCalls(toolCalls) + .build(); } private static Instant merge(Instant previous, Instant current) { @@ -134,6 +141,17 @@ private static String mergeContent(OllamaApi.Message previous, OllamaApi.Message return previous.content() + current.content(); } + private static String mergeThinking(OllamaApi.Message previous, OllamaApi.Message current) { + if (previous == null || previous.thinking() == null) { + return (current != null ? current.thinking() : null); + } + if (current == null || current.thinking() == null) { + return (previous != null ? previous.thinking() : null); + } + + return previous.thinking() + current.thinking(); + } + private static List mergeToolCall(OllamaApi.Message previous, OllamaApi.Message current) { if (previous == null) { diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java index 7602eca2584..aa91a2ec837 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaModel.java @@ -23,6 +23,7 @@ * * @author Siarhei Blashuk * @author Thomas Vitale + * @author Sun Yuhan * @since 1.0.0 */ public enum OllamaModel implements ChatModelDescription { @@ -32,6 +33,21 @@ public enum OllamaModel implements ChatModelDescription { */ QWEN_2_5_7B("qwen2.5"), + /** + * Qwen3 + */ + QWEN_3_8B("qwen3"), + + /** + * Qwen3 1.7b + */ + QWEN_3_1_7_B("qwen3:1.7b"), + + /** + * Qwen3 0.6b + */ + QWEN_3_06B("qwen3:0.6b"), + /** * QwQ is the reasoning model of the Qwen 
series. */ diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a71be1ce2b2..ca5aca624a5 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -44,6 +44,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Sun Yuhan * @since 0.8.0 * @see Ollama @@ -318,6 +319,14 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { @JsonProperty("truncate") private Boolean truncate; + /** + * Whether the model should think before responding, if supported. + * If this value is not specified, it defaults to null, and Ollama will return + * the thought process within the `content` field of the response, wrapped in `<think>` tags. + */ + @JsonProperty("think") + private Boolean think; + @JsonIgnore private Boolean internalToolExecutionEnabled; @@ -365,6 +374,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .format(fromOptions.getFormat()) .keepAlive(fromOptions.getKeepAlive()) .truncate(fromOptions.getTruncate()) + .think(fromOptions.getThink()) .useNUMA(fromOptions.getUseNUMA()) .numCtx(fromOptions.getNumCtx()) .numBatch(fromOptions.getNumBatch()) @@ -704,6 +714,14 @@ public void setTruncate(Boolean truncate) { this.truncate = truncate; } + public Boolean getThink() { + return this.think; + } + + public void setThink(Boolean think) { + this.think = think; + } + @Override @JsonIgnore public List getToolCallbacks() { @@ -804,7 +822,8 @@ public boolean equals(Object o) { && Objects.equals(this.repeatPenalty, that.repeatPenalty) && Objects.equals(this.presencePenalty, that.presencePenalty) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) - && Objects.equals(this.mirostat, 
that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau) + && Objects.equals(this.think, that.think) && Objects.equals(this.mirostat, that.mirostat) + && Objects.equals(this.mirostatTau, that.mirostatTau) && Objects.equals(this.mirostatEta, that.mirostatEta) && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) && Objects.equals(this.toolCallbacks, that.toolCallbacks) @@ -814,13 +833,13 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx, - this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, - this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, - this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, - this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, - this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext); + return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.think, this.useNUMA, + this.numCtx, this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, + this.vocabOnly, this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, + this.topK, this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, + this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, + this.mirostatEta, this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, + this.internalToolExecutionEnabled, this.toolContext); } public static class Builder { @@ -852,6 +871,11 @@ public Builder truncate(Boolean truncate) { return this; } + public Builder think(Boolean think) { + this.options.think = think; + return 
this; + } + public Builder useNUMA(Boolean useNUMA) { this.options.useNUMA = useNUMA; return this; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java new file mode 100644 index 00000000000..9f2d1c44dc6 --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMetadataTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.ollama; + +import io.micrometer.observation.tck.TestObservationRegistry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link OllamaChatModel} asserting AI metadata. + * + * @author Sun Yuhan + */ +@SpringBootTest(classes = OllamaChatModelMetadataTests.Config.class) +class OllamaChatModelMetadataTests extends BaseOllamaIT { + + private static final String MODEL = OllamaModel.QWEN_3_06B.getName(); + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + OllamaChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void ollamaThinkingMetadataCaptured() { + var options = OllamaOptions.builder().model(MODEL).think(true).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNotNull(); + }); + } + + @Test + void 
ollamaThinkingMetadataNotCapturedWhenNotSetThinkFlag() { + var options = OllamaOptions.builder().model(MODEL).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNull(); + }); + } + + @Test + void ollamaThinkingMetadataNotCapturedWhenSetThinkFlagToFalse() { + var options = OllamaOptions.builder().model(MODEL).think(false).build(); + + Prompt prompt = new Prompt("Why is the sky blue?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + chatResponse.getResults().forEach(generation -> { + ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata(); + assertThat(chatGenerationMetadata).isNotNull(); + var thinking = chatGenerationMetadata.get("thinking"); + assertThat(thinking).isNull(); + }); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public OllamaApi ollamaApi() { + return initializeOllama(MODEL); + } + + @Bean + public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) { + return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build(); + } + + } + +} diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java index 2220bf22695..146c2c042d5 100644 --- 
a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaImage.java @@ -23,7 +23,7 @@ */ public final class OllamaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.2"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.9.0"); private OllamaImage() { diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java index 98af032efbd..cf2ea4f8afd 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java @@ -33,14 +33,16 @@ import org.springframework.ai.ollama.api.OllamaApi.Message.Role; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNull; /** * @author Christian Tzolov * @author Thomas Vitale + * @author Sun Yuhan */ public class OllamaApiIT extends BaseOllamaIT { - private static final String MODEL = OllamaModel.LLAMA3_2.getName(); + private static final String MODEL = OllamaModel.QWEN_3_1_7_B.getName(); @BeforeAll public static void beforeAll() throws IOException, InterruptedException { @@ -107,11 +109,67 @@ public void embedText() { assertThat(response).isNotNull(); assertThat(response.embeddings()).hasSize(1); - assertThat(response.embeddings().get(0)).hasSize(3072); + assertThat(response.embeddings().get(0)).hasSize(2048); assertThat(response.model()).isEqualTo(MODEL); assertThat(response.promptEvalCount()).isEqualTo(5); assertThat(response.loadDuration()).isGreaterThan(1); assertThat(response.totalDuration()).isGreaterThan(1); } + @Test + public void streamChatWithThinking() { + var request = ChatRequest.builder(MODEL) + .stream(true) 
+ .think(true) + .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().thinking()) + .collect(Collectors.joining(System.lineSeparator()))).contains("solar"); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + + @Test + public void streamChatWithoutThinking() { + var request = ChatRequest.builder(MODEL) + .stream(true) + .think(false) + .messages(List.of(Message.builder(Role.USER).content("What are the planets in the solar system?").build())) + .options(OllamaOptions.builder().temperature(0.9).build().toMap()) + .build(); + + Flux response = getOllamaApi().streamingChat(request); + + List responses = response.collectList().block(); + System.out.println(responses); + + assertThat(responses).isNotNull(); + + assertThat(responses.stream() + .filter(r -> r.message() != null) + .map(r -> r.message().content()) + .collect(Collectors.joining(System.lineSeparator()))).contains("Earth"); + + assertThat(responses.stream().filter(r -> r.message() != null).allMatch(r -> r.message().thinking() == null)) + .isTrue(); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.message().content()).isEmpty(); + assertNull(lastResponse.message().thinking()); + assertThat(lastResponse.done()).isTrue(); + } + }