Skip to content

Commit c01eebc

Browse files
committed
Upgrade to Spring AI 1.0.0-M8
1 parent 3da36ad commit c01eebc

File tree

5 files changed

+700
-656
lines changed

5 files changed

+700
-656
lines changed

pom.xml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>com.javaaidev</groupId>
88
<artifactId>springai-openai-client</artifactId>
9-
<version>0.5.2</version>
9+
<version>0.6.0</version>
1010

1111
<name>OpenAI ChatModel</name>
1212
<description>Spring AI ChatModel for OpenAI using official Java SDK</description>
@@ -43,8 +43,8 @@
4343
<kotlin.version>1.9.25</kotlin.version>
4444
<kotlin.code.style>official</kotlin.code.style>
4545
<kotlin.compiler.jvmTarget>${java.version}</kotlin.compiler.jvmTarget>
46-
<spring-ai.version>1.0.0-M5</spring-ai.version>
47-
<openai-java.version>0.44.3</openai-java.version>
46+
<spring-ai.version>1.0.0-M8</spring-ai.version>
47+
<openai-java.version>1.6.0</openai-java.version>
4848
</properties>
4949

5050
<repositories>
@@ -183,7 +183,12 @@
183183
<dependencies>
184184
<dependency>
185185
<groupId>org.springframework.ai</groupId>
186-
<artifactId>spring-ai-core</artifactId>
186+
<artifactId>spring-ai-client-chat</artifactId>
187+
<version>${spring-ai.version}</version>
188+
</dependency>
189+
<dependency>
190+
<groupId>org.springframework.ai</groupId>
191+
<artifactId>spring-ai-openai</artifactId>
187192
<version>${spring-ai.version}</version>
188193
</dependency>
189194
<dependency>

src/main/kotlin/com/javaaidev/openai/OpenAIChatModel.kt

Lines changed: 90 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,76 @@ import org.springframework.ai.chat.messages.SystemMessage
1111
import org.springframework.ai.chat.messages.ToolResponseMessage
1212
import org.springframework.ai.chat.messages.UserMessage
1313
import org.springframework.ai.chat.metadata.ChatGenerationMetadata
14-
import org.springframework.ai.chat.model.AbstractToolCallSupport
1514
import org.springframework.ai.chat.model.ChatModel
1615
import org.springframework.ai.chat.model.ChatResponse
1716
import org.springframework.ai.chat.model.Generation
17+
import org.springframework.ai.chat.prompt.ChatOptions
1818
import org.springframework.ai.chat.prompt.Prompt
1919
import org.springframework.ai.model.ModelOptionsUtils
20-
import org.springframework.ai.model.function.FunctionCallbackResolver
21-
import org.springframework.ai.model.function.FunctionCallingOptions
20+
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate
21+
import org.springframework.ai.model.tool.ToolCallingChatOptions
22+
import org.springframework.ai.model.tool.ToolCallingManager
23+
import org.springframework.ai.model.tool.ToolExecutionResult
2224
import org.springframework.util.MimeType
2325
import org.springframework.util.MimeTypeUtils
2426
import java.util.*
2527

