From cc401063b73045a0f1e963eb0226c9cc3f1faf2d Mon Sep 17 00:00:00 2001 From: Fangyi Zhou Date: Thu, 19 Sep 2024 10:51:37 +0100 Subject: [PATCH] feat: retrieve embeddings from database only when necessary When performing a similarity search without using maximal marginal relevance, the database query includes the embeddings by default, but the retrieved embeddings are then discarded unused. This can be highly inefficient when retrieving a large number of documents, due to communication overhead. --- langchain_postgres/vectorstores.py | 40 +++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 044bece6..13fb8f4d 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1060,9 +1060,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa docs = [ ( Document( - id=str(result.EmbeddingStore.id), - page_content=result.EmbeddingStore.document, - metadata=result.EmbeddingStore.cmetadata, + id=str(result.id), + page_content=result.document, + metadata=result.cmetadata, ), result.distance if self.embeddings is not None else None, ) @@ -1395,8 +1395,16 @@ def __query_collection( embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, + retrieve_embeddings: bool = False, ) -> Sequence[Any]: """Query the collection.""" + columns_to_select = [ + self.EmbeddingStore.id, + self.EmbeddingStore.document, + self.EmbeddingStore.cmetadata, + ] + if retrieve_embeddings: + columns_to_select.append(self.EmbeddingStore.embedding) with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: @@ -1417,7 +1425,7 @@ def __query_collection( results: List[Any] = ( session.query( - self.EmbeddingStore, + *columns_to_select, self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) @@ -1438,8 +1446,16 @@ async def 
__aquery_collection( embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, + retrieve_embeddings: bool = False, ) -> Sequence[Any]: """Query the collection.""" + columns_to_select = [ + self.EmbeddingStore.id, + self.EmbeddingStore.document, + self.EmbeddingStore.cmetadata, + ] + if retrieve_embeddings: + columns_to_select.append(self.EmbeddingStore.embedding) async with self._make_async_session() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: @@ -1460,7 +1476,7 @@ async def __aquery_collection( stmt = ( select( - self.EmbeddingStore, + *columns_to_select, self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) @@ -1899,9 +1915,11 @@ def max_marginal_relevance_search_with_score_by_vector( relevance to the query and score for each. """ assert not self._async_engine, "This method must be called without async_mode" - results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) + results = self.__query_collection( + embedding=embedding, k=fetch_k, filter=filter, retrieve_embeddings=True + ) - embedding_list = [result.EmbeddingStore.embedding for result in results] + embedding_list = [result.embedding for result in results] mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), @@ -1947,10 +1965,14 @@ async def amax_marginal_relevance_search_with_score_by_vector( await self.__apost_init__() # Lazy async init async with self._make_async_session() as session: results = await self.__aquery_collection( - session=session, embedding=embedding, k=fetch_k, filter=filter + session=session, + embedding=embedding, + k=fetch_k, + filter=filter, + retrieve_embeddings=True, ) - embedding_list = [result.EmbeddingStore.embedding for result in results] + embedding_list = [result.embedding for result in results] mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32),