Skip to content

Commit 034e2e7

Browse files
chris-chambersnat-nGobot1234
authored
Add support for recursive messages (#130)
Changes message initialization (`__post_init__`) so that default values are no longer eagerly created to prevent infinite recursion when initializing recursive messages. As a result, `PLACEHOLDER` will be present in the message for any uninitialized fields. So, an implementation of `__get_attribute__` is added that checks for `PLACEHOLDER` and lazily creates and stores default field values. And, because `PLACEHOLDER` values don't compare equal with zero values, a custom implementation of `__eq__` is provided, and the code generation template is updated so that messages generate with `@dataclass(eq=False)`. Also add new Message __repr__ implementation that skips PLACEHOLDER values and orders keys by number from the proto. Co-authored-by: Christopher Chambers <chris@peanutcode.com> Co-authored-by: nat <n@natn.me> Co-authored-by: James <50501825+Gobot1234@users.noreply.github.com>
1 parent ca16b6e commit 034e2e7

File tree

5 files changed

+125
-10
lines changed

5 files changed

+125
-10
lines changed

src/betterproto/__init__.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ class ProtoClassMetadata:
428428
"cls_by_field",
429429
"field_name_by_number",
430430
"meta_by_field_name",
431+
"sorted_field_names",
431432
)
432433

433434
def __init__(self, cls: Type["Message"]):
@@ -453,6 +454,9 @@ def __init__(self, cls: Type["Message"]):
453454
self.oneof_field_by_group = by_group
454455
self.field_name_by_number = by_field_number
455456
self.meta_by_field_name = by_field_name
457+
self.sorted_field_names = tuple(
458+
by_field_number[number] for number in sorted(by_field_number.keys())
459+
)
456460

457461
self.default_gen = self._get_default_gen(cls, fields)
458462
self.cls_by_field = self._get_cls_by_field(cls, fields)
@@ -513,23 +517,63 @@ def __post_init__(self) -> None:
513517
if meta.group:
514518
group_current.setdefault(meta.group)
515519

516-
if getattr(self, field_name) != PLACEHOLDER:
517-
# Skip anything not set to the sentinel value
520+
if self.__raw_get(field_name) != PLACEHOLDER:
521+
# Found a non-sentinel value
518522
all_sentinel = False
519523

520524
if meta.group:
521525
# This was set, so make it the selected value of the one-of.
522526
group_current[meta.group] = field_name
523527

524-
continue
525-
526-
setattr(self, field_name, self._get_field_default(field_name))
527-
528528
# Now that all the defaults are set, reset it!
529529
self.__dict__["_serialized_on_wire"] = not all_sentinel
530530
self.__dict__["_unknown_fields"] = b""
531531
self.__dict__["_group_current"] = group_current
532532

533+
def __raw_get(self, name: str) -> Any:
534+
return super().__getattribute__(name)
535+
536+
def __eq__(self, other) -> bool:
537+
if type(self) is not type(other):
538+
return False
539+
540+
for field_name in self._betterproto.meta_by_field_name:
541+
self_val = self.__raw_get(field_name)
542+
other_val = other.__raw_get(field_name)
543+
if self_val is PLACEHOLDER:
544+
if other_val is PLACEHOLDER:
545+
continue
546+
self_val = self._get_field_default(field_name)
547+
elif other_val is PLACEHOLDER:
548+
other_val = other._get_field_default(field_name)
549+
550+
if self_val != other_val:
551+
return False
552+
553+
return True
554+
555+
def __repr__(self) -> str:
556+
parts = [
557+
f"{field_name}={value!r}"
558+
for field_name in self._betterproto.sorted_field_names
559+
for value in (self.__raw_get(field_name),)
560+
if value is not PLACEHOLDER
561+
]
562+
return f"{self.__class__.__name__}({', '.join(parts)})"
563+
564+
def __getattribute__(self, name: str) -> Any:
565+
"""
566+
Lazily initialize default values to avoid infinite recursion for recursive
567+
message types
568+
"""
569+
value = super().__getattribute__(name)
570+
if value is not PLACEHOLDER:
571+
return value
572+
573+
value = self._get_field_default(name)
574+
super().__setattr__(name, value)
575+
return value
576+
533577
def __setattr__(self, attr: str, value: Any) -> None:
534578
if attr != "_serialized_on_wire":
535579
# Track when a field has been set.
@@ -542,9 +586,7 @@ def __setattr__(self, attr: str, value: Any) -> None:
542586
if field.name == attr:
543587
self._group_current[group] = field.name
544588
else:
545-
super().__setattr__(
546-
field.name, self._get_field_default(field.name)
547-
)
589+
super().__setattr__(field.name, PLACEHOLDER)
548590

549591
super().__setattr__(attr, value)
550592

src/betterproto/templates/template.py.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum):
3737
{% endfor %}
3838
{% endif %}
3939
{% for message in output_file.messages %}
40-
@dataclass
40+
@dataclass(eq=False, repr=False)
4141
class {{ message.py_name }}(betterproto.Message):
4242
{% if message.comment %}
4343
{{ message.comment }}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"name": "Zues",
3+
"child": {
4+
"name": "Hercules"
5+
},
6+
"intermediate": {
7+
"child": {
8+
"name": "Douglas Adams"
9+
},
10+
"number": 42
11+
}
12+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
message Test {
4+
string name = 1;
5+
Test child = 2;
6+
Intermediate intermediate = 3;
7+
}
8+
9+
10+
message Intermediate {
11+
int32 number = 1;
12+
Test child = 2;
13+
}

tests/test_features.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,51 @@ def _round_trip_serialization(foo: Foo) -> Foo:
317317
== betterproto.which_one_of(_round_trip_serialization(foo3), "group1")
318318
== ("", None)
319319
)
320+
321+
322+
def test_recursive_message():
323+
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
324+
325+
msg = RecursiveMessage()
326+
327+
assert msg.child == RecursiveMessage()
328+
329+
# Lazily-created zero-value children must not affect equality.
330+
assert msg == RecursiveMessage()
331+
332+
# Lazily-created zero-value children must not affect serialization.
333+
assert bytes(msg) == b""
334+
335+
336+
def test_recursive_message_defaults():
337+
from tests.output_betterproto.recursivemessage import (
338+
Test as RecursiveMessage,
339+
Intermediate,
340+
)
341+
342+
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
343+
344+
# set values are as expected
345+
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
346+
347+
# lazy initialized works modifies the message
348+
assert msg != RecursiveMessage(
349+
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
350+
)
351+
msg.child.child.name = "jude"
352+
assert msg == RecursiveMessage(
353+
name="bob",
354+
intermediate=Intermediate(42),
355+
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
356+
)
357+
358+
# lazily initialization recurses as needed
359+
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
360+
assert msg.intermediate.child.intermediate == Intermediate()
361+
362+
363+
def test_message_repr():
364+
from tests.output_betterproto.recursivemessage import Test
365+
366+
assert repr(Test(name="Loki")) == "Test(name='Loki')"
367+
assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())"

0 commit comments

Comments
 (0)