From b81053b5b4ca5c050ce0fc171823d8dba79b9479 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Mon, 28 Oct 2024 12:40:17 -0700 Subject: [PATCH 01/13] Add vector search support using BSON Subtype 9 Vector. JAVA-5650 Correct assume in vector tests. JAVA-5650 Add tests. JAVA-5650 --- .../com/mongodb/client/model/Aggregates.java | 91 +++-- .../CreateSearchIndexesOperation.java | 4 +- .../operation/ListSearchIndexesOperation.java | 9 +- .../operation/SearchIndexRequest.java | 5 +- ...AggregatesVectorSearchIntegrationTest.java | 324 ++++++++++++++++++ .../mongodb/client/test/CollectionHelper.java | 24 ++ .../model/AggregatesSpecification.groovy | 75 ++-- .../org/mongodb/scala/model/Aggregates.scala | 27 +- .../mongodb/scala/model/AggregatesSpec.scala | 92 ++++- 9 files changed, 560 insertions(+), 91 deletions(-) create mode 100644 driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java diff --git a/driver-core/src/main/com/mongodb/client/model/Aggregates.java b/driver-core/src/main/com/mongodb/client/model/Aggregates.java index 4bb3a03771c..af0384ba0dd 100644 --- a/driver-core/src/main/com/mongodb/client/model/Aggregates.java +++ b/driver-core/src/main/com/mongodb/client/model/Aggregates.java @@ -37,6 +37,7 @@ import org.bson.BsonType; import org.bson.BsonValue; import org.bson.Document; +import org.bson.Vector; import org.bson.codecs.configuration.CodecRegistry; import org.bson.conversions.Bson; @@ -963,28 +964,37 @@ public static Bson vectorSearch( notNull("queryVector", queryVector); notNull("index", index); notNull("options", options); - return new Bson() { - @Override - public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { - Document specificationDoc = new Document("path", path.toValue()) - .append("queryVector", queryVector) - .append("index", index) - .append("limit", limit); - specificationDoc.putAll(options.toBsonDocument(documentClass, codecRegistry)); - return new Document("$vectorSearch", specificationDoc).toBsonDocument(documentClass, codecRegistry); - } + return new VectorSearchBson<>(path, queryVector, index, limit, options); + } - @Override - public String toString() { - return "Stage{name=$vectorSearch" - + ", path=" + path - + ", queryVector=" + queryVector - + ", index=" + index - + ", limit=" + limit - + ", options=" + options - + '}'; - } - }; + /** + * Creates a {@code $vectorSearch} pipeline stage supported by MongoDB Atlas. + * You may use the {@code $meta: "vectorSearchScore"} expression, e.g., via {@link Projections#metaVectorSearchScore(String)}, + * to extract the relevance score assigned to each found document. + * + * @param queryVector The {@linkplain Vector query vector}. The number of dimensions must match that of the {@code index}. + * @param path The field to be searched. + * @param index The name of the index to use. + * @param limit The limit on the number of documents produced by the pipeline stage. + * @param options Optional {@code $vectorSearch} pipeline stage fields. + * @return The {@code $vectorSearch} pipeline stage. + * @mongodb.atlas.manual atlas-vector-search/vector-search-stage/ $vectorSearch + * @mongodb.atlas.manual atlas-search/scoring/ Scoring + * @mongodb.server.release 6.0.11 + * @see Vector + * @since 5.3 + */ + public static Bson vectorSearch( + final FieldSearchPath path, + final Vector queryVector, + final String index, + final long limit, + final VectorSearchOptions options) { + notNull("path", path); + notNull("queryVector", queryVector); + notNull("index", index); + notNull("options", options); + return new VectorSearchBson<>(path, queryVector, index, limit, options); } /** @@ -2145,6 +2155,45 @@ public String toString() { } } + private static class VectorSearchBson implements Bson { + private final FieldSearchPath path; + private final T queryVector; + private final String index; + private final long limit; + private final VectorSearchOptions options; + + VectorSearchBson(final FieldSearchPath path, final T queryVector, + final String index, final long limit, + final VectorSearchOptions options) { + this.path = path; + this.queryVector = queryVector; + this.index = index; + this.limit = limit; + this.options = options; + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + Document specificationDoc = new Document("path", path.toValue()) + .append("queryVector", queryVector) + .append("index", index) + .append("limit", limit); + specificationDoc.putAll(options.toBsonDocument(documentClass, codecRegistry)); + return new Document("$vectorSearch", specificationDoc).toBsonDocument(documentClass, codecRegistry); + } + + @Override + public String toString() { + return "Stage{name=$vectorSearch" + + ", path=" + path + + ", queryVector=" + queryVector + + ", index=" + index + + ", limit=" + limit + + ", options=" + options + + '}'; + } + } + private Aggregates() { } } diff --git a/driver-core/src/main/com/mongodb/internal/operation/CreateSearchIndexesOperation.java b/driver-core/src/main/com/mongodb/internal/operation/CreateSearchIndexesOperation.java index 2e52e3fa0ae..a57087e9217 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/CreateSearchIndexesOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/CreateSearchIndexesOperation.java @@ -32,11 +32,11 @@ * *

This class is not part of the public API and may be removed or changed at any time

*/ -final class CreateSearchIndexesOperation extends AbstractWriteSearchIndexOperation { +public final class CreateSearchIndexesOperation extends AbstractWriteSearchIndexOperation { private static final String COMMAND_NAME = "createSearchIndexes"; private final List indexRequests; - CreateSearchIndexesOperation(final MongoNamespace namespace, final List indexRequests) { + public CreateSearchIndexesOperation(final MongoNamespace namespace, final List indexRequests) { super(namespace); this.indexRequests = assertNotNull(indexRequests); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java b/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java index 0f9a81dbf19..3dfde30511d 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ListSearchIndexesOperation.java @@ -42,7 +42,7 @@ * *

This class is not part of the public API and may be removed or changed at any time

*/ -final class ListSearchIndexesOperation +public final class ListSearchIndexesOperation implements AsyncExplainableReadOperation>, ExplainableReadOperation> { private static final String STAGE_LIST_SEARCH_INDEXES = "$listSearchIndexes"; private final MongoNamespace namespace; @@ -59,9 +59,10 @@ final class ListSearchIndexesOperation private final String indexName; private final boolean retryReads; - ListSearchIndexesOperation(final MongoNamespace namespace, final Decoder decoder, @Nullable final String indexName, - @Nullable final Integer batchSize, @Nullable final Collation collation, @Nullable final BsonValue comment, - @Nullable final Boolean allowDiskUse, final boolean retryReads) { + public ListSearchIndexesOperation(final MongoNamespace namespace, final Decoder decoder, @Nullable final String indexName, + @Nullable final Integer batchSize, @Nullable final Collation collation, + @Nullable final BsonValue comment, + @Nullable final Boolean allowDiskUse, final boolean retryReads) { this.namespace = namespace; this.decoder = decoder; this.allowDiskUse = allowDiskUse; diff --git a/driver-core/src/main/com/mongodb/internal/operation/SearchIndexRequest.java b/driver-core/src/main/com/mongodb/internal/operation/SearchIndexRequest.java index 0d37d2c2178..29b9b1ef34d 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/SearchIndexRequest.java +++ b/driver-core/src/main/com/mongodb/internal/operation/SearchIndexRequest.java @@ -31,14 +31,15 @@ * *

This class is not part of the public API and may be removed or changed at any time

*/ -final class SearchIndexRequest { +public final class SearchIndexRequest { private final BsonDocument definition; @Nullable private final String indexName; @Nullable private final SearchIndexType searchIndexType; - SearchIndexRequest(final BsonDocument definition, @Nullable final String indexName, @Nullable final SearchIndexType searchIndexType) { + public SearchIndexRequest(final BsonDocument definition, @Nullable final String indexName, + @Nullable final SearchIndexType searchIndexType) { assertNotNull(definition); this.definition = definition; this.indexName = indexName; diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java new file mode 100644 index 00000000000..9ade413073c --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -0,0 +1,324 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model.search; + +import com.mongodb.MongoNamespace; +import com.mongodb.client.model.Aggregates; +import com.mongodb.client.model.SearchIndexType; +import com.mongodb.client.test.CollectionHelper; +import com.mongodb.internal.operation.SearchIndexRequest; +import org.bson.BsonDocument; +import org.bson.Document; +import org.bson.Vector; +import org.bson.codecs.DocumentCodec; +import org.bson.conversions.Bson; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import static com.mongodb.ClusterFixture.isAtlasSearchTest; +import static com.mongodb.ClusterFixture.serverVersionAtLeast; +import static com.mongodb.client.model.Filters.and; +import static com.mongodb.client.model.Filters.eq; +import static com.mongodb.client.model.Filters.gt; +import static com.mongodb.client.model.Filters.gte; +import static com.mongodb.client.model.Filters.in; +import static com.mongodb.client.model.Filters.lt; +import static com.mongodb.client.model.Filters.lte; +import static com.mongodb.client.model.Filters.ne; +import static com.mongodb.client.model.Filters.nin; +import static com.mongodb.client.model.Filters.or; +import static com.mongodb.client.model.Projections.fields; +import static com.mongodb.client.model.Projections.metaVectorSearchScore; +import static com.mongodb.client.model.search.SearchPath.fieldPath; +import static com.mongodb.client.model.search.VectorSearchOptions.approximateVectorSearchOptions; +import static com.mongodb.client.model.search.VectorSearchOptions.exactVectorSearchOptions; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +class AggregatesVectorSearchIntegrationTest { + private static final String VECTOR_INDEX = "vector_search_index"; + private static final String VECTOR_FIELD_INT_8 = "int8Vector"; + private static final String VECTOR_FIELD_FLOAT_32 = "float32Vector"; + private static final String VECTOR_FIELD_LEGACY_DOUBLE_LIST = "legacyDoubleVector"; + private static final int LIMIT = 5; + private static final String FIELD_YEAR = "year"; + private static CollectionHelper collectionHelper; + private static final BsonDocument VECTOR_SEARCH_DEFINITION = BsonDocument.parse( + "{" + + " fields: [" + + " {" + + " path: '" + VECTOR_FIELD_INT_8 + "'," + + " numDimensions: 5," + + " similarity: 'cosine'," + + " type: 'vector'," + + " }," + + " {" + + " path: '" + VECTOR_FIELD_FLOAT_32 + "'," + + " numDimensions: 5," + + " similarity: 'cosine'," + + " type: 'vector'," + + " }," + + " {" + + " path: '" + VECTOR_FIELD_LEGACY_DOUBLE_LIST + "'," + + " numDimensions: 5," + + " similarity: 'cosine'," + + " type: 'vector'," + + " }," + + " {" + + " path: '" + FIELD_YEAR + "'," + + " type: 'filter'," + + " }," + + " ]" + + "}"); + + @BeforeAll + static void beforeAll() throws InterruptedException { + collectionHelper = + new CollectionHelper<>(new DocumentCodec(), new MongoNamespace("test", "test")); + collectionHelper.drop(); + collectionHelper.insertDocuments( + new Document() + .append("_id", 0) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{0, 1, 2, 3, 4})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{0.0001, 1.12345, 2.23456, 3.34567, 4.45678}) + .append(FIELD_YEAR, 2016), + new Document() + .append("_id", 1) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{1, 2, 3, 4, 5})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{1.0001f, 2.12345f, 3.23456f, 4.34567f, 5.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{1.0001, 2.12345, 3.23456, 4.34567, 5.45678}) + .append(FIELD_YEAR, 2017), + new Document() + .append("_id", 2) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{2, 3, 4, 5, 6})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{2.0002f, 3.12345f, 4.23456f, 5.34567f, 6.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{2.0002, 3.12345, 4.23456, 5.34567, 6.45678}) + .append(FIELD_YEAR, 2018), + new Document() + .append("_id", 3) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{3, 4, 5, 6, 7})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{3.0003f, 4.12345f, 5.23456f, 6.34567f, 7.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{3.0003, 4.12345, 5.23456, 6.34567, 7.45678}) + .append(FIELD_YEAR, 2019), + new Document() + .append("_id", 4) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{4, 5, 6, 7, 8})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{4.0004f, 5.12345f, 6.23456f, 7.34567f, 8.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{4.0004, 5.12345, 6.23456, 7.34567, 8.45678}) + .append(FIELD_YEAR, 2020), + new Document() + .append("_id", 5) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{5, 6, 7, 8, 9})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{5.0005f, 6.12345f, 7.23456f, 8.34567f, 9.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{5.0005, 6.12345, 7.23456, 8.34567, 9.45678}) + .append(FIELD_YEAR, 2021), + new Document() + .append("_id", 6) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{6, 7, 8, 9, 10})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{6.0006f, 7.12345f, 8.23456f, 9.34567f, 10.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{6.0006, 7.12345, 8.23456, 9.34567, 10.45678}) + .append(FIELD_YEAR, 2022), + new Document() + .append("_id", 7) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{7, 8, 9, 10, 11})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{7.0007f, 8.12345f, 9.23456f, 10.34567f, 11.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{7.0007, 8.12345, 9.23456, 10.34567, 11.45678}) + .append(FIELD_YEAR, 2023), + new Document() + .append("_id", 8) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{8, 9, 10, 11, 12})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{8.0008f, 9.12345f, 10.23456f, 11.34567f, 12.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{8.0008, 9.12345, 10.23456, 11.34567, 12.45678}) + .append(FIELD_YEAR, 2024), + new Document() + .append("_id", 9) + .append(VECTOR_FIELD_INT_8, Vector.int8Vector(new byte[]{9, 10, 11, 12, 13})) + .append(VECTOR_FIELD_FLOAT_32, Vector.floatVector(new float[]{9.0009f, 10.12345f, 11.23456f, 12.34567f, 13.45678f})) + .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, new double[]{9.0009, 10.12345, 11.23456, 12.34567, 13.45678}) + .append(FIELD_YEAR, 2025) + ); + + collectionHelper.createSearchIndex( + new SearchIndexRequest(VECTOR_SEARCH_DEFINITION, VECTOR_INDEX, + SearchIndexType.vectorSearch())); + awaitIndexCreation(); + } + + @AfterAll + static void afterAll() { + collectionHelper.drop(); + } + + @BeforeEach + void beforeEach() { + assumeTrue(isAtlasSearchTest()); + assumeTrue(serverVersionAtLeast(6, 0)); + } + + private static Stream provideSupportedVectors() { + return Stream.of( + arguments(Vector.int8Vector(new byte[]{0, 1, 2, 3, 4}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_INT_8).multi("ignored"), + approximateVectorSearchOptions(LIMIT * 2)), + + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_FLOAT_32), + approximateVectorSearchOptions(LIMIT * 2)), + + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_FLOAT_32).multi("ignored"), + exactVectorSearchOptions()), + + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST).multi("ignored"), + exactVectorSearchOptions()), + + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST).multi("ignored"), + approximateVectorSearchOptions(LIMIT * 2)) + ); + } + + @ParameterizedTest + @MethodSource("provideSupportedVectors") + void shouldSearchBySupportedVectorWithSearchScore(final Vector vector, + final FieldSearchPath fieldSearchPath, + final VectorSearchOptions vectorSearchOptions) { + //given + List pipeline = asList( + Aggregates.vectorSearch( + fieldSearchPath, + vector, + VECTOR_INDEX, LIMIT, + vectorSearchOptions), + Aggregates.project( + fields( + metaVectorSearchScore("vectorSearchScore") + )) + ); + + //when + List aggregate = collectionHelper.aggregate(pipeline); + + //then + Assertions.assertEquals(LIMIT, aggregate.size()); + assertScoreIsDecreasing(aggregate); + Document highestScoreDocument = aggregate.get(0); + assertEquals(1, highestScoreDocument.getDouble("vectorSearchScore")); + } + + @ParameterizedTest + @MethodSource("provideSupportedVectors") + void shouldSearchBySupportedVector(final Vector vector, + final FieldSearchPath fieldSearchPath, + final VectorSearchOptions vectorSearchOptions) { + //given + List pipeline = asList( + Aggregates.vectorSearch( + fieldSearchPath, + vector, + VECTOR_INDEX, LIMIT, + vectorSearchOptions) + ); + + //when + List aggregate = collectionHelper.aggregate(pipeline); + + //then + Assertions.assertEquals(LIMIT, aggregate.size()); + assertFalse( + aggregate.stream() + .anyMatch(document -> document.containsKey("vectorSearchScore")) + ); + } + + @ParameterizedTest + @MethodSource("provideSupportedVectors") + void provideSupportedVectors(final Vector vector, + final FieldSearchPath fieldSearchPath, + final VectorSearchOptions vectorSearchOptions) { + Consumer asserter = filter -> { + List pipeline = singletonList( + Aggregates.vectorSearch( + fieldSearchPath, vector, VECTOR_INDEX, 1, + vectorSearchOptions.filter(filter)) + ); + + List aggregate = collectionHelper.aggregate(pipeline); + Assertions.assertFalse(aggregate.isEmpty()); + }; + + assertAll( + () -> asserter.accept(lt("year", 2020)), + () -> asserter.accept(lte("year", 2020)), + () -> asserter.accept(eq("year", 2020)), + () -> asserter.accept(gte("year", 2016)), + () -> asserter.accept(gt("year", 2015)), + () -> asserter.accept(ne("year", 2016)), + () -> asserter.accept(in("year", 2000, 2024)), + () -> asserter.accept(nin("year", 2000, 2024)), + () -> asserter.accept(and(gte("year", 2015), lte("year", 2017))), + () -> asserter.accept(or(eq("year", 2015), eq("year", 2017))) + ); + } + + private static void assertScoreIsDecreasing(final List aggregate) { + double previousScore = Integer.MAX_VALUE; + for (Document document : aggregate) { + Double vectorSearchScore = document.getDouble("vectorSearchScore"); + assertTrue(vectorSearchScore > 0, "Expected positive score"); + assertTrue(vectorSearchScore < previousScore, "Expected decreasing score"); + previousScore = vectorSearchScore; + } + } + + private static void awaitIndexCreation() throws InterruptedException { + int attempts = 5; + while (attempts-- > 0) { + if (collectionHelper.listSearchIndex(VECTOR_INDEX) + .filter(document -> document.getBoolean("queryable")) + .isPresent()) { + return; + } + + TimeUnit.SECONDS.sleep(1); + } + Assertions.fail("Exceeded maximum attempts waiting for Search Index creation in Atlas cluster"); + } +} diff --git a/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java b/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java index e297726d325..adce165ee51 100644 --- a/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java +++ b/driver-core/src/test/functional/com/mongodb/client/test/CollectionHelper.java @@ -43,11 +43,14 @@ import com.mongodb.internal.operation.CountDocumentsOperation; import com.mongodb.internal.operation.CreateCollectionOperation; import com.mongodb.internal.operation.CreateIndexesOperation; +import com.mongodb.internal.operation.CreateSearchIndexesOperation; import com.mongodb.internal.operation.DropCollectionOperation; import com.mongodb.internal.operation.DropDatabaseOperation; import com.mongodb.internal.operation.FindOperation; import com.mongodb.internal.operation.ListIndexesOperation; +import com.mongodb.internal.operation.ListSearchIndexesOperation; import com.mongodb.internal.operation.MixedBulkWriteOperation; +import com.mongodb.internal.operation.SearchIndexRequest; import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonDocumentWrapper; @@ -56,6 +59,7 @@ import org.bson.BsonString; import org.bson.BsonValue; import org.bson.Document; +import org.bson.assertions.Assertions; import org.bson.codecs.BsonDocumentCodec; import org.bson.codecs.Codec; import org.bson.codecs.Decoder; @@ -65,6 +69,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static com.mongodb.ClusterFixture.executeAsync; @@ -297,6 +302,25 @@ public List find() { return find(codec); } + public Optional listSearchIndex(final String indexName) { + ListSearchIndexesOperation listSearchIndexesOperation = + new ListSearchIndexesOperation<>(namespace, codec, indexName, null, null, null, null, true); + BatchCursor cursor = listSearchIndexesOperation.execute(getBinding()); + + List results = new ArrayList<>(); + while (cursor.hasNext()) { + results.addAll(cursor.next()); + } + Assertions.assertTrue("Expected at most one result, but found " + results.size(), results.size() <= 1); + return results.isEmpty() ? Optional.empty() : Optional.of(results.get(0)); + } + + public void createSearchIndex(final SearchIndexRequest searchIndexModel) { + CreateSearchIndexesOperation searchIndexesOperation = + new CreateSearchIndexesOperation(namespace, singletonList(searchIndexModel)); + searchIndexesOperation.execute(getBinding()); + } + public List find(final Codec codec) { BatchCursor cursor = new FindOperation<>(namespace, codec) .sort(new BsonDocument("_id", new BsonInt32(1))) diff --git a/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy b/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy index 21df76e401e..d9819d691d5 100644 --- a/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy @@ -23,6 +23,7 @@ import com.mongodb.client.model.search.SearchOperator import org.bson.BsonDocument import org.bson.BsonInt32 import org.bson.Document +import org.bson.Vector import org.bson.conversions.Bson import spock.lang.IgnoreIf import spock.lang.Specification @@ -30,59 +31,12 @@ import spock.lang.Specification import static BucketGranularity.R5 import static MongoTimeUnit.DAY import static com.mongodb.ClusterFixture.serverVersionLessThan -import static com.mongodb.client.model.Accumulators.accumulator -import static com.mongodb.client.model.Accumulators.addToSet -import static com.mongodb.client.model.Accumulators.avg -import static com.mongodb.client.model.Accumulators.bottom -import static com.mongodb.client.model.Accumulators.bottomN -import static com.mongodb.client.model.Accumulators.first -import static com.mongodb.client.model.Accumulators.firstN -import static com.mongodb.client.model.Accumulators.last -import static com.mongodb.client.model.Accumulators.lastN -import static com.mongodb.client.model.Accumulators.max -import static com.mongodb.client.model.Accumulators.maxN -import static com.mongodb.client.model.Accumulators.mergeObjects -import static com.mongodb.client.model.Accumulators.min -import static com.mongodb.client.model.Accumulators.minN -import static com.mongodb.client.model.Accumulators.push -import static com.mongodb.client.model.Accumulators.stdDevPop -import static com.mongodb.client.model.Accumulators.stdDevSamp -import static com.mongodb.client.model.Accumulators.sum -import static com.mongodb.client.model.Accumulators.top -import static com.mongodb.client.model.Accumulators.topN -import static com.mongodb.client.model.Aggregates.addFields -import static com.mongodb.client.model.Aggregates.bucket -import static com.mongodb.client.model.Aggregates.bucketAuto -import static com.mongodb.client.model.Aggregates.count -import static com.mongodb.client.model.Aggregates.densify -import static com.mongodb.client.model.Aggregates.fill -import static com.mongodb.client.model.Aggregates.graphLookup -import static com.mongodb.client.model.Aggregates.group -import static com.mongodb.client.model.Aggregates.limit -import static com.mongodb.client.model.Aggregates.lookup -import static com.mongodb.client.model.Aggregates.match -import static com.mongodb.client.model.Aggregates.merge -import static com.mongodb.client.model.Aggregates.out -import static com.mongodb.client.model.Aggregates.project -import static com.mongodb.client.model.Aggregates.replaceRoot -import static com.mongodb.client.model.Aggregates.replaceWith -import static com.mongodb.client.model.Aggregates.sample -import static com.mongodb.client.model.Aggregates.search -import static com.mongodb.client.model.Aggregates.searchMeta -import static com.mongodb.client.model.Aggregates.set -import static com.mongodb.client.model.Aggregates.setWindowFields -import static com.mongodb.client.model.Aggregates.skip -import static com.mongodb.client.model.Aggregates.sort -import static com.mongodb.client.model.Aggregates.sortByCount -import static com.mongodb.client.model.Aggregates.unionWith -import static com.mongodb.client.model.Aggregates.unwind -import static com.mongodb.client.model.Aggregates.vectorSearch +import static com.mongodb.client.model.Accumulators.* +import static com.mongodb.client.model.Aggregates.* import static com.mongodb.client.model.BsonHelper.toBson import static com.mongodb.client.model.Filters.eq import static com.mongodb.client.model.Filters.expr -import static com.mongodb.client.model.Projections.computed -import static com.mongodb.client.model.Projections.fields -import static com.mongodb.client.model.Projections.include +import static com.mongodb.client.model.Projections.* import static com.mongodb.client.model.Sorts.ascending import static com.mongodb.client.model.Sorts.descending import static com.mongodb.client.model.Windows.Bound.CURRENT @@ -855,7 +809,7 @@ class AggregatesSpecification extends Specification { BsonDocument vectorSearchDoc = toBson( vectorSearch( fieldPath('fieldName').multi('ignored'), - [1.0d, 2.0d], + vector, 'indexName', 1, approximateVectorSearchOptions(2) @@ -868,13 +822,20 @@ class AggregatesSpecification extends Specification { vectorSearchDoc == parse('''{ "$vectorSearch": { "path": "fieldName", - "queryVector": [1.0, 2.0], + "queryVector": ''' + queryVector + ''', "index": "indexName", "numCandidates": {"$numberLong": "2"}, "limit": {"$numberLong": "1"}, "filter": {"fieldName": {"$ne": "fieldValue"}} } }''') + + where: + vectorType | vector | queryVector + "int8" | Vector.int8Vector(new byte[]{127, 7}) | '{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}' + "float32" | Vector.floatVector(new float[]{127.0f, 7.0f}) | '{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}' + "packedBit" | Vector.packedBitVector(new byte[]{127, 7}, (byte) 0) | '{"$binary": {"base64": "EAB/Bw==", "subType": "09"}}' + "double" | [1.0d, 2.0d] | "[1.0, 2.0]" } def 'should render exact $vectorSearch'() { @@ -882,7 +843,7 @@ class AggregatesSpecification extends Specification { BsonDocument vectorSearchDoc = toBson( vectorSearch( fieldPath('fieldName').multi('ignored'), - [1.0d, 2.0d], + vector, 'indexName', 1, exactVectorSearchOptions() @@ -895,13 +856,19 @@ class AggregatesSpecification extends Specification { vectorSearchDoc == parse('''{ "$vectorSearch": { "path": "fieldName", - "queryVector": [1.0, 2.0], + "queryVector": ''' + queryVector + ''', "index": "indexName", "exact": true, "limit": {"$numberLong": "1"}, "filter": {"fieldName": {"$ne": "fieldValue"}} } }''') + + where: + vectorType | vector | queryVector + "int8" | Vector.int8Vector(new byte[]{127, 7}) | '{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}' + "float32" | Vector.floatVector(new float[]{127.0f, 7.0f}) | '{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}' + "double" | [1.0d, 2.0d] | "[1.0, 2.0]" } def 'should create string representation for simple stages'() { diff --git a/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala b/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala index c7b8d120cf7..0f5f5636360 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala @@ -16,12 +16,12 @@ package org.mongodb.scala.model -import com.mongodb.annotations.{ Beta, Reason } import com.mongodb.client.model.fill.FillOutputField import com.mongodb.client.model.search.FieldSearchPath import scala.collection.JavaConverters._ import com.mongodb.client.model.{ Aggregates => JAggregates } +import org.bson.Vector import org.mongodb.scala.MongoNamespace import org.mongodb.scala.bson.conversions.Bson import org.mongodb.scala.model.densify.{ DensifyOptions, DensifyRange } @@ -746,6 +746,31 @@ object Aggregates { ): Bson = JAggregates.vectorSearch(path, queryVector.asJava, index, limit, options) + /** + * Creates a `\$vectorSearch` pipeline stage supported by MongoDB Atlas. + * You may use the `\$meta: "vectorSearchScore"` expression, e.g., via [[Projections.metaVectorSearchScore]], + * to extract the relevance score assigned to each found document. + * + * @param queryVector The query vector. The number of dimensions must match that of the `index`. + * @param path The field to be searched. + * @param index The name of the index to use. + * @param limit The limit on the number of documents produced by the pipeline stage. + * @param options Optional `\$vectorSearch` pipeline stage fields. + * @return The `\$vectorSearch` pipeline stage. + * @see [[https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ \$vectorSearch]] + * @note Requires MongoDB 6.0.10 or greater + * @see [[org.bson.Vector]] + * @since 5.3 + */ + def vectorSearch( + path: FieldSearchPath, + queryVector: org.bson.Vector, + index: String, + limit: Long, + options: VectorSearchOptions + ): Bson = + JAggregates.vectorSearch(path, queryVector, index, limit, options) + /** * Creates an `\$unset` pipeline stage that removes/excludes fields from documents * diff --git a/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala b/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala index 25152a22d97..e38fb8189b4 100644 --- a/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala +++ b/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala @@ -18,8 +18,6 @@ package org.mongodb.scala.model import com.mongodb.client.model.GeoNearOptions.geoNearOptions import com.mongodb.client.model.fill.FillOutputField - -import java.lang.reflect.Modifier._ import org.bson.BsonDocument import org.mongodb.scala.bson.BsonArray import org.mongodb.scala.bson.collection.immutable.Document @@ -34,15 +32,19 @@ import org.mongodb.scala.model.Windows.{ documents, range } import org.mongodb.scala.model.densify.DensifyRange.fullRangeWithStep import org.mongodb.scala.model.fill.FillOptions.fillOptions import org.mongodb.scala.model.geojson.{ Point, Position } +import org.mongodb.scala.model.search.SearchCollector import org.mongodb.scala.model.search.SearchCount.total import org.mongodb.scala.model.search.SearchFacet.stringFacet import org.mongodb.scala.model.search.SearchHighlight.paths -import org.mongodb.scala.model.search.SearchCollector import org.mongodb.scala.model.search.SearchOperator.exists import org.mongodb.scala.model.search.SearchOptions.searchOptions import org.mongodb.scala.model.search.SearchPath.{ fieldPath, wildcardPath } import org.mongodb.scala.model.search.VectorSearchOptions.{ approximateVectorSearchOptions, exactVectorSearchOptions } import org.mongodb.scala.{ BaseSpec, MongoClient, MongoNamespace } +import org.scalatest.prop.TableDrivenPropertyChecks.forAll +import org.scalatest.prop.Tables.Table + +import java.lang.reflect.Modifier._ class AggregatesSpec extends BaseSpec { val registry = MongoClient.DEFAULT_CODEC_REGISTRY @@ -763,11 +765,85 @@ class AggregatesSpec extends BaseSpec { ) } - it should "render approximate $vectorSearch" in { + val vectorTestCases = Table( + ("vector", "queryVector"), + ( + org.bson.Vector.int8Vector(Array(127.toByte, 7.toByte)), + """{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}""" + ), + ( + org.bson.Vector.floatVector(Array(127.0f, 7.0f)), + """{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}""" + ), + ( + org.bson.Vector.packedBitVector(Array(127.toByte, 7.toByte), 0.toByte), + """{"$binary": {"base64": "EAB/Bw==", "subType": "09"}}""" + ) + ) + + it should "render approximate $vectorSearch with Vector" in { + forAll(vectorTestCases) { (vector: org.bson.Vector, expectedSerializedVector: String) => + toBson( + Aggregates.vectorSearch( + fieldPath("fieldName").multi("ignored"), + vector, + "indexName", + 1, + approximateVectorSearchOptions(2) + .filter(Filters.ne("fieldName", "fieldValue")) + ) + ) should equal( + Document( + s"""{ + | "$$vectorSearch": { + | "path": "fieldName", + | "queryVector": $expectedSerializedVector, + | "index": "indexName", + | "limit": {"$$numberLong": "1"}, + | "numCandidates": {"$$numberLong": "2"}, + | "filter": {"fieldName": {"$$ne": "fieldValue"}} + | } + |}""".stripMargin + ) + ) + } + } + + it should "render exact $vectorSearch with Vector" in { + forAll(vectorTestCases) { (vector: org.bson.Vector, expectedSerializedVector: String) => + toBson( + Aggregates.vectorSearch( + fieldPath("fieldName").multi("ignored"), + vector, + "indexName", + 1, + exactVectorSearchOptions() + .filter(Filters.ne("fieldName", "fieldValue")) + ) + ) should equal( + Document( + s"""{ + | "$$vectorSearch": { + | "path": "fieldName", + | "queryVector": $expectedSerializedVector, + | "index": "indexName", + | "exact": true, + | "limit": {"$$numberLong": "1"}, + | "filter": {"fieldName": {"$$ne": "fieldValue"}} + | } + |}""".stripMargin + ) + ) + } + } + + it should "render approximate $vectorSearch with List" in { + val queryVectorJava: List[java.lang.Double] = List(1.0d, 2.0d) + toBson( Aggregates.vectorSearch( fieldPath("fieldName").multi("ignored"), - List(1.0d, 2.0d), + queryVectorJava, "indexName", 1, approximateVectorSearchOptions(2) @@ -789,11 +865,13 @@ class AggregatesSpec extends BaseSpec { ) } - it should "render exact $vectorSearch" in { + it should "render exact $vectorSearch with List" in { + val queryVectorJava: List[java.lang.Double] = List(1.0d, 2.0d) + toBson( Aggregates.vectorSearch( fieldPath("fieldName").multi("ignored"), - List(1.0d, 2.0d), + queryVectorJava, "indexName", 1, exactVectorSearchOptions() From b176ec20c87e19e7a426478e4c61b485b0b8429a Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 13:17:35 -0700 Subject: [PATCH 02/13] Revert Scala API due to backwards compatibility issues. JAVA-5650 --- .../org/mongodb/scala/model/Aggregates.scala | 27 +----- .../mongodb/scala/model/AggregatesSpec.scala | 92 ++----------------- 2 files changed, 8 insertions(+), 111 deletions(-) diff --git a/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala b/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala index 0f5f5636360..c7b8d120cf7 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala @@ -16,12 +16,12 @@ package org.mongodb.scala.model +import com.mongodb.annotations.{ Beta, Reason } import com.mongodb.client.model.fill.FillOutputField import com.mongodb.client.model.search.FieldSearchPath import scala.collection.JavaConverters._ import com.mongodb.client.model.{ Aggregates => JAggregates } -import org.bson.Vector import org.mongodb.scala.MongoNamespace import org.mongodb.scala.bson.conversions.Bson import org.mongodb.scala.model.densify.{ DensifyOptions, DensifyRange } @@ -746,31 +746,6 @@ object Aggregates { ): Bson = JAggregates.vectorSearch(path, queryVector.asJava, index, limit, options) - /** - * Creates a `\$vectorSearch` pipeline stage supported by MongoDB Atlas. - * You may use the `\$meta: "vectorSearchScore"` expression, e.g., via [[Projections.metaVectorSearchScore]], - * to extract the relevance score assigned to each found document. - * - * @param queryVector The query vector. The number of dimensions must match that of the `index`. - * @param path The field to be searched. - * @param index The name of the index to use. - * @param limit The limit on the number of documents produced by the pipeline stage. - * @param options Optional `\$vectorSearch` pipeline stage fields. - * @return The `\$vectorSearch` pipeline stage. - * @see [[https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ \$vectorSearch]] - * @note Requires MongoDB 6.0.10 or greater - * @see [[org.bson.Vector]] - * @since 5.3 - */ - def vectorSearch( - path: FieldSearchPath, - queryVector: org.bson.Vector, - index: String, - limit: Long, - options: VectorSearchOptions - ): Bson = - JAggregates.vectorSearch(path, queryVector, index, limit, options) - /** * Creates an `\$unset` pipeline stage that removes/excludes fields from documents * diff --git a/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala b/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala index e38fb8189b4..25152a22d97 100644 --- a/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala +++ b/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala @@ -18,6 +18,8 @@ package org.mongodb.scala.model import com.mongodb.client.model.GeoNearOptions.geoNearOptions import com.mongodb.client.model.fill.FillOutputField + +import java.lang.reflect.Modifier._ import org.bson.BsonDocument import org.mongodb.scala.bson.BsonArray import org.mongodb.scala.bson.collection.immutable.Document @@ -32,19 +34,15 @@ import org.mongodb.scala.model.Windows.{ documents, range } import org.mongodb.scala.model.densify.DensifyRange.fullRangeWithStep import org.mongodb.scala.model.fill.FillOptions.fillOptions import org.mongodb.scala.model.geojson.{ Point, Position } -import org.mongodb.scala.model.search.SearchCollector import org.mongodb.scala.model.search.SearchCount.total import org.mongodb.scala.model.search.SearchFacet.stringFacet import org.mongodb.scala.model.search.SearchHighlight.paths +import org.mongodb.scala.model.search.SearchCollector import org.mongodb.scala.model.search.SearchOperator.exists import org.mongodb.scala.model.search.SearchOptions.searchOptions import org.mongodb.scala.model.search.SearchPath.{ fieldPath, wildcardPath } import org.mongodb.scala.model.search.VectorSearchOptions.{ approximateVectorSearchOptions, exactVectorSearchOptions } import org.mongodb.scala.{ BaseSpec, MongoClient, MongoNamespace } -import org.scalatest.prop.TableDrivenPropertyChecks.forAll -import org.scalatest.prop.Tables.Table - -import java.lang.reflect.Modifier._ class AggregatesSpec extends BaseSpec { val registry = MongoClient.DEFAULT_CODEC_REGISTRY @@ -765,85 +763,11 @@ class AggregatesSpec extends BaseSpec { ) } - val vectorTestCases = Table( - ("vector", "queryVector"), - ( - org.bson.Vector.int8Vector(Array(127.toByte, 7.toByte)), - """{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}""" - ), - ( - org.bson.Vector.floatVector(Array(127.0f, 7.0f)), - """{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}""" - ), - ( - org.bson.Vector.packedBitVector(Array(127.toByte, 7.toByte), 0.toByte), - """{"$binary": {"base64": "EAB/Bw==", "subType": "09"}}""" - ) - ) - - it should "render approximate $vectorSearch with Vector" in { - forAll(vectorTestCases) { (vector: org.bson.Vector, expectedSerializedVector: String) => - toBson( - Aggregates.vectorSearch( - fieldPath("fieldName").multi("ignored"), - vector, - "indexName", - 1, - approximateVectorSearchOptions(2) - .filter(Filters.ne("fieldName", "fieldValue")) - ) - ) should equal( - Document( - s"""{ - | "$$vectorSearch": { - | "path": "fieldName", - | "queryVector": $expectedSerializedVector, - | "index": "indexName", - | "limit": {"$$numberLong": "1"}, - | "numCandidates": {"$$numberLong": "2"}, - | "filter": {"fieldName": {"$$ne": "fieldValue"}} - | } - |}""".stripMargin - ) - ) - } - } - - it should "render exact $vectorSearch with Vector" in { - forAll(vectorTestCases) { (vector: org.bson.Vector, expectedSerializedVector: String) => - toBson( - Aggregates.vectorSearch( - fieldPath("fieldName").multi("ignored"), - vector, - "indexName", - 1, - exactVectorSearchOptions() - .filter(Filters.ne("fieldName", "fieldValue")) - ) - ) should equal( - Document( - s"""{ - | "$$vectorSearch": { - | "path": "fieldName", - | "queryVector": $expectedSerializedVector, - | "index": "indexName", - | "exact": true, - | "limit": {"$$numberLong": "1"}, - | "filter": {"fieldName": {"$$ne": "fieldValue"}} - | } - |}""".stripMargin - ) - ) - } - } - - it should "render approximate $vectorSearch with List" in { - val queryVectorJava: List[java.lang.Double] = List(1.0d, 2.0d) - + it should "render approximate $vectorSearch" in { toBson( Aggregates.vectorSearch( fieldPath("fieldName").multi("ignored"), - queryVectorJava, + List(1.0d, 2.0d), "indexName", 1, approximateVectorSearchOptions(2) @@ -865,13 +789,11 @@ class AggregatesSpec extends BaseSpec { ) } - it should "render exact $vectorSearch with List" in { - val queryVectorJava: List[java.lang.Double] = List(1.0d, 2.0d) - + it should "render exact $vectorSearch" in { toBson( Aggregates.vectorSearch( fieldPath("fieldName").multi("ignored"), - queryVectorJava, + List(1.0d, 2.0d), "indexName", 1, exactVectorSearchOptions() From 138e2b6d60ed787a83ab1d70b25c50fd3435bfbd Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 13:31:57 -0700 Subject: [PATCH 03/13] Change test names. JAVA-5650 --- ...AggregatesVectorSearchIntegrationTest.java | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java index 9ade413073c..d2de0f47cd0 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -102,7 +102,7 @@ class AggregatesVectorSearchIntegrationTest { + "}"); @BeforeAll - static void beforeAll() throws InterruptedException { + static void beforeAll() { collectionHelper = new CollectionHelper<>(new DocumentCodec(), new MongoNamespace("test", "test")); collectionHelper.drop(); @@ -192,9 +192,15 @@ private static Stream provideSupportedVectors() { // `multi` is used here only to verify that it is tolerated fieldPath(VECTOR_FIELD_INT_8).multi("ignored"), approximateVectorSearchOptions(LIMIT * 2)), + arguments(Vector.int8Vector(new byte[]{0, 1, 2, 3, 4}), + fieldPath(VECTOR_FIELD_INT_8), + approximateVectorSearchOptions(LIMIT * 2)), arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), // `multi` is used here only to verify that it is tolerated + fieldPath(VECTOR_FIELD_FLOAT_32).multi("ignored"), + approximateVectorSearchOptions(LIMIT * 2)), + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), fieldPath(VECTOR_FIELD_FLOAT_32), approximateVectorSearchOptions(LIMIT * 2)), @@ -202,22 +208,31 @@ private static Stream provideSupportedVectors() { // `multi` is used here only to verify that it is tolerated fieldPath(VECTOR_FIELD_FLOAT_32).multi("ignored"), exactVectorSearchOptions()), + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + fieldPath(VECTOR_FIELD_FLOAT_32), + exactVectorSearchOptions()), arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), // `multi` is used here only to verify that it is tolerated fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST).multi("ignored"), exactVectorSearchOptions()), + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST), + exactVectorSearchOptions()), arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), // `multi` is used here only to verify that it is tolerated fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST).multi("ignored"), + approximateVectorSearchOptions(LIMIT * 2)), + arguments(Vector.floatVector(new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f}), + fieldPath(VECTOR_FIELD_LEGACY_DOUBLE_LIST), approximateVectorSearchOptions(LIMIT * 2)) ); } @ParameterizedTest @MethodSource("provideSupportedVectors") - void shouldSearchBySupportedVectorWithSearchScore(final Vector vector, + void shouldSearchByVectorWithSearchScore(final Vector vector, final FieldSearchPath fieldSearchPath, final VectorSearchOptions vectorSearchOptions) { //given @@ -245,7 +260,7 @@ void shouldSearchBySupportedVectorWithSearchScore(final Vector vector, @ParameterizedTest @MethodSource("provideSupportedVectors") - void shouldSearchBySupportedVector(final Vector vector, + void shouldSearchByVector(final Vector vector, final FieldSearchPath fieldSearchPath, final VectorSearchOptions vectorSearchOptions) { //given @@ -270,7 +285,7 @@ void shouldSearchBySupportedVector(final Vector vector, @ParameterizedTest @MethodSource("provideSupportedVectors") - void provideSupportedVectors(final Vector vector, + void shouldSearchByVectorWithFilter(final Vector vector, final FieldSearchPath fieldSearchPath, final VectorSearchOptions vectorSearchOptions) { Consumer asserter = filter -> { @@ -308,8 +323,8 @@ private static void assertScoreIsDecreasing(final List aggregate) { } } - private static void awaitIndexCreation() throws InterruptedException { - int attempts = 5; + private static void awaitIndexCreation() { + int attempts = 10; while (attempts-- > 0) { if (collectionHelper.listSearchIndex(VECTOR_INDEX) .filter(document -> document.getBoolean("queryable")) @@ -317,7 +332,11 @@ private static void awaitIndexCreation() throws InterruptedException { return; } - TimeUnit.SECONDS.sleep(1); + try { + TimeUnit.SECONDS.sleep(1); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } } Assertions.fail("Exceeded maximum attempts waiting for Search Index creation in Atlas cluster"); } From 6a9e8ae87cdc63629589d5c7baaca401f18ecb33 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 13:39:58 -0700 Subject: [PATCH 04/13] Remove redundant test parameters. JAVA-5650 --- .../model/AggregatesSpecification.groovy | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy b/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy index d9819d691d5..273fe588f83 100644 --- a/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy @@ -831,11 +831,11 @@ class AggregatesSpecification extends Specification { }''') where: - vectorType | vector | queryVector - "int8" | Vector.int8Vector(new byte[]{127, 7}) | '{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}' - "float32" | Vector.floatVector(new float[]{127.0f, 7.0f}) | '{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}' - "packedBit" | Vector.packedBitVector(new byte[]{127, 7}, (byte) 0) | '{"$binary": {"base64": "EAB/Bw==", "subType": "09"}}' - "double" | [1.0d, 2.0d] | "[1.0, 2.0]" + vector | queryVector + Vector.int8Vector(new byte[]{127, 7}) | '{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}' + Vector.floatVector(new float[]{127.0f, 7.0f}) | '{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}' + Vector.packedBitVector(new byte[]{127, 7}, (byte) 0) | '{"$binary": {"base64": "EAB/Bw==", "subType": "09"}}' + [1.0d, 2.0d] | "[1.0, 2.0]" } def 'should render exact $vectorSearch'() { @@ -865,10 +865,10 @@ class AggregatesSpecification extends Specification { }''') where: - vectorType | vector | queryVector - "int8" | Vector.int8Vector(new byte[]{127, 7}) | '{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}' - "float32" | Vector.floatVector(new float[]{127.0f, 7.0f}) | '{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}' - "double" | [1.0d, 2.0d] | "[1.0, 2.0]" + vector | queryVector + Vector.int8Vector(new byte[]{127, 7}) | '{"$binary": {"base64": "AwB/Bw==", "subType": "09"}}' + Vector.floatVector(new float[]{127.0f, 7.0f}) | '{"$binary": {"base64": "JwAAAP5CAADgQA==", "subType": "09"}}' + [1.0d, 2.0d] | "[1.0, 2.0]" } def 'should create string representation for simple stages'() { From 40b4f7eee7eeacec0d6c40e6c2428aa0c6733f92 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 13:45:05 -0700 Subject: [PATCH 05/13] Revert imports. JAVA-5650 --- .../model/AggregatesSpecification.groovy | 53 +++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy b/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy index 273fe588f83..3af81fc992c 100644 --- a/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/client/model/AggregatesSpecification.groovy @@ -31,12 +31,59 @@ import spock.lang.Specification import static BucketGranularity.R5 import static MongoTimeUnit.DAY import static com.mongodb.ClusterFixture.serverVersionLessThan -import static com.mongodb.client.model.Accumulators.* -import static com.mongodb.client.model.Aggregates.* +import static com.mongodb.client.model.Accumulators.accumulator +import static com.mongodb.client.model.Accumulators.addToSet +import static com.mongodb.client.model.Accumulators.avg +import static com.mongodb.client.model.Accumulators.bottom +import static com.mongodb.client.model.Accumulators.bottomN +import static com.mongodb.client.model.Accumulators.first +import static com.mongodb.client.model.Accumulators.firstN +import static com.mongodb.client.model.Accumulators.last +import static com.mongodb.client.model.Accumulators.lastN +import static com.mongodb.client.model.Accumulators.max +import static com.mongodb.client.model.Accumulators.maxN +import static com.mongodb.client.model.Accumulators.mergeObjects +import static com.mongodb.client.model.Accumulators.min +import static com.mongodb.client.model.Accumulators.minN +import static com.mongodb.client.model.Accumulators.push +import static com.mongodb.client.model.Accumulators.stdDevPop +import static com.mongodb.client.model.Accumulators.stdDevSamp +import static com.mongodb.client.model.Accumulators.sum +import static com.mongodb.client.model.Accumulators.top +import static com.mongodb.client.model.Accumulators.topN +import static com.mongodb.client.model.Aggregates.addFields +import static com.mongodb.client.model.Aggregates.bucket +import static com.mongodb.client.model.Aggregates.bucketAuto +import static com.mongodb.client.model.Aggregates.count +import static com.mongodb.client.model.Aggregates.densify +import static com.mongodb.client.model.Aggregates.fill +import static com.mongodb.client.model.Aggregates.graphLookup +import static com.mongodb.client.model.Aggregates.group +import static com.mongodb.client.model.Aggregates.limit +import static com.mongodb.client.model.Aggregates.lookup +import static com.mongodb.client.model.Aggregates.match +import static com.mongodb.client.model.Aggregates.merge +import static com.mongodb.client.model.Aggregates.out +import static com.mongodb.client.model.Aggregates.project +import static com.mongodb.client.model.Aggregates.replaceRoot +import static com.mongodb.client.model.Aggregates.replaceWith +import static com.mongodb.client.model.Aggregates.sample +import static com.mongodb.client.model.Aggregates.search +import static com.mongodb.client.model.Aggregates.searchMeta +import static com.mongodb.client.model.Aggregates.set +import static com.mongodb.client.model.Aggregates.setWindowFields +import static com.mongodb.client.model.Aggregates.skip +import static com.mongodb.client.model.Aggregates.sort +import static com.mongodb.client.model.Aggregates.sortByCount +import static com.mongodb.client.model.Aggregates.unionWith +import static com.mongodb.client.model.Aggregates.unwind +import static com.mongodb.client.model.Aggregates.vectorSearch import static com.mongodb.client.model.BsonHelper.toBson import static com.mongodb.client.model.Filters.eq import static com.mongodb.client.model.Filters.expr -import static com.mongodb.client.model.Projections.* +import static com.mongodb.client.model.Projections.computed +import static com.mongodb.client.model.Projections.fields +import static com.mongodb.client.model.Projections.include import static com.mongodb.client.model.Sorts.ascending import static com.mongodb.client.model.Sorts.descending import static com.mongodb.client.model.Windows.Bound.CURRENT From 2fe167b657a7936eb38ac0d78bd329e6fa5c2ab6 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 13:56:37 -0700 Subject: [PATCH 06/13] Add assume for Atlas Search tests. JAVA-5650 --- .../model/search/AggregatesVectorSearchIntegrationTest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java index d2de0f47cd0..c9b4e7854f3 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -103,6 +103,8 @@ class AggregatesVectorSearchIntegrationTest { @BeforeAll static void beforeAll() { + assumeTrue(isAtlasSearchTest()); + collectionHelper = new CollectionHelper<>(new DocumentCodec(), new MongoNamespace("test", "test")); collectionHelper.drop(); @@ -182,7 +184,6 @@ static void afterAll() { @BeforeEach void beforeEach() { - assumeTrue(isAtlasSearchTest()); assumeTrue(serverVersionAtLeast(6, 0)); } From ba3407fa3c0abd7c96c6005dc053be673b45cc36 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 13:57:31 -0700 Subject: [PATCH 07/13] Move all assumes to static initialization for Atlas Search tests. JAVA-5650 --- .../search/AggregatesVectorSearchIntegrationTest.java | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java index c9b4e7854f3..518d739efec 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -29,7 +29,6 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -104,6 +103,7 @@ class AggregatesVectorSearchIntegrationTest { @BeforeAll static void beforeAll() { assumeTrue(isAtlasSearchTest()); + assumeTrue(serverVersionAtLeast(6, 0)); collectionHelper = new CollectionHelper<>(new DocumentCodec(), new MongoNamespace("test", "test")); @@ -182,11 +182,6 @@ static void afterAll() { collectionHelper.drop(); } - @BeforeEach - void beforeEach() { - assumeTrue(serverVersionAtLeast(6, 0)); - } - private static Stream provideSupportedVectors() { return Stream.of( arguments(Vector.int8Vector(new byte[]{0, 1, 2, 3, 4}), From 29494ee26821c70e3bf1e29898db966c882fee15 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 14:17:14 -0700 Subject: [PATCH 08/13] Fix test. JAVA-5650 --- .../model/search/AggregatesVectorSearchIntegrationTest.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java index 518d739efec..d1a68b66d19 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -179,7 +179,9 @@ static void beforeAll() { @AfterAll static void afterAll() { - collectionHelper.drop(); + if (collectionHelper != null) { + collectionHelper.drop(); + } } private static Stream provideSupportedVectors() { From a3c46907685b0b4a5c7e65a098106c0d4e1ecc6d Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 14:48:00 -0700 Subject: [PATCH 09/13] Add Vector search test to sh script. JAVA-5650 --- .evergreen/run-atlas-search-tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/run-atlas-search-tests.sh b/.evergreen/run-atlas-search-tests.sh index 7669c87ae5d..36cc981b3f4 100755 --- a/.evergreen/run-atlas-search-tests.sh +++ b/.evergreen/run-atlas-search-tests.sh @@ -16,4 +16,4 @@ echo "Running Atlas Search tests" ./gradlew --stacktrace --info \ -Dorg.mongodb.test.atlas.search=true \ -Dorg.mongodb.test.uri=${MONGODB_URI} \ - driver-core:test --tests AggregatesSearchIntegrationTest + driver-core:test --tests AggregatesSearchIntegrationTest --tests AggregatesVectorSearchIntegrationTest From 85b9c2babd082f631d5f6f8ffcb4ef6b3e21c478 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 16:34:43 -0700 Subject: [PATCH 10/13] Change database name. JAVA-5650 --- .../search/AggregatesVectorSearchIntegrationTest.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java index d1a68b66d19..cde126f035f 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -106,7 +106,7 @@ static void beforeAll() { assumeTrue(serverVersionAtLeast(6, 0)); collectionHelper = - new CollectionHelper<>(new DocumentCodec(), new MongoNamespace("test", "test")); + new CollectionHelper<>(new DocumentCodec(), new MongoNamespace("javaVectorSearchTest", AggregatesVectorSearchIntegrationTest.class.getSimpleName())); collectionHelper.drop(); collectionHelper.insertDocuments( new Document() @@ -331,11 +331,15 @@ private static void awaitIndexCreation() { } try { - TimeUnit.SECONDS.sleep(1); + TimeUnit.SECONDS.sleep(5); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } + collectionHelper.listSearchIndex(VECTOR_INDEX).ifPresent(document -> { + Assertions.fail("Exceeded maximum attempts waiting for Search Index creation in Atlas cluster. Index document: " + document.toJson()); + }); + Assertions.fail("Exceeded maximum attempts waiting for Search Index creation in Atlas cluster"); } } From 3b2074f642d036de7bdfb1907c95c5b4b8f7f3e9 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 30 Oct 2024 17:37:20 -0700 Subject: [PATCH 11/13] Change assertion. JAVA-5650 --- .../AggregatesVectorSearchIntegrationTest.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java index cde126f035f..710026634c7 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -34,6 +34,7 @@ import org.junit.jupiter.params.provider.MethodSource; import java.util.List; +import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.stream.Stream; @@ -323,9 +324,11 @@ private static void assertScoreIsDecreasing(final List aggregate) { private static void awaitIndexCreation() { int attempts = 10; + Optional searchIndex = Optional.empty(); + while (attempts-- > 0) { - if (collectionHelper.listSearchIndex(VECTOR_INDEX) - .filter(document -> document.getBoolean("queryable")) + searchIndex = collectionHelper.listSearchIndex(VECTOR_INDEX); + if (searchIndex.filter(document -> document.getBoolean("queryable")) .isPresent()) { return; } @@ -336,10 +339,9 @@ private static void awaitIndexCreation() { Thread.currentThread().interrupt(); } } - collectionHelper.listSearchIndex(VECTOR_INDEX).ifPresent(document -> { - Assertions.fail("Exceeded maximum attempts waiting for Search Index creation in Atlas cluster. Index document: " + document.toJson()); - }); - Assertions.fail("Exceeded maximum attempts waiting for Search Index creation in Atlas cluster"); + String message = "Exceeded maximum attempts waiting for Search Index creation in Atlas cluster."; + searchIndex.ifPresent(document -> Assertions.fail(message + " Index document: " + document.toJson())); + Assertions.fail(message); } } From bc94e7993aa994696b4fdc8b19f7008664a5c541 Mon Sep 17 00:00:00 2001 From: Viacheslav Babanin Date: Wed, 30 Oct 2024 21:49:42 -0700 Subject: [PATCH 12/13] Update driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java Co-authored-by: Valentin Kovalenko --- .../model/search/AggregatesVectorSearchIntegrationTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java index 710026634c7..fd2c54750a2 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -337,6 +337,7 @@ private static void awaitIndexCreation() { TimeUnit.SECONDS.sleep(5); } catch (InterruptedException e) { Thread.currentThread().interrupt(); + throw new MongoInterruptedException(null, e); } } From 92030fece6fed7a1eb6f20ee75a06499f7c4e7a3 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Thu, 31 Oct 2024 10:29:03 -0700 Subject: [PATCH 13/13] Remove generics. JAVA-5650 --- .../main/com/mongodb/client/model/Aggregates.java | 12 ++++++------ .../AggregatesVectorSearchIntegrationTest.java | 15 ++++++++++----- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/driver-core/src/main/com/mongodb/client/model/Aggregates.java b/driver-core/src/main/com/mongodb/client/model/Aggregates.java index af0384ba0dd..7d6306cdd23 100644 --- a/driver-core/src/main/com/mongodb/client/model/Aggregates.java +++ b/driver-core/src/main/com/mongodb/client/model/Aggregates.java @@ -964,7 +964,7 @@ public static Bson vectorSearch( notNull("queryVector", queryVector); notNull("index", index); notNull("options", options); - return new VectorSearchBson<>(path, queryVector, index, limit, options); + return new VectorSearchBson(path, queryVector, index, limit, options); } /** @@ -980,7 +980,7 @@ public static Bson vectorSearch( * @return The {@code $vectorSearch} pipeline stage. * @mongodb.atlas.manual atlas-vector-search/vector-search-stage/ $vectorSearch * @mongodb.atlas.manual atlas-search/scoring/ Scoring - * @mongodb.server.release 6.0.11 + * @mongodb.server.release 6.0 * @see Vector * @since 5.3 */ @@ -994,7 +994,7 @@ public static Bson vectorSearch( notNull("queryVector", queryVector); notNull("index", index); notNull("options", options); - return new VectorSearchBson<>(path, queryVector, index, limit, options); + return new VectorSearchBson(path, queryVector, index, limit, options); } /** @@ -2155,14 +2155,14 @@ public String toString() { } } - private static class VectorSearchBson implements Bson { + private static class VectorSearchBson implements Bson { private final FieldSearchPath path; - private final T queryVector; + private final Object queryVector; private final String index; private final long limit; private final VectorSearchOptions options; - VectorSearchBson(final FieldSearchPath path, final T queryVector, + VectorSearchBson(final FieldSearchPath path, final Object queryVector, final String index, final long limit, final VectorSearchOptions options) { this.path = path; diff --git a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java index fd2c54750a2..15def0f5d71 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/search/AggregatesVectorSearchIntegrationTest.java @@ -16,6 +16,7 @@ package com.mongodb.client.model.search; +import com.mongodb.MongoInterruptedException; import com.mongodb.MongoNamespace; import com.mongodb.client.model.Aggregates; import com.mongodb.client.model.SearchIndexType; @@ -56,6 +57,7 @@ import static com.mongodb.client.model.search.SearchPath.fieldPath; import static com.mongodb.client.model.search.VectorSearchOptions.approximateVectorSearchOptions; import static com.mongodb.client.model.search.VectorSearchOptions.exactVectorSearchOptions; +import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static org.junit.jupiter.api.Assertions.assertAll; @@ -66,6 +68,9 @@ import static org.junit.jupiter.params.provider.Arguments.arguments; class AggregatesVectorSearchIntegrationTest { + private static final String EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE = + "Exceeded maximum attempts waiting for Search Index creation in Atlas cluster. Index document: %s"; + private static final String VECTOR_INDEX = "vector_search_index"; private static final String VECTOR_FIELD_INT_8 = "int8Vector"; private static final String VECTOR_FIELD_FLOAT_32 = "float32Vector"; @@ -73,7 +78,7 @@ class AggregatesVectorSearchIntegrationTest { private static final int LIMIT = 5; private static final String FIELD_YEAR = "year"; private static CollectionHelper collectionHelper; - private static final BsonDocument VECTOR_SEARCH_DEFINITION = BsonDocument.parse( + private static final BsonDocument VECTOR_SEARCH_INDEX_DEFINITION = BsonDocument.parse( "{" + " fields: [" + " {" @@ -173,7 +178,7 @@ static void beforeAll() { ); collectionHelper.createSearchIndex( - new SearchIndexRequest(VECTOR_SEARCH_DEFINITION, VECTOR_INDEX, + new SearchIndexRequest(VECTOR_SEARCH_INDEX_DEFINITION, VECTOR_INDEX, SearchIndexType.vectorSearch())); awaitIndexCreation(); } @@ -341,8 +346,8 @@ private static void awaitIndexCreation() { } } - String message = "Exceeded maximum attempts waiting for Search Index creation in Atlas cluster."; - searchIndex.ifPresent(document -> Assertions.fail(message + " Index document: " + document.toJson())); - Assertions.fail(message); + searchIndex.ifPresent(document -> + Assertions.fail(format(EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE, document.toJson()))); + Assertions.fail(format(EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE, "null")); } }