diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index a3743e4..c1fc779 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1061,9 +1061,9 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa docs = [ ( Document( - id=str(result.EmbeddingStore.id), - page_content=result.EmbeddingStore.document, - metadata=result.EmbeddingStore.cmetadata, + id=str(result.id), + page_content=result.document, + metadata=result.cmetadata, ), result.distance if self.embeddings is not None else None, ) @@ -1396,8 +1396,16 @@ def __query_collection( embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, + retrieve_embeddings: bool = False, ) -> Sequence[Any]: """Query the collection.""" + columns_to_select = [ + self.EmbeddingStore.id, + self.EmbeddingStore.document, + self.EmbeddingStore.cmetadata, + ] + if retrieve_embeddings: + columns_to_select.append(self.EmbeddingStore.embedding) with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: @@ -1418,7 +1426,7 @@ def __query_collection( results: List[Any] = ( session.query( - self.EmbeddingStore, + *columns_to_select, self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) @@ -1439,8 +1447,16 @@ async def __aquery_collection( embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, + retrieve_embeddings: bool = False, ) -> Sequence[Any]: """Query the collection.""" + columns_to_select = [ + self.EmbeddingStore.id, + self.EmbeddingStore.document, + self.EmbeddingStore.cmetadata, + ] + if retrieve_embeddings: + columns_to_select.append(self.EmbeddingStore.embedding) async with self._make_async_session() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: @@ -1461,7 +1477,7 @@ async def __aquery_collection( stmt = ( select( - self.EmbeddingStore, + *columns_to_select, self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) @@ -1900,9 +1916,11 @@ def max_marginal_relevance_search_with_score_by_vector( relevance to the query and score for each. """ assert not self._async_engine, "This method must be called without async_mode" - results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) + results = self.__query_collection( + embedding=embedding, k=fetch_k, filter=filter, retrieve_embeddings=True + ) - embedding_list = [result.EmbeddingStore.embedding for result in results] + embedding_list = [result.embedding for result in results] mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), @@ -1948,10 +1966,14 @@ async def amax_marginal_relevance_search_with_score_by_vector( await self.__apost_init__() # Lazy async init async with self._make_async_session() as session: results = await self.__aquery_collection( - session=session, embedding=embedding, k=fetch_k, filter=filter + session=session, + embedding=embedding, + k=fetch_k, + filter=filter, + retrieve_embeddings=True, ) - embedding_list = [result.EmbeddingStore.embedding for result in results] + embedding_list = [result.embedding for result in results] mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32),