Skip to content

feat: adds hybrid search for sync VS interface [4/N] #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: hybrid_search_3
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions langchain_postgres/v2/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .async_vectorstore import AsyncPGVectorStore
from .engine import PGEngine
from .hybrid_search_config import HybridSearchConfig
from .indexes import (
DEFAULT_DISTANCE_STRATEGY,
BaseIndex,
Expand Down Expand Up @@ -59,6 +60,7 @@ async def create(
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
) -> PGVectorStore:
"""Create an PGVectorStore instance.

Expand All @@ -78,6 +80,7 @@ async def create(
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.

Returns:
PGVectorStore
Expand All @@ -98,6 +101,7 @@ async def create(
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
)
vs = await engine._run_as_async(coro)
return cls(cls.__create_key, engine, vs)
Expand All @@ -120,6 +124,7 @@ def create_sync(
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
) -> PGVectorStore:
"""Create an PGVectorStore instance.

Expand All @@ -140,6 +145,7 @@ def create_sync(
fetch_k (int, optional): Number of Documents to fetch to pass to MMR algorithm. Defaults to 20.
lambda_mult (float, optional): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (Optional[QueryOptions], optional): Index query option. Defaults to None.
hybrid_search_config (Optional[HybridSearchConfig], optional): Hybrid search configuration. Defaults to None.

Returns:
PGVectorStore
Expand All @@ -160,6 +166,7 @@ def create_sync(
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
)
vs = engine._run_as_sync(coro)
return cls(cls.__create_key, engine, vs)
Expand Down Expand Up @@ -301,6 +308,7 @@ async def afrom_texts( # type: ignore[override]
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
**kwargs: Any,
) -> PGVectorStore:
"""Create an PGVectorStore instance from texts.
Expand All @@ -324,6 +332,7 @@ async def afrom_texts( # type: ignore[override]
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.

Raises:
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
Expand All @@ -347,6 +356,7 @@ async def afrom_texts( # type: ignore[override]
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
)
await vs.aadd_texts(texts, metadatas=metadatas, ids=ids)
return vs
Expand All @@ -371,6 +381,7 @@ async def afrom_documents( # type: ignore[override]
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
**kwargs: Any,
) -> PGVectorStore:
"""Create an PGVectorStore instance from documents.
Expand All @@ -393,6 +404,7 @@ async def afrom_documents( # type: ignore[override]
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.

Raises:
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
Expand All @@ -417,6 +429,7 @@ async def afrom_documents( # type: ignore[override]
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
)
await vs.aadd_documents(documents, ids=ids)
return vs
Expand All @@ -442,6 +455,7 @@ def from_texts( # type: ignore[override]
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
**kwargs: Any,
) -> PGVectorStore:
"""Create an PGVectorStore instance from texts.
Expand All @@ -465,6 +479,7 @@ def from_texts( # type: ignore[override]
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.

Raises:
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
Expand All @@ -488,6 +503,7 @@ def from_texts( # type: ignore[override]
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
**kwargs,
)
vs.add_texts(texts, metadatas=metadatas, ids=ids)
Expand All @@ -513,6 +529,7 @@ def from_documents( # type: ignore[override]
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
**kwargs: Any,
) -> PGVectorStore:
"""Create an PGVectorStore instance from documents.
Expand All @@ -535,6 +552,7 @@ def from_documents( # type: ignore[override]
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.

Raises:
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
Expand All @@ -558,6 +576,7 @@ def from_documents( # type: ignore[override]
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
**kwargs,
)
vs.add_documents(documents, ids=ids)
Expand Down
60 changes: 60 additions & 0 deletions tests/unit_tests/v2/test_pg_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from sqlalchemy import text

from langchain_postgres import Column, PGEngine, PGVectorStore
from langchain_postgres.v2.hybrid_search_config import (
HybridSearchConfig,
reciprocal_rank_fusion,
weighted_sum_ranking,
)
from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions
from tests.unit_tests.fixtures.metadata_filtering_data import (
FILTERING_TEST_CASES,
Expand Down Expand Up @@ -261,6 +266,37 @@ async def test_vectorstore_with_metadata_filters(
)
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter

async def test_asimilarity_hybrid_search(self, vs: PGVectorStore):
    """Exercise ``asimilarity_search`` with a per-call hybrid search config.

    Three single-hit (``k=1``) queries are issued against ``vs``: two with a
    default ``HybridSearchConfig`` and one that combines a content filter
    with ``weighted_sum_ranking`` fusion.
    """
    # Default hybrid configuration, query "foo": exactly one hit, the "foo" row.
    default_config = HybridSearchConfig()
    foo_hits = await vs.asimilarity_search(
        "foo", k=1, hybrid_search_config=default_config
    )
    assert len(foo_hits) == 1
    assert foo_hits == [Document(page_content="foo", id=ids[0])]

    # Fresh default configuration, query "bar": the top hit is the "bar" row.
    bar_hits = await vs.asimilarity_search(
        "bar",
        k=1,
        hybrid_search_config=HybridSearchConfig(),
    )
    assert bar_hits[0] == Document(page_content="bar", id=ids[1])

    # Weighted-sum fusion with explicit weights, combined with a filter on
    # "content" that uses the "$ne" operator against "baz".
    weighted_config = HybridSearchConfig(
        fusion_function=weighted_sum_ranking,
        fusion_function_parameters={
            "primary_results_weight": 0.1,
            "secondary_results_weight": 0.9,
            "fetch_top_k": 10,
        },
        primary_top_k=1,
        secondary_top_k=1,
    )
    filtered_hits = await vs.asimilarity_search(
        "foo",
        k=1,
        filter={"content": {"$ne": "baz"}},
        hybrid_search_config=weighted_config,
    )
    assert filtered_hits == [Document(page_content="foo", id=ids[0])]


@pytest.mark.enable_socket
class TestVectorStoreSearchSync:
Expand Down Expand Up @@ -401,3 +437,27 @@ def test_metadata_filter_negative_tests(
docs = vs_custom_filter_sync.similarity_search(
"meow", k=5, filter=test_filter
)

def test_similarity_hybrid_search(self, vs_custom):
    """Exercise the sync ``similarity_search`` with a hybrid search config.

    Three single-hit (``k=1``) queries are issued against ``vs_custom``: two
    with a default ``HybridSearchConfig`` and one that combines a filter on
    ``mycontent`` with ``reciprocal_rank_fusion``.
    """
    # Default hybrid configuration, query "foo": exactly one hit, the "foo" row.
    default_config = HybridSearchConfig()
    foo_hits = vs_custom.similarity_search(
        "foo", k=1, hybrid_search_config=default_config
    )
    assert len(foo_hits) == 1
    assert foo_hits == [Document(page_content="foo", id=ids[0])]

    # Fresh default configuration, query "bar": the only hit is the "bar" row.
    bar_hits = vs_custom.similarity_search(
        "bar",
        k=1,
        hybrid_search_config=HybridSearchConfig(),
    )
    assert bar_hits == [Document(page_content="bar", id=ids[1])]

    # Reciprocal-rank fusion combined with a "$ne" filter on "mycontent".
    # NOTE(review): the filter looks like it excludes rows whose "mycontent"
    # is "baz", yet the expected hit below is the "baz" row -- confirm this
    # expectation against the fixture data and the filter semantics.
    rrf_config = HybridSearchConfig(fusion_function=reciprocal_rank_fusion)
    filtered_hits = vs_custom.similarity_search(
        "foo",
        k=1,
        filter={"mycontent": {"$ne": "baz"}},
        hybrid_search_config=rrf_config,
    )
    assert filtered_hits == [Document(page_content="baz", id=ids[2])]