Skip to content

Commit 596ac5f

Browse files
authored
fix: remove string filters and parameterize filters (#185)
1 parent 50b567b commit 596ac5f

File tree

4 files changed

+76
-120
lines changed

4 files changed

+76
-120
lines changed

langchain_postgres/v2/async_vectorstore.py

Lines changed: 62 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import copy
55
import json
6-
import re
76
import uuid
87
from typing import Any, Callable, Iterable, Optional, Sequence
98

@@ -175,7 +174,8 @@ async def create(
175174
stmt = "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = :table_name AND table_schema = :schema_name"
176175
async with engine._pool.connect() as conn:
177176
result = await conn.execute(
178-
text(stmt), {"table_name": table_name, "schema_name": schema_name}
177+
text(stmt),
178+
{"table_name": table_name, "schema_name": schema_name},
179179
)
180180
result_map = result.mappings()
181181
results = result_map.fetchall()
@@ -535,7 +535,7 @@ async def __query_collection(
535535
embedding: list[float],
536536
*,
537537
k: Optional[int] = None,
538-
filter: Optional[dict] | Optional[str] = None,
538+
filter: Optional[dict] = None,
539539
**kwargs: Any,
540540
) -> Sequence[RowMapping]:
541541
"""Perform similarity search query on database."""
@@ -553,16 +553,22 @@ async def __query_collection(
553553

554554
column_names = ", ".join(f'"{col}"' for col in columns)
555555

556+
safe_filter = None
557+
filter_dict = None
556558
if filter and isinstance(filter, dict):
557-
filter = self._create_filter_clause(filter)
558-
filter = f"WHERE {filter}" if filter else ""
559+
safe_filter, filter_dict = self._create_filter_clause(filter)
560+
param_filter = f"WHERE {safe_filter}" if safe_filter else ""
559561
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
560562
if not embedding and callable(inline_embed_func) and "query" in kwargs:
561563
query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore
562564
else:
563565
query_embedding = f"{[float(dimension) for dimension in embedding]}"
564-
stmt = f'SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k;'
566+
stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance
567+
FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k;
568+
"""
565569
param_dict = {"query_embedding": query_embedding, "k": k}
570+
if filter_dict:
571+
param_dict.update(filter_dict)
566572
if self.index_query_options:
567573
async with self.engine.connect() as conn:
568574
# Set each query option individually
@@ -583,7 +589,7 @@ async def asimilarity_search(
583589
self,
584590
query: str,
585591
k: Optional[int] = None,
586-
filter: Optional[dict] | Optional[str] = None,
592+
filter: Optional[dict] = None,
587593
**kwargs: Any,
588594
) -> list[Document]:
589595
"""Return docs selected by similarity search on query."""
@@ -614,7 +620,7 @@ async def asimilarity_search_with_score(
614620
self,
615621
query: str,
616622
k: Optional[int] = None,
617-
filter: Optional[dict] | Optional[str] = None,
623+
filter: Optional[dict] = None,
618624
**kwargs: Any,
619625
) -> list[tuple[Document, float]]:
620626
"""Return docs and distance scores selected by similarity search on query."""
@@ -635,7 +641,7 @@ async def asimilarity_search_by_vector(
635641
self,
636642
embedding: list[float],
637643
k: Optional[int] = None,
638-
filter: Optional[dict] | Optional[str] = None,
644+
filter: Optional[dict] = None,
639645
**kwargs: Any,
640646
) -> list[Document]:
641647
"""Return docs selected by vector similarity search."""
@@ -649,7 +655,7 @@ async def asimilarity_search_with_score_by_vector(
649655
self,
650656
embedding: list[float],
651657
k: Optional[int] = None,
652-
filter: Optional[dict] | Optional[str] = None,
658+
filter: Optional[dict] = None,
653659
**kwargs: Any,
654660
) -> list[tuple[Document, float]]:
655661
"""Return docs and distance scores selected by vector similarity search."""
@@ -685,7 +691,7 @@ async def amax_marginal_relevance_search(
685691
k: Optional[int] = None,
686692
fetch_k: Optional[int] = None,
687693
lambda_mult: Optional[float] = None,
688-
filter: Optional[dict] | Optional[str] = None,
694+
filter: Optional[dict] = None,
689695
**kwargs: Any,
690696
) -> list[Document]:
691697
"""Return docs selected using the maximal marginal relevance."""
@@ -706,7 +712,7 @@ async def amax_marginal_relevance_search_by_vector(
706712
k: Optional[int] = None,
707713
fetch_k: Optional[int] = None,
708714
lambda_mult: Optional[float] = None,
709-
filter: Optional[dict] | Optional[str] = None,
715+
filter: Optional[dict] = None,
710716
**kwargs: Any,
711717
) -> list[Document]:
712718
"""Return docs selected using the maximal marginal relevance."""
@@ -729,7 +735,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
729735
k: Optional[int] = None,
730736
fetch_k: Optional[int] = None,
731737
lambda_mult: Optional[float] = None,
732-
filter: Optional[dict] | Optional[str] = None,
738+
filter: Optional[dict] = None,
733739
**kwargs: Any,
734740
) -> list[tuple[Document, float]]:
735741
"""Return docs and distance scores selected using the maximal marginal relevance."""
@@ -834,7 +840,7 @@ async def is_valid_index(
834840
) -> bool:
835841
"""Check if index exists in the table."""
836842
index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX
837-
query = f"""
843+
query = """
838844
SELECT tablename, indexname
839845
FROM pg_indexes
840846
WHERE tablename = :table_name AND schemaname = :schema_name AND indexname = :index_name;
@@ -898,7 +904,7 @@ def _handle_field_filter(
898904
*,
899905
field: str,
900906
value: Any,
901-
) -> str:
907+
) -> tuple[str, dict]:
902908
"""Create a filter for a specific field.
903909
904910
Args:
@@ -951,15 +957,17 @@ def _handle_field_filter(
951957
if operator in COMPARISONS_TO_NATIVE:
952958
# Then we implement an equality filter
953959
# native is trusted input
954-
if isinstance(filter_value, str):
955-
filter_value = f"'{filter_value}'"
956960
native = COMPARISONS_TO_NATIVE[operator]
957-
return f"({field} {native} {filter_value})"
961+
id = str(uuid.uuid4()).split("-")[0]
962+
return f"{field} {native} :{field}_{id}", {f"{field}_{id}": filter_value}
958963
elif operator == "$between":
959964
# Use AND with two comparisons
960965
low, high = filter_value
961966

962-
return f"({field} BETWEEN {low} AND {high})"
967+
return f"({field} BETWEEN :{field}_low AND :{field}_high)", {
968+
f"{field}_low": low,
969+
f"{field}_high": high,
970+
}
963971
elif operator in {"$in", "$nin", "$like", "$ilike"}:
964972
# We'll do force coercion to text
965973
if operator in {"$in", "$nin"}:
@@ -975,15 +983,15 @@ def _handle_field_filter(
975983
)
976984

977985
if operator in {"$in"}:
978-
values = str(tuple(val for val in filter_value))
979-
return f"({field} IN {values})"
986+
return f"{field} = ANY(:{field}_in)", {f"{field}_in": filter_value}
980987
elif operator in {"$nin"}:
981-
values = str(tuple(val for val in filter_value))
982-
return f"({field} NOT IN {values})"
988+
return f"{field} <> ALL (:{field}_nin)", {f"{field}_nin": filter_value}
983989
elif operator in {"$like"}:
984-
return f"({field} LIKE '{filter_value}')"
990+
return f"({field} LIKE :{field}_like)", {f"{field}_like": filter_value}
985991
elif operator in {"$ilike"}:
986-
return f"({field} ILIKE '{filter_value}')"
992+
return f"({field} ILIKE :{field}_ilike)", {
993+
f"{field}_ilike": filter_value
994+
}
987995
else:
988996
raise NotImplementedError()
989997
elif operator == "$exists":
@@ -994,13 +1002,13 @@ def _handle_field_filter(
9941002
)
9951003
else:
9961004
if filter_value:
997-
return f"({field} IS NOT NULL)"
1005+
return f"({field} IS NOT NULL)", {}
9981006
else:
999-
return f"({field} IS NULL)"
1007+
return f"({field} IS NULL)", {}
10001008
else:
10011009
raise NotImplementedError()
10021010

1003-
def _create_filter_clause(self, filters: Any) -> str:
1011+
def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
10041012
"""Create LangChain filter representation to matching SQL where clauses
10051013
10061014
Args:
@@ -1037,7 +1045,11 @@ def _create_filter_clause(self, filters: Any) -> str:
10371045
op = key[1:].upper() # Extract the operator
10381046
filter_clause = [self._create_filter_clause(el) for el in value]
10391047
if len(filter_clause) > 1:
1040-
return f"({f' {op} '.join(filter_clause)})"
1048+
all_clauses = [clause[0] for clause in filter_clause]
1049+
params = {}
1050+
for clause in filter_clause:
1051+
params.update(clause[1])
1052+
return f"({f' {op} '.join(all_clauses)})", params
10411053
elif len(filter_clause) == 1:
10421054
return filter_clause[0]
10431055
else:
@@ -1050,11 +1062,15 @@ def _create_filter_clause(self, filters: Any) -> str:
10501062
not_conditions = [
10511063
self._create_filter_clause(item) for item in value
10521064
]
1053-
not_stmts = [f"NOT {condition}" for condition in not_conditions]
1054-
return f"({' AND '.join(not_stmts)})"
1065+
all_clauses = [clause[0] for clause in not_conditions]
1066+
params = {}
1067+
for clause in not_conditions:
1068+
params.update(clause[1])
1069+
not_stmts = [f"NOT {condition}" for condition in all_clauses]
1070+
return f"({' AND '.join(not_stmts)})", params
10551071
elif isinstance(value, dict):
1056-
not_ = self._create_filter_clause(value)
1057-
return f"(NOT {not_})"
1072+
not_, params = self._create_filter_clause(value)
1073+
return f"(NOT {not_})", params
10581074
else:
10591075
raise ValueError(
10601076
f"Invalid filter condition. Expected a dictionary "
@@ -1077,7 +1093,11 @@ def _create_filter_clause(self, filters: Any) -> str:
10771093
self._handle_field_filter(field=k, value=v) for k, v in filters.items()
10781094
]
10791095
if len(and_) > 1:
1080-
return f"({' AND '.join(and_)})"
1096+
all_clauses = [clause[0] for clause in and_]
1097+
params = {}
1098+
for clause in and_:
1099+
params.update(clause[1])
1100+
return f"({' AND '.join(all_clauses)})", params
10811101
elif len(and_) == 1:
10821102
return and_[0]
10831103
else:
@@ -1086,7 +1106,7 @@ def _create_filter_clause(self, filters: Any) -> str:
10861106
"but got an empty dictionary"
10871107
)
10881108
else:
1089-
return ""
1109+
return "", {}
10901110

10911111
def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
10921112
raise NotImplementedError(
@@ -1168,7 +1188,7 @@ def similarity_search(
11681188
self,
11691189
query: str,
11701190
k: Optional[int] = None,
1171-
filter: Optional[dict] | Optional[str] = None,
1191+
filter: Optional[dict] = None,
11721192
**kwargs: Any,
11731193
) -> list[Document]:
11741194
raise NotImplementedError(
@@ -1179,7 +1199,7 @@ def similarity_search_with_score(
11791199
self,
11801200
query: str,
11811201
k: Optional[int] = None,
1182-
filter: Optional[dict] | Optional[str] = None,
1202+
filter: Optional[dict] = None,
11831203
**kwargs: Any,
11841204
) -> list[tuple[Document, float]]:
11851205
raise NotImplementedError(
@@ -1190,7 +1210,7 @@ def similarity_search_by_vector(
11901210
self,
11911211
embedding: list[float],
11921212
k: Optional[int] = None,
1193-
filter: Optional[dict] | Optional[str] = None,
1213+
filter: Optional[dict] = None,
11941214
**kwargs: Any,
11951215
) -> list[Document]:
11961216
raise NotImplementedError(
@@ -1201,7 +1221,7 @@ def similarity_search_with_score_by_vector(
12011221
self,
12021222
embedding: list[float],
12031223
k: Optional[int] = None,
1204-
filter: Optional[dict] | Optional[str] = None,
1224+
filter: Optional[dict] = None,
12051225
**kwargs: Any,
12061226
) -> list[tuple[Document, float]]:
12071227
raise NotImplementedError(
@@ -1214,7 +1234,7 @@ def max_marginal_relevance_search(
12141234
k: Optional[int] = None,
12151235
fetch_k: Optional[int] = None,
12161236
lambda_mult: Optional[float] = None,
1217-
filter: Optional[dict] | Optional[str] = None,
1237+
filter: Optional[dict] = None,
12181238
**kwargs: Any,
12191239
) -> list[Document]:
12201240
raise NotImplementedError(
@@ -1227,7 +1247,7 @@ def max_marginal_relevance_search_by_vector(
12271247
k: Optional[int] = None,
12281248
fetch_k: Optional[int] = None,
12291249
lambda_mult: Optional[float] = None,
1230-
filter: Optional[dict] | Optional[str] = None,
1250+
filter: Optional[dict] = None,
12311251
**kwargs: Any,
12321252
) -> list[Document]:
12331253
raise NotImplementedError(
@@ -1240,7 +1260,7 @@ def max_marginal_relevance_search_with_score_by_vector(
12401260
k: Optional[int] = None,
12411261
fetch_k: Optional[int] = None,
12421262
lambda_mult: Optional[float] = None,
1243-
filter: Optional[dict] | Optional[str] = None,
1263+
filter: Optional[dict] = None,
12441264
**kwargs: Any,
12451265
) -> list[tuple[Document, float]]:
12461266
raise NotImplementedError(

0 commit comments

Comments
 (0)