diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py
index 11e5ff9..bff36cb 100644
--- a/langchain_postgres/v2/async_vectorstore.py
+++ b/langchain_postgres/v2/async_vectorstore.py
@@ -565,9 +565,37 @@ async def __query_collection(
         else:
             query_embedding = f"{[float(dimension) for dimension in embedding]}"
             embedding_data_string = ":query_embedding"
-        stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
-            FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
-        """
+
+        if "query" in kwargs and "full_text_weight" in kwargs:
+            if not isinstance(kwargs["full_text_weight"], (int, float)) or kwargs["full_text_weight"] < 0 or kwargs["full_text_weight"] > 1:
+                raise ValueError("full_text_weight must be a number between 0 and 1")
+
+            full_text_weight = float(kwargs["full_text_weight"])
+            full_text_query = " | ".join(kwargs["query"].split(" "))
+
+            stmt = f"""
+            WITH A AS (
+                SELECT {column_names},
+                    {search_function}("{self.embedding_column}", {embedding_data_string}) as semantic_distance,
+                    ts_rank(to_tsvector('english', "{self.content_column}"), to_tsquery('english', :full_text_query)) as keyword_distance
+                FROM "{self.schema_name}"."{self.table_name}" {param_filter}
+            ),
+            B AS (
+                SELECT *,
+                    semantic_distance * {1.0 - full_text_weight} + keyword_distance * {full_text_weight} as distance
+                FROM A
+            )
+
+            SELECT * FROM B ORDER BY distance LIMIT :k;
+            """
+        else:
+            stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
+                FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
+            """
+
         param_dict = {"query_embedding": query_embedding, "k": k}
+        if "query" in kwargs and "full_text_weight" in kwargs:
+            # Bind the tsquery text as a parameter instead of interpolating it
+            # into the SQL string (prevents SQL injection via the search query).
+            param_dict["full_text_query"] = full_text_query
         if filter_dict:
             param_dict.update(filter_dict)
diff --git a/langchain_postgres/v2/engine.py 
b/langchain_postgres/v2/engine.py
index c2a0d93..0e95b9e 100644
--- a/langchain_postgres/v2/engine.py
+++ b/langchain_postgres/v2/engine.py
@@ -156,6 +156,8 @@ async def _ainit_vectorstore_table(
         id_column: Union[str, Column, ColumnDict] = "langchain_id",
         overwrite_existing: bool = False,
         store_metadata: bool = True,
+        full_text_index: str | None = None,
+
     ) -> None:
         """
         Create a table for saving of vectors to be used with PGVectorStore.
@@ -178,6 +180,8 @@ async def _ainit_vectorstore_table(
             overwrite_existing (bool): Whether to drop existing table.
                 Default: False.
             store_metadata (bool): Whether to store metadata in the table.
                 Default: True.
+            full_text_index (str): Language used to construct full text index. If None then no index will be used.
+                Default: None

         Raises:
             :class:`DuplicateTableError `: if table already exists.
@@ -241,6 +245,10 @@ async def _ainit_vectorstore_table(
             query += f""",\n"{metadata_json_column}" JSON"""
         query += "\n);"

+        if full_text_index:
+            # NOTE(review): appended to the CREATE TABLE string and sent in one execute; confirm the async driver accepts multi-statement text(), otherwise issue a second conn.execute().
+            query += f"""\nCREATE INDEX ON "{schema_name}"."{table_name}"
+            USING GIN (to_tsvector('{full_text_index}', "{content_column}"));"""
         async with self._pool.connect() as conn:
             await conn.execute(text(query))
             await conn.commit()
@@ -258,6 +266,7 @@ async def ainit_vectorstore_table(
         id_column: Union[str, Column, ColumnDict] = "langchain_id",
         overwrite_existing: bool = False,
         store_metadata: bool = True,
+        full_text_index: str | None = None,
     ) -> None:
         """
         Create a table for saving of vectors to be used with PGVectorStore.
@@ -280,6 +289,8 @@ async def ainit_vectorstore_table(
             overwrite_existing (bool): Whether to drop existing table.
                 Default: False.
             store_metadata (bool): Whether to store metadata in the table.
                 Default: True.
+            full_text_index (str): Language used to construct full text index. If None then no index will be used. 
+                Default: None
         """
         await self._run_as_async(
             self._ainit_vectorstore_table(
@@ -293,6 +304,7 @@ async def ainit_vectorstore_table(
                 id_column=id_column,
                 overwrite_existing=overwrite_existing,
                 store_metadata=store_metadata,
+                full_text_index=full_text_index,
             )
         )
@@ -309,6 +321,7 @@ def init_vectorstore_table(
         id_column: Union[str, Column, ColumnDict] = "langchain_id",
         overwrite_existing: bool = False,
         store_metadata: bool = True,
+        full_text_index: str | None = None,
     ) -> None:
         """
         Create a table for saving of vectors to be used with PGVectorStore.
@@ -331,6 +344,8 @@ def init_vectorstore_table(
             overwrite_existing (bool): Whether to drop existing table.
                 Default: False.
             store_metadata (bool): Whether to store metadata in the table.
                 Default: True.
+            full_text_index (str): Language used to construct full text index. If None then no index will be used.
+                Default: None
         """
         self._run_as_sync(
             self._ainit_vectorstore_table(
@@ -344,6 +359,7 @@ def init_vectorstore_table(
                 id_column=id_column,
                 overwrite_existing=overwrite_existing,
                 store_metadata=store_metadata,
+                full_text_index=full_text_index,
             )
         )
diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py
index 379f529..9c6d826 100644
--- a/tests/unit_tests/v2/test_pg_vectorstore_search.py
+++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py
@@ -76,7 +76,7 @@ async def engine(self) -> AsyncIterator[PGEngine]:
     @pytest_asyncio.fixture(scope="class")
     async def vs(self, engine: PGEngine) -> AsyncIterator[PGVectorStore]:
         await engine.ainit_vectorstore_table(
-            DEFAULT_TABLE, VECTOR_SIZE, store_metadata=False
+            DEFAULT_TABLE, VECTOR_SIZE, store_metadata=False, full_text_index="english"
         )
         vs = await PGVectorStore.create(
             engine,
@@ -197,6 +197,19 @@ async def test_similarity_search_with_relevance_scores_threshold_cosine(
         assert len(results) == 1
         assert results[0][0] == Document(page_content="foo", id=ids[0])

+    @pytest.mark.parametrize(("full_text_weight", "expected_result", 
"expected_score"), [(1.0, "foo", 1.0), (0.5, "foo", 1.4392072893679142), (0.0, "bar", 2.0)])
+    async def test_similarity_search_hybrid_1(
+        self, vs: PGVectorStore, full_text_weight: float, expected_result: str, expected_score: float
+    ) -> None:
+        results = await vs.asimilarity_search_with_relevance_scores(
+            "foo", full_text_weight=full_text_weight
+        )
+        top_doc, top_score = results[0]
+        assert top_doc.page_content == expected_result
+        assert top_score == pytest.approx(expected_score)
+
+
     async def test_similarity_search_with_relevance_scores_threshold_euclidean(
         self, engine: PGEngine
     ) -> None: