From aae00ff7aa3f018313e6b1268bc8281ac18c5d85 Mon Sep 17 00:00:00 2001 From: rishabh208gupta Date: Tue, 6 Aug 2024 15:23:25 +0530 Subject: [PATCH 1/2] allow embeddings vector to be used for mmr searching (#2620) Signed-off-by: rishabh208gupta --- .../helpers/vectorstore/_async/vectorstore.py | 20 +++++++++++++------ .../helpers/vectorstore/_sync/vectorstore.py | 20 +++++++++++++------ .../test_vectorstore/test_vectorstore.py | 15 ++++++++++---- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index 81356cf92..16d0aece3 100644 --- a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -232,7 +232,7 @@ async def delete( # type: ignore[no-untyped-def] async def search( self, *, - query: Optional[str], + query: Optional[str] = None, query_vector: Optional[List[float]] = None, k: int = 4, num_candidates: int = 50, @@ -344,8 +344,9 @@ async def _create_index_if_not_exists(self) -> None: async def max_marginal_relevance_search( self, *, - embedding_service: AsyncEmbeddingService, - query: str, + query: Optional[str] = None, + query_embedding: Optional[List[float]] = None, + embedding_service: Optional[AsyncEmbeddingService] = None, vector_field: str, k: int = 4, num_candidates: int = 20, @@ -361,6 +362,8 @@ async def max_marginal_relevance_search( among selected documents. :param query (str): Text to look up documents similar to. + :param query_embedding: Input embedding vector. If given, input query string is + ignored. :param k (int): Number of Documents to return. Defaults to 4. :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. :param lambda_mult (float): Number between 0 and 1 that determines the degree @@ -381,12 +384,17 @@ async def max_marginal_relevance_search( remove_vector_query_field_from_metadata = False # Embed the query - query_embedding = await embedding_service.embed_query(query) + if query_embedding: + query_vector = query_embedding + elif self.embedding_service: + if not query: + raise ValueError("specify a query or a query_embedding to search") + query_vector = await self.embedding_service.embed_query(query) # Fetch the initial documents got_hits = await self.search( query=None, - query_vector=query_embedding, + query_vector=query_vector, k=num_candidates, fields=fields, custom_query=custom_query, @@ -397,7 +405,7 @@ async def max_marginal_relevance_search( # Select documents using maximal marginal relevance selected_indices = maximal_marginal_relevance( - query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k + query_vector, got_embeddings, lambda_mult=lambda_mult, k=k ) selected_hits = [got_hits[i] for i in selected_indices] diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index 9aaa966f3..80f53fc12 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -229,7 +229,7 @@ def delete( # type: ignore[no-untyped-def] def search( self, *, - query: Optional[str], + query: Optional[str] = None, query_vector: Optional[List[float]] = None, k: int = 4, num_candidates: int = 50, @@ -341,8 +341,9 @@ def _create_index_if_not_exists(self) -> None: def max_marginal_relevance_search( self, *, - embedding_service: EmbeddingService, - query: str, + query: Optional[str] = None, + query_embedding: Optional[List[float]] = None, + embedding_service: Optional[EmbeddingService] = None, vector_field: str, k: int = 4, num_candidates: int = 20, @@ -358,6 +359,8 @@ def max_marginal_relevance_search( among selected documents. :param query (str): Text to look up documents similar to. + :param query_embedding: Input embedding vector. If given, input query string is + ignored. :param k (int): Number of Documents to return. Defaults to 4. :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. :param lambda_mult (float): Number between 0 and 1 that determines the degree @@ -378,12 +381,17 @@ def max_marginal_relevance_search( remove_vector_query_field_from_metadata = False # Embed the query - query_embedding = embedding_service.embed_query(query) + if query_embedding: + query_vector = query_embedding + elif self.embedding_service: + if not query: + raise ValueError("specify a query or a query_embedding to search") + query_vector = self.embedding_service.embed_query(query) # Fetch the initial documents got_hits = self.search( query=None, - query_vector=query_embedding, + query_vector=query_vector, k=num_candidates, fields=fields, custom_query=custom_query, @@ -394,7 +402,7 @@ def max_marginal_relevance_search( # Select documents using maximal marginal relevance selected_indices = maximal_marginal_relevance( - query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k + query_vector, got_embeddings, lambda_mult=lambda_mult, k=k ) selected_hits = [got_hits[i] for i in selected_indices] diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index a8cae670f..af6562867 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -822,6 +822,7 @@ def test_max_marginal_relevance_search( texts = ["foo", "bar", "baz"] vector_field = "vector_field" text_field = "text_field" + query_embedding = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0] embedding_service = ConsistentFakeEmbeddings() store = VectorStore( index=index, @@ -834,7 +835,6 @@ def test_max_marginal_relevance_search( store.add_texts(texts) mmr_output = store.max_marginal_relevance_search( - embedding_service=embedding_service, query=texts[0], vector_field=vector_field, k=3, @@ -843,8 +843,17 @@ def test_max_marginal_relevance_search( sim_output = store.search(query=texts[0], k=3) assert mmr_output == sim_output + # search using query embeddings vector instead of query + mmr_output = store.max_marginal_relevance_search( + query_embedding=query_embedding, + vector_field=vector_field, + k=3, + num_candidates=3, + ) + sim_output = store.search(query_vector=query_embedding, k=3) + assert mmr_output == sim_output + mmr_output = store.max_marginal_relevance_search( - embedding_service=embedding_service, query=texts[0], vector_field=vector_field, k=2, @@ -855,7 +864,6 @@ def test_max_marginal_relevance_search( assert mmr_output[1]["_source"][text_field] == texts[1] mmr_output = store.max_marginal_relevance_search( - embedding_service=embedding_service, query=texts[0], vector_field=vector_field, k=2, @@ -868,7 +876,6 @@ def test_max_marginal_relevance_search( # if fetch_k < k, then the output will be less than k mmr_output = store.max_marginal_relevance_search( - embedding_service=embedding_service, query=texts[0], vector_field=vector_field, k=3, From a17a04c7fbd012cc78da32c72df79d2efbb24700 Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Mon, 12 Aug 2024 17:32:48 +0400 Subject: [PATCH 2/2] Use embedding service if provided --- .../helpers/vectorstore/_async/vectorstore.py | 11 +++-- .../helpers/vectorstore/_sync/vectorstore.py | 11 +++-- .../test_vectorstore/test_vectorstore.py | 44 ++++++++++++++++++- 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/elasticsearch/helpers/vectorstore/_async/vectorstore.py b/elasticsearch/helpers/vectorstore/_async/vectorstore.py index 16d0aece3..3b8c1e9e9 100644 --- a/elasticsearch/helpers/vectorstore/_async/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_async/vectorstore.py @@ -386,10 +386,15 @@ async def max_marginal_relevance_search( # Embed the query if query_embedding: query_vector = query_embedding - elif self.embedding_service: + else: if not query: - raise ValueError("specify a query or a query_embedding to search") - query_vector = await self.embedding_service.embed_query(query) + raise ValueError("specify either query or query_embedding to search") + elif embedding_service: + query_vector = await embedding_service.embed_query(query) + elif self.embedding_service: + query_vector = await self.embedding_service.embed_query(query) + else: + raise ValueError("specify embedding_service to search with query") # Fetch the initial documents got_hits = await self.search( diff --git a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py index 80f53fc12..3c4a0d51a 100644 --- a/elasticsearch/helpers/vectorstore/_sync/vectorstore.py +++ b/elasticsearch/helpers/vectorstore/_sync/vectorstore.py @@ -383,10 +383,15 @@ def max_marginal_relevance_search( # Embed the query if query_embedding: query_vector = query_embedding - elif self.embedding_service: + else: if not query: - raise ValueError("specify a query or a query_embedding to search") - query_vector = self.embedding_service.embed_query(query) + raise ValueError("specify either query or query_embedding to search") + elif embedding_service: + query_vector = embedding_service.embed_query(query) + elif self.embedding_service: + query_vector = self.embedding_service.embed_query(query) + else: + raise ValueError("specify embedding_service to search with query") # Fetch the initial documents got_hits = self.search( diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index af6562867..820746acd 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -815,6 +815,47 @@ def test_bulk_args(self, sync_client_request_saving: Any, index: str) -> None: # 1 for index exist, 1 for index create, 3 to index docs assert len(store.client.transport.requests) == 5 # type: ignore + def test_max_marginal_relevance_search_errors( + self, sync_client: Elasticsearch, index: str + ) -> None: + """Test max marginal relevance search error conditions.""" + texts = ["foo", "bar", "baz"] + vector_field = "vector_field" + embedding_service = ConsistentFakeEmbeddings() + store = VectorStore( + index=index, + retrieval_strategy=DenseVectorScriptScoreStrategy(), + embedding_service=embedding_service, + client=sync_client, + ) + store.add_texts(texts) + + # search without query embeddings vector or query + with pytest.raises( + ValueError, match="specify either query or query_embedding to search" + ): + store.max_marginal_relevance_search( + vector_field=vector_field, + k=3, + num_candidates=3, + ) + + # search without service + no_service_store = VectorStore( + index=index, + retrieval_strategy=DenseVectorScriptScoreStrategy(), + client=sync_client, + ) + with pytest.raises( + ValueError, match="specify embedding_service to search with query" + ): + no_service_store.max_marginal_relevance_search( + query=texts[0], + vector_field=vector_field, + k=3, + num_candidates=3, + ) + def test_max_marginal_relevance_search( self, sync_client: Elasticsearch, index: str ) -> None: @@ -834,6 +875,7 @@ def test_max_marginal_relevance_search( ) store.add_texts(texts) + # search with query mmr_output = store.max_marginal_relevance_search( query=texts[0], vector_field=vector_field, @@ -843,7 +885,7 @@ def test_max_marginal_relevance_search( sim_output = store.search(query=texts[0], k=3) assert mmr_output == sim_output - # search using query embeddings vector instead of query + # search with query embeddings mmr_output = store.max_marginal_relevance_search( query_embedding=query_embedding, vector_field=vector_field,