|
16 | 16 | # under the License.
|
17 | 17 |
|
18 | 18 | import collections.abc
|
19 |
| -from typing import Dict, Optional, ClassVar |
| 19 | +from copy import deepcopy |
| 20 | +from typing import Dict, Optional, ClassVar, Union, MutableMapping, Any |
20 | 21 |
|
21 | 22 | from .utils import DslBase, _JSONSafeTypes
|
22 | 23 |
|
23 | 24 |
|
24 |
| -# Incomplete annotation to not break query.py tests |
25 |
| -def SF(name_or_sf, **params) -> "ScoreFunction": |
| 25 | +def SF(name_or_sf: Union[str, "ScoreFunction", MutableMapping[str, Any]], **params: Any) -> "ScoreFunction": |
26 | 26 | # {"script_score": {"script": "_score"}, "filter": {}}
|
27 |
| - if isinstance(name_or_sf, collections.abc.Mapping): |
| 27 | + if isinstance(name_or_sf, collections.abc.MutableMapping): |
28 | 28 | if params:
|
29 | 29 | raise ValueError("SF() cannot accept parameters when passing in a dict.")
|
30 |
| - kwargs = {} |
31 |
| - sf = name_or_sf.copy() |
| 30 | + |
| 31 | + kwargs: Dict[str, Any] = {} |
| 32 | + sf = deepcopy(name_or_sf) |
32 | 33 | for k in ScoreFunction._param_defs:
|
33 | 34 | if k in name_or_sf:
|
34 | 35 | kwargs[k] = sf.pop(k)
|
35 | 36 |
|
36 | 37 | # not sf, so just filter+weight, which used to be boost factor
|
| 38 | + sf_params = params |
37 | 39 | if not sf:
|
38 | 40 | name = "boost_factor"
|
39 | 41 | # {'FUNCTION': {...}}
|
40 | 42 | elif len(sf) == 1:
|
41 |
| - name, params = sf.popitem() |
| 43 | + name, sf_params = sf.popitem() |
42 | 44 | else:
|
43 | 45 | raise ValueError(f"SF() got an unexpected fields in the dictionary: {sf!r}")
|
44 | 46 |
|
45 | 47 | # boost factor special case, see elasticsearch #6343
|
46 |
| - if not isinstance(params, collections.abc.Mapping): |
47 |
| - params = {"value": params} |
| 48 | + if not isinstance(sf_params, collections.abc.Mapping): |
| 49 | + sf_params = {"value": sf_params} |
48 | 50 |
|
49 | 51 | # mix known params (from _param_defs) and from inside the function
|
50 |
| - kwargs.update(params) |
| 52 | + kwargs.update(sf_params) |
51 | 53 | return ScoreFunction.get_dsl_class(name)(**kwargs)
|
52 | 54 |
|
53 | 55 | # ScriptScore(script="_score", filter=Q())
|
|
0 commit comments