 
 import copy
 import json
-import re
 import uuid
 from typing import Any, Callable, Iterable, Optional, Sequence
 
@@ -175,7 +174,8 @@ async def create(
         stmt = "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = :table_name AND table_schema = :schema_name"
         async with engine._pool.connect() as conn:
             result = await conn.execute(
-                text(stmt), {"table_name": table_name, "schema_name": schema_name}
+                text(stmt),
+                {"table_name": table_name, "schema_name": schema_name},
             )
             result_map = result.mappings()
             results = result_map.fetchall()
@@ -535,7 +535,7 @@ async def __query_collection(
         embedding: list[float],
         *,
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> Sequence[RowMapping]:
         """Perform similarity search query on database."""
@@ -553,16 +553,22 @@ async def __query_collection(
 
         column_names = ", ".join(f'"{col}"' for col in columns)
 
+        safe_filter = None
+        filter_dict = None
         if filter and isinstance(filter, dict):
-            filter = self._create_filter_clause(filter)
-        filter = f"WHERE {filter}" if filter else ""
+            safe_filter, filter_dict = self._create_filter_clause(filter)
+        param_filter = f"WHERE {safe_filter}" if safe_filter else ""
         inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
         if not embedding and callable(inline_embed_func) and "query" in kwargs:
             query_embedding = self.embedding_service.embed_query_inline(kwargs["query"])  # type: ignore
         else:
             query_embedding = f"{[float(dimension) for dimension in embedding]}"
-        stmt = f'SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k;'
+        stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", :query_embedding) as distance
+            FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} :query_embedding LIMIT :k;
+        """
         param_dict = {"query_embedding": query_embedding, "k": k}
+        if filter_dict:
+            param_dict.update(filter_dict)
         if self.index_query_options:
             async with self.engine.connect() as conn:
                 # Set each query option individually
@@ -583,7 +589,7 @@ async def asimilarity_search(
         self,
         query: str,
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[Document]:
         """Return docs selected by similarity search on query."""
@@ -614,7 +620,7 @@ async def asimilarity_search_with_score(
         self,
         query: str,
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Return docs and distance scores selected by similarity search on query."""
@@ -635,7 +641,7 @@ async def asimilarity_search_by_vector(
         self,
         embedding: list[float],
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[Document]:
         """Return docs selected by vector similarity search."""
@@ -649,7 +655,7 @@ async def asimilarity_search_with_score_by_vector(
         self,
         embedding: list[float],
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Return docs and distance scores selected by vector similarity search."""
@@ -685,7 +691,7 @@ async def amax_marginal_relevance_search(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[Document]:
         """Return docs selected using the maximal marginal relevance."""
@@ -706,7 +712,7 @@ async def amax_marginal_relevance_search_by_vector(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[Document]:
         """Return docs selected using the maximal marginal relevance."""
@@ -729,7 +735,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Return docs and distance scores selected using the maximal marginal relevance."""
@@ -834,7 +840,7 @@ async def is_valid_index(
     ) -> bool:
         """Check if index exists in the table."""
         index_name = index_name or self.table_name + DEFAULT_INDEX_NAME_SUFFIX
-        query = f"""
+        query = """
         SELECT tablename, indexname
         FROM pg_indexes
         WHERE tablename = :table_name AND schemaname = :schema_name AND indexname = :index_name;
@@ -898,7 +904,7 @@ def _handle_field_filter(
         *,
         field: str,
         value: Any,
-    ) -> str:
+    ) -> tuple[str, dict]:
         """Create a filter for a specific field.
 
         Args:
@@ -951,15 +957,17 @@ def _handle_field_filter(
         if operator in COMPARISONS_TO_NATIVE:
             # Then we implement an equality filter
             # native is trusted input
-            if isinstance(filter_value, str):
-                filter_value = f"'{filter_value}'"
             native = COMPARISONS_TO_NATIVE[operator]
-            return f"({field} {native} {filter_value})"
+            id = str(uuid.uuid4()).split("-")[0]
+            return f"{field} {native} :{field}_{id}", {f"{field}_{id}": filter_value}
         elif operator == "$between":
             # Use AND with two comparisons
             low, high = filter_value
 
-            return f"({field} BETWEEN {low} AND {high})"
+            return f"({field} BETWEEN :{field}_low AND :{field}_high)", {
+                f"{field}_low": low,
+                f"{field}_high": high,
+            }
         elif operator in {"$in", "$nin", "$like", "$ilike"}:
             # We'll do force coercion to text
             if operator in {"$in", "$nin"}:
@@ -975,15 +983,15 @@ def _handle_field_filter(
                         )
 
             if operator in {"$in"}:
-                values = str(tuple(val for val in filter_value))
-                return f"({field} IN {values})"
+                return f"{field} = ANY(:{field}_in)", {f"{field}_in": filter_value}
             elif operator in {"$nin"}:
-                values = str(tuple(val for val in filter_value))
-                return f"({field} NOT IN {values})"
+                return f"{field} <> ALL (:{field}_nin)", {f"{field}_nin": filter_value}
             elif operator in {"$like"}:
-                return f"({field} LIKE '{filter_value}')"
+                return f"({field} LIKE :{field}_like)", {f"{field}_like": filter_value}
             elif operator in {"$ilike"}:
-                return f"({field} ILIKE '{filter_value}')"
+                return f"({field} ILIKE :{field}_ilike)", {
+                    f"{field}_ilike": filter_value
+                }
             else:
                 raise NotImplementedError()
         elif operator == "$exists":
@@ -994,13 +1002,13 @@ def _handle_field_filter(
                 )
             else:
                 if filter_value:
-                    return f"({field} IS NOT NULL)"
+                    return f"({field} IS NOT NULL)", {}
                 else:
-                    return f"({field} IS NULL)"
+                    return f"({field} IS NULL)", {}
         else:
             raise NotImplementedError()
 
-    def _create_filter_clause(self, filters: Any) -> str:
+    def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
         """Create LangChain filter representation to matching SQL where clauses
 
         Args:
@@ -1037,7 +1045,11 @@ def _create_filter_clause(self, filters: Any) -> str:
                 op = key[1:].upper()  # Extract the operator
                 filter_clause = [self._create_filter_clause(el) for el in value]
                 if len(filter_clause) > 1:
-                    return f"({f' {op} '.join(filter_clause)})"
+                    all_clauses = [clause[0] for clause in filter_clause]
+                    params = {}
+                    for clause in filter_clause:
+                        params.update(clause[1])
+                    return f"({f' {op} '.join(all_clauses)})", params
                 elif len(filter_clause) == 1:
                     return filter_clause[0]
                 else:
@@ -1050,11 +1062,15 @@ def _create_filter_clause(self, filters: Any) -> str:
                     not_conditions = [
                         self._create_filter_clause(item) for item in value
                     ]
-                    not_stmts = [f"NOT {condition}" for condition in not_conditions]
-                    return f"({' AND '.join(not_stmts)})"
+                    all_clauses = [clause[0] for clause in not_conditions]
+                    params = {}
+                    for clause in not_conditions:
+                        params.update(clause[1])
+                    not_stmts = [f"NOT {condition}" for condition in all_clauses]
+                    return f"({' AND '.join(not_stmts)})", params
                 elif isinstance(value, dict):
-                    not_ = self._create_filter_clause(value)
-                    return f"(NOT {not_})"
+                    not_, params = self._create_filter_clause(value)
+                    return f"(NOT {not_})", params
                 else:
                     raise ValueError(
                         f"Invalid filter condition. Expected a dictionary "
@@ -1077,7 +1093,11 @@ def _create_filter_clause(self, filters: Any) -> str:
                 self._handle_field_filter(field=k, value=v) for k, v in filters.items()
             ]
             if len(and_) > 1:
-                return f"({' AND '.join(and_)})"
+                all_clauses = [clause[0] for clause in and_]
+                params = {}
+                for clause in and_:
+                    params.update(clause[1])
+                return f"({' AND '.join(all_clauses)})", params
             elif len(and_) == 1:
                 return and_[0]
             else:
@@ -1086,7 +1106,7 @@ def _create_filter_clause(self, filters: Any) -> str:
                     "but got an empty dictionary"
                 )
         else:
-            return ""
+            return "", {}
 
     def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
         raise NotImplementedError(
@@ -1168,7 +1188,7 @@ def similarity_search(
         self,
         query: str,
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[Document]:
         raise NotImplementedError(
@@ -1179,7 +1199,7 @@ def similarity_search_with_score(
         self,
         query: str,
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         raise NotImplementedError(
@@ -1190,7 +1210,7 @@ def similarity_search_by_vector(
         self,
         embedding: list[float],
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[Document]:
         raise NotImplementedError(
@@ -1201,7 +1221,7 @@ def similarity_search_with_score_by_vector(
         self,
         embedding: list[float],
         k: Optional[int] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         raise NotImplementedError(
@@ -1214,7 +1234,7 @@ def max_marginal_relevance_search(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[Document]:
         raise NotImplementedError(
@@ -1227,7 +1247,7 @@ def max_marginal_relevance_search_by_vector(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[Document]:
         raise NotImplementedError(
@@ -1240,7 +1260,7 @@ def max_marginal_relevance_search_with_score_by_vector(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[dict] | Optional[str] = None,
+        filter: Optional[dict] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         raise NotImplementedError(
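For reference, the pattern these hunks adopt can be sketched in isolation: each field filter now yields a parameterized SQL fragment together with a dict of bind values (hence the `tuple[str, dict]` return types), and `__query_collection` merges those values into `param_dict` instead of interpolating them into the statement text. The snippet below is a minimal standalone sketch for illustration only; `handle_field_filter` and its `$eq`/`$in` handling are simplified stand-ins, not the library's actual API.

```python
# Illustrative sketch of the parameterized-filter pattern introduced above.
import uuid
from typing import Any


def handle_field_filter(field: str, value: Any) -> tuple[str, dict]:
    """Return a SQL fragment with named placeholders plus its bind values."""
    if isinstance(value, dict):
        operator, filter_value = list(value.items())[0]
    else:
        operator, filter_value = "$eq", value

    if operator == "$eq":
        # A short random suffix keeps placeholder names unique if the same
        # field appears more than once in a filter.
        suffix = str(uuid.uuid4()).split("-")[0]
        return f"{field} = :{field}_{suffix}", {f"{field}_{suffix}": filter_value}
    if operator == "$in":
        return f"{field} = ANY(:{field}_in)", {f"{field}_in": filter_value}
    raise NotImplementedError(operator)


clause, params = handle_field_filter("genre", {"$in": ["sci-fi", "fantasy"]})
where = f"WHERE {clause}" if clause else ""
stmt = f'SELECT * FROM "my_table" {where} LIMIT :k;'
param_dict = {"k": 4, **params}  # values travel as bind parameters, not SQL text
print(stmt)        # SELECT * FROM "my_table" WHERE genre = ANY(:genre_in) LIMIT :k;
print(param_dict)  # {'k': 4, 'genre_in': ['sci-fi', 'fantasy']}
```

Executing `stmt` with `param_dict` through SQLAlchemy's `text()` lets the driver handle quoting, which is what removes the need for the old string-escaping branches (`f"'{filter_value}'"`, `IN {values}`, and the quoted `LIKE`/`ILIKE` patterns).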