Skip to content

refactor: refactor reference filter #144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 100 additions & 71 deletions src/spaceone/core/model/mongo_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
from datetime import datetime, date
from dateutil.relativedelta import relativedelta
from functools import reduce
from functools import reduce, partial
from mongoengine import (
EmbeddedDocumentField,
EmbeddedDocument,
Expand Down Expand Up @@ -446,15 +446,24 @@ def _get_reference_model(cls, key):
return None, None, None, None

@classmethod
def _change_reference_condition(cls, key, value, operator):
def _change_reference_condition(cls, key, value, operator, reference_filter=None):
ref_model, ref_key, ref_query_key, foreign_key = cls._get_reference_model(key)
if ref_model:
if value is None:
return ref_key, value, operator
else:
ref_vos, total_count = ref_model.query(
filter=[{"k": ref_query_key, "v": value, "o": operator}]
)
if operator == "not":
_filter = [{"k": ref_query_key, "v": value, "o": "eq"}]
elif operator == "not_in":
_filter = [{"k": ref_query_key, "v": value, "o": "in"}]
else:
_filter = [{"k": ref_query_key, "v": value, "o": operator}]
if reference_filter:
for key, value in reference_filter.items():
if value:
_filter.append({"k": key, "v": value, "o": "eq"})

ref_vos, total_count = ref_model.query(filter=_filter)

if foreign_key:
ref_values = []
Expand All @@ -464,13 +473,17 @@ def _change_reference_condition(cls, key, value, operator):
ref_values.append(ref_value)
else:
ref_values = list(ref_vos)
return ref_key, ref_values, "in"

if operator in ["not", "not_in"]:
return ref_key, ref_values, "not_in"
else:
return ref_key, ref_values, "in"

else:
return key, value, operator

@classmethod
def _make_condition(cls, condition):
def _make_condition(cls, condition, reference_filter=None):
key = condition.get("key", condition.get("k"))
value = condition.get("value", condition.get("v"))
operator = condition.get("operator", condition.get("o"))
Expand All @@ -479,7 +492,7 @@ def _make_condition(cls, condition):
if operator not in FILTER_OPERATORS:
raise ERROR_DB_QUERY(
reason=f"Filter operator is not supported. (operator = "
f"{FILTER_OPERATORS.keys()})"
f"{FILTER_OPERATORS.keys()})"
)

resolver, mongo_operator, is_multiple = FILTER_OPERATORS.get(operator)
Expand All @@ -493,7 +506,7 @@ def _make_condition(cls, condition):
if operator not in ["regex", "regex_in"]:
if cls._check_reference_field(key):
key, value, operator = cls._change_reference_condition(
key, value, operator
key, value, operator, reference_filter
)

resolver, mongo_operator, is_multiple = FILTER_OPERATORS[operator]
Expand All @@ -507,15 +520,27 @@ def _make_condition(cls, condition):
)

@classmethod
def _make_filter(cls, filter, filter_or):
def _make_filter(cls, filter, filter_or, reference_filter):
_filter = None
_filter_or = None

if len(filter) > 0:
_filter = reduce(lambda x, y: x & y, map(cls._make_condition, filter))
_filter = reduce(
lambda x, y: x & y,
map(
partial(cls._make_condition, reference_filter=reference_filter),
filter,
),
)

if len(filter_or) > 0:
_filter_or = reduce(lambda x, y: x | y, map(cls._make_condition, filter_or))
_filter_or = reduce(
lambda x, y: x | y,
map(
partial(cls._make_condition, reference_filter=reference_filter),
filter_or,
),
)

if _filter and _filter_or:
_filter = _filter & _filter_or
Expand Down Expand Up @@ -566,14 +591,14 @@ def _make_unwind_project_stage(only: list):

