Skip to content

Commit f3de1b4

Browse files
committed
feat: add _JSONSafeTypes annotation to to_dict methods
1 parent 1a8e62c commit f3de1b4

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

elasticsearch_dsl/function.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import collections.abc
1919
from typing import Dict, Optional, ClassVar
2020

21-
from .utils import DslBase
21+
from .utils import DslBase, _JSONSafeTypes
2222

2323

2424
# Incomplete annotation to not break query.py tests
@@ -72,12 +72,14 @@ class ScoreFunction(DslBase):
7272
}
7373
name: ClassVar[Optional[str]] = None
7474

75-
def to_dict(self):
75+
def to_dict(self) -> Dict[str, _JSONSafeTypes]:
7676
d = super().to_dict()
7777
# filter and query dicts should be at the same level as us
7878
for k in self._param_defs:
79-
if k in d[self.name]:
80-
d[k] = d[self.name].pop(k)
79+
if self.name is not None:
80+
val = d[self.name]
81+
if isinstance(val, dict) and k in val:
82+
d[k] = val.pop(k)
8183
return d
8284

8385

@@ -88,12 +90,15 @@ class ScriptScore(ScoreFunction):
8890
class BoostFactor(ScoreFunction):
8991
name = "boost_factor"
9092

91-
def to_dict(self) -> Dict[str, int]:
93+
def to_dict(self) -> Dict[str, _JSONSafeTypes]:
9294
d = super().to_dict()
93-
if "value" in d[self.name]:
94-
d[self.name] = d[self.name].pop("value")
95-
else:
96-
del d[self.name]
95+
if self.name is not None:
96+
val = d[self.name]
97+
if isinstance(val, dict):
98+
if "value" in val:
99+
d[self.name] = val.pop("value")
100+
else:
101+
del d[self.name]
97102
return d
98103

99104

elasticsearch_dsl/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818

1919
import collections.abc
2020
from copy import copy
21-
from typing import Any, Dict, Optional, Type, ClassVar, Union
21+
from typing import Any, Dict, Optional, Type, ClassVar, Union, List
2222

2323
from typing_extensions import Self
2424

2525
from .exceptions import UnknownDslObject, ValidationException
2626

27+
_JSONSafeTypes = Union[int, bool, str, float, List["_JSONSafeTypes"], Dict[str, "_JSONSafeTypes"]]
28+
2729
SKIP_VALUES = ("", None)
2830
EXPAND__TO_DOT = True
2931

@@ -356,8 +358,7 @@ def __getattr__(self, name):
356358
return AttrDict(value)
357359
return value
358360

359-
# TODO: This type annotation can probably be made tighter
360-
def to_dict(self) -> Dict[str, Dict[str, Any]]:
361+
def to_dict(self) -> Dict[str, _JSONSafeTypes]:
361362
"""
362363
Serialize the DSL object to plain dict
363364
"""

0 commit comments

Comments
 (0)