@@ -2,27 +2,65 @@ package com.javaaidev.openai
2
2
3
3
import com.openai.client.OpenAIClient
4
4
import com.openai.models.EmbeddingCreateParams
5
+ import org.springframework.ai.chat.metadata.EmptyUsage
5
6
import org.springframework.ai.document.Document
6
- import org.springframework.ai.embedding.AbstractEmbeddingModel
7
- import org.springframework.ai.embedding.Embedding
8
- import org.springframework.ai.embedding.EmbeddingRequest
9
- import org.springframework.ai.embedding.EmbeddingResponse
7
+ import org.springframework.ai.embedding.*
8
+ import org.springframework.ai.model.ModelOptionsUtils
10
9
11
- class OpenAIEmbeddingModel (private val openAIClient : OpenAIClient ) : AbstractEmbeddingModel() {
10
+ class OpenAIEmbeddingModel (
11
+ private val openAIClient : OpenAIClient ,
12
+ private val defaultOptions : OpenAIEmbeddingOptions ? = null ,
13
+ ) :
14
+ AbstractEmbeddingModel () {
12
15
override fun call (request : EmbeddingRequest ): EmbeddingResponse {
13
16
val paramsBuilder = EmbeddingCreateParams .builder()
14
17
.inputOfArrayOfStrings(request.instructions)
15
- request.options.model?.let {
18
+
19
+ val options = mergeOptions(request.options)
20
+
21
+ options.model?.let {
16
22
paramsBuilder.model(it)
17
23
}
18
- request. options.dimensions?.let {
24
+ options.dimensions?.let {
19
25
paramsBuilder.dimensions(it.toLong())
20
26
}
27
+ options.encodingFormat?.let {
28
+ paramsBuilder.encodingFormat(EmbeddingCreateParams .EncodingFormat .of(it))
29
+ }
30
+ options.user?.let {
31
+ paramsBuilder.user(it)
32
+ }
33
+
21
34
val response = openAIClient.embeddings().create(paramsBuilder.build())
22
35
val embeddings = response.data().map { e ->
23
36
Embedding (e.embedding().map { v -> v.toFloat() }.toFloatArray(), e.index().toInt())
24
37
}
25
- return EmbeddingResponse (embeddings)
38
+ return EmbeddingResponse (embeddings, EmbeddingResponseMetadata (response.model(), EmptyUsage ()))
39
+ }
40
+
41
+ private fun mergeOptions (runtimeOptions : EmbeddingOptions ? ): OpenAIEmbeddingOptions {
42
+ val defaultOptions = this .defaultOptions ? : OpenAIEmbeddingOptions .builder().build()
43
+ return ModelOptionsUtils .copyToTarget(
44
+ runtimeOptions, EmbeddingOptions ::class .java,
45
+ OpenAIEmbeddingOptions ::class .java
46
+ )?.let { options ->
47
+ OpenAIEmbeddingOptions .builder()
48
+ .model(ModelOptionsUtils .mergeOption(options.model, defaultOptions.model))
49
+ .dimensions(
50
+ ModelOptionsUtils .mergeOption(
51
+ options.dimensions,
52
+ defaultOptions.dimensions
53
+ )
54
+ )
55
+ .encodingFormat(
56
+ ModelOptionsUtils .mergeOption(
57
+ options.encodingFormat,
58
+ defaultOptions.encodingFormat
59
+ )
60
+ )
61
+ .user(ModelOptionsUtils .mergeOption(options.user, defaultOptions.user))
62
+ .build()
63
+ } ? : defaultOptions
26
64
}
27
65
28
66
override fun embed (document : Document ): FloatArray {
0 commit comments