@@ -11,24 +11,76 @@ import org.springframework.ai.chat.messages.SystemMessage
11
11
import org.springframework.ai.chat.messages.ToolResponseMessage
12
12
import org.springframework.ai.chat.messages.UserMessage
13
13
import org.springframework.ai.chat.metadata.ChatGenerationMetadata
14
- import org.springframework.ai.chat.model.AbstractToolCallSupport
15
14
import org.springframework.ai.chat.model.ChatModel
16
15
import org.springframework.ai.chat.model.ChatResponse
17
16
import org.springframework.ai.chat.model.Generation
17
+ import org.springframework.ai.chat.prompt.ChatOptions
18
18
import org.springframework.ai.chat.prompt.Prompt
19
19
import org.springframework.ai.model.ModelOptionsUtils
20
- import org.springframework.ai.model.function.FunctionCallbackResolver
21
- import org.springframework.ai.model.function.FunctionCallingOptions
20
+ import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate
21
+ import org.springframework.ai.model.tool.ToolCallingChatOptions
22
+ import org.springframework.ai.model.tool.ToolCallingManager
23
+ import org.springframework.ai.model.tool.ToolExecutionResult
22
24
import org.springframework.util.MimeType
23
25
import org.springframework.util.MimeTypeUtils
24
26
import java.util.*
25
27
26
28
class OpenAIChatModel (
27
29
private val openAIClient : OpenAIClient ,
28
- functionCallbackResolver : FunctionCallbackResolver ? = null
29
- ) : AbstractToolCallSupport(functionCallbackResolver), ChatModel {
30
+ manager : ToolCallingManager ? = null ,
31
+ options : OpenAiChatOptions ? = null ,
32
+ ) : ChatModel {
33
+ private val defaultOptions = options ? : OpenAiChatOptions .builder().build()
34
+ private val toolCallingManager = manager ? : ToolCallingManager .builder().build()
35
+ private val toolExecutionEligibilityPredicate = DefaultToolExecutionEligibilityPredicate ()
30
36
31
37
override fun call (prompt : Prompt ): ChatResponse {
38
+ var runtimeOptions: OpenAiChatOptions ? = null
39
+ if (prompt.options != null ) {
40
+ runtimeOptions = if (prompt.options is ToolCallingChatOptions ) {
41
+ ModelOptionsUtils .copyToTarget(
42
+ prompt.options as ToolCallingChatOptions ,
43
+ ToolCallingChatOptions ::class .java,
44
+ OpenAiChatOptions ::class .java
45
+ )
46
+ } else {
47
+ ModelOptionsUtils .copyToTarget(
48
+ prompt.options, ChatOptions ::class .java,
49
+ OpenAiChatOptions ::class .java
50
+ )
51
+ }
52
+ }
53
+
54
+ val requestOptions = ModelOptionsUtils .merge(
55
+ runtimeOptions, this .defaultOptions,
56
+ OpenAiChatOptions ::class .java
57
+ )
58
+
59
+ if (runtimeOptions != null ) {
60
+ requestOptions.httpHeaders = mergeHttpHeaders(runtimeOptions.httpHeaders, this .defaultOptions.httpHeaders)
61
+ requestOptions.isInternalToolExecutionEnabled = ModelOptionsUtils .mergeOption<Boolean >(
62
+ runtimeOptions.internalToolExecutionEnabled,
63
+ this .defaultOptions.internalToolExecutionEnabled
64
+ )
65
+ requestOptions.toolNames = ToolCallingChatOptions .mergeToolNames(
66
+ runtimeOptions.toolNames,
67
+ this .defaultOptions.toolNames
68
+ )
69
+ requestOptions.toolCallbacks = ToolCallingChatOptions .mergeToolCallbacks(
70
+ runtimeOptions.toolCallbacks,
71
+ this .defaultOptions.toolCallbacks
72
+ )
73
+ requestOptions.toolContext = ToolCallingChatOptions .mergeToolContext(
74
+ runtimeOptions.toolContext,
75
+ this .defaultOptions.toolContext
76
+ )
77
+ } else {
78
+ requestOptions.httpHeaders = this .defaultOptions.httpHeaders
79
+ requestOptions.isInternalToolExecutionEnabled = this .defaultOptions.internalToolExecutionEnabled
80
+ requestOptions.toolNames = this .defaultOptions.toolNames
81
+ requestOptions.toolCallbacks = this .defaultOptions.toolCallbacks
82
+ requestOptions.toolContext = this .defaultOptions.toolContext
83
+ }
32
84
return internalCall(prompt, null )
33
85
}
34
86
@@ -44,7 +96,7 @@ class OpenAIChatModel(
44
96
ChatCompletionContentPartText .builder().text(message.text).build()
45
97
)
46
98
)
47
- message.media? .map { media ->
99
+ message.media.map { media ->
48
100
when (media.mimeType) {
49
101
MimeTypeUtils .parseMimeType(" audio/mp3" ) -> ChatCompletionContentPart .ofInputAudio(
50
102
ChatCompletionContentPartInputAudio .builder()
@@ -78,7 +130,7 @@ class OpenAIChatModel(
78
130
.build()
79
131
)
80
132
}
81
- }? .let {
133
+ }.let {
82
134
contentParts.addAll(it)
83
135
}
84
136
paramsBuilder.addMessage(
@@ -96,7 +148,7 @@ class OpenAIChatModel(
96
148
ChatCompletionContentPartText .builder().text(message.text).build()
97
149
)
98
150
)
99
- message.toolCalls? .map { toolCall ->
151
+ message.toolCalls.map { toolCall ->
100
152
ChatCompletionMessageToolCall .builder()
101
153
.id(toolCall.id)
102
154
.function(
@@ -106,7 +158,7 @@ class OpenAIChatModel(
106
158
.build()
107
159
)
108
160
.build()
109
- }? .let {
161
+ }.let {
110
162
if (it.isNotEmpty()) {
111
163
messageParamBuilder.toolCalls(it)
112
164
}
@@ -134,17 +186,17 @@ class OpenAIChatModel(
134
186
paramsBuilder.temperature(it)
135
187
}
136
188
137
- if (prompt.options is FunctionCallingOptions ) {
138
- val tools = (prompt.options as FunctionCallingOptions ).functions? .let {
139
- resolveFunctionCallbacks (it).map { functionCallback ->
189
+ if (prompt.options is ToolCallingChatOptions ) {
190
+ val tools = (prompt.options as ToolCallingChatOptions ) .let {
191
+ toolCallingManager.resolveToolDefinitions (it).map { toolDefinition ->
140
192
val parametersMap =
141
- ModelOptionsUtils .jsonToMap(functionCallback.inputTypeSchema )
193
+ ModelOptionsUtils .jsonToMap(toolDefinition.inputSchema() )
142
194
val jsonValue = JsonValue .from(parametersMap)
143
195
ChatCompletionTool .builder()
144
196
.function(
145
197
FunctionDefinition .builder()
146
- .name(functionCallback .name)
147
- .description(functionCallback .description)
198
+ .name(toolDefinition .name() )
199
+ .description(toolDefinition .description() )
148
200
.parameters(
149
201
FunctionParameters .builder()
150
202
.putAllAdditionalProperties((jsonValue as JsonObject ).values)
@@ -155,7 +207,7 @@ class OpenAIChatModel(
155
207
.build()
156
208
}
157
209
}
158
- if (tools? .isNotEmpty() == true ) {
210
+ if (tools.isNotEmpty()) {
159
211
paramsBuilder.tools(tools)
160
212
}
161
213
}
@@ -171,9 +223,19 @@ class OpenAIChatModel(
171
223
)
172
224
}
173
225
val response = ChatResponse .builder().generations(generations).build()
174
- if (isToolCall(response, setOf (" TOOL_CALLS" , " STOP" ))) {
175
- val toolCallConversation = handleToolCalls(prompt, response)
176
- return this .internalCall(Prompt (toolCallConversation, prompt.options), response)
226
+ if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.options, response)) {
227
+ val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response)
228
+ if (toolExecutionResult.returnDirect()) {
229
+ return ChatResponse .builder()
230
+ .from(response)
231
+ .generations(ToolExecutionResult .buildGenerations(toolExecutionResult))
232
+ .build()
233
+ } else {
234
+ return this .internalCall(
235
+ Prompt (toolExecutionResult.conversationHistory(), prompt.options),
236
+ response
237
+ )
238
+ }
177
239
}
178
240
return response
179
241
}
@@ -226,4 +288,13 @@ class OpenAIChatModel(
226
288
}
227
289
}
228
290
}
291
+
292
+ private fun mergeHttpHeaders (
293
+ runtimeHttpHeaders : Map <String , String >,
294
+ defaultHttpHeaders : Map <String , String >
295
+ ): Map <String , String > {
296
+ val mergedHttpHeaders = HashMap (defaultHttpHeaders)
297
+ mergedHttpHeaders.putAll(runtimeHttpHeaders)
298
+ return mergedHttpHeaders
299
+ }
229
300
}
0 commit comments