diff --git a/langchain_postgres/v2/vectorstores.py b/langchain_postgres/v2/vectorstores.py index 1dc1be9..52224db 100644 --- a/langchain_postgres/v2/vectorstores.py +++ b/langchain_postgres/v2/vectorstores.py @@ -9,6 +9,7 @@ from .async_vectorstore import AsyncPGVectorStore from .engine import PGEngine +from .hybrid_search_config import HybridSearchConfig from .indexes import ( DEFAULT_DISTANCE_STRATEGY, BaseIndex, @@ -59,6 +60,7 @@ async def create( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> PGVectorStore: """Create an PGVectorStore instance. @@ -78,6 +80,7 @@ async def create( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: PGVectorStore @@ -98,6 +101,7 @@ async def create( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) vs = await engine._run_as_async(coro) return cls(cls.__create_key, engine, vs) @@ -120,6 +124,7 @@ def create_sync( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> PGVectorStore: """Create an PGVectorStore instance. @@ -140,6 +145,7 @@ def create_sync( fetch_k (int, optional): Number of Documents to fetch to pass to MMR algorithm. Defaults to 20. lambda_mult (float, optional): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (Optional[QueryOptions], optional): Index query option. Defaults to None. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: PGVectorStore @@ -160,6 +166,7 @@ def create_sync( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) vs = engine._run_as_sync(coro) return cls(cls.__create_key, engine, vs) @@ -301,6 +308,7 @@ async def afrom_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from texts. @@ -324,6 +332,7 @@ async def afrom_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -347,6 +356,7 @@ async def afrom_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_texts(texts, metadatas=metadatas, ids=ids) return vs @@ -371,6 +381,7 @@ async def afrom_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from documents. @@ -393,6 +404,7 @@ async def afrom_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -417,6 +429,7 @@ async def afrom_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_documents(documents, ids=ids) return vs @@ -442,6 +455,7 @@ def from_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from texts. @@ -465,6 +479,7 @@ def from_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -488,6 +503,7 @@ def from_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, **kwargs, ) vs.add_texts(texts, metadatas=metadatas, ids=ids) @@ -513,6 +529,7 @@ def from_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> PGVectorStore: """Create an PGVectorStore instance from documents. @@ -535,6 +552,7 @@ def from_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -558,6 +576,7 @@ def from_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, **kwargs, ) vs.add_documents(documents, ids=ids) diff --git a/tests/unit_tests/v2/test_pg_vectorstore_search.py b/tests/unit_tests/v2/test_pg_vectorstore_search.py index 379f529..d783a85 100644 --- a/tests/unit_tests/v2/test_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_pg_vectorstore_search.py @@ -9,6 +9,11 @@ from sqlalchemy import text from langchain_postgres import Column, PGEngine, PGVectorStore +from langchain_postgres.v2.hybrid_search_config import ( + HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking, +) from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from tests.unit_tests.fixtures.metadata_filtering_data import ( FILTERING_TEST_CASES, @@ -261,6 +266,37 @@ async def test_vectorstore_with_metadata_filters( ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + async def test_asimilarity_hybrid_search(self, vs: PGVectorStore): + results = await vs.asimilarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results[0] == Document(page_content="bar", id=ids[1]) + + results = await vs.asimilarity_search( + "foo", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="foo", id=ids[0])] + @pytest.mark.enable_socket class TestVectorStoreSearchSync: @@ -401,3 +437,27 @@ def test_metadata_filter_negative_tests( docs = vs_custom_filter_sync.similarity_search( "meow", k=5, filter=test_filter ) + + def test_similarity_hybrid_search(self, vs_custom): + results = vs_custom.similarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = vs_custom.similarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results == [Document(page_content="bar", id=ids[1])] + + results = vs_custom.similarity_search( + "foo", + k=1, + filter={"mycontent": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), + ) + assert results == [Document(page_content="baz", id=ids[2])]