Skip to content

Commit 8d43af9

Browse files
committed
add default embedding options
1 parent e81ffab commit 8d43af9

File tree

4 files changed

+66
-14
lines changed

4 files changed

+66
-14
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Add Maven dependency.
1515
<dependency>
1616
<groupId>com.javaaidev</groupId>
1717
<artifactId>springai-openai-client</artifactId>
18-
<version>0.4.0</version>
18+
<version>0.4.1</version>
1919
</dependency>
2020
```
2121

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>com.javaaidev</groupId>
88
<artifactId>springai-openai-client</artifactId>
9-
<version>0.4.0</version>
9+
<version>0.4.1</version>
1010

1111
<name>OpenAI ChatModel</name>
1212
<description>Spring AI ChatModel for OpenAI using official Java SDK</description>

src/main/kotlin/com/javaaidev/openai/OpenAIEmbeddingModel.kt

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,65 @@ package com.javaaidev.openai
22

33
import com.openai.client.OpenAIClient
44
import com.openai.models.EmbeddingCreateParams
5+
import org.springframework.ai.chat.metadata.EmptyUsage
56
import org.springframework.ai.document.Document
6-
import org.springframework.ai.embedding.AbstractEmbeddingModel
7-
import org.springframework.ai.embedding.Embedding
8-
import org.springframework.ai.embedding.EmbeddingRequest
9-
import org.springframework.ai.embedding.EmbeddingResponse
7+
import org.springframework.ai.embedding.*
8+
import org.springframework.ai.model.ModelOptionsUtils
109

11-
class OpenAIEmbeddingModel(private val openAIClient: OpenAIClient) : AbstractEmbeddingModel() {
10+
class OpenAIEmbeddingModel(
11+
private val openAIClient: OpenAIClient,
12+
private val defaultOptions: OpenAIEmbeddingOptions? = null,
13+
) :
14+
AbstractEmbeddingModel() {
1215
override fun call(request: EmbeddingRequest): EmbeddingResponse {
1316
val paramsBuilder = EmbeddingCreateParams.builder()
1417
.inputOfArrayOfStrings(request.instructions)
15-
request.options.model?.let {
18+
19+
val options = mergeOptions(request.options)
20+
21+
options.model?.let {
1622
paramsBuilder.model(it)
1723
}
18-
request.options.dimensions?.let {
24+
options.dimensions?.let {
1925
paramsBuilder.dimensions(it.toLong())
2026
}
27+
options.encodingFormat?.let {
28+
paramsBuilder.encodingFormat(EmbeddingCreateParams.EncodingFormat.of(it))
29+
}
30+
options.user?.let {
31+
paramsBuilder.user(it)
32+
}
33+
2134
val response = openAIClient.embeddings().create(paramsBuilder.build())
2235
val embeddings = response.data().map { e ->
2336
Embedding(e.embedding().map { v -> v.toFloat() }.toFloatArray(), e.index().toInt())
2437
}
25-
return EmbeddingResponse(embeddings)
38+
return EmbeddingResponse(embeddings, EmbeddingResponseMetadata(response.model(), EmptyUsage()))
39+
}
40+
41+
private fun mergeOptions(runtimeOptions: EmbeddingOptions?): OpenAIEmbeddingOptions {
42+
val defaultOptions = this.defaultOptions ?: OpenAIEmbeddingOptions.builder().build()
43+
return ModelOptionsUtils.copyToTarget(
44+
runtimeOptions, EmbeddingOptions::class.java,
45+
OpenAIEmbeddingOptions::class.java
46+
)?.let { options ->
47+
OpenAIEmbeddingOptions.builder()
48+
.model(ModelOptionsUtils.mergeOption(options.model, defaultOptions.model))
49+
.dimensions(
50+
ModelOptionsUtils.mergeOption(
51+
options.dimensions,
52+
defaultOptions.dimensions
53+
)
54+
)
55+
.encodingFormat(
56+
ModelOptionsUtils.mergeOption(
57+
options.encodingFormat,
58+
defaultOptions.encodingFormat
59+
)
60+
)
61+
.user(ModelOptionsUtils.mergeOption(options.user, defaultOptions.user))
62+
.build()
63+
} ?: defaultOptions
2664
}
2765

2866
override fun embed(document: Document): FloatArray {
Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package com.javaaidev.openai
22

33
import com.openai.client.okhttp.OpenAIOkHttpClient
4-
import org.junit.jupiter.api.Assertions.assertNotNull
5-
import org.junit.jupiter.api.Assertions.assertTrue
4+
import org.junit.jupiter.api.Assertions.*
65
import org.junit.jupiter.api.DisplayName
76
import org.junit.jupiter.api.Test
87
import org.springframework.ai.embedding.EmbeddingModel
@@ -13,21 +12,36 @@ class OpenAIEmbeddingModelTest {
1312

1413
init {
1514
val client = OpenAIOkHttpClient.fromEnv()
16-
embeddingModel = OpenAIEmbeddingModel(client)
15+
embeddingModel = OpenAIEmbeddingModel(
16+
client,
17+
OpenAIEmbeddingOptions.builder().model("text-embedding-3-small").build()
18+
)
1719
}
1820

1921
@Test
2022
@DisplayName("simple embedding")
2123
fun testEmbedding() {
24+
val response = embeddingModel.embed(
25+
listOf("hello", "world")
26+
)
27+
assertNotNull(response)
28+
assertTrue(response.isNotEmpty())
29+
}
30+
31+
@Test
32+
@DisplayName("embedding with options")
33+
fun testEmbeddingWithOptions() {
34+
val model = "text-embedding-3-large"
2235
val response = embeddingModel.call(
2336
EmbeddingRequest(
2437
listOf("hello", "world"),
2538
OpenAIEmbeddingOptions.builder()
26-
.model("text-embedding-3-small")
39+
.model(model)
2740
.build()
2841
)
2942
)
3043
assertNotNull(response)
44+
assertEquals(model, response.metadata.model)
3145
assertTrue(response.results.isNotEmpty())
3246
}
3347
}

0 commit comments

Comments
 (0)