Skip to content

Commit d9b7608

Browse files
Support Message.from_dict() as a class and an instance method (#476)
* Make Message.from_dict() a class method Signed-off-by: Marek Pikuła <marek.pikula@embevity.com> * Sync 1/2 of review comments * Sync other half * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Update src/betterproto/__init__.py * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Fix CI again * Fix failing formatting --------- Signed-off-by: Marek Pikuła <marek.pikula@embevity.com> Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
1 parent 02aa4e8 commit d9b7608

File tree

2 files changed

+164
-76
lines changed

2 files changed

+164
-76
lines changed

src/betterproto/__init__.py

Lines changed: 108 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
import enum as builtin_enum
35
import json
@@ -22,8 +24,8 @@
2224
from typing import (
2325
TYPE_CHECKING,
2426
Any,
25-
BinaryIO,
2627
Callable,
28+
ClassVar,
2729
Dict,
2830
Generator,
2931
Iterable,
@@ -37,6 +39,7 @@
3739
)
3840

3941
from dateutil.parser import isoparse
42+
from typing_extensions import Self
4043

4144
from ._types import T
4245
from ._version import __version__
@@ -47,6 +50,10 @@
4750
)
4851
from .enum import Enum as Enum
4952
from .grpc.grpclib_client import ServiceStub as ServiceStub
53+
from .utils import (
54+
classproperty,
55+
hybridmethod,
56+
)
5057

5158

5259
if TYPE_CHECKING:
@@ -729,6 +736,7 @@ class Message(ABC):
729736
_serialized_on_wire: bool
730737
_unknown_fields: bytes
731738
_group_current: Dict[str, str]
739+
_betterproto_meta: ClassVar[ProtoClassMetadata]
732740

733741
def __post_init__(self) -> None:
734742
# Keep track of whether every field was default
@@ -882,18 +890,18 @@ def __copy__(self: T, _: Any = {}) -> T:
882890
kwargs[name] = value
883891
return self.__class__(**kwargs) # type: ignore
884892

885-
@property
886-
def _betterproto(self) -> ProtoClassMetadata:
893+
@classproperty
894+
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
887895
"""
888896
Lazy initialize metadata for each protobuf class.
889897
It may be initialized multiple times in a multi-threaded environment,
890898
but that won't affect the correctness.
891899
"""
892-
meta = getattr(self.__class__, "_betterproto_meta", None)
893-
if not meta:
894-
meta = ProtoClassMetadata(self.__class__)
895-
self.__class__._betterproto_meta = meta # type: ignore
896-
return meta
900+
try:
901+
return cls._betterproto_meta
902+
except AttributeError:
903+
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
904+
return meta
897905

