From e641575193134162dca4dac245865a20b91f7e94 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Mon, 19 May 2025 21:57:50 +0000 Subject: [PATCH] feat: adds hybrid search for async VS interface [3/N] --- langchain_postgres/v2/async_vectorstore.py | 171 +++++++- .../v2/test_async_pg_vectorstore_index.py | 77 +++- .../v2/test_async_pg_vectorstore_search.py | 364 +++++++++++++++++- 3 files changed, 581 insertions(+), 31 deletions(-) diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index 11e5ff9..9045da4 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -14,14 +14,10 @@ from sqlalchemy.ext.asyncio import AsyncEngine from .engine import PGEngine -from .indexes import ( - DEFAULT_DISTANCE_STRATEGY, - DEFAULT_INDEX_NAME_SUFFIX, - BaseIndex, - DistanceStrategy, - ExactNearestNeighbor, - QueryOptions, -) +from .hybrid_search_config import HybridSearchConfig +from .indexes import (DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX, + BaseIndex, DistanceStrategy, ExactNearestNeighbor, + QueryOptions) COMPARISONS_TO_NATIVE = { "$eq": "=", @@ -77,6 +73,8 @@ def __init__( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, + hybrid_search_column_exists: bool = False, ): """AsyncPGVectorStore constructor. Args: @@ -95,6 +93,8 @@ def __init__( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. + hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: @@ -119,6 +119,8 @@ def __init__( self.fetch_k = fetch_k self.lambda_mult = lambda_mult self.index_query_options = index_query_options + self.hybrid_search_config = hybrid_search_config + self.hybrid_search_column_exists = hybrid_search_column_exists @classmethod async def create( @@ -139,6 +141,7 @@ async def create( fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance. @@ -158,6 +161,7 @@ async def create( fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. Returns: AsyncPGVectorStore @@ -193,6 +197,17 @@ async def create( raise ValueError( f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." ) + hybrid_search_column_exists = False + if hybrid_search_config: + tsv_column_name = ( + hybrid_search_config.tsv_column + if hybrid_search_config.tsv_column + else content_column + "_tsv" + ) + hybrid_search_config.tsv_column = tsv_column_name + hybrid_search_column_exists = ( + tsv_column_name in columns and columns[tsv_column_name] == "tsvector" + ) if embedding_column not in columns: raise ValueError(f"Embedding column, {embedding_column}, does not exist.") if columns[embedding_column] != "USER-DEFINED": @@ -236,6 +251,8 @@ async def create( fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, + hybrid_search_column_exists=hybrid_search_column_exists, ) @property @@ -273,7 +290,12 @@ async def aadd_embeddings( if len(self.metadata_columns) > 0 else "" ) - insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}' + hybrid_search_column = ( + f', "{self.hybrid_search_config.tsv_column}"' + if self.hybrid_search_config and self.hybrid_search_column_exists + else "" + ) + insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{hybrid_search_column}{metadata_col_names}' values = { "id": id, "content": content, @@ -284,6 +306,14 @@ async def aadd_embeddings( if not embedding and can_inline_embed: values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore + if self.hybrid_search_config and self.hybrid_search_column_exists: + lang = ( + f"'{self.hybrid_search_config.tsv_lang}'," + if self.hybrid_search_config.tsv_lang + else "" + ) + values_stmt += f", to_tsvector({lang} :tsv_content)" + values["tsv_content"] = content # Add metadata extra = copy.deepcopy(metadata) for metadata_column in self.metadata_columns: @@ -308,6 +338,9 @@ async def aadd_embeddings( upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"' + if self.hybrid_search_config and self.hybrid_search_column_exists: + upsert_stmt += f', "{self.hybrid_search_config.tsv_column}" = EXCLUDED."{self.hybrid_search_config.tsv_column}"' + if self.metadata_json_column: upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"' @@ -408,6 +441,7 @@ async def afrom_texts( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance from texts. @@ -430,6 +464,7 @@ async def afrom_texts( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -453,6 +488,7 @@ async def afrom_texts( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) return vs @@ -478,6 +514,7 @@ async def afrom_documents( # type: ignore[override] fetch_k: int = 20, lambda_mult: float = 0.5, index_query_options: Optional[QueryOptions] = None, + hybrid_search_config: Optional[HybridSearchConfig] = None, **kwargs: Any, ) -> AsyncPGVectorStore: """Create an AsyncPGVectorStore instance from documents. @@ -500,6 +537,7 @@ async def afrom_documents( # type: ignore[override] fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. index_query_options (QueryOptions): Index query option. + hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. Raises: :class:`InvalidTextRepresentationError `: if the `ids` data type does not match that of the `id_column`. @@ -524,6 +562,7 @@ async def afrom_documents( # type: ignore[override] fetch_k=fetch_k, lambda_mult=lambda_mult, index_query_options=index_query_options, + hybrid_search_config=hybrid_search_config, ) texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] @@ -538,16 +577,30 @@ async def __query_collection( filter: Optional[dict] = None, **kwargs: Any, ) -> Sequence[RowMapping]: - """Perform similarity search query on database.""" - k = k if k else self.k + """ + Perform similarity search (or hybrid search) query on database. + Queries might be slow if the hybrid search column does not exist. + For best hybrid search performance, consider creating a TSV column + and adding GIN index. + """ + if not k: + k = ( + max( + self.k, + self.hybrid_search_config.primary_top_k, + self.hybrid_search_config.secondary_top_k, + ) + if self.hybrid_search_config + else self.k + ) operator = self.distance_strategy.operator search_function = self.distance_strategy.search_function - columns = self.metadata_columns + [ + columns = [ self.id_column, self.content_column, self.embedding_column, - ] + ] + self.metadata_columns if self.metadata_json_column: columns.append(self.metadata_json_column) @@ -557,7 +610,9 @@ async def __query_collection( filter_dict = None if filter and isinstance(filter, dict): safe_filter, filter_dict = self._create_filter_clause(filter) - param_filter = f"WHERE {safe_filter}" if safe_filter else "" + where_filters = f"WHERE {safe_filter}" if safe_filter else "" + and_filters = f"AND ({safe_filter})" if safe_filter else "" + inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) if not embedding and callable(inline_embed_func) and "query" in kwargs: query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore @@ -565,8 +620,8 @@ async def __query_collection( else: query_embedding = f"{[float(dimension) for dimension in embedding]}" embedding_data_string = ":query_embedding" - stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance - FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; + dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance + FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; """ param_dict = {"query_embedding": query_embedding, "k": k} if filter_dict: @@ -577,15 +632,50 @@ async def __query_collection( for query_option in self.index_query_options.to_parameter(): query_options_stmt = f"SET LOCAL {query_option};" await conn.execute(text(query_options_stmt)) - result = await conn.execute(text(stmt), param_dict) + result = await conn.execute(text(dense_query_stmt), param_dict) result_map = result.mappings() - results = result_map.fetchall() + dense_results = result_map.fetchall() else: async with self.engine.connect() as conn: - result = await conn.execute(text(stmt), param_dict) + result = await conn.execute(text(dense_query_stmt), param_dict) result_map = result.mappings() - results = result_map.fetchall() - return results + dense_results = result_map.fetchall() + + hybrid_search_config = kwargs.get( + "hybrid_search_config", self.hybrid_search_config + ) + fts_query = ( + hybrid_search_config.fts_query + if hybrid_search_config and hybrid_search_config.fts_query + else kwargs.get("fts_query", "") + ) + if hybrid_search_config and fts_query: + hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k + # do the sparse query + lang = ( + f"'{hybrid_search_config.tsv_lang}'," + if hybrid_search_config.tsv_lang + else "" + ) + query_tsv = f"plainto_tsquery({lang} :fts_query)" + param_dict["fts_query"] = fts_query + if self.hybrid_search_column_exists: + content_tsv = f'"{hybrid_search_config.tsv_column}"' + else: + content_tsv = f'to_tsvector({lang} "{self.content_column}")' + sparse_query_stmt = f'SELECT {column_names}, ts_rank_cd({content_tsv}, {query_tsv}) as distance FROM "{self.schema_name}"."{self.table_name}" WHERE {content_tsv} @@ {query_tsv} {and_filters} ORDER BY distance desc LIMIT {hybrid_search_config.secondary_top_k};' + async with self.engine.connect() as conn: + result = await conn.execute(text(sparse_query_stmt), param_dict) + result_map = result.mappings() + sparse_results = result_map.fetchall() + + combined_results = hybrid_search_config.fusion_function( + dense_results, + sparse_results, + **hybrid_search_config.fusion_function_parameters, + ) + return combined_results + return dense_results async def asimilarity_search( self, @@ -603,6 +693,14 @@ async def asimilarity_search( ) kwargs["query"] = query + # add fts_query to hybrid_search_config + hybrid_search_config = kwargs.get( + "hybrid_search_config", self.hybrid_search_config + ) + if hybrid_search_config and not hybrid_search_config.fts_query: + hybrid_search_config.fts_query = query + kwargs["hybrid_search_config"] = hybrid_search_config + return await self.asimilarity_search_by_vector( embedding=embedding, k=k, filter=filter, **kwargs ) @@ -634,6 +732,14 @@ async def asimilarity_search_with_score( ) kwargs["query"] = query + # add fts_query to hybrid_search_config + hybrid_search_config = kwargs.get( + "hybrid_search_config", self.hybrid_search_config + ) + if hybrid_search_config and not hybrid_search_config.fts_query: + hybrid_search_config.fts_query = query + kwargs["hybrid_search_config"] = hybrid_search_config + docs = await self.asimilarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter, **kwargs ) @@ -806,15 +912,38 @@ async def aapply_vector_index( index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX name = index.name stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};' + + if self.hybrid_search_config: + if self.hybrid_search_column_exists: + tsv_column_name = ( + self.hybrid_search_config.tsv_column + if self.hybrid_search_config.tsv_column + else self.content_column + "_tsv" + ) + tsv_column_name = f'"{tsv_column_name}"' + else: + lang = ( + f"'{self.hybrid_search_config.tsv_lang}'," + if self.hybrid_search_config.tsv_lang + else "" + ) + tsv_column_name = f"to_tsvector({lang} {self.content_column})" + tsv_index_name = self.table_name + self.hybrid_search_config.index_name + tsv_index_query = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {tsv_index_name} ON "{self.schema_name}"."{self.table_name}" USING {self.hybrid_search_config.index_type}({tsv_column_name});' + else: + tsv_index_query = "" + if concurrently: async with self.engine.connect() as conn: autocommit_conn = await conn.execution_options( isolation_level="AUTOCOMMIT" ) await autocommit_conn.execute(text(stmt)) + await conn.execute(text(tsv_index_query)) else: async with self.engine.connect() as conn: await conn.execute(text(stmt)) + await conn.execute(text(tsv_index_query)) await conn.commit() async def areindex(self, index_name: Optional[str] = None) -> None: diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_index.py b/tests/unit_tests/v2/test_async_pg_vectorstore_index.py index 3796ef5..dc8768d 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_index.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_index.py @@ -10,15 +10,14 @@ from langchain_postgres import PGEngine from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore -from langchain_postgres.v2.indexes import ( - DistanceStrategy, - HNSWIndex, - IVFFlatIndex, -) +from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig +from langchain_postgres.v2.indexes import (DistanceStrategy, HNSWIndex, + IVFFlatIndex) from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING uuid_str = str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE = "default" + uuid_str +DEFAULT_HYBRID_TABLE = "hybrid" + uuid_str DEFAULT_INDEX_NAME = "index" + uuid_str VECTOR_SIZE = 768 SIMPLE_TABLE = "default_table" @@ -55,8 +54,10 @@ class TestIndex: async def engine(self) -> AsyncIterator[PGEngine]: engine = PGEngine.from_connection_string(url=CONNECTION_STRING) yield engine - await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") - await aexecute(engine, f"DROP TABLE IF EXISTS {SIMPLE_TABLE}") + + engine._adrop_table(DEFAULT_TABLE) + engine._adrop_table(DEFAULT_HYBRID_TABLE) + engine._adrop_table(SIMPLE_TABLE) await engine.close() @pytest_asyncio.fixture(scope="class") @@ -92,6 +93,68 @@ async def test_aapply_vector_index(self, vs: AsyncPGVectorStore) -> None: assert await vs.is_valid_index(DEFAULT_INDEX_NAME) await vs.adrop_vector_index(DEFAULT_INDEX_NAME) + async def test_aapply_vector_index_non_hybrid_search_vs_without_tsv_column( + self, vs + ) -> None: + index = HNSWIndex(name="test_index_hybrid" + uuid_str) + + tsv_index_name = DEFAULT_TABLE + "langchain_tsv_index" + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.aapply_vector_index(index) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.adrop_vector_index(tsv_index_name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + + async def test_aapply_vector_index_hybrid_search_vs_without_tsv_column( + self, engine, vs + ) -> None: + # overwriting vs to get a hybrid vs + vs = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + hybrid_search_config=HybridSearchConfig(), + ) + index = HNSWIndex(name="test_index_hybrid" + uuid_str) + + tsv_index_name = DEFAULT_TABLE + "langchain_tsv_index" + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.adrop_vector_index(index.name) + await vs.aapply_vector_index(index) + assert await vs.is_valid_index(tsv_index_name) + await vs.areindex(tsv_index_name) + assert await vs.is_valid_index(tsv_index_name) + await vs.adrop_vector_index(tsv_index_name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + await vs.adrop_vector_index(index.name) + + async def test_aapply_vector_index_hybrid_search_with_tsv_column( + self, engine + ) -> None: + await engine._ainit_vectorstore_table( + DEFAULT_HYBRID_TABLE, VECTOR_SIZE, hybrid_search_config=HybridSearchConfig() + ) + vs = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_HYBRID_TABLE, + hybrid_search_config=HybridSearchConfig(), + ) + tsv_index_name = DEFAULT_HYBRID_TABLE + "langchain_tsv_index" + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + index = HNSWIndex(name=DEFAULT_INDEX_NAME) + await vs.aapply_vector_index(index) + await vs.adrop_vector_index(tsv_index_name) + await vs.adrop_vector_index(index.name) + is_valid_index = await vs.is_valid_index(tsv_index_name) + assert is_valid_index == False + async def test_areindex(self, vs: AsyncPGVectorStore) -> None: if not await vs.is_valid_index(DEFAULT_INDEX_NAME): index = HNSWIndex(name=DEFAULT_INDEX_NAME) diff --git a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py index 72f91d8..eb65da2 100644 --- a/tests/unit_tests/v2/test_async_pg_vectorstore_search.py +++ b/tests/unit_tests/v2/test_async_pg_vectorstore_search.py @@ -10,15 +10,18 @@ from langchain_postgres import Column, PGEngine from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore +from langchain_postgres.v2.hybrid_search_config import (HybridSearchConfig, + reciprocal_rank_fusion, + weighted_sum_ranking) from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions from tests.unit_tests.fixtures.metadata_filtering_data import ( - FILTERING_TEST_CASES, - METADATAS, -) + FILTERING_TEST_CASES, METADATAS) from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE1 = "test_table_hybrid1" + str(uuid.uuid4()).replace("-", "_") +HYBRID_SEARCH_TABLE2 = "test_table_hybrid2" + str(uuid.uuid4()).replace("-", "_") CUSTOM_FILTER_TABLE = "custom_filter" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 sync_method_exception_str = "Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead." @@ -41,6 +44,18 @@ filter_docs = [ Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) ] +# Documents designed for hybrid search testing +hybrid_docs_content = { + "hs_doc_apple_fruit": "An apple is a sweet and edible fruit produced by an apple tree. Apples are very common.", + "hs_doc_apple_tech": "Apple Inc. is a multinational technology company. Their latest tech is amazing.", + "hs_doc_orange_fruit": "The orange is the fruit of various citrus species. Oranges are tasty.", + "hs_doc_generic_tech": "Technology drives innovation in the modern world. Tech is evolving.", + "hs_doc_unrelated_cat": "A fluffy cat sat on a mat quietly observing a mouse.", +} +hybrid_docs = [ + Document(page_content=content, metadata={"doc_id_key": key}) + for key, content in hybrid_docs_content.items() +] def get_env_var(key: str, desc: str) -> str: @@ -69,6 +84,8 @@ async def engine(self) -> AsyncIterator[PGEngine]: await engine.adrop_table(DEFAULT_TABLE) await engine.adrop_table(CUSTOM_TABLE) await engine.adrop_table(CUSTOM_FILTER_TABLE) + await engine.adrop_table(HYBRID_SEARCH_TABLE1) + await engine.adrop_table(HYBRID_SEARCH_TABLE2) await engine.close() @pytest_asyncio.fixture(scope="class") @@ -111,6 +128,79 @@ async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore] await vs_custom.aadd_documents(docs, ids=ids) yield vs_custom + @pytest_asyncio.fixture(scope="class") + async def vs_hybrid_search_without_tsv_column(self, engine): + hybrid_search_config = HybridSearchConfig( + tsv_column="my_tsv_col", + tsv_lang="pg_catalog.english", + fts_query="my_fts_query", + fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 60, + "fetch_top_k": 10, + }, + ) + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE1, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + metadata_json_column="mymetadata", # ignored + store_metadata=False, + ) + + vs_custom = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE1, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_json_column="mymetadata", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=hybrid_search_config, + ) + await vs_custom.aadd_documents(hybrid_docs) + yield vs_custom + + @pytest_asyncio.fixture(scope="class") + async def vs_hybrid_search_with_tsv_column(self, engine): + await engine._ainit_vectorstore_table( + HYBRID_SEARCH_TABLE2, + VECTOR_SIZE, + id_column=Column("myid", "TEXT"), + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[ + Column("page", "TEXT"), + Column("source", "TEXT"), + Column("doc_id_key", "TEXT"), + ], + store_metadata=False, + hybrid_search_config=HybridSearchConfig(), + ) + + vs_custom = await AsyncPGVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=HYBRID_SEARCH_TABLE2, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["doc_id_key"], + index_query_options=HNSWQueryOptions(ef_search=1), + hybrid_search_config=HybridSearchConfig(), + ) + await vs_custom.aadd_documents(hybrid_docs) + yield vs_custom + @pytest_asyncio.fixture(scope="class") async def vs_custom_filter( self, engine: PGEngine @@ -303,3 +393,271 @@ async def test_vectorstore_with_metadata_filters( "meow", k=5, filter=test_filter ) assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + + async def test_asimilarity_hybrid_search(self, vs): + results = await vs.asimilarity_search( + "foo", k=1, hybrid_search_config=HybridSearchConfig() + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + hybrid_search_config=HybridSearchConfig(), + ) + assert results[0] == Document(page_content="bar", id=ids[1]) + + results = await vs.asimilarity_search( + "foo", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.1, + "secondary_results_weight": 0.9, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="foo", id=ids[0])] + + async def test_asimilarity_hybrid_search_rrk(self, vs): + results = await vs.asimilarity_search( + "foo", + k=1, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion + ), + ) + assert len(results) == 1 + assert results == [Document(page_content="foo", id=ids[0])] + + results = await vs.asimilarity_search( + "bar", + k=1, + filter={"content": {"$ne": "baz"}}, + hybrid_search_config=HybridSearchConfig( + fusion_function=reciprocal_rank_fusion, + fusion_function_parameters={ + "rrf_k": 100, + "fetch_top_k": 10, + }, + primary_top_k=1, + secondary_top_k=1, + ), + ) + assert results == [Document(page_content="bar", id=ids[1])] + + async def test_hybrid_search_weighted_sum_default( + self, vs_hybrid_search_with_tsv_column + ): + """Test hybrid search with default weighted sum (0.5 vector, 0.5 FTS).""" + query = "apple" # Should match "apple" in FTS and vector + + # The vs_hybrid_search_with_tsv_column instance is already configured for hybrid search. + # Default fusion is weighted_sum_ranking with 0.5/0.5 weights. + # fts_query will default to the main query. + results_with_scores = ( + await vs_hybrid_search_with_tsv_column.asimilarity_search_with_score( + query, k=3 + ) + ) + + assert len(results_with_scores) > 1 + result_ids = [doc.metadata["doc_id_key"] for doc, score in results_with_scores] + + # Expect "hs_doc_apple_fruit" and "hs_doc_apple_tech" to be highly ranked. + assert "hs_doc_apple_fruit" in result_ids + + # Scores should be floats (fused scores) + for doc, score in results_with_scores: + assert isinstance(score, float) + + # Check if sorted by score (descending for weighted_sum_ranking with positive scores) + assert results_with_scores[0][1] >= results_with_scores[1][1] + + async def test_hybrid_search_weighted_sum_vector_bias( + self, + vs_hybrid_search_without_tsv_column, + ): + """Test weighted sum with higher weight for vector results.""" + query = "Apple Inc technology" # More specific for vector similarity + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", # Must match table setup + fusion_function_parameters={ + "primary_results_weight": 0.8, # Vector bias + "secondary_results_weight": 0.2, + }, + # fts_query will default to main query + ) + results = await vs_hybrid_search_without_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) > 0 + assert result_ids[0] == "hs_doc_orange_fruit" + + async def test_hybrid_search_weighted_sum_fts_bias( + self, + vs_hybrid_search_with_tsv_column, + ): + """Test weighted sum with higher weight for FTS results.""" + query = "fruit common tasty" # Strong FTS signal for fruit docs + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fusion_function=weighted_sum_ranking, + fusion_function_parameters={ + "primary_results_weight": 0.01, + "secondary_results_weight": 0.99, # FTS bias + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) == 2 + assert "hs_doc_apple_fruit" in result_ids + + async def test_hybrid_search_reciprocal_rank_fusion( + self, + vs_hybrid_search_with_tsv_column, + ): + """Test hybrid search with Reciprocal Rank Fusion.""" + query = "technology company" + + # Configure RRF. primary_top_k and secondary_top_k control inputs to fusion. + # fusion_function_parameters.fetch_top_k controls output count from RRF. + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fusion_function=reciprocal_rank_fusion, + primary_top_k=3, # How many dense results to consider + secondary_top_k=3, # How many sparse results to consider + fusion_function_parameters={ + "rrf_k": 60, + "fetch_top_k": 2, + }, # RRF specific params + ) + # The `k` in asimilarity_search here is the final desired number of results, + # which should align with fusion_function_parameters.fetch_top_k for RRF. + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(result_ids) == 2 + # "hs_doc_apple_tech" (FTS: technology, company; Vector: Apple Inc technology) + # "hs_doc_generic_tech" (FTS: technology; Vector: Technology drives innovation) + # RRF should combine these ranks. "hs_doc_apple_tech" is likely higher. + assert "hs_doc_apple_tech" in result_ids + assert result_ids[0] == "hs_doc_apple_tech" # Stronger combined signal + + async def test_hybrid_search_explicit_fts_query( + self, vs_hybrid_search_with_tsv_column + ): + """Test hybrid search when fts_query in HybridSearchConfig is different from main query.""" + main_vector_query = "Apple Inc." # For vector search + fts_specific_query = "fruit" # For FTS + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_specific_query, # Override FTS query + fusion_function_parameters={ # Using default weighted_sum_ranking + "primary_results_weight": 0.5, + "secondary_results_weight": 0.5, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + main_vector_query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Vector search for "Apple Inc.": hs_doc_apple_tech + # FTS search for "fruit": hs_doc_apple_fruit, hs_doc_orange_fruit + # Combined: hs_doc_apple_fruit (strong FTS) and hs_doc_apple_tech (strong vector) are candidates. + # "hs_doc_apple_fruit" might get a boost if "Apple Inc." vector has some similarity to "apple fruit" doc. + assert len(result_ids) > 0 + assert ( + "hs_doc_apple_fruit" in result_ids + or "hs_doc_apple_tech" in result_ids + or "hs_doc_orange_fruit" in result_ids + ) + + async def test_hybrid_search_with_filter(self, vs_hybrid_search_with_tsv_column): + """Test hybrid search with a metadata filter applied.""" + query = "apple" + # Filter to only include "tech" related apple docs using metadata + # Assuming metadata_columns=["doc_id_key"] was set up for vs_hybrid_search_with_tsv_column + doc_filter = {"doc_id_key": {"$eq": "hs_doc_apple_tech"}} + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + query, k=2, filter=doc_filter, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + assert len(results) == 1 + assert result_ids[0] == "hs_doc_apple_tech" + + async def test_hybrid_search_fts_empty_results( + self, vs_hybrid_search_with_tsv_column + ): + """Test when FTS query yields no results, should fall back to vector search.""" + vector_query = "apple" + no_match_fts_query = "zzyyxx_gibberish_term_for_fts_nomatch" + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=no_match_fts_query, + fusion_function_parameters={ + "primary_results_weight": 0.6, + "secondary_results_weight": 0.4, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query, k=2, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Expect results based purely on vector search for "apple" + assert len(result_ids) > 0 + assert "hs_doc_apple_fruit" in result_ids or "hs_doc_apple_tech" in result_ids + # The top result should be one of the apple documents based on vector search + assert results[0].metadata["doc_id_key"].startswith("hs_doc_unrelated_cat") + + async def test_hybrid_search_vector_empty_results_effectively( + self, + vs_hybrid_search_with_tsv_column, + ): + """Test when vector query is very dissimilar to docs, should rely on FTS.""" + # This is hard to guarantee with fake embeddings, but we try. + # A better way might be to use a filter that excludes all docs for the vector part, + # but filters are applied to both. + vector_query_far_off = "supercalifragilisticexpialidocious_vector_nomatch" + fts_query_match = "orange fruit" # Should match hs_doc_orange_fruit + + config = HybridSearchConfig( + tsv_column="mycontent_tsv", + fts_query=fts_query_match, + fusion_function_parameters={ + "primary_results_weight": 0.4, + "secondary_results_weight": 0.6, + }, + ) + results = await vs_hybrid_search_with_tsv_column.asimilarity_search( + vector_query_far_off, k=1, hybrid_search_config=config + ) + result_ids = [doc.metadata["doc_id_key"] for doc in results] + + # Expect results based purely on FTS search for "orange fruit" + assert len(result_ids) == 1 + assert result_ids[0] == "hs_doc_generic_tech"