Skip to content

Commit 631a65c

Browse files
committed
feat: add typing for SF function
1 parent f3de1b4 commit 631a65c

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

elasticsearch_dsl/function.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,40 @@
1616
# under the License.
1717

1818
import collections.abc
19-
from typing import Dict, Optional, ClassVar
19+
from copy import deepcopy
20+
from typing import Dict, Optional, ClassVar, Union, MutableMapping, Any
2021

2122
from .utils import DslBase, _JSONSafeTypes
2223

2324

24-
# Incomplete annotation to not break query.py tests
25-
def SF(name_or_sf, **params) -> "ScoreFunction":
25+
def SF(name_or_sf: Union[str, "ScoreFunction", MutableMapping[str, Any]], **params: Any) -> "ScoreFunction":
2626
# {"script_score": {"script": "_score"}, "filter": {}}
27-
if isinstance(name_or_sf, collections.abc.Mapping):
27+
if isinstance(name_or_sf, collections.abc.MutableMapping):
2828
if params:
2929
raise ValueError("SF() cannot accept parameters when passing in a dict.")
30-
kwargs = {}
31-
sf = name_or_sf.copy()
30+
31+
kwargs: Dict[str, Any] = {}
32+
sf = deepcopy(name_or_sf)
3233
for k in ScoreFunction._param_defs:
3334
if k in name_or_sf:
3435
kwargs[k] = sf.pop(k)
3536

3637
# not sf, so just filter+weight, which used to be boost factor
38+
sf_params = params
3739
if not sf:
3840
name = "boost_factor"
3941
# {'FUNCTION': {...}}
4042
elif len(sf) == 1:
41-
name, params = sf.popitem()
43+
name, sf_params = sf.popitem()
4244
else:
4345
raise ValueError(f"SF() got an unexpected fields in the dictionary: {sf!r}")
4446

4547
# boost factor special case, see elasticsearch #6343
46-
if not isinstance(params, collections.abc.Mapping):
47-
params = {"value": params}
48+
if not isinstance(sf_params, collections.abc.Mapping):
49+
sf_params = {"value": sf_params}
4850

4951
# mix known params (from _param_defs) and from inside the function
50-
kwargs.update(params)
52+
kwargs.update(sf_params)
5153
return ScoreFunction.get_dsl_class(name)(**kwargs)
5254

5355
# ScriptScore(script="_score", filter=Q())

elasticsearch_dsl/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class DslBase(metaclass=DslMeta):
253253
all values in the `must` attribute into Query objects)
254254
"""
255255

256+
_type_name: ClassVar[str]
256257
_param_defs: ClassVar[Dict[str, Dict[str, Union[str, bool]]]] = {}
257258

258259
@classmethod

0 commit comments

Comments
 (0)