898906
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
899907
"""
@@ -1512,10 +1520,74 @@ def to_dict(
15121520
output[cased_name] = value
15131521
return output
15141522

1515-
def from_dict(self: T, value: Mapping[str, Any]) -> T:
1523+
@classmethod
1524+
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
1525+
init_kwargs: Dict[str, Any] = {}
1526+
for key, value in mapping.items():
1527+
field_name = safe_snake_case(key)
1528+
try:
1529+
meta = cls._betterproto.meta_by_field_name[field_name]
1530+
except KeyError:
1531+
continue
1532+
if value is None:
1533+
continue
1534+
1535+
if meta.proto_type == TYPE_MESSAGE:
1536+
sub_cls = cls._betterproto.cls_by_field[field_name]
1537+
if sub_cls == datetime:
1538+
value = (
1539+
[isoparse(item) for item in value]
1540+
if isinstance(value, list)
1541+
else isoparse(value)
1542+
)
1543+
elif sub_cls == timedelta:
1544+
value = (
1545+
[timedelta(seconds=float(item[:-1])) for item in value]
1546+
if isinstance(value, list)
1547+
else timedelta(seconds=float(value[:-1]))
1548+
)
1549+
elif not meta.wraps:
1550+
value = (
1551+
[sub_cls.from_dict(item) for item in value]
1552+
if isinstance(value, list)
1553+
else sub_cls.from_dict(value)
1554+
)
1555+
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1556+
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
1557+
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
1558+
else:
1559+
if meta.proto_type in INT_64_TYPES:
1560+
value = (
1561+
[int(n) for n in value]
1562+
if isinstance(value, list)
1563+
else int(value)
1564+
)
1565+
elif meta.proto_type == TYPE_BYTES:
1566+
value = (
1567+
[b64decode(n) for n in value]
1568+
if isinstance(value, list)
1569+
else b64decode(value)
1570+
)
1571+
elif meta.proto_type == TYPE_ENUM:
1572+
enum_cls = cls._betterproto.cls_by_field[field_name]
1573+
if isinstance(value, list):
1574+
value = [enum_cls.from_string(e) for e in value]
1575+
elif isinstance(value, str):
1576+
value = enum_cls.from_string(value)
1577+
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1578+
value = (
1579+
[_parse_float(n) for n in value]
1580+
if isinstance(value, list)
1581+
else _parse_float(value)
1582+
)
1583+
1584+
init_kwargs[field_name] = value
1585+
return init_kwargs
1586+
1587+
@hybridmethod
1588+
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
15161589
"""
1517-
Parse the key/value pairs into the current message instance. This returns the
1518-
instance itself and is therefore assignable and chainable.
1590+
Parse the key/value pairs into the a new message instance.
15191591
15201592
Parameters
15211593
-----------
@@ -1527,72 +1599,29 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T:
15271599
:class:`Message`
15281600
The initialized message.
15291601
"""
1602+
self = cls(**cls._from_dict_init(value))
15301603
self._serialized_on_wire = True
1531-
for key in value:
1532-
field_name = safe_snake_case(key)
1533-
meta = self._betterproto.meta_by_field_name.get(field_name)
1534-
if not meta:
1535-
continue
1604+
return self
15361605

1537-
if value[key] is not None:
1538-
if meta.proto_type == TYPE_MESSAGE:
1539-
v = self._get_field_default(field_name)
1540-
cls = self._betterproto.cls_by_field[field_name]
1541-
if isinstance(v, list):
1542-
if cls == datetime:
1543-
v = [isoparse(item) for item in value[key]]
1544-
elif cls == timedelta:
1545-
v = [
1546-
timedelta(seconds=float(item[:-1]))
1547-
for item in value[key]
1548-
]
1549-
else:
1550-
v = [cls().from_dict(item) for item in value[key]]
1551-
elif cls == datetime:
1552-
v = isoparse(value[key])
1553-
setattr(self, field_name, v)
1554-
elif cls == timedelta:
1555-
v = timedelta(seconds=float(value[key][:-1]))
1556-
setattr(self, field_name, v)
1557-
elif meta.wraps:
1558-
setattr(self, field_name, value[key])
1559-
elif v is None:
1560-
setattr(self, field_name, cls().from_dict(value[key]))
1561-
else:
1562-
# NOTE: `from_dict` mutates the underlying message, so no
1563-
# assignment here is necessary.
1564-
v.from_dict(value[key])
1565-
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
1566-
v = getattr(self, field_name)
1567-
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
1568-
for k in value[key]:
1569-
v[k] = cls().from_dict(value[key][k])
1570-
else:
1571-
v = value[key]
1572-
if meta.proto_type in INT_64_TYPES:
1573-
if isinstance(value[key], list):
1574-
v = [int(n) for n in value[key]]
1575-
else:
1576-
v = int(value[key])
1577-
elif meta.proto_type == TYPE_BYTES:
1578-
if isinstance(value[key], list):
1579-
v = [b64decode(n) for n in value[key]]
1580-
else:
1581-
v = b64decode(value[key])
1582-
elif meta.proto_type == TYPE_ENUM:
1583-
enum_cls = self._betterproto.cls_by_field[field_name]
1584-
if isinstance(v, list):
1585-
v = [enum_cls.from_string(e) for e in v]
1586-
elif isinstance(v, str):
1587-
v = enum_cls.from_string(v)
1588-
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
1589-
if isinstance(value[key], list):
1590-
v = [_parse_float(n) for n in value[key]]
1591-
else:
1592-
v = _parse_float(value[key])
1606+
@from_dict.instancemethod
1607+
def from_dict(self, value: Mapping[str, Any]) -> Self:
1608+
"""
1609+
Parse the key/value pairs into the current message instance. This returns the
1610+
instance itself and is therefore assignable and chainable.
15931611
1594-
if v is not None:
1595-
setattr(self, field_name, v)
1612+
Parameters
1613+
-----------
1614+
value: Dict[:class:`str`, Any]
1615+
The dictionary to parse from.
1616+
1617+
Returns
1618+
--------
1619+
:class:`Message`
1620+
The initialized message.
1621+
"""
1622+
self._serialized_on_wire = True
1623+
for field, value in self._from_dict_init(value).items():
1624+
setattr(self, field, value)
15961625
return self
15971626

15981627
def to_json(
@@ -1809,8 +1838,8 @@ def is_set(self, name: str) -> bool:
18091838

18101839
@classmethod
18111840
def _validate_field_groups(cls, values):
1812-
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
1813-
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
1841+
group_to_one_ofs = cls._betterproto.oneof_field_by_group
1842+
field_name_to_meta = cls._betterproto.meta_by_field_name
18141843

18151844
for group, field_set in group_to_one_ofs.items():
18161845
if len(field_set) == 1:
@@ -1837,6 +1866,9 @@ def _validate_field_groups(cls, values):
18371866
return values
18381867

18391868

1869+
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
1870+
1871+
18401872
def serialized_on_wire(message: Message) -> bool:
18411873
"""
18421874
If this message was or should be serialized on the wire. This can be used to detect

src/betterproto/utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
from typing import (
4+
Any,
5+
Callable,
6+
Generic,
7+
Optional,
8+
Type,
9+
TypeVar,
10+
)
11+
12+
from typing_extensions import (
13+
Concatenate,
14+
ParamSpec,
15+
Self,
16+
)
17+
18+
19+
SelfT = TypeVar("SelfT")
20+
P = ParamSpec("P")
21+
HybridT = TypeVar("HybridT", covariant=True)
22+
23+
24+
class hybridmethod(Generic[SelfT, P, HybridT]):
25+
def __init__(
26+
self,
27+
func: Callable[
28+
Concatenate[type[SelfT], P], HybridT
29+
], # Must be the classmethod version
30+
):
31+
self.cls_func = func
32+
self.__doc__ = func.__doc__
33+
34+
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
35+
self.instance_func = func
36+
return self
37+
38+
def __get__(
39+
self, instance: Optional[SelfT], owner: Type[SelfT]
40+
) -> Callable[P, HybridT]:
41+
if instance is None or self.instance_func is None:
42+
# either bound to the class, or no instance method available
43+
return self.cls_func.__get__(owner, None)
44+
return self.instance_func.__get__(instance, owner)
45+
46+
47+
T_co = TypeVar("T_co")
48+
TT_co = TypeVar("TT_co", bound="type[Any]")
49+
50+
51+
class classproperty(Generic[TT_co, T_co]):
52+
def __init__(self, func: Callable[[TT_co], T_co]):
53+
self.__func__ = func
54+
55+
def __get__(self, instance: Any, type: TT_co) -> T_co:
56+
return self.__func__(type)

0 commit comments

Comments
 (0)