@classmethod
def _stat_with_unwind(
cls,
unwind: list,
only: list = None,
filter: list = None,
filter_or: list = None,
sort: list = None,
page: dict = None,
target: str = None,
cls,
unwind: list,
only: list = None,
filter: list = None,
filter_or: list = None,
sort: list = None,
page: dict = None,
target: str = None,
):
if only is None:
raise ERROR_DB_QUERY(reason="unwind option requires only option.")
Expand Down Expand Up @@ -641,19 +666,20 @@ def _stat_with_unwind(

@classmethod
def query(
cls,
*args,
only=None,
exclude=None,
filter=None,
filter_or=None,
sort=None,
page=None,
minimal=False,
count_only=False,
unwind=None,
target=None,
**kwargs,
cls,
*args,
only=None,
exclude=None,
filter=None,
filter_or=None,
sort=None,
page=None,
minimal=False,
count_only=False,
unwind=None,
reference_filter=None,
target=None,
**kwargs,
):
filter = filter or []
filter_or = filter_or or []
Expand All @@ -669,7 +695,7 @@ def query(
_order_by = []
minimal_fields = cls._meta.get("minimal_fields")

_filter = cls._make_filter(filter, filter_or)
_filter = cls._make_filter(filter, filter_or, reference_filter)

for sort_option in sort:
if sort_option.get("desc", False):
Expand Down Expand Up @@ -715,7 +741,7 @@ def query(
if start < 1:
start = 1

vos = vos[start - 1: start + page["limit"] - 1]
vos = vos[start - 1 : start + page["limit"] - 1]

return vos, total_count

Expand Down Expand Up @@ -786,7 +812,7 @@ def _make_sub_conditions(cls, sub_conditions, _before_group_keys):
if operator not in _SUPPORTED_OPERATOR:
raise ERROR_DB_QUERY(
reason=f"'aggregate.group.fields.conditions.operator' condition's {operator} operator is not "
f"supported. (supported_operator = {_SUPPORTED_OPERATOR})"
f"supported. (supported_operator = {_SUPPORTED_OPERATOR})"
)

if key in _before_group_keys:
Expand All @@ -808,7 +834,7 @@ def _get_group_fields(cls, condition, _before_group_keys):
if operator not in STAT_GROUP_OPERATORS:
raise ERROR_DB_QUERY(
reason=f"'aggregate.group.fields' condition's {operator} operator is not supported. "
f"(supported_operator = {list(STAT_GROUP_OPERATORS.keys())})"
f"(supported_operator = {list(STAT_GROUP_OPERATORS.keys())})"
)

if name is None:
Expand Down Expand Up @@ -927,7 +953,7 @@ def _get_project_fields(cls, condition):
if operator and operator not in STAT_PROJECT_OPERATORS:
raise ERROR_DB_QUERY(
reason=f"'aggregate.project.fields' condition's {operator} operator is not supported. "
f"(supported_operator = {list(STAT_PROJECT_OPERATORS.keys())})"
f"(supported_operator = {list(STAT_PROJECT_OPERATORS.keys())})"
)

if name is None:
Expand Down Expand Up @@ -1085,9 +1111,9 @@ def _make_aggregate_rules(cls, aggregate):
else:
raise ERROR_REQUIRED_PARAMETER(
key="aggregate.unwind or aggregate.group or "
"aggregate.count or aggregate.sort or "
"aggregate.project or aggregate.limit or "
"aggregate.skip"
"aggregate.count or aggregate.sort or "
"aggregate.project or aggregate.limit or "
"aggregate.skip"
)

return _aggregate_rules
Expand Down Expand Up @@ -1141,23 +1167,24 @@ def _stat_distinct(cls, vos, distinct, page):
start = 1

result["total_count"] = len(values)
values = values[start - 1: start + page["limit"] - 1]
values = values[start - 1 : start + page["limit"] - 1]

result["results"] = cls._make_distinct_values(values)
return result

@classmethod
def stat(
cls,
*args,
aggregate=None,
distinct=None,
filter=None,
filter_or=None,
page=None,
target="SECONDARY_PREFERRED",
allow_disk_use=False,
**kwargs,
cls,
*args,
aggregate=None,
distinct=None,
filter=None,
filter_or=None,
page=None,
reference_filter=None,
target="SECONDARY_PREFERRED",
allow_disk_use=False,
**kwargs,
):
filter = filter or []
filter_or = filter_or or []
Expand All @@ -1166,7 +1193,7 @@ def stat(
if not (aggregate or distinct):
raise ERROR_REQUIRED_PARAMETER(key="aggregate")

_filter = cls._make_filter(filter, filter_or)
_filter = cls._make_filter(filter, filter_or, reference_filter)

try:
vos = cls._get_target_objects(target).filter(_filter)
Expand Down Expand Up @@ -1453,24 +1480,25 @@ def _convert_date_value(cls, date_value, date_field_format):

@classmethod
def analyze(
cls,
*args,
granularity=None,
fields=None,
select=None,
group_by=None,
field_group=None,
filter=None,
filter_or=None,
page=None,
sort=None,
start=None,
end=None,
date_field="date",
date_field_format="%Y-%m-%d",
target="SECONDARY_PREFERRED",
allow_disk_use=False,
**kwargs,
cls,
*args,
granularity=None,
fields=None,
select=None,
group_by=None,
field_group=None,
filter=None,
filter_or=None,
page=None,
sort=None,
start=None,
end=None,
date_field="date",
date_field_format="%Y-%m-%d",
reference_filter=None,
target="SECONDARY_PREFERRED",
allow_disk_use=False,
**kwargs,
):
if fields is None:
raise ERROR_REQUIRED_PARAMETER(key="fields")
Expand Down Expand Up @@ -1504,6 +1532,7 @@ def analyze(
"aggregate": [{"group": {"keys": group_keys, "fields": group_fields}}],
"target": target,
"allow_disk_use": allow_disk_use,
"reference_filter": reference_filter,
}

if select:
Expand Down
Loading