-
Notifications
You must be signed in to change notification settings - Fork 0
feat: adds hybrid search for async VS interface [3/N] #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: hybrid_search_2
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,14 +14,10 @@ | |
from sqlalchemy.ext.asyncio import AsyncEngine | ||
|
||
from .engine import PGEngine | ||
from .indexes import ( | ||
DEFAULT_DISTANCE_STRATEGY, | ||
DEFAULT_INDEX_NAME_SUFFIX, | ||
BaseIndex, | ||
DistanceStrategy, | ||
ExactNearestNeighbor, | ||
QueryOptions, | ||
) | ||
from .hybrid_search_config import HybridSearchConfig | ||
from .indexes import (DEFAULT_DISTANCE_STRATEGY, DEFAULT_INDEX_NAME_SUFFIX, | ||
BaseIndex, DistanceStrategy, ExactNearestNeighbor, | ||
QueryOptions) | ||
|
||
COMPARISONS_TO_NATIVE = { | ||
"$eq": "=", | ||
|
@@ -77,6 +73,8 @@ def __init__( | |
fetch_k: int = 20, | ||
lambda_mult: float = 0.5, | ||
index_query_options: Optional[QueryOptions] = None, | ||
hybrid_search_config: Optional[HybridSearchConfig] = None, | ||
hybrid_search_column_exists: bool = False, | ||
): | ||
"""AsyncPGVectorStore constructor. | ||
Args: | ||
|
@@ -95,6 +93,8 @@ def __init__( | |
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. | ||
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. | ||
index_query_options (QueryOptions): Index query option. | ||
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. | ||
hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. | ||
|
||
|
||
Raises: | ||
|
@@ -119,6 +119,8 @@ def __init__( | |
self.fetch_k = fetch_k | ||
self.lambda_mult = lambda_mult | ||
self.index_query_options = index_query_options | ||
self.hybrid_search_config = hybrid_search_config | ||
self.hybrid_search_column_exists = hybrid_search_column_exists | ||
|
||
@classmethod | ||
async def create( | ||
|
@@ -139,6 +141,7 @@ async def create( | |
fetch_k: int = 20, | ||
lambda_mult: float = 0.5, | ||
index_query_options: Optional[QueryOptions] = None, | ||
hybrid_search_config: Optional[HybridSearchConfig] = None, | ||
) -> AsyncPGVectorStore: | ||
"""Create an AsyncPGVectorStore instance. | ||
|
||
|
@@ -158,6 +161,7 @@ async def create( | |
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. | ||
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. | ||
index_query_options (QueryOptions): Index query option. | ||
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None. | ||
|
||
Returns: | ||
AsyncPGVectorStore | ||
|
@@ -193,6 +197,17 @@ async def create( | |
raise ValueError( | ||
f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." | ||
) | ||
hybrid_search_column_exists = False | ||
if hybrid_search_config: | ||
tsv_column_name = ( | ||
hybrid_search_config.tsv_column | ||
if hybrid_search_config.tsv_column | ||
else content_column + "_tsv" | ||
) | ||
hybrid_search_config.tsv_column = tsv_column_name | ||
hybrid_search_column_exists = ( | ||
tsv_column_name in columns and columns[tsv_column_name] == "tsvector" | ||
) | ||
if embedding_column not in columns: | ||
raise ValueError(f"Embedding column, {embedding_column}, does not exist.") | ||
if columns[embedding_column] != "USER-DEFINED": | ||
|
@@ -236,6 +251,8 @@ async def create( | |
fetch_k=fetch_k, | ||
lambda_mult=lambda_mult, | ||
index_query_options=index_query_options, | ||
hybrid_search_config=hybrid_search_config, | ||
hybrid_search_column_exists=hybrid_search_column_exists, | ||
) | ||
|
||
@property | ||
|
@@ -273,7 +290,12 @@ async def aadd_embeddings( | |
if len(self.metadata_columns) > 0 | ||
else "" | ||
) | ||
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}' | ||
hybrid_search_column = ( | ||
f', "{self.hybrid_search_config.tsv_column}"' | ||
if self.hybrid_search_config and self.hybrid_search_column_exists | ||
else "" | ||
) | ||
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{hybrid_search_column}{metadata_col_names}' | ||
values = { | ||
"id": id, | ||
"content": content, | ||
|
@@ -284,6 +306,14 @@ async def aadd_embeddings( | |
if not embedding and can_inline_embed: | ||
values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore | ||
|
||
if self.hybrid_search_config and self.hybrid_search_column_exists: | ||
lang = ( | ||
f"'{self.hybrid_search_config.tsv_lang}'," | ||
if self.hybrid_search_config.tsv_lang | ||
else "" | ||
) | ||
values_stmt += f", to_tsvector({lang} :tsv_content)" | ||
values["tsv_content"] = content | ||
# Add metadata | ||
extra = copy.deepcopy(metadata) | ||
for metadata_column in self.metadata_columns: | ||
|
@@ -308,6 +338,9 @@ async def aadd_embeddings( | |
|
||
upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"' | ||
|
||
if self.hybrid_search_config and self.hybrid_search_column_exists: | ||
upsert_stmt += f', "{self.hybrid_search_config.tsv_column}" = EXCLUDED."{self.hybrid_search_config.tsv_column}"' | ||
|
||
if self.metadata_json_column: | ||
upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"' | ||
|
||
|
@@ -408,6 +441,7 @@ async def afrom_texts( # type: ignore[override] | |
fetch_k: int = 20, | ||
lambda_mult: float = 0.5, | ||
index_query_options: Optional[QueryOptions] = None, | ||
hybrid_search_config: Optional[HybridSearchConfig] = None, | ||
**kwargs: Any, | ||
) -> AsyncPGVectorStore: | ||
"""Create an AsyncPGVectorStore instance from texts. | ||
|
@@ -430,6 +464,7 @@ async def afrom_texts( # type: ignore[override] | |
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. | ||
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. | ||
index_query_options (QueryOptions): Index query option. | ||
hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. | ||
|
||
Raises: | ||
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`. | ||
|
@@ -453,6 +488,7 @@ async def afrom_texts( # type: ignore[override] | |
fetch_k=fetch_k, | ||
lambda_mult=lambda_mult, | ||
index_query_options=index_query_options, | ||
hybrid_search_config=hybrid_search_config, | ||
) | ||
await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) | ||
return vs | ||
|
@@ -478,6 +514,7 @@ async def afrom_documents( # type: ignore[override] | |
fetch_k: int = 20, | ||
lambda_mult: float = 0.5, | ||
index_query_options: Optional[QueryOptions] = None, | ||
hybrid_search_config: Optional[HybridSearchConfig] = None, | ||
**kwargs: Any, | ||
) -> AsyncPGVectorStore: | ||
"""Create an AsyncPGVectorStore instance from documents. | ||
|
@@ -500,6 +537,7 @@ async def afrom_documents( # type: ignore[override] | |
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. | ||
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. | ||
index_query_options (QueryOptions): Index query option. | ||
hybrid_search_column_exists (bool): Defines whether the existing table has the hybrid search column. | ||
|
||
Raises: | ||
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`. | ||
|
@@ -524,6 +562,7 @@ async def afrom_documents( # type: ignore[override] | |
fetch_k=fetch_k, | ||
lambda_mult=lambda_mult, | ||
index_query_options=index_query_options, | ||
hybrid_search_config=hybrid_search_config, | ||
) | ||
texts = [doc.page_content for doc in documents] | ||
metadatas = [doc.metadata for doc in documents] | ||
|
@@ -538,16 +577,30 @@ async def __query_collection( | |
filter: Optional[dict] = None, | ||
**kwargs: Any, | ||
) -> Sequence[RowMapping]: | ||
"""Perform similarity search query on database.""" | ||
k = k if k else self.k | ||
""" | ||
Perform similarity search (or hybrid search) query on database. | ||
Queries might be slow if the hybrid search column does not exist. | ||
For best hybrid search performance, consider creating a TSV column | ||
and adding GIN index. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Since this is a private function, this docstring will not be visible to users. Consider moving it elsewhere There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also please clarify that only the hybrid search will be slow (not the normal similarity searches) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may want to do this for the user when creating the new table. At a minimum, we should add documentation for using hybrid search and call this out. |
||
""" | ||
if not k: | ||
k = ( | ||
max( | ||
self.k, | ||
self.hybrid_search_config.primary_top_k, | ||
self.hybrid_search_config.secondary_top_k, | ||
) | ||
if self.hybrid_search_config | ||
else self.k | ||
) | ||
operator = self.distance_strategy.operator | ||
search_function = self.distance_strategy.search_function | ||
|
||
columns = self.metadata_columns + [ | ||
columns = [ | ||
self.id_column, | ||
self.content_column, | ||
self.embedding_column, | ||
] | ||
] + self.metadata_columns | ||
if self.metadata_json_column: | ||
columns.append(self.metadata_json_column) | ||
|
||
|
@@ -557,16 +610,18 @@ async def __query_collection( | |
filter_dict = None | ||
if filter and isinstance(filter, dict): | ||
safe_filter, filter_dict = self._create_filter_clause(filter) | ||
param_filter = f"WHERE {safe_filter}" if safe_filter else "" | ||
where_filters = f"WHERE {safe_filter}" if safe_filter else "" | ||
and_filters = f"AND ({safe_filter})" if safe_filter else "" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add a comment on what this is or move this closer to where the code is used. |
||
|
||
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None) | ||
if not embedding and callable(inline_embed_func) and "query" in kwargs: | ||
query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore | ||
embedding_data_string = f"{query_embedding}" | ||
else: | ||
query_embedding = f"{[float(dimension) for dimension in embedding]}" | ||
embedding_data_string = ":query_embedding" | ||
stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance | ||
FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; | ||
dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance | ||
FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; | ||
""" | ||
param_dict = {"query_embedding": query_embedding, "k": k} | ||
if filter_dict: | ||
|
@@ -577,15 +632,50 @@ async def __query_collection( | |
for query_option in self.index_query_options.to_parameter(): | ||
query_options_stmt = f"SET LOCAL {query_option};" | ||
await conn.execute(text(query_options_stmt)) | ||
result = await conn.execute(text(stmt), param_dict) | ||
result = await conn.execute(text(dense_query_stmt), param_dict) | ||
result_map = result.mappings() | ||
results = result_map.fetchall() | ||
dense_results = result_map.fetchall() | ||
else: | ||
async with self.engine.connect() as conn: | ||
result = await conn.execute(text(stmt), param_dict) | ||
result = await conn.execute(text(dense_query_stmt), param_dict) | ||
result_map = result.mappings() | ||
results = result_map.fetchall() | ||
return results | ||
dense_results = result_map.fetchall() | ||
|
||
hybrid_search_config = kwargs.get( | ||
"hybrid_search_config", self.hybrid_search_config | ||
) | ||
fts_query = ( | ||
hybrid_search_config.fts_query | ||
if hybrid_search_config and hybrid_search_config.fts_query | ||
else kwargs.get("fts_query", "") | ||
) | ||
if hybrid_search_config and fts_query: | ||
hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k | ||
# do the sparse query | ||
lang = ( | ||
f"'{hybrid_search_config.tsv_lang}'," | ||
if hybrid_search_config.tsv_lang | ||
else "" | ||
) | ||
query_tsv = f"plainto_tsquery({lang} :fts_query)" | ||
param_dict["fts_query"] = fts_query | ||
if self.hybrid_search_column_exists: | ||
content_tsv = f'"{hybrid_search_config.tsv_column}"' | ||
else: | ||
content_tsv = f'to_tsvector({lang} "{self.content_column}")' | ||
sparse_query_stmt = f'SELECT {column_names}, ts_rank_cd({content_tsv}, {query_tsv}) as distance FROM "{self.schema_name}"."{self.table_name}" WHERE {content_tsv} @@ {query_tsv} {and_filters} ORDER BY distance desc LIMIT {hybrid_search_config.secondary_top_k};' | ||
async with self.engine.connect() as conn: | ||
result = await conn.execute(text(sparse_query_stmt), param_dict) | ||
result_map = result.mappings() | ||
sparse_results = result_map.fetchall() | ||
|
||
combined_results = hybrid_search_config.fusion_function( | ||
dense_results, | ||
sparse_results, | ||
**hybrid_search_config.fusion_function_parameters, | ||
) | ||
return combined_results | ||
return dense_results | ||
|
||
async def asimilarity_search( | ||
self, | ||
|
@@ -603,6 +693,14 @@ async def asimilarity_search( | |
) | ||
kwargs["query"] = query | ||
|
||
# add fts_query to hybrid_search_config | ||
hybrid_search_config = kwargs.get( | ||
"hybrid_search_config", self.hybrid_search_config | ||
) | ||
if hybrid_search_config and not hybrid_search_config.fts_query: | ||
hybrid_search_config.fts_query = query | ||
kwargs["hybrid_search_config"] = hybrid_search_config | ||
|
||
return await self.asimilarity_search_by_vector( | ||
embedding=embedding, k=k, filter=filter, **kwargs | ||
) | ||
|
@@ -634,6 +732,14 @@ async def asimilarity_search_with_score( | |
) | ||
kwargs["query"] = query | ||
|
||
# add fts_query to hybrid_search_config | ||
hybrid_search_config = kwargs.get( | ||
"hybrid_search_config", self.hybrid_search_config | ||
) | ||
if hybrid_search_config and not hybrid_search_config.fts_query: | ||
hybrid_search_config.fts_query = query | ||
kwargs["hybrid_search_config"] = hybrid_search_config | ||
|
||
docs = await self.asimilarity_search_with_score_by_vector( | ||
embedding=embedding, k=k, filter=filter, **kwargs | ||
) | ||
|
@@ -806,15 +912,38 @@ async def aapply_vector_index( | |
index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX | ||
name = index.name | ||
stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};' | ||
|
||
if self.hybrid_search_config: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like this in the vector index method |
||
if self.hybrid_search_column_exists: | ||
tsv_column_name = ( | ||
self.hybrid_search_config.tsv_column | ||
if self.hybrid_search_config.tsv_column | ||
else self.content_column + "_tsv" | ||
) | ||
tsv_column_name = f'"{tsv_column_name}"' | ||
else: | ||
lang = ( | ||
f"'{self.hybrid_search_config.tsv_lang}'," | ||
if self.hybrid_search_config.tsv_lang | ||
else "" | ||
) | ||
tsv_column_name = f"to_tsvector({lang} {self.content_column})" | ||
tsv_index_name = self.table_name + self.hybrid_search_config.index_name | ||
tsv_index_query = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {tsv_index_name} ON "{self.schema_name}"."{self.table_name}" USING {self.hybrid_search_config.index_type}({tsv_column_name});' | ||
else: | ||
tsv_index_query = "" | ||
|
||
if concurrently: | ||
async with self.engine.connect() as conn: | ||
autocommit_conn = await conn.execution_options( | ||
isolation_level="AUTOCOMMIT" | ||
) | ||
await autocommit_conn.execute(text(stmt)) | ||
await conn.execute(text(tsv_index_query)) | ||
else: | ||
async with self.engine.connect() as conn: | ||
await conn.execute(text(stmt)) | ||
await conn.execute(text(tsv_index_query)) | ||
await conn.commit() | ||
|
||
async def areindex(self, index_name: Optional[str] = None) -> None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we only update the name if the column exists, then we don't have to manage the separate boolean?