diff --git a/src/main/antora/modules/ROOT/pages/migration-guides/migration-guide-5.3-5.4.adoc b/src/main/antora/modules/ROOT/pages/migration-guides/migration-guide-5.3-5.4.adoc index d9eddf030..c5178ff75 100644 --- a/src/main/antora/modules/ROOT/pages/migration-guides/migration-guide-5.3-5.4.adoc +++ b/src/main/antora/modules/ROOT/pages/migration-guides/migration-guide-5.3-5.4.adoc @@ -6,6 +6,17 @@ This section describes breaking changes from version 5.3.x to 5.4.x and how remo [[elasticsearch-migration-guide-5.3-5.4.breaking-changes]] == Breaking Changes +[[elasticsearch-migration-guide-5.3-5.4.breaking-changes.knn-search]] +=== kNN search +The `withKnnQuery` method in `NativeQueryBuilder` has been replaced with `withKnnSearches` to build a `NativeQuery` with a kNN search. + +`KnnQuery` and `KnnSearch` are two different classes in the Elasticsearch Java client; they are used for different queries and support different parameters: + +- `KnnSearch` is https://www.elastic.co/guide/en/elasticsearch/reference/8.13/search-search.html#search-api-knn[the top-level `knn` query] of the Elasticsearch search request; +- `KnnQuery` is https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html[the `knn` query inside the `query` clause]. + +If `KnnQuery` is still preferred, construct it manually inside the `query` clause by means of the `withQuery(co.elastic.clients.elasticsearch._types.query_dsl.Query query)` method of `NativeQueryBuilder`. 
+ [[elasticsearch-migration-guide-5.3-5.4.deprecations]] == Deprecations diff --git a/src/main/java/org/springframework/data/elasticsearch/annotations/Field.java b/src/main/java/org/springframework/data/elasticsearch/annotations/Field.java index dd299a493..63f74716b 100644 --- a/src/main/java/org/springframework/data/elasticsearch/annotations/Field.java +++ b/src/main/java/org/springframework/data/elasticsearch/annotations/Field.java @@ -37,6 +37,7 @@ * @author Brian Kimmig * @author Morgan Lutz * @author Sascha Woo + * @author Haibo Liu */ @Retention(RetentionPolicy.RUNTIME) @Target({ ElementType.FIELD, ElementType.ANNOTATION_TYPE, ElementType.METHOD }) @@ -195,6 +196,27 @@ */ int dims() default -1; + /** + * The {@code element_type} mapping parameter, to be used in combination with {@link FieldType#Dense_Vector}. Use the constants of {@link FieldElementType}; {@link FieldElementType#DEFAULT} writes no value. + * + * @since 5.4 + */ + String elementType() default FieldElementType.DEFAULT; + + /** + * The {@code similarity} mapping parameter of a vector field, to be used in combination with {@link FieldType#Dense_Vector}. {@link KnnSimilarity#DEFAULT} writes no value. + * + * @since 5.4 + */ + KnnSimilarity knnSimilarity() default KnnSimilarity.DEFAULT; + + /** + * The {@code index_options} of a vector field, to be used in combination with {@link FieldType#Dense_Vector}. Only the first element is evaluated; an empty array writes no index options. + * + * @since 5.4 + */ + KnnIndexOptions[] knnIndexOptions() default {}; + /** * Controls how Elasticsearch dynamically adds fields to the inner object within the document.
* To be used in combination with {@link FieldType#Object} or {@link FieldType#Nested} diff --git a/src/main/java/org/springframework/data/elasticsearch/annotations/FieldElementType.java b/src/main/java/org/springframework/data/elasticsearch/annotations/FieldElementType.java new file mode 100644 index 000000000..93247b735 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/annotations/FieldElementType.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.elasticsearch.annotations; + +/** + * Constants for the {@code element_type} parameter of a {@link FieldType#Dense_Vector} field, to be used as values of + * {@link Field#elementType()} and {@link InnerField#elementType()}. + * + * @author Haibo Liu + * @since 5.4 + */ +public final class FieldElementType { + /** marker for "not set": no {@code element_type} is written to the mapping */ + public static final String DEFAULT = ""; + public static final String FLOAT = "float"; + public static final String BYTE = "byte"; + + private FieldElementType() {} +} diff --git a/src/main/java/org/springframework/data/elasticsearch/annotations/InnerField.java b/src/main/java/org/springframework/data/elasticsearch/annotations/InnerField.java index ceb605411..35bccd968 100644 --- a/src/main/java/org/springframework/data/elasticsearch/annotations/InnerField.java +++ b/src/main/java/org/springframework/data/elasticsearch/annotations/InnerField.java @@ -29,6 +29,7 @@ * @author Aleksei Arsenev * @author Brian Kimmig * @author Morgan Lutz + * @author Haibo Liu */ @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.ANNOTATION_TYPE) @@ -149,4 +150,25 @@ * @since 4.2 */ int dims() default -1; + + /** + * The {@code element_type} mapping parameter, to be used in combination with {@link FieldType#Dense_Vector}. Use the constants of {@link FieldElementType}; {@link FieldElementType#DEFAULT} writes no value. + * + * @since 5.4 + */ + String elementType() default FieldElementType.DEFAULT; + + /** + * The {@code similarity} mapping parameter of a vector field, to be used in combination with {@link FieldType#Dense_Vector}. {@link KnnSimilarity#DEFAULT} writes no value. + * + * @since 5.4 + */ + KnnSimilarity knnSimilarity() default KnnSimilarity.DEFAULT; + + /** + * The {@code index_options} of a vector field, to be used in combination with {@link FieldType#Dense_Vector}. Only the first element is evaluated; an empty array writes no index options. + * + * @since 5.4 + */ + KnnIndexOptions[] knnIndexOptions() default {}; } diff --git a/src/main/java/org/springframework/data/elasticsearch/annotations/KnnAlgorithmType.java b/src/main/java/org/springframework/data/elasticsearch/annotations/KnnAlgorithmType.java new file mode 100644 index 000000000..1eae1188d --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/annotations/KnnAlgorithmType.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.annotations; + +/** Index algorithms of a dense vector field; {@link #getType()} is the value written to the {@code type} entry of the mapping's {@code index_options}. + * @author Haibo Liu + * @since 5.4 + */ +public enum KnnAlgorithmType { + HNSW("hnsw"), + INT8_HNSW("int8_hnsw"), + FLAT("flat"), + INT8_FLAT("int8_flat"), + DEFAULT(""); /* marker for "not set": MappingParameters writes no type entry for it */ + + private final String type; + + KnnAlgorithmType(String type) { + this.type = type; + } + + public String getType() { + return type; + } +} diff --git a/src/main/java/org/springframework/data/elasticsearch/annotations/KnnIndexOptions.java b/src/main/java/org/springframework/data/elasticsearch/annotations/KnnIndexOptions.java new file mode 100644 index 000000000..48e27a45f --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/annotations/KnnIndexOptions.java @@ -0,0 +1,40 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.springframework.data.elasticsearch.annotations; + +/** Index options of a {@link FieldType#Dense_Vector} field, written to the {@code index_options} object of the mapping. + * @author Haibo Liu + * @since 5.4 + */ +public @interface KnnIndexOptions { + /** The index algorithm; {@link KnnAlgorithmType#DEFAULT} writes no {@code type} entry. */ + KnnAlgorithmType type() default KnnAlgorithmType.DEFAULT; + + /** + * The {@code m} option; a negative value (the default) is not written. Only applicable to {@link KnnAlgorithmType#HNSW} and {@link KnnAlgorithmType#INT8_HNSW} index types. + */ + int m() default -1; + + /** + * The {@code ef_construction} option; a negative value (the default) is not written. Only applicable to {@link KnnAlgorithmType#HNSW} and {@link KnnAlgorithmType#INT8_HNSW} index types. + */ + int efConstruction() default -1; + + /** + * The {@code confidence_interval} option; a negative value (the default) is not written. Only applicable to {@link KnnAlgorithmType#INT8_HNSW} and {@link KnnAlgorithmType#INT8_FLAT} index types. + */ + float confidenceInterval() default -1F; +} diff --git a/src/main/java/org/springframework/data/elasticsearch/annotations/KnnSimilarity.java b/src/main/java/org/springframework/data/elasticsearch/annotations/KnnSimilarity.java new file mode 100644 index 000000000..97a23aa35 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/annotations/KnnSimilarity.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.elasticsearch.annotations; + +/** + * @author Haibo Liu + * @since 5.4 + */ +public enum KnnSimilarity { + L2_NORM("l2_norm"), + DOT_PRODUCT("dot_product"), + COSINE("cosine"), + MAX_INNER_PRODUCT("max_inner_product"), + DEFAULT(""); + + private final String similarity; + + KnnSimilarity(String similarity) { + this.similarity = similarity; + } + + public String getSimilarity() { + return similarity; + } +} diff --git a/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQuery.java b/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQuery.java index f1b5fdc24..ee1978688 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQuery.java @@ -15,7 +15,6 @@ */ package org.springframework.data.elasticsearch.client.elc; -import co.elastic.clients.elasticsearch._types.KnnQuery; import co.elastic.clients.elasticsearch._types.KnnSearch; import co.elastic.clients.elasticsearch._types.SortOptions; import co.elastic.clients.elasticsearch._types.aggregations.Aggregation; @@ -30,7 +29,6 @@ import java.util.Map; import org.springframework.data.elasticsearch.core.query.BaseQuery; -import org.springframework.data.elasticsearch.core.query.ScriptedField; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -40,6 +38,7 @@ * * @author Peter-Josef Meisch * @author Sascha Woo + * @author Haibo Liu * @since 4.4 */ public class NativeQuery extends BaseQuery { @@ -54,7 +53,6 @@ public class NativeQuery extends BaseQuery { private List sortOptions = Collections.emptyList(); private Map searchExtensions = Collections.emptyMap(); - @Nullable private KnnQuery knnQuery; @Nullable private List knnSearches = Collections.emptyList(); public NativeQuery(NativeQueryBuilder builder) { @@ -72,7 +70,6 @@ public NativeQuery(NativeQueryBuilder builder) { "Cannot add an NativeQuery in a 
NativeQuery"); } this.springDataQuery = builder.getSpringDataQuery(); - this.knnQuery = builder.getKnnQuery(); this.knnSearches = builder.getKnnSearches(); } @@ -124,14 +121,6 @@ public void setSpringDataQuery(@Nullable org.springframework.data.elasticsearch. this.springDataQuery = springDataQuery; } - /** - * @since 5.1 - */ - @Nullable - public KnnQuery getKnnQuery() { - return knnQuery; - } - /** * @since 5.3.1 */ diff --git a/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQueryBuilder.java b/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQueryBuilder.java index 1956a75ea..6887963da 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQueryBuilder.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQueryBuilder.java @@ -40,6 +40,7 @@ /** * @author Peter-Josef Meisch * @author Sascha Woo + * @author Haibo Liu * @since 4.4 */ public class NativeQueryBuilder extends BaseQueryBuilder { @@ -213,13 +214,30 @@ public NativeQueryBuilder withQuery(org.springframework.data.elasticsearch.core. 
} /** - * @since 5.1 + * @since 5.4 */ - public NativeQueryBuilder withKnnQuery(KnnQuery knnQuery) { - this.knnQuery = knnQuery; + public NativeQueryBuilder withKnnSearches(List<KnnSearch> knnSearches) { + this.knnSearches = knnSearches; return this; } + /** Sets a single kNN search that is configured by the given builder function. + * @since 5.4 + */ + public NativeQueryBuilder withKnnSearches(Function<KnnSearch.Builder, ObjectBuilder<KnnSearch>> fn) { + + Assert.notNull(fn, "fn must not be null"); + + return withKnnSearches(fn.apply(new KnnSearch.Builder()).build()); + } + + /** Sets the given kNN search as the only kNN search of the query. + * @since 5.4 + */ + public NativeQueryBuilder withKnnSearches(KnnSearch knnSearch) { + return withKnnSearches(List.of(knnSearch)); + } + public NativeQuery build() { Assert.isTrue(query == null || springDataQuery == null, "Cannot have both a native query and a Spring Data query"); return new NativeQuery(this); diff --git a/src/main/java/org/springframework/data/elasticsearch/client/elc/RequestConverter.java b/src/main/java/org/springframework/data/elasticsearch/client/elc/RequestConverter.java index 80efd4487..09ad82c6f 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/elc/RequestConverter.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/elc/RequestConverter.java @@ -1377,7 +1377,7 @@ public MsearchRequest searchMsearchRequest( private Function> msearchHeaderBuilder(Query query, IndexCoordinates index, @Nullable String routing) { return h -> { - var searchType = (query instanceof NativeQuery nativeQuery && nativeQuery.getKnnQuery() != null) ? null + var searchType = (query instanceof NativeQuery nativeQuery && !isEmpty(nativeQuery.getKnnSearches())) ? null : searchType(query.getSearchType()); h // @@ -1409,7 +1409,7 @@ private void prepareSearchRequest(Query query, @Nullable String routing, @Nu ElasticsearchPersistentEntity persistentEntity = getPersistentEntity(clazz); - var searchType = (query instanceof NativeQuery nativeQuery && nativeQuery.getKnnQuery() != null) ? 
null + var searchType = (query instanceof NativeQuery nativeQuery && !isEmpty(nativeQuery.getKnnSearches())) ? null : searchType(query.getSearchType()); builder // @@ -1728,17 +1728,6 @@ private void prepareNativeSearch(NativeQuery query, SearchRequest.Builder builde .sort(query.getSortOptions()) // ; - if (query.getKnnQuery() != null) { - var kq = query.getKnnQuery(); - builder.knn(ksb -> ksb - .field(kq.field()) - .queryVector(kq.queryVector()) - .numCandidates(kq.numCandidates()) - .filter(kq.filter()) - .similarity(kq.similarity())); - - } - if (!isEmpty(query.getKnnSearches())) { builder.knn(query.getKnnSearches()); } @@ -1760,17 +1749,6 @@ private void prepareNativeSearch(NativeQuery query, MultisearchBody.Builder buil .collapse(query.getFieldCollapse()) // .sort(query.getSortOptions()); - if (query.getKnnQuery() != null) { - var kq = query.getKnnQuery(); - builder.knn(ksb -> ksb - .field(kq.field()) - .queryVector(kq.queryVector()) - .numCandidates(kq.numCandidates()) - .filter(kq.filter()) - .similarity(kq.similarity())); - - } - if (!isEmpty(query.getKnnSearches())) { builder.knn(query.getKnnSearches()); } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/index/MappingParameters.java b/src/main/java/org/springframework/data/elasticsearch/core/index/MappingParameters.java index 595cc6bd6..cd3903376 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/index/MappingParameters.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/index/MappingParameters.java @@ -23,15 +23,7 @@ import java.util.List; import java.util.stream.Collectors; -import org.springframework.data.elasticsearch.annotations.DateFormat; -import org.springframework.data.elasticsearch.annotations.Field; -import org.springframework.data.elasticsearch.annotations.FieldType; -import org.springframework.data.elasticsearch.annotations.IndexOptions; -import org.springframework.data.elasticsearch.annotations.IndexPrefixes; -import 
org.springframework.data.elasticsearch.annotations.InnerField; -import org.springframework.data.elasticsearch.annotations.NullValueType; -import org.springframework.data.elasticsearch.annotations.Similarity; -import org.springframework.data.elasticsearch.annotations.TermVector; +import org.springframework.data.elasticsearch.annotations.*; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -49,6 +41,7 @@ * @author Brian Kimmig * @author Morgan Lutz * @author Sascha Woo + * @author Haibo Liu * @since 4.0 */ public final class MappingParameters { @@ -78,6 +71,10 @@ public final class MappingParameters { static final String FIELD_PARAM_ORIENTATION = "orientation"; static final String FIELD_PARAM_POSITIVE_SCORE_IMPACT = "positive_score_impact"; static final String FIELD_PARAM_DIMS = "dims"; + static final String FIELD_PARAM_ELEMENT_TYPE = "element_type"; + static final String FIELD_PARAM_M = "m"; + static final String FIELD_PARAM_EF_CONSTRUCTION = "ef_construction"; + static final String FIELD_PARAM_CONFIDENCE_INTERVAL = "confidence_interval"; static final String FIELD_PARAM_SCALING_FACTOR = "scaling_factor"; static final String FIELD_PARAM_SEARCH_ANALYZER = "search_analyzer"; static final String FIELD_PARAM_STORE = "store"; @@ -110,6 +107,9 @@ public final class MappingParameters { private final Integer positionIncrementGap; private final boolean positiveScoreImpact; private final Integer dims; + private final String elementType; + private final KnnSimilarity knnSimilarity; + @Nullable private final KnnIndexOptions knnIndexOptions; private final String searchAnalyzer; private final double scalingFactor; private final String similarity; @@ -174,6 +174,9 @@ private MappingParameters(Field field) { Assert.isTrue(dims >= 1 && dims <= 4096, "Invalid required parameter! 
Dense_Vector value \"dims\" must be between 1 and 4096."); } + elementType = field.elementType(); + knnSimilarity = field.knnSimilarity(); + knnIndexOptions = field.knnIndexOptions().length > 0 ? field.knnIndexOptions()[0] : null; Assert.isTrue(field.enabled() || type == FieldType.Object, "enabled false is only allowed for field type object"); enabled = field.enabled(); eagerGlobalOrdinals = field.eagerGlobalOrdinals(); @@ -217,6 +220,9 @@ private MappingParameters(InnerField field) { Assert.isTrue(dims >= 1 && dims <= 4096, "Invalid required parameter! Dense_Vector value \"dims\" must be between 1 and 4096."); } + elementType = field.elementType(); + knnSimilarity = field.knnSimilarity(); + knnIndexOptions = field.knnIndexOptions().length > 0 ? field.knnIndexOptions()[0] : null; enabled = true; eagerGlobalOrdinals = field.eagerGlobalOrdinals(); } @@ -356,6 +362,45 @@ public void writeTypeAndParametersTo(ObjectNode objectNode) throws IOException { if (type == FieldType.Dense_Vector) { objectNode.put(FIELD_PARAM_DIMS, dims); + + if (!FieldElementType.DEFAULT.equals(elementType)) { + objectNode.put(FIELD_PARAM_ELEMENT_TYPE, elementType); + } + + /* similarity and index options are only accepted for indexed vectors, so 'index' is checked before writing them */ + if (knnSimilarity != KnnSimilarity.DEFAULT) { + Assert.isTrue(index, "knn similarity can only be specified when 'index' is true."); + objectNode.put(FIELD_PARAM_SIMILARITY, knnSimilarity.getSimilarity()); + } + + if (knnIndexOptions != null) { + Assert.isTrue(index, "knn index options can only be specified when 'index' is true."); + ObjectNode indexOptionsNode = objectNode.putObject(FIELD_PARAM_INDEX_OPTIONS); + KnnAlgorithmType algoType = knnIndexOptions.type(); + if (algoType != KnnAlgorithmType.DEFAULT) { + if (algoType == KnnAlgorithmType.INT8_HNSW || algoType == KnnAlgorithmType.INT8_FLAT) { + Assert.isTrue(!FieldElementType.BYTE.equals(elementType), + "'element_type' can only be float when using 
vector quantization."); + } + indexOptionsNode.put(FIELD_PARAM_TYPE, algoType.getType()); + } + if (knnIndexOptions.m() >= 0) { + Assert.isTrue(algoType == KnnAlgorithmType.HNSW || algoType == KnnAlgorithmType.INT8_HNSW, + "knn 'm' parameter can only be applicable to hnsw and int8_hnsw index types."); + indexOptionsNode.put(FIELD_PARAM_M, knnIndexOptions.m()); + } + if (knnIndexOptions.efConstruction() >= 0) { + Assert.isTrue(algoType == KnnAlgorithmType.HNSW || algoType == KnnAlgorithmType.INT8_HNSW, + "knn 'ef_construction' can only be applicable to hnsw and int8_hnsw index types."); + indexOptionsNode.put(FIELD_PARAM_EF_CONSTRUCTION, knnIndexOptions.efConstruction()); + } + if (knnIndexOptions.confidenceInterval() >= 0) { + Assert.isTrue(algoType == KnnAlgorithmType.INT8_HNSW + || algoType == KnnAlgorithmType.INT8_FLAT, + "knn 'confidence_interval' can only be applicable to int8_hnsw and int8_flat index types."); + indexOptionsNode.put(FIELD_PARAM_CONFIDENCE_INTERVAL, knnIndexOptions.confidenceInterval()); + } + } } if (!enabled) { diff --git a/src/test/java/org/springframework/data/elasticsearch/core/index/MappingBuilderIntegrationTests.java b/src/test/java/org/springframework/data/elasticsearch/core/index/MappingBuilderIntegrationTests.java index f769ffaaa..1d37f007a 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/index/MappingBuilderIntegrationTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/index/MappingBuilderIntegrationTests.java @@ -58,6 +58,7 @@ * @author Roman Puchkovskiy * @author Brian Kimmig * @author Morgan Lutz + * @author Haibo Liu */ @SpringIntegrationTest public abstract class MappingBuilderIntegrationTests extends MappingContextBaseTests { @@ -908,7 +909,8 @@ static class SimilarityEntity { @Nullable @Id private String id; - @Field(type = FieldType.Dense_Vector, dims = 42, similarity = "cosine") private double[] denseVector; + @Field(type = FieldType.Dense_Vector, dims = 42, knnSimilarity = 
KnnSimilarity.COSINE) + private double[] denseVector; } @Mapping(aliases = { diff --git a/src/test/java/org/springframework/data/elasticsearch/core/index/MappingBuilderUnitTests.java b/src/test/java/org/springframework/data/elasticsearch/core/index/MappingBuilderUnitTests.java index 8ce062c28..cf5ac6f3b 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/index/MappingBuilderUnitTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/index/MappingBuilderUnitTests.java @@ -62,6 +62,7 @@ * @author Roman Puchkovskiy * @author Brian Kimmig * @author Morgan Lutz + * @author Haibo Liu */ public class MappingBuilderUnitTests extends MappingContextBaseTests { @@ -695,6 +696,32 @@ void shouldWriteDenseVectorProperties() throws JSONException { assertEquals(expected, mapping, false); } + @Test + @DisplayName("should write dense_vector properties for knn search") + void shouldWriteDenseVectorPropertiesWithKnnSearch() throws JSONException { + String expected = """ + { + "properties":{ + "my_vector":{ + "type":"dense_vector", + "dims":16, + "element_type":"float", + "similarity":"dot_product", + "index_options":{ + "type":"hnsw", + "m":16, + "ef_construction":100 + } + } + } + } + """; + + String mapping = getMappingBuilder().buildPropertyMapping(DenseVectorEntityWithKnnSearch.class); + + assertEquals(expected, mapping, false); + } + @Test // #1370 @DisplayName("should not write mapping when enabled is false on entity") void shouldNotWriteMappingWhenEnabledIsFalseOnEntity() throws JSONException { @@ -741,6 +768,14 @@ void shouldOnlyAllowDisabledPropertiesOnTypeObject() { .isInstanceOf(MappingException.class); } + @Test + @DisplayName("should match confidence interval parameter for dense_vector type") + void shouldMatchConfidenceIntervalParameterForDenseVectorType() { + + assertThatThrownBy(() -> getMappingBuilder().buildPropertyMapping(DenseVectorMisMatchConfidenceIntervalClass.class)) + .isInstanceOf(IllegalArgumentException.class); + } + 
@Test // #1711 @DisplayName("should write typeHint entries") void shouldWriteTypeHintEntries() throws JSONException { @@ -2063,6 +2098,36 @@ public void setMy_vector(@Nullable float[] my_vector) { } } + @SuppressWarnings("unused") + static class DenseVectorEntityWithKnnSearch { + @Nullable + @Id private String id; + + @Nullable + @Field(type = FieldType.Dense_Vector, dims = 16, elementType = FieldElementType.FLOAT, + knnIndexOptions = @KnnIndexOptions(type = KnnAlgorithmType.HNSW, m = 16, efConstruction = 100), + knnSimilarity = KnnSimilarity.DOT_PRODUCT) + private float[] my_vector; + + @Nullable + public String getId() { + return id; + } + + public void setId(@Nullable String id) { + this.id = id; + } + + @Nullable + public float[] getMy_vector() { + return my_vector; + } + + public void setMy_vector(@Nullable float[] my_vector) { + this.my_vector = my_vector; + } + } + @Mapping(enabled = false) static class DisabledMappingEntity { @Nullable @@ -2115,6 +2180,13 @@ public void setText(@Nullable String text) { } } + static class DenseVectorMisMatchConfidenceIntervalClass { + @Field(type = Dense_Vector, dims = 16, elementType = FieldElementType.FLOAT, + knnIndexOptions = @KnnIndexOptions(type = KnnAlgorithmType.HNSW, m = 16, confidenceInterval = 0.95F), + knnSimilarity = KnnSimilarity.DOT_PRODUCT) + private float[] dense_vector; + } + static class DisabledMappingProperty { @Nullable @Id private String id; diff --git a/src/test/java/org/springframework/data/elasticsearch/repositories/knn/KnnSearchELCIntegrationTests.java b/src/test/java/org/springframework/data/elasticsearch/repositories/knn/KnnSearchELCIntegrationTests.java new file mode 100644 index 000000000..6edc3044d --- /dev/null +++ b/src/test/java/org/springframework/data/elasticsearch/repositories/knn/KnnSearchELCIntegrationTests.java @@ -0,0 +1,44 @@ +/* + * Copyright 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.repositories.knn; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.data.elasticsearch.junit.jupiter.ElasticsearchTemplateConfiguration; +import org.springframework.data.elasticsearch.repository.config.EnableElasticsearchRepositories; +import org.springframework.data.elasticsearch.utils.IndexNameProvider; +import org.springframework.test.context.ContextConfiguration; + +/** Runs the tests inherited from {@link KnnSearchIntegrationTests} with the imported {@link ElasticsearchTemplateConfiguration}. + * @author Haibo Liu + * @since 5.4 + */ +@ContextConfiguration(classes = { KnnSearchELCIntegrationTests.Config.class }) +public class KnnSearchELCIntegrationTests extends KnnSearchIntegrationTests { + /** Test context: enables the repositories nested in this package and provides the IndexNameProvider the tests autowire. */ + @Configuration + @Import({ ElasticsearchTemplateConfiguration.class }) + @EnableElasticsearchRepositories( + basePackages = { "org.springframework.data.elasticsearch.repositories.knn" }, + considerNestedRepositories = true) + static class Config { + @Bean + IndexNameProvider indexNameProvider() { + return new IndexNameProvider("knn-repository"); /* presumably the index name prefix - confirm against IndexNameProvider */ + } + } +} diff --git a/src/test/java/org/springframework/data/elasticsearch/repositories/knn/KnnSearchIntegrationTests.java b/src/test/java/org/springframework/data/elasticsearch/repositories/knn/KnnSearchIntegrationTests.java new file mode 100644 index 000000000..1def46cd8 --- /dev/null +++ 
b/src/test/java/org/springframework/data/elasticsearch/repositories/knn/KnnSearchIntegrationTests.java @@ -0,0 +1,185 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.repositories.knn; + +import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.elasticsearch.annotations.FieldType.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Pageable; +import org.springframework.data.elasticsearch.annotations.*; +import org.springframework.data.elasticsearch.client.elc.NativeQuery; +import org.springframework.data.elasticsearch.client.elc.NativeQueryBuilder; +import org.springframework.data.elasticsearch.core.ElasticsearchOperations; +import org.springframework.data.elasticsearch.core.SearchHit; +import org.springframework.data.elasticsearch.core.SearchHits; +import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; +import org.springframework.data.elasticsearch.junit.jupiter.SpringIntegrationTest; +import org.springframework.data.elasticsearch.repository.ElasticsearchRepository; +import org.springframework.data.elasticsearch.utils.IndexNameProvider; 
+import org.springframework.lang.Nullable; + +/** + * Integration tests for kNN searches built with {@code NativeQueryBuilder.withKnnSearches(...)}, run against an index of two-dimensional {@code dense_vector} documents. + * + * @author Haibo Liu + * @since 5.4 + */ +@SpringIntegrationTest +public abstract class KnnSearchIntegrationTests { + + @Autowired ElasticsearchOperations operations; + @Autowired private IndexNameProvider indexNameProvider; + @Autowired private VectorEntityRepository vectorEntityRepository; + + @BeforeEach + public void before() { + indexNameProvider.increment(); + operations.indexOps(VectorEntity.class).createWithMapping(); + } + + @Test + @org.junit.jupiter.api.Order(java.lang.Integer.MAX_VALUE) + void cleanup() { + operations.indexOps(IndexCoordinates.of(indexNameProvider.getPrefix() + "*")).delete(); + } + + /** + * Creates {@code n} entities with the messages "top1" to "top" + n. The vectors all lie in the first quadrant and + * sweep from close to the x-axis (first entity) to the 45 degree diagonal (last entity). + */ + private List<VectorEntity> createVectorEntities(int n) { + List<VectorEntity> entities = new ArrayList<>(); + float increment = 1.0f / n; + for (int i = 0; i < n; i++) { + VectorEntity entity = new VectorEntity(); + entity.setId(UUID.randomUUID().toString()); + entity.setMessage("top" + (i + 1)); + + // x decreases from 1 to 1/n while y stays at 1/n, so the angle to the x-axis grows with i up to 45 degrees + float[] vector = new float[] {1.0f - i * increment, increment}; + entity.setVector(vector); + entities.add(entity); + } + + return entities; + } + + @Test + public void shouldReturnXAxisVector() { + + // given + List<VectorEntity> entities = createVectorEntities(5); + vectorEntityRepository.saveAll(entities); + List<Float> xAxisVector = List.of(100f, 0f); + + // when + NativeQuery query = new NativeQueryBuilder() + .withKnnSearches(ksb -> ksb.queryVector(xAxisVector).k(3L).field("vector")) + .withPageable(Pageable.ofSize(2)) + .build(); + SearchHits<VectorEntity> result = operations.search(query, VectorEntity.class); + + List<VectorEntity> vectorEntities = result.getSearchHits().stream().map(SearchHit::getContent).toList(); + + // then + assertThat(result).isNotNull(); + assertThat(result.getTotalHits()).isEqualTo(3L); + // should return the first vector, because it is the one closest to the x-axis + assertThat(vectorEntities.get(0).getMessage()).isEqualTo("top1"); + } + + @Test + public void 
shouldReturnYAxisVector() { + + // given + List<VectorEntity> entities = createVectorEntities(10); + vectorEntityRepository.saveAll(entities); + List<Float> yAxisVector = List.of(0f, 100f); + + // when + NativeQuery query = new NativeQueryBuilder() + .withKnnSearches(ksb -> ksb.queryVector(yAxisVector).k(3L).field("vector")) + .withPageable(Pageable.ofSize(2)) + .build(); + SearchHits<VectorEntity> result = operations.search(query, VectorEntity.class); + + List<VectorEntity> vectorEntities = result.getSearchHits().stream().map(SearchHit::getContent).toList(); + + // then + assertThat(result).isNotNull(); + assertThat(result.getTotalHits()).isEqualTo(3L); + // should return the last vector, because it is the one closest to the y-axis + assertThat(vectorEntities.get(0).getMessage()).isEqualTo("top10"); + } + + public interface VectorEntityRepository extends ElasticsearchRepository<VectorEntity, String> { + } + + @Document(indexName = "#{@indexNameProvider.indexName()}") + static class VectorEntity { + @Nullable + @Id + private String id; + + @Nullable + @Field(type = Keyword) + private String message; + + // TODO: `elementType = FieldElementType.FLOAT,` is to be added here later + // TODO: element_type can not be set here, because it's left out in elasticsearch-specification + // TODO: the issue is fixed in https://github.com/elastic/elasticsearch-java/pull/800, but still not released in 8.13.x + // TODO: will be fixed later by either upgrading to 8.14.0 or a newer 8.13.x + @Field(type = FieldType.Dense_Vector, dims = 2, + knnIndexOptions = @KnnIndexOptions(type = KnnAlgorithmType.HNSW, m = 16, efConstruction = 100), + knnSimilarity = KnnSimilarity.COSINE) + private float[] vector; + + @Nullable + public String getId() { + return id; + } + + public void setId(@Nullable String id) { + this.id = id; + } + + @Nullable + public String getMessage() { + return message; + } + + public void setMessage(@Nullable String message) { + this.message = message; + } + + @Nullable + public float[] getVector() { + return vector; + } + + public void setVector(@Nullable 
float[] vector) { + this.vector = vector; + } + } +}