Skip to content

feat: adds hybrid search for async VS interface [3/N] #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: hybrid_search_2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 150 additions & 21 deletions langchain_postgres/v2/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,10 @@
from sqlalchemy.ext.asyncio import AsyncEngine

from .engine import PGEngine
from .indexes import (
DEFAULT_DISTANCE_STRATEGY,
DEFAULT_INDEX_NAME_SUFFIX,
BaseIndex,
DistanceStrategy,
ExactNearestNeighbor,
QueryOptions,
)
from .hybrid_search_config import HybridSearchConfig
from .indexes import (DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX,
BaseIndex, DistanceStrategy, ExactNearestNeighbor,
QueryOptions)

COMPARISONS_TO_NATIVE = {
"$eq": "=",
Expand Down Expand Up @@ -77,6 +73,8 @@ def __init__(
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
hybrid_search_column_exists: bool = False,
):
"""AsyncPGVectorStore constructor.
Args:
Expand All @@ -95,6 +93,8 @@ def __init__(
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column.


Raises:
Expand All @@ -119,6 +119,8 @@ def __init__(
self.fetch_k = fetch_k
self.lambda_mult = lambda_mult
self.index_query_options = index_query_options
self.hybrid_search_config = hybrid_search_config
self.hybrid_search_column_exists = hybrid_search_column_exists

@classmethod
async def create(
Expand All @@ -139,6 +141,7 @@ async def create(
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
) -> AsyncPGVectorStore:
"""Create an AsyncPGVectorStore instance.

Expand All @@ -158,6 +161,7 @@ async def create(
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.

Returns:
AsyncPGVectorStore
Expand Down Expand Up @@ -193,6 +197,17 @@ async def create(
raise ValueError(
f"Content column, {content_column}, is type, {content_type}. It must be a type of character string."
)
hybrid_search_column_exists = False
if hybrid_search_config:
tsv_column_name = (
hybrid_search_config.tsv_column
if hybrid_search_config.tsv_column
else content_column + "_tsv"
)
hybrid_search_config.tsv_column = tsv_column_name

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we only update the name if the column exists? Then we wouldn't have to manage the separate boolean.

hybrid_search_column_exists = (
tsv_column_name in columns and columns[tsv_column_name] == "tsvector"
)
if embedding_column not in columns:
raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
if columns[embedding_column] != "USER-DEFINED":
Expand Down Expand Up @@ -236,6 +251,8 @@ async def create(
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
hybrid_search_column_exists=hybrid_search_column_exists,
)

@property
Expand Down Expand Up @@ -273,7 +290,12 @@ async def aadd_embeddings(
if len(self.metadata_columns) > 0
else ""
)
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}'
hybrid_search_column = (
f', "{self.hybrid_search_config.tsv_column}"'
if self.hybrid_search_config and self.hybrid_search_column_exists
else ""
)
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{hybrid_search_column}{metadata_col_names}'
values = {
"id": id,
"content": content,
Expand All @@ -284,6 +306,14 @@ async def aadd_embeddings(
if not embedding and can_inline_embed:
values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore

if self.hybrid_search_config and self.hybrid_search_column_exists:
lang = (
f"'{self.hybrid_search_config.tsv_lang}',"
if self.hybrid_search_config.tsv_lang
else ""
)
values_stmt += f", to_tsvector({lang} :tsv_content)"
values["tsv_content"] = content
# Add metadata
extra = copy.deepcopy(metadata)
for metadata_column in self.metadata_columns:
Expand All @@ -308,6 +338,9 @@ async def aadd_embeddings(

upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"'

if self.hybrid_search_config and self.hybrid_search_column_exists:
upsert_stmt += f', "{self.hybrid_search_config.tsv_column}" = EXCLUDED."{self.hybrid_search_config.tsv_column}"'

if self.metadata_json_column:
upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"'

Expand Down Expand Up @@ -408,6 +441,7 @@ async def afrom_texts( # type: ignore[override]
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
**kwargs: Any,
) -> AsyncPGVectorStore:
"""Create an AsyncPGVectorStore instance from texts.
Expand All @@ -430,6 +464,7 @@ async def afrom_texts( # type: ignore[override]
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column.

Raises:
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
Expand All @@ -453,6 +488,7 @@ async def afrom_texts( # type: ignore[override]
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
)
await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
return vs
Expand All @@ -478,6 +514,7 @@ async def afrom_documents( # type: ignore[override]
fetch_k: int = 20,
lambda_mult: float = 0.5,
index_query_options: Optional[QueryOptions] = None,
hybrid_search_config: Optional[HybridSearchConfig] = None,
**kwargs: Any,
) -> AsyncPGVectorStore:
"""Create an AsyncPGVectorStore instance from documents.
Expand All @@ -500,6 +537,7 @@ async def afrom_documents( # type: ignore[override]
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
index_query_options (QueryOptions): Index query option.
hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column.

Raises:
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
Expand All @@ -524,6 +562,7 @@ async def afrom_documents( # type: ignore[override]
fetch_k=fetch_k,
lambda_mult=lambda_mult,
index_query_options=index_query_options,
hybrid_search_config=hybrid_search_config,
)
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
Expand All @@ -538,16 +577,30 @@ async def __query_collection(
filter: Optional[dict] = None,
**kwargs: Any,
) -> Sequence[RowMapping]:
"""Perform similarity search query on database."""
k = k if k else self.k
"""
Perform similarity search (or hybrid search) query on database.
Queries might be slow if the hybrid search column does not exist.
For best hybrid search performance, consider creating a TSV column
and adding GIN index.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Since this is a private function, this docstring will not be visible to users. Consider moving it elsewhere (e.g., to the public search methods or user-facing documentation).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, please clarify that only hybrid search will be slow (not the normal similarity searches).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to do this for the user when creating the new table. At a minimum, we should add documentation for using hybrid search and call this out.

"""
if not k:
k = (
max(
self.k,
self.hybrid_search_config.primary_top_k,
self.hybrid_search_config.secondary_top_k,
)
if self.hybrid_search_config
else self.k
)
operator = self.distance_strategy.operator
search_function = self.distance_strategy.search_function

columns = self.metadata_columns + [
columns = [
self.id_column,
self.content_column,
self.embedding_column,
]
] + self.metadata_columns
if self.metadata_json_column:
columns.append(self.metadata_json_column)

Expand All @@ -557,16 +610,18 @@ async def __query_collection(
filter_dict = None
if filter and isinstance(filter, dict):
safe_filter, filter_dict = self._create_filter_clause(filter)
param_filter = f"WHERE {safe_filter}" if safe_filter else ""
where_filters = f"WHERE {safe_filter}" if safe_filter else ""
and_filters = f"AND ({safe_filter})" if safe_filter else ""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add a comment on what this is or move this closer to where the code is used.


inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
if not embedding and callable(inline_embed_func) and "query" in kwargs:
query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore
embedding_data_string = f"{query_embedding}"
else:
query_embedding = f"{[float(dimension) for dimension in embedding]}"
embedding_data_string = ":query_embedding"
stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
"""
param_dict = {"query_embedding": query_embedding, "k": k}
if filter_dict:
Expand All @@ -577,15 +632,50 @@ async def __query_collection(
for query_option in self.index_query_options.to_parameter():
query_options_stmt = f"SET LOCAL {query_option};"
await conn.execute(text(query_options_stmt))
result = await conn.execute(text(stmt), param_dict)
result = await conn.execute(text(dense_query_stmt), param_dict)
result_map = result.mappings()
results = result_map.fetchall()
dense_results = result_map.fetchall()
else:
async with self.engine.connect() as conn:
result = await conn.execute(text(stmt), param_dict)
result = await conn.execute(text(dense_query_stmt), param_dict)
result_map = result.mappings()
results = result_map.fetchall()
return results
dense_results = result_map.fetchall()

hybrid_search_config = kwargs.get(
"hybrid_search_config", self.hybrid_search_config
)
fts_query = (
hybrid_search_config.fts_query
if hybrid_search_config and hybrid_search_config.fts_query
else kwargs.get("fts_query", "")
)
if hybrid_search_config and fts_query:
hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k
# do the sparse query
lang = (
f"'{hybrid_search_config.tsv_lang}',"
if hybrid_search_config.tsv_lang
else ""
)
query_tsv = f"plainto_tsquery({lang} :fts_query)"
param_dict["fts_query"] = fts_query
if self.hybrid_search_column_exists:
content_tsv = f'"{hybrid_search_config.tsv_column}"'
else:
content_tsv = f'to_tsvector({lang} "{self.content_column}")'
sparse_query_stmt = f'SELECT {column_names}, ts_rank_cd({content_tsv}, {query_tsv}) as distance FROM "{self.schema_name}"."{self.table_name}" WHERE {content_tsv} @@ {query_tsv} {and_filters} ORDER BY distance desc LIMIT {hybrid_search_config.secondary_top_k};'
async with self.engine.connect() as conn:
result = await conn.execute(text(sparse_query_stmt), param_dict)
result_map = result.mappings()
sparse_results = result_map.fetchall()

combined_results = hybrid_search_config.fusion_function(
dense_results,
sparse_results,
**hybrid_search_config.fusion_function_parameters,
)
return combined_results
return dense_results

async def asimilarity_search(
self,
Expand All @@ -603,6 +693,14 @@ async def asimilarity_search(
)
kwargs["query"] = query

# add fts_query to hybrid_search_config
hybrid_search_config = kwargs.get(
"hybrid_search_config", self.hybrid_search_config
)
if hybrid_search_config and not hybrid_search_config.fts_query:
hybrid_search_config.fts_query = query
kwargs["hybrid_search_config"] = hybrid_search_config

return await self.asimilarity_search_by_vector(
embedding=embedding, k=k, filter=filter, **kwargs
)
Expand Down Expand Up @@ -634,6 +732,14 @@ async def asimilarity_search_with_score(
)
kwargs["query"] = query

# add fts_query to hybrid_search_config
hybrid_search_config = kwargs.get(
"hybrid_search_config", self.hybrid_search_config
)
if hybrid_search_config and not hybrid_search_config.fts_query:
hybrid_search_config.fts_query = query
kwargs["hybrid_search_config"] = hybrid_search_config

docs = await self.asimilarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter, **kwargs
)
Expand Down Expand Up @@ -806,15 +912,38 @@ async def aapply_vector_index(
index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX
name = index.name
stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};'

if self.hybrid_search_config:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this in the vector index method

if self.hybrid_search_column_exists:
tsv_column_name = (
self.hybrid_search_config.tsv_column
if self.hybrid_search_config.tsv_column
else self.content_column + "_tsv"
)
tsv_column_name = f'"{tsv_column_name}"'
else:
lang = (
f"'{self.hybrid_search_config.tsv_lang}',"
if self.hybrid_search_config.tsv_lang
else ""
)
tsv_column_name = f"to_tsvector({lang} {self.content_column})"
tsv_index_name = self.table_name + self.hybrid_search_config.index_name
tsv_index_query = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {tsv_index_name} ON "{self.schema_name}"."{self.table_name}" USING {self.hybrid_search_config.index_type}({tsv_column_name});'
else:
tsv_index_query = ""

if concurrently:
async with self.engine.connect() as conn:
autocommit_conn = await conn.execution_options(
isolation_level="AUTOCOMMIT"
)
await autocommit_conn.execute(text(stmt))
await conn.execute(text(tsv_index_query))
else:
async with self.engine.connect() as conn:
await conn.execute(text(stmt))
await conn.execute(text(tsv_index_query))
await conn.commit()

async def areindex(self, index_name: Optional[str] = None) -> None:
Expand Down
Loading