5
5
import copy
6
6
from datetime import datetime , date
7
7
from dateutil .relativedelta import relativedelta
8
- from functools import reduce
8
+ from functools import reduce , partial
9
9
from mongoengine import (
10
10
EmbeddedDocumentField ,
11
11
EmbeddedDocument ,
@@ -446,15 +446,24 @@ def _get_reference_model(cls, key):
446
446
return None , None , None , None
447
447
448
448
@classmethod
449
- def _change_reference_condition (cls , key , value , operator ):
449
+ def _change_reference_condition (cls , key , value , operator , reference_filter = None ):
450
450
ref_model , ref_key , ref_query_key , foreign_key = cls ._get_reference_model (key )
451
451
if ref_model :
452
452
if value is None :
453
453
return ref_key , value , operator
454
454
else :
455
- ref_vos , total_count = ref_model .query (
456
- filter = [{"k" : ref_query_key , "v" : value , "o" : operator }]
457
- )
455
+ if operator == "not" :
456
+ _filter = [{"k" : ref_query_key , "v" : value , "o" : "eq" }]
457
+ elif operator == "not_in" :
458
+ _filter = [{"k" : ref_query_key , "v" : value , "o" : "in" }]
459
+ else :
460
+ _filter = [{"k" : ref_query_key , "v" : value , "o" : operator }]
461
+ if reference_filter :
462
+ for key , value in reference_filter .items ():
463
+ if value :
464
+ _filter .append ({"k" : key , "v" : value , "o" : "eq" })
465
+
466
+ ref_vos , total_count = ref_model .query (filter = _filter )
458
467
459
468
if foreign_key :
460
469
ref_values = []
@@ -464,13 +473,17 @@ def _change_reference_condition(cls, key, value, operator):
464
473
ref_values .append (ref_value )
465
474
else :
466
475
ref_values = list (ref_vos )
467
- return ref_key , ref_values , "in"
476
+
477
+ if operator in ["not" , "not_in" ]:
478
+ return ref_key , ref_values , "not_in"
479
+ else :
480
+ return ref_key , ref_values , "in"
468
481
469
482
else :
470
483
return key , value , operator
471
484
472
485
@classmethod
473
- def _make_condition (cls , condition ):
486
+ def _make_condition (cls , condition , reference_filter = None ):
474
487
key = condition .get ("key" , condition .get ("k" ))
475
488
value = condition .get ("value" , condition .get ("v" ))
476
489
operator = condition .get ("operator" , condition .get ("o" ))
@@ -479,7 +492,7 @@ def _make_condition(cls, condition):
479
492
if operator not in FILTER_OPERATORS :
480
493
raise ERROR_DB_QUERY (
481
494
reason = f"Filter operator is not supported. (operator = "
482
- f"{ FILTER_OPERATORS .keys ()} )"
495
+ f"{ FILTER_OPERATORS .keys ()} )"
483
496
)
484
497
485
498
resolver , mongo_operator , is_multiple = FILTER_OPERATORS .get (operator )
@@ -493,7 +506,7 @@ def _make_condition(cls, condition):
493
506
if operator not in ["regex" , "regex_in" ]:
494
507
if cls ._check_reference_field (key ):
495
508
key , value , operator = cls ._change_reference_condition (
496
- key , value , operator
509
+ key , value , operator , reference_filter
497
510
)
498
511
499
512
resolver , mongo_operator , is_multiple = FILTER_OPERATORS [operator ]
@@ -507,15 +520,27 @@ def _make_condition(cls, condition):
507
520
)
508
521
509
522
@classmethod
510
- def _make_filter (cls , filter , filter_or ):
523
+ def _make_filter (cls , filter , filter_or , reference_filter ):
511
524
_filter = None
512
525
_filter_or = None
513
526
514
527
if len (filter ) > 0 :
515
- _filter = reduce (lambda x , y : x & y , map (cls ._make_condition , filter ))
528
+ _filter = reduce (
529
+ lambda x , y : x & y ,
530
+ map (
531
+ partial (cls ._make_condition , reference_filter = reference_filter ),
532
+ filter ,
533
+ ),
534
+ )
516
535
517
536
if len (filter_or ) > 0 :
518
- _filter_or = reduce (lambda x , y : x | y , map (cls ._make_condition , filter_or ))
537
+ _filter_or = reduce (
538
+ lambda x , y : x | y ,
539
+ map (
540
+ partial (cls ._make_condition , reference_filter = reference_filter ),
541
+ filter_or ,
542
+ ),
543
+ )
519
544
520
545
if _filter and _filter_or :
521
546
_filter = _filter & _filter_or
@@ -566,14 +591,14 @@ def _make_unwind_project_stage(only: list):
566
591
567
592
@classmethod
568
593
def _stat_with_unwind (
569
- cls ,
570
- unwind : list ,
571
- only : list = None ,
572
- filter : list = None ,
573
- filter_or : list = None ,
574
- sort : list = None ,
575
- page : dict = None ,
576
- target : str = None ,
594
+ cls ,
595
+ unwind : list ,
596
+ only : list = None ,
597
+ filter : list = None ,
598
+ filter_or : list = None ,
599
+ sort : list = None ,
600
+ page : dict = None ,
601
+ target : str = None ,
577
602
):
578
603
if only is None :
579
604
raise ERROR_DB_QUERY (reason = "unwind option requires only option." )
@@ -641,19 +666,20 @@ def _stat_with_unwind(
641
666
642
667
@classmethod
643
668
def query (
644
- cls ,
645
- * args ,
646
- only = None ,
647
- exclude = None ,
648
- filter = None ,
649
- filter_or = None ,
650
- sort = None ,
651
- page = None ,
652
- minimal = False ,
653
- count_only = False ,
654
- unwind = None ,
655
- target = None ,
656
- ** kwargs ,
669
+ cls ,
670
+ * args ,
671
+ only = None ,
672
+ exclude = None ,
673
+ filter = None ,
674
+ filter_or = None ,
675
+ sort = None ,
676
+ page = None ,
677
+ minimal = False ,
678
+ count_only = False ,
679
+ unwind = None ,
680
+ reference_filter = None ,
681
+ target = None ,
682
+ ** kwargs ,
657
683
):
658
684
filter = filter or []
659
685
filter_or = filter_or or []
@@ -669,7 +695,7 @@ def query(
669
695
_order_by = []
670
696
minimal_fields = cls ._meta .get ("minimal_fields" )
671
697
672
- _filter = cls ._make_filter (filter , filter_or )
698
+ _filter = cls ._make_filter (filter , filter_or , reference_filter )
673
699
674
700
for sort_option in sort :
675
701
if sort_option .get ("desc" , False ):
@@ -715,7 +741,7 @@ def query(
715
741
if start < 1 :
716
742
start = 1
717
743
718
- vos = vos [start - 1 : start + page ["limit" ] - 1 ]
744
+ vos = vos [start - 1 : start + page ["limit" ] - 1 ]
719
745
720
746
return vos , total_count
721
747
@@ -786,7 +812,7 @@ def _make_sub_conditions(cls, sub_conditions, _before_group_keys):
786
812
if operator not in _SUPPORTED_OPERATOR :
787
813
raise ERROR_DB_QUERY (
788
814
reason = f"'aggregate.group.fields.conditions.operator' condition's { operator } operator is not "
789
- f"supported. (supported_operator = { _SUPPORTED_OPERATOR } )"
815
+ f"supported. (supported_operator = { _SUPPORTED_OPERATOR } )"
790
816
)
791
817
792
818
if key in _before_group_keys :
@@ -808,7 +834,7 @@ def _get_group_fields(cls, condition, _before_group_keys):
808
834
if operator not in STAT_GROUP_OPERATORS :
809
835
raise ERROR_DB_QUERY (
810
836
reason = f"'aggregate.group.fields' condition's { operator } operator is not supported. "
811
- f"(supported_operator = { list (STAT_GROUP_OPERATORS .keys ())} )"
837
+ f"(supported_operator = { list (STAT_GROUP_OPERATORS .keys ())} )"
812
838
)
813
839
814
840
if name is None :
@@ -927,7 +953,7 @@ def _get_project_fields(cls, condition):
927
953
if operator and operator not in STAT_PROJECT_OPERATORS :
928
954
raise ERROR_DB_QUERY (
929
955
reason = f"'aggregate.project.fields' condition's { operator } operator is not supported. "
930
- f"(supported_operator = { list (STAT_PROJECT_OPERATORS .keys ())} )"
956
+ f"(supported_operator = { list (STAT_PROJECT_OPERATORS .keys ())} )"
931
957
)
932
958
933
959
if name is None :
@@ -1085,9 +1111,9 @@ def _make_aggregate_rules(cls, aggregate):
1085
1111
else :
1086
1112
raise ERROR_REQUIRED_PARAMETER (
1087
1113
key = "aggregate.unwind or aggregate.group or "
1088
- "aggregate.count or aggregate.sort or "
1089
- "aggregate.project or aggregate.limit or "
1090
- "aggregate.skip"
1114
+ "aggregate.count or aggregate.sort or "
1115
+ "aggregate.project or aggregate.limit or "
1116
+ "aggregate.skip"
1091
1117
)
1092
1118
1093
1119
return _aggregate_rules
@@ -1141,23 +1167,24 @@ def _stat_distinct(cls, vos, distinct, page):
1141
1167
start = 1
1142
1168
1143
1169
result ["total_count" ] = len (values )
1144
- values = values [start - 1 : start + page ["limit" ] - 1 ]
1170
+ values = values [start - 1 : start + page ["limit" ] - 1 ]
1145
1171
1146
1172
result ["results" ] = cls ._make_distinct_values (values )
1147
1173
return result
1148
1174
1149
1175
@classmethod
1150
1176
def stat (
1151
- cls ,
1152
- * args ,
1153
- aggregate = None ,
1154
- distinct = None ,
1155
- filter = None ,
1156
- filter_or = None ,
1157
- page = None ,
1158
- target = "SECONDARY_PREFERRED" ,
1159
- allow_disk_use = False ,
1160
- ** kwargs ,
1177
+ cls ,
1178
+ * args ,
1179
+ aggregate = None ,
1180
+ distinct = None ,
1181
+ filter = None ,
1182
+ filter_or = None ,
1183
+ page = None ,
1184
+ reference_filter = None ,
1185
+ target = "SECONDARY_PREFERRED" ,
1186
+ allow_disk_use = False ,
1187
+ ** kwargs ,
1161
1188
):
1162
1189
filter = filter or []
1163
1190
filter_or = filter_or or []
@@ -1166,7 +1193,7 @@ def stat(
1166
1193
if not (aggregate or distinct ):
1167
1194
raise ERROR_REQUIRED_PARAMETER (key = "aggregate" )
1168
1195
1169
- _filter = cls ._make_filter (filter , filter_or )
1196
+ _filter = cls ._make_filter (filter , filter_or , reference_filter )
1170
1197
1171
1198
try :
1172
1199
vos = cls ._get_target_objects (target ).filter (_filter )
@@ -1453,24 +1480,25 @@ def _convert_date_value(cls, date_value, date_field_format):
1453
1480
1454
1481
@classmethod
1455
1482
def analyze (
1456
- cls ,
1457
- * args ,
1458
- granularity = None ,
1459
- fields = None ,
1460
- select = None ,
1461
- group_by = None ,
1462
- field_group = None ,
1463
- filter = None ,
1464
- filter_or = None ,
1465
- page = None ,
1466
- sort = None ,
1467
- start = None ,
1468
- end = None ,
1469
- date_field = "date" ,
1470
- date_field_format = "%Y-%m-%d" ,
1471
- target = "SECONDARY_PREFERRED" ,
1472
- allow_disk_use = False ,
1473
- ** kwargs ,
1483
+ cls ,
1484
+ * args ,
1485
+ granularity = None ,
1486
+ fields = None ,
1487
+ select = None ,
1488
+ group_by = None ,
1489
+ field_group = None ,
1490
+ filter = None ,
1491
+ filter_or = None ,
1492
+ page = None ,
1493
+ sort = None ,
1494
+ start = None ,
1495
+ end = None ,
1496
+ date_field = "date" ,
1497
+ date_field_format = "%Y-%m-%d" ,
1498
+ reference_filter = None ,
1499
+ target = "SECONDARY_PREFERRED" ,
1500
+ allow_disk_use = False ,
1501
+ ** kwargs ,
1474
1502
):
1475
1503
if fields is None :
1476
1504
raise ERROR_REQUIRED_PARAMETER (key = "fields" )
@@ -1504,6 +1532,7 @@ def analyze(
1504
1532
"aggregate" : [{"group" : {"keys" : group_keys , "fields" : group_fields }}],
1505
1533
"target" : target ,
1506
1534
"allow_disk_use" : allow_disk_use ,
1535
+ "reference_filter" : reference_filter ,
1507
1536
}
1508
1537
1509
1538
if select :
0 commit comments