
Commit 447eb00

allow embeddings vector to be used for mmr searching (#2620)
Signed-off-by: rishabh208gupta <rishabhgupta.52pp@gmail.com>
1 parent 6521b55 commit 447eb00

3 files changed: +31 additions, −12 deletions


elasticsearch/helpers/vectorstore/_async/vectorstore.py

Lines changed: 10 additions & 4 deletions
@@ -232,7 +232,7 @@ async def delete(  # type: ignore[no-untyped-def]
     async def search(
         self,
         *,
-        query: Optional[str],
+        query: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
         k: int = 4,
         num_candidates: int = 50,
@@ -344,8 +344,9 @@ async def _create_index_if_not_exists(self) -> None:
     async def max_marginal_relevance_search(
         self,
         *,
-        embedding_service: AsyncEmbeddingService,
-        query: str,
+        query: Optional[str] = None,
+        query_embedding: Optional[List[float]] = None,
+        embedding_service: Optional[AsyncEmbeddingService] = None,
         vector_field: str,
         k: int = 4,
         num_candidates: int = 20,
@@ -361,6 +362,8 @@ async def max_marginal_relevance_search(
             among selected documents.

         :param query (str): Text to look up documents similar to.
+        :param query_embedding: Input embedding vector. If given, input query string is
+            ignored.
         :param k (int): Number of Documents to return. Defaults to 4.
         :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
         :param lambda_mult (float): Number between 0 and 1 that determines the degree
@@ -381,7 +384,10 @@ async def max_marginal_relevance_search(
         remove_vector_query_field_from_metadata = False

         # Embed the query
-        query_embedding = await embedding_service.embed_query(query)
+        if self.embedding_service and not query_embedding:
+            if not query:
+                raise ValueError("specify a query or a query_embedding to search")
+            query_embedding = await self.embedding_service.embed_query(query)

         # Fetch the initial documents
         got_hits = await self.search(
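
A minimal sketch of the two call patterns this change enables on the async helper. It assumes store is an already-configured AsyncVectorStore (client, index, and strategy setup omitted), and the field name "vector_field", the query text, and the vector values below are placeholders, not anything taken from this PR:

async def mmr_examples(store):
    # Text query: the store's own embedding_service embeds it before the MMR pass.
    hits_from_text = await store.max_marginal_relevance_search(
        query="example query text",
        vector_field="vector_field",
        k=4,
        num_candidates=20,
    )

    # Precomputed vector: pass query_embedding directly; per the new docstring,
    # any query string is ignored when query_embedding is given.
    hits_from_vector = await store.max_marginal_relevance_search(
        query_embedding=[0.1, 0.2, 0.3],  # placeholder; must match the index's vector dims
        vector_field="vector_field",
        k=4,
        num_candidates=20,
    )
    return hits_from_text, hits_from_vector
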

elasticsearch/helpers/vectorstore/_sync/vectorstore.py

Lines changed: 10 additions & 4 deletions
@@ -229,7 +229,7 @@ def delete(  # type: ignore[no-untyped-def]
     def search(
         self,
         *,
-        query: Optional[str],
+        query: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
         k: int = 4,
         num_candidates: int = 50,
@@ -341,8 +341,9 @@ def _create_index_if_not_exists(self) -> None:
     def max_marginal_relevance_search(
         self,
         *,
-        embedding_service: EmbeddingService,
-        query: str,
+        query: Optional[str] = None,
+        query_embedding: Optional[List[float]] = None,
+        embedding_service: Optional[EmbeddingService] = None,
         vector_field: str,
         k: int = 4,
         num_candidates: int = 20,
@@ -358,6 +359,8 @@ def max_marginal_relevance_search(
             among selected documents.

         :param query (str): Text to look up documents similar to.
+        :param query_embedding: Input embedding vector. If given, input query string is
+            ignored.
         :param k (int): Number of Documents to return. Defaults to 4.
         :param fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
         :param lambda_mult (float): Number between 0 and 1 that determines the degree
@@ -378,7 +381,10 @@ def max_marginal_relevance_search(
         remove_vector_query_field_from_metadata = False

         # Embed the query
-        query_embedding = embedding_service.embed_query(query)
+        if self.embedding_service and not query_embedding:
+            if not query:
+                raise ValueError("specify a query or a query_embedding to search")
+            query_embedding = self.embedding_service.embed_query(query)

         # Fetch the initial documents
         got_hits = self.search(
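
The sync helper mirrors this. A hedged sketch of reusing a precomputed embedding with the sync VectorStore, again with store, the field name, and the vector treated as assumed placeholders; note that, per the new guard above, a store with an embedding_service attached now raises ValueError("specify a query or a query_embedding to search") when called with neither a query nor a query_embedding:

def mmr_with_precomputed_vector(store, query_vector):
    # Reuse an embedding computed elsewhere (e.g. cached, or produced by a separate
    # inference service) instead of re-embedding the query text inside the helper.
    return store.max_marginal_relevance_search(
        query_embedding=query_vector,
        vector_field="vector_field",  # placeholder field name
        k=3,
        num_candidates=10,
    )
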

test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py

Lines changed: 11 additions & 4 deletions
@@ -822,6 +822,7 @@ def test_max_marginal_relevance_search(
         texts = ["foo", "bar", "baz"]
         vector_field = "vector_field"
         text_field = "text_field"
+        query_embedding = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
         embedding_service = ConsistentFakeEmbeddings()
         store = VectorStore(
             index=index,
@@ -834,7 +835,6 @@ def test_max_marginal_relevance_search(
         store.add_texts(texts)

         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=3,
@@ -843,8 +843,17 @@ def test_max_marginal_relevance_search(
         sim_output = store.search(query=texts[0], k=3)
         assert mmr_output == sim_output

+        # search using query embeddings vector instead of query
+        mmr_output = store.max_marginal_relevance_search(
+            query_embedding=query_embedding,
+            vector_field=vector_field,
+            k=3,
+            num_candidates=3,
+        )
+        sim_output = store.search(query_vector=query_embedding, k=3)
+        assert mmr_output == sim_output
+
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=2,
@@ -855,7 +864,6 @@ def test_max_marginal_relevance_search(
         assert mmr_output[1]["_source"][text_field] == texts[1]

         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=2,
@@ -868,7 +876,6 @@ def test_max_marginal_relevance_search(

         # if fetch_k < k, then the output will be less than k
         mmr_output = store.max_marginal_relevance_search(
-            embedding_service=embedding_service,
             query=texts[0],
             vector_field=vector_field,
             k=3,