2628
class OpenAIChatModel(
2729
private val openAIClient: OpenAIClient,
28-
functionCallbackResolver: FunctionCallbackResolver? = null
29-
) : AbstractToolCallSupport(functionCallbackResolver), ChatModel {
30+
manager: ToolCallingManager? = null,
31+
options: OpenAiChatOptions? = null,
32+
) : ChatModel {
33+
private val defaultOptions = options ?: OpenAiChatOptions.builder().build()
34+
private val toolCallingManager = manager ?: ToolCallingManager.builder().build()
35+
private val toolExecutionEligibilityPredicate = DefaultToolExecutionEligibilityPredicate()
3036

3137
override fun call(prompt: Prompt): ChatResponse {
38+
var runtimeOptions: OpenAiChatOptions? = null
39+
if (prompt.options != null) {
40+
runtimeOptions = if (prompt.options is ToolCallingChatOptions) {
41+
ModelOptionsUtils.copyToTarget(
42+
prompt.options as ToolCallingChatOptions,
43+
ToolCallingChatOptions::class.java,
44+
OpenAiChatOptions::class.java
45+
)
46+
} else {
47+
ModelOptionsUtils.copyToTarget(
48+
prompt.options, ChatOptions::class.java,
49+
OpenAiChatOptions::class.java
50+
)
51+
}
52+
}
53+
54+
val requestOptions = ModelOptionsUtils.merge(
55+
runtimeOptions, this.defaultOptions,
56+
OpenAiChatOptions::class.java
57+
)
58+
59+
if (runtimeOptions != null) {
60+
requestOptions.httpHeaders = mergeHttpHeaders(runtimeOptions.httpHeaders, this.defaultOptions.httpHeaders)
61+
requestOptions.isInternalToolExecutionEnabled = ModelOptionsUtils.mergeOption<Boolean>(
62+
runtimeOptions.internalToolExecutionEnabled,
63+
this.defaultOptions.internalToolExecutionEnabled
64+
)
65+
requestOptions.toolNames = ToolCallingChatOptions.mergeToolNames(
66+
runtimeOptions.toolNames,
67+
this.defaultOptions.toolNames
68+
)
69+
requestOptions.toolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(
70+
runtimeOptions.toolCallbacks,
71+
this.defaultOptions.toolCallbacks
72+
)
73+
requestOptions.toolContext = ToolCallingChatOptions.mergeToolContext(
74+
runtimeOptions.toolContext,
75+
this.defaultOptions.toolContext
76+
)
77+
} else {
78+
requestOptions.httpHeaders = this.defaultOptions.httpHeaders
79+
requestOptions.isInternalToolExecutionEnabled = this.defaultOptions.internalToolExecutionEnabled
80+
requestOptions.toolNames = this.defaultOptions.toolNames
81+
requestOptions.toolCallbacks = this.defaultOptions.toolCallbacks
82+
requestOptions.toolContext = this.defaultOptions.toolContext
83+
}
3284
return internalCall(prompt, null)
3385
}
3486

@@ -44,7 +96,7 @@ class OpenAIChatModel(
4496
ChatCompletionContentPartText.builder().text(message.text).build()
4597
)
4698
)
47-
message.media?.map { media ->
99+
message.media.map { media ->
48100
when (media.mimeType) {
49101
MimeTypeUtils.parseMimeType("audio/mp3") -> ChatCompletionContentPart.ofInputAudio(
50102
ChatCompletionContentPartInputAudio.builder()
@@ -78,7 +130,7 @@ class OpenAIChatModel(
78130
.build()
79131
)
80132
}
81-
}?.let {
133+
}.let {
82134
contentParts.addAll(it)
83135
}
84136
paramsBuilder.addMessage(
@@ -96,7 +148,7 @@ class OpenAIChatModel(
96148
ChatCompletionContentPartText.builder().text(message.text).build()
97149
)
98150
)
99-
message.toolCalls?.map { toolCall ->
151+
message.toolCalls.map { toolCall ->
100152
ChatCompletionMessageToolCall.builder()
101153
.id(toolCall.id)
102154
.function(
@@ -106,7 +158,7 @@ class OpenAIChatModel(
106158
.build()
107159
)
108160
.build()
109-
}?.let {
161+
}.let {
110162
if (it.isNotEmpty()) {
111163
messageParamBuilder.toolCalls(it)
112164
}
@@ -134,17 +186,17 @@ class OpenAIChatModel(
134186
paramsBuilder.temperature(it)
135187
}
136188

137-
if (prompt.options is FunctionCallingOptions) {
138-
val tools = (prompt.options as FunctionCallingOptions).functions?.let {
139-
resolveFunctionCallbacks(it).map { functionCallback ->
189+
if (prompt.options is ToolCallingChatOptions) {
190+
val tools = (prompt.options as ToolCallingChatOptions).let {
191+
toolCallingManager.resolveToolDefinitions(it).map { toolDefinition ->
140192
val parametersMap =
141-
ModelOptionsUtils.jsonToMap(functionCallback.inputTypeSchema)
193+
ModelOptionsUtils.jsonToMap(toolDefinition.inputSchema())
142194
val jsonValue = JsonValue.from(parametersMap)
143195
ChatCompletionTool.builder()
144196
.function(
145197
FunctionDefinition.builder()
146-
.name(functionCallback.name)
147-
.description(functionCallback.description)
198+
.name(toolDefinition.name())
199+
.description(toolDefinition.description())
148200
.parameters(
149201
FunctionParameters.builder()
150202
.putAllAdditionalProperties((jsonValue as JsonObject).values)
@@ -155,7 +207,7 @@ class OpenAIChatModel(
155207
.build()
156208
}
157209
}
158-
if (tools?.isNotEmpty() == true) {
210+
if (tools.isNotEmpty()) {
159211
paramsBuilder.tools(tools)
160212
}
161213
}
@@ -171,9 +223,19 @@ class OpenAIChatModel(
171223
)
172224
}
173225
val response = ChatResponse.builder().generations(generations).build()
174-
if (isToolCall(response, setOf("TOOL_CALLS", "STOP"))) {
175-
val toolCallConversation = handleToolCalls(prompt, response)
176-
return this.internalCall(Prompt(toolCallConversation, prompt.options), response)
226+
if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
227+
val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
228+
if (toolExecutionResult.returnDirect()) {
229+
return ChatResponse.builder()
230+
.from(response)
231+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
232+
.build()
233+
} else {
234+
return this.internalCall(
235+
Prompt(toolExecutionResult.conversationHistory(), prompt.options),
236+
response
237+
)
238+
}
177239
}
178240
return response
179241
}
@@ -226,4 +288,13 @@ class OpenAIChatModel(
226288
}
227289
}
228290
}
291+
292+
private fun mergeHttpHeaders(
293+
runtimeHttpHeaders: Map<String, String>,
294+
defaultHttpHeaders: Map<String, String>
295+
): Map<String, String> {
296+
val mergedHttpHeaders = HashMap(defaultHttpHeaders)
297+
mergedHttpHeaders.putAll(runtimeHttpHeaders)
298+
return mergedHttpHeaders
299+
}
229300
}

0 commit comments

Comments
 (0)