From 7c5bdbb789c88be6e686eae3b4b8effbcee83156 Mon Sep 17 00:00:00 2001 From: Jongmin Kim Date: Thu, 16 May 2024 08:53:08 +0900 Subject: [PATCH] refactor: refactor reference filter Signed-off-by: Jongmin Kim --- .../core/model/mongo_model/__init__.py | 171 ++++++++++-------- 1 file changed, 100 insertions(+), 71 deletions(-) diff --git a/src/spaceone/core/model/mongo_model/__init__.py b/src/spaceone/core/model/mongo_model/__init__.py index 52ec474..1da3a4f 100644 --- a/src/spaceone/core/model/mongo_model/__init__.py +++ b/src/spaceone/core/model/mongo_model/__init__.py @@ -5,7 +5,7 @@ import copy from datetime import datetime, date from dateutil.relativedelta import relativedelta -from functools import reduce +from functools import reduce, partial from mongoengine import ( EmbeddedDocumentField, EmbeddedDocument, @@ -446,15 +446,24 @@ def _get_reference_model(cls, key): return None, None, None, None @classmethod - def _change_reference_condition(cls, key, value, operator): + def _change_reference_condition(cls, key, value, operator, reference_filter=None): ref_model, ref_key, ref_query_key, foreign_key = cls._get_reference_model(key) if ref_model: if value is None: return ref_key, value, operator else: - ref_vos, total_count = ref_model.query( - filter=[{"k": ref_query_key, "v": value, "o": operator}] - ) + if operator == "not": + _filter = [{"k": ref_query_key, "v": value, "o": "eq"}] + elif operator == "not_in": + _filter = [{"k": ref_query_key, "v": value, "o": "in"}] + else: + _filter = [{"k": ref_query_key, "v": value, "o": operator}] + if reference_filter: + for key, value in reference_filter.items(): + if value: + _filter.append({"k": key, "v": value, "o": "eq"}) + + ref_vos, total_count = ref_model.query(filter=_filter) if foreign_key: ref_values = [] @@ -464,13 +473,17 @@ def _change_reference_condition(cls, key, value, operator): ref_values.append(ref_value) else: ref_values = list(ref_vos) - return ref_key, ref_values, "in" + + if operator in ["not", "not_in"]: + return ref_key, ref_values, "not_in" + else: + return ref_key, ref_values, "in" else: return key, value, operator @classmethod - def _make_condition(cls, condition): + def _make_condition(cls, condition, reference_filter=None): key = condition.get("key", condition.get("k")) value = condition.get("value", condition.get("v")) operator = condition.get("operator", condition.get("o")) @@ -479,7 +492,7 @@ def _make_condition(cls, condition): if operator not in FILTER_OPERATORS: raise ERROR_DB_QUERY( reason=f"Filter operator is not supported. (operator = " - f"{FILTER_OPERATORS.keys()})" + f"{FILTER_OPERATORS.keys()})" ) resolver, mongo_operator, is_multiple = FILTER_OPERATORS.get(operator) @@ -493,7 +506,7 @@ def _make_condition(cls, condition): if operator not in ["regex", "regex_in"]: if cls._check_reference_field(key): key, value, operator = cls._change_reference_condition( - key, value, operator + key, value, operator, reference_filter ) resolver, mongo_operator, is_multiple = FILTER_OPERATORS[operator] @@ -507,15 +520,27 @@ def _make_condition(cls, condition): ) @classmethod - def _make_filter(cls, filter, filter_or): + def _make_filter(cls, filter, filter_or, reference_filter): _filter = None _filter_or = None if len(filter) > 0: - _filter = reduce(lambda x, y: x & y, map(cls._make_condition, filter)) + _filter = reduce( + lambda x, y: x & y, + map( + partial(cls._make_condition, reference_filter=reference_filter), + filter, + ), + ) if len(filter_or) > 0: - _filter_or = reduce(lambda x, y: x | y, map(cls._make_condition, filter_or)) + _filter_or = reduce( + lambda x, y: x | y, + map( + partial(cls._make_condition, reference_filter=reference_filter), + filter_or, + ), + ) if _filter and _filter_or: _filter = _filter & _filter_or @@ -566,14 +591,14 @@ def _make_unwind_project_stage(only: list): @classmethod def _stat_with_unwind( - cls, - unwind: list, - only: list = None, - filter: list = None, - filter_or: list = None, - sort: list = None, - page: dict = None, - target: str = None, + cls, + unwind: list, + only: list = None, + filter: list = None, + filter_or: list = None, + sort: list = None, + page: dict = None, + target: str = None, ): if only is None: raise ERROR_DB_QUERY(reason="unwind option requires only option.") @@ -641,19 +666,20 @@ def _stat_with_unwind( @classmethod def query( - cls, - *args, - only=None, - exclude=None, - filter=None, - filter_or=None, - sort=None, - page=None, - minimal=False, - count_only=False, - unwind=None, - target=None, - **kwargs, + cls, + *args, + only=None, + exclude=None, + filter=None, + filter_or=None, + sort=None, + page=None, + minimal=False, + count_only=False, + unwind=None, + reference_filter=None, + target=None, + **kwargs, ): filter = filter or [] filter_or = filter_or or [] @@ -669,7 +695,7 @@ def query( _order_by = [] minimal_fields = cls._meta.get("minimal_fields") - _filter = cls._make_filter(filter, filter_or) + _filter = cls._make_filter(filter, filter_or, reference_filter) for sort_option in sort: if sort_option.get("desc", False): @@ -715,7 +741,7 @@ def query( if start < 1: start = 1 - vos = vos[start - 1: start + page["limit"] - 1] + vos = vos[start - 1 : start + page["limit"] - 1] return vos, total_count @@ -786,7 +812,7 @@ def _make_sub_conditions(cls, sub_conditions, _before_group_keys): if operator not in _SUPPORTED_OPERATOR: raise ERROR_DB_QUERY( reason=f"'aggregate.group.fields.conditions.operator' condition's {operator} operator is not " - f"supported. (supported_operator = {_SUPPORTED_OPERATOR})" + f"supported. (supported_operator = {_SUPPORTED_OPERATOR})" ) if key in _before_group_keys: @@ -808,7 +834,7 @@ def _get_group_fields(cls, condition, _before_group_keys): if operator not in STAT_GROUP_OPERATORS: raise ERROR_DB_QUERY( reason=f"'aggregate.group.fields' condition's {operator} operator is not supported. " - f"(supported_operator = {list(STAT_GROUP_OPERATORS.keys())})" + f"(supported_operator = {list(STAT_GROUP_OPERATORS.keys())})" ) if name is None: @@ -927,7 +953,7 @@ def _get_project_fields(cls, condition): if operator and operator not in STAT_PROJECT_OPERATORS: raise ERROR_DB_QUERY( reason=f"'aggregate.project.fields' condition's {operator} operator is not supported. " - f"(supported_operator = {list(STAT_PROJECT_OPERATORS.keys())})" + f"(supported_operator = {list(STAT_PROJECT_OPERATORS.keys())})" ) if name is None: @@ -1085,9 +1111,9 @@ def _make_aggregate_rules(cls, aggregate): else: raise ERROR_REQUIRED_PARAMETER( key="aggregate.unwind or aggregate.group or " - "aggregate.count or aggregate.sort or " - "aggregate.project or aggregate.limit or " - "aggregate.skip" + "aggregate.count or aggregate.sort or " + "aggregate.project or aggregate.limit or " + "aggregate.skip" ) return _aggregate_rules @@ -1141,23 +1167,24 @@ def _stat_distinct(cls, vos, distinct, page): start = 1 result["total_count"] = len(values) - values = values[start - 1: start + page["limit"] - 1] + values = values[start - 1 : start + page["limit"] - 1] result["results"] = cls._make_distinct_values(values) return result @classmethod def stat( - cls, - *args, - aggregate=None, - distinct=None, - filter=None, - filter_or=None, - page=None, - target="SECONDARY_PREFERRED", - allow_disk_use=False, - **kwargs, + cls, + *args, + aggregate=None, + distinct=None, + filter=None, + filter_or=None, + page=None, + reference_filter=None, + target="SECONDARY_PREFERRED", + allow_disk_use=False, + **kwargs, ): filter = filter or [] filter_or = filter_or or [] @@ -1166,7 +1193,7 @@ def stat( if not (aggregate or distinct): raise ERROR_REQUIRED_PARAMETER(key="aggregate") - _filter = cls._make_filter(filter, filter_or) + _filter = cls._make_filter(filter, filter_or, reference_filter) try: vos = cls._get_target_objects(target).filter(_filter) @@ -1453,24 +1480,25 @@ def _convert_date_value(cls, date_value, date_field_format): @classmethod def analyze( - cls, - *args, - granularity=None, - fields=None, - select=None, - group_by=None, - field_group=None, - filter=None, - filter_or=None, - page=None, - sort=None, - start=None, - end=None, - date_field="date", - date_field_format="%Y-%m-%d", - target="SECONDARY_PREFERRED", - allow_disk_use=False, - **kwargs, + cls, + *args, + granularity=None, + fields=None, + select=None, + group_by=None, + field_group=None, + filter=None, + filter_or=None, + page=None, + sort=None, + start=None, + end=None, + date_field="date", + date_field_format="%Y-%m-%d", + reference_filter=None, + target="SECONDARY_PREFERRED", + allow_disk_use=False, + **kwargs, ): if fields is None: raise ERROR_REQUIRED_PARAMETER(key="fields") @@ -1504,6 +1532,7 @@ def analyze( "aggregate": [{"group": {"keys": group_keys, "fields": group_fields}}], "target": target, "allow_disk_use": allow_disk_use, + "reference_filter": reference_filter, } if select: