Skip to content

feature/hybrid-search #204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions langchain_postgres/v2/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,34 @@ async def __query_collection(
else:
query_embedding = f"{[float(dimension) for dimension in embedding]}"
embedding_data_string = ":query_embedding"
stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
"""

if "query" in kwargs and "full_text_weight" in kwargs:
if not isinstance(kwargs["full_text_weight"], (int, float)) or kwargs["full_text_weight"] < 0 or kwargs["full_text_weight"] > 1:
raise ValueError("full_text_weight must be a number between 0 and 1")

full_text_weight = float(kwargs["full_text_weight"])
full_text_query = " | ".join(kwargs["query"].split(" "))

stmt = f"""
WITH A AS (
SELECT {column_names},
{search_function}("{self.embedding_column}", {embedding_data_string}) as semantic_distance,
ts_rank(to_tsvector('english', "{self.content_column}"), to_tsquery('{full_text_query}')) as keyword_distance
FROM "{self.schema_name}"."{self.table_name}" {param_filter}
),
B AS (
SELECT *,
semantic_distance * {1.0 - full_text_weight} + keyword_distance * {full_text_weight} as distance
FROM A
)

SELECT * FROM B ORDER BY distance LIMIT :k;
"""
else:
stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
"""

param_dict = {"query_embedding": query_embedding, "k": k}
if filter_dict:
param_dict.update(filter_dict)
Expand Down
16 changes: 16 additions & 0 deletions langchain_postgres/v2/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ async def _ainit_vectorstore_table(
id_column: Union[str, Column, ColumnDict] = "langchain_id",
overwrite_existing: bool = False,
store_metadata: bool = True,
full_text_index: str | None = None

) -> None:
"""
Create a table for saving of vectors to be used with PGVectorStore.
Expand All @@ -178,6 +180,8 @@ async def _ainit_vectorstore_table(
overwrite_existing (bool): Whether to drop existing table. Default: False.
store_metadata (bool): Whether to store metadata in the table.
Default: True.
full_text_index (str): Language used to construct full text index. If None then no index will be used.
Default: None

Raises:
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
Expand Down Expand Up @@ -241,6 +245,10 @@ async def _ainit_vectorstore_table(
query += f""",\n"{metadata_json_column}" JSON"""
query += "\n);"

if full_text_index:
query += f"""CREATE INDEX ON \"{schema_name}\".\"{table_name}\"
USING GIN (to_tsvector('{full_text_index}', "{content_column}"))"""

async with self._pool.connect() as conn:
await conn.execute(text(query))
await conn.commit()
Expand All @@ -258,6 +266,7 @@ async def ainit_vectorstore_table(
id_column: Union[str, Column, ColumnDict] = "langchain_id",
overwrite_existing: bool = False,
store_metadata: bool = True,
full_text_index: str | None = None
) -> None:
"""
Create a table for saving of vectors to be used with PGVectorStore.
Expand All @@ -280,6 +289,8 @@ async def ainit_vectorstore_table(
overwrite_existing (bool): Whether to drop existing table. Default: False.
store_metadata (bool): Whether to store metadata in the table.
Default: True.
full_text_index (str): Language used to construct full text index. If None then no index will be used.
Default: None
"""
await self._run_as_async(
self._ainit_vectorstore_table(
Expand All @@ -293,6 +304,7 @@ async def ainit_vectorstore_table(
id_column=id_column,
overwrite_existing=overwrite_existing,
store_metadata=store_metadata,
full_text_index=full_text_index
)
)

Expand All @@ -309,6 +321,7 @@ def init_vectorstore_table(
id_column: Union[str, Column, ColumnDict] = "langchain_id",
overwrite_existing: bool = False,
store_metadata: bool = True,
full_text_index: str = None
) -> None:
"""
Create a table for saving of vectors to be used with PGVectorStore.
Expand All @@ -331,6 +344,8 @@ def init_vectorstore_table(
overwrite_existing (bool): Whether to drop existing table. Default: False.
store_metadata (bool): Whether to store metadata in the table.
Default: True.
full_text_index (str): Language used to construct full text index. If None then no index will be used.
Default: None
"""
self._run_as_sync(
self._ainit_vectorstore_table(
Expand All @@ -344,6 +359,7 @@ def init_vectorstore_table(
id_column=id_column,
overwrite_existing=overwrite_existing,
store_metadata=store_metadata,
full_text_index=full_text_index
)
)

Expand Down
15 changes: 14 additions & 1 deletion tests/unit_tests/v2/test_pg_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def engine(self) -> AsyncIterator[PGEngine]:
@pytest_asyncio.fixture(scope="class")
async def vs(self, engine: PGEngine) -> AsyncIterator[PGVectorStore]:
await engine.ainit_vectorstore_table(
DEFAULT_TABLE, VECTOR_SIZE, store_metadata=False
DEFAULT_TABLE, VECTOR_SIZE, store_metadata=False, full_text_index="english"
)
vs = await PGVectorStore.create(
engine,
Expand Down Expand Up @@ -197,6 +197,19 @@ async def test_similarity_search_with_relevance_scores_threshold_cosine(
assert len(results) == 1
assert results[0][0] == Document(page_content="foo", id=ids[0])

@pytest.mark.parametrize(("full_text_weight", "expected_result", "expected_score"), [(1.0, "foo", 1.0), (0.5, "foo", 1.4392072893679142), (0.0, "bar", 2.0)])
async def test_similarity_search_hybrid_1(
self, vs: PGVectorStore, full_text_weight, expected_result, expected_score
) -> None:
results = await vs.asimilarity_search_with_relevance_scores(
"foo", full_text_weight=full_text_weight
)
top_doc, top_score = results[0]
assert top_doc.page_content == expected_result
assert top_score == expected_score



async def test_similarity_search_with_relevance_scores_threshold_euclidean(
self, engine: PGEngine
) -> None:
Expand Down