Skip to content

Commit cb97d9c

Browse files
meistermeierilayaperumalg
authored andcommitted
Neo4j module: Determine default embedding dimension from model.
In cases where no custom size is set, derive the size by the given embedding model. Had to migrate the embeddingDimension Spring Boot property from int to Integer to introduce the null check for the fluent config. Everything else would have been just noisy. Auto-cherry-pick to 1.0.x Fixes #977 Signed-off-by: Gerrit Meier <meistermeier@gmail.com>
1 parent 5f23dca commit cb97d9c

File tree

4 files changed

+36
-6
lines changed

4 files changed

+36
-6
lines changed

auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-neo4j/src/main/java/org/springframework/ai/vectorstore/neo4j/autoconfigure/Neo4jVectorStoreAutoConfiguration.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ public Neo4jVectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel
6868
.customObservationConvention(customObservationConvention.getIfAvailable(() -> null))
6969
.batchingStrategy(batchingStrategy)
7070
.databaseName(properties.getDatabaseName())
71-
.embeddingDimension(properties.getEmbeddingDimension())
71+
.embeddingDimension(properties.getEmbeddingDimension() != null ? properties.getEmbeddingDimension()
72+
: embeddingModel.dimensions())
7273
.distanceType(properties.getDistanceType())
7374
.label(properties.getLabel())
7475
.embeddingProperty(properties.getEmbeddingProperty())

auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-neo4j/src/main/java/org/springframework/ai/vectorstore/neo4j/autoconfigure/Neo4jVectorStoreProperties.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class Neo4jVectorStoreProperties extends CommonVectorStoreProperties {
3333

3434
private String databaseName;
3535

36-
private int embeddingDimension = Neo4jVectorStore.DEFAULT_EMBEDDING_DIMENSION;
36+
private Integer embeddingDimension;
3737

3838
private Neo4jVectorStore.Neo4jDistanceType distanceType = Neo4jVectorStore.Neo4jDistanceType.COSINE;
3939

@@ -57,7 +57,7 @@ public void setDatabaseName(String databaseName) {
5757
this.databaseName = databaseName;
5858
}
5959

60-
public int getEmbeddingDimension() {
60+
public Integer getEmbeddingDimension() {
6161
return this.embeddingDimension;
6262
}
6363

vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ public class Neo4jVectorStore extends AbstractObservationVectorStore implements
136136

137137
private static final Logger logger = LoggerFactory.getLogger(Neo4jVectorStore.class);
138138

139+
@Deprecated(forRemoval = true)
139140
public static final int DEFAULT_EMBEDDING_DIMENSION = 1536;
140141

141142
public static final int DEFAULT_TRANSACTION_SIZE = 10_000;
@@ -189,7 +190,7 @@ protected Neo4jVectorStore(Builder builder) {
189190

190191
this.driver = builder.driver;
191192
this.sessionConfig = builder.sessionConfig;
192-
this.embeddingDimension = builder.embeddingDimension;
193+
this.embeddingDimension = builder.embeddingDimension.orElseGet(() -> builder.getEmbeddingModel().dimensions());
193194
this.distanceType = builder.distanceType;
194195
this.embeddingProperty = SchemaNames.sanitize(builder.embeddingProperty).orElseThrow();
195196
this.label = SchemaNames.sanitize(builder.label).orElseThrow();
@@ -404,7 +405,7 @@ public static class Builder extends AbstractVectorStoreBuilder<Builder> {
404405

405406
private SessionConfig sessionConfig = SessionConfig.defaultConfig();
406407

407-
private int embeddingDimension = DEFAULT_EMBEDDING_DIMENSION;
408+
private Optional<Integer> embeddingDimension = Optional.empty();
408409

409410
private Neo4jDistanceType distanceType = Neo4jDistanceType.COSINE;
410411

@@ -459,7 +460,7 @@ public Builder sessionConfig(SessionConfig sessionConfig) {
459460
*/
460461
public Builder embeddingDimension(int dimension) {
461462
Assert.isTrue(dimension >= 1, "Dimension has to be positive");
462-
this.embeddingDimension = dimension;
463+
this.embeddingDimension = Optional.of(dimension);
463464
return this;
464465
}
465466

vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreIT.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.neo4j.driver.AuthTokens;
3232
import org.neo4j.driver.Driver;
3333
import org.neo4j.driver.GraphDatabase;
34+
import org.springframework.context.annotation.Primary;
3435
import org.testcontainers.containers.Neo4jContainer;
3536
import org.testcontainers.junit.jupiter.Container;
3637
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -356,16 +357,43 @@ void getNativeClientTest() {
356357
});
357358
}
358359

360+
@Test
361+
void vectorIndexDimensionsDefaultAndOverwriteWorks() {
362+
this.contextRunner.run(context -> {
363+
var result = context.getBean(Driver.class)
364+
.executableQuery(
365+
"SHOW VECTOR INDEXES yield name, options return name, options['indexConfig']['vector.dimensions'] as dimensions")
366+
.execute()
367+
.records()
368+
.stream()
369+
.map(r -> r.get("name").asString() + r.get("dimensions").asInt())
370+
.toList();
371+
assertThat(result).containsExactlyInAnyOrder("secondIndex123", "spring-ai-document-index1536");
372+
});
373+
}
374+
359375
@SpringBootConfiguration
360376
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
361377
public static class TestApplication {
362378

363379
@Bean
380+
@Primary
364381
public VectorStore vectorStore(Driver driver, EmbeddingModel embeddingModel) {
365382

366383
return Neo4jVectorStore.builder(driver, embeddingModel).initializeSchema(true).build();
367384
}
368385

386+
@Bean
387+
public VectorStore vectorStoreWithCustomDimension(Driver driver, EmbeddingModel embeddingModel) {
388+
389+
return Neo4jVectorStore.builder(driver, embeddingModel)
390+
.initializeSchema(true)
391+
.indexName("secondIndex")
392+
.embeddingProperty("somethingElse")
393+
.embeddingDimension(123)
394+
.build();
395+
}
396+
369397
@Bean
370398
public Driver driver() {
371399
return GraphDatabase.driver(neo4jContainer.getBoltUrl(),

0 commit comments

Comments
 (0)