Skip to content

Commit 2af6f63

Browse files
committed
Finish off #130
- Add more test cases using the standard test pattern - Add sorted_field_names to message subclass metadata to support stable ordering of keys in repr. - Tweak code in new message methods
1 parent 9b3b451 commit 2af6f63

File tree

4 files changed

+75
-22
lines changed

4 files changed

+75
-22
lines changed

src/betterproto/__init__.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ class ProtoClassMetadata:
428428
"cls_by_field",
429429
"field_name_by_number",
430430
"meta_by_field_name",
431+
"sorted_field_names",
431432
)
432433

433434
def __init__(self, cls: Type["Message"]):
@@ -453,6 +454,9 @@ def __init__(self, cls: Type["Message"]):
453454
self.oneof_field_by_group = by_group
454455
self.field_name_by_number = by_field_number
455456
self.meta_by_field_name = by_field_name
457+
self.sorted_field_names = tuple(
458+
by_field_number[number] for number in sorted(by_field_number.keys())
459+
)
456460

457461
self.default_gen = self._get_default_gen(cls, fields)
458462
self.cls_by_field = self._get_cls_by_field(cls, fields)
@@ -529,17 +533,16 @@ def __post_init__(self) -> None:
529533
def __raw_get(self, name: str) -> Any:
530534
return super().__getattribute__(name)
531535

532-
def __eq__(self, other: T) -> bool:
536+
def __eq__(self, other) -> bool:
533537
if type(self) is not type(other):
534538
return False
535539

536540
for field_name in self._betterproto.meta_by_field_name:
537541
self_val = self.__raw_get(field_name)
538542
other_val = other.__raw_get(field_name)
539-
if self_val is PLACEHOLDER and other_val is PLACEHOLDER:
540-
continue
541-
542543
if self_val is PLACEHOLDER:
544+
if other_val is PLACEHOLDER:
545+
continue
543546
self_val = self._get_field_default(field_name)
544547
elif other_val is PLACEHOLDER:
545548
other_val = other._get_field_default(field_name)
@@ -550,16 +553,19 @@ def __eq__(self, other: T) -> bool:
550553
return True
551554

552555
def __repr__(self) -> str:
553-
parts = []
554-
for field_name in self._betterproto.meta_by_field_name:
555-
value = self.__raw_get(field_name)
556-
if value is PLACEHOLDER:
557-
continue
558-
parts.append(f"{field_name}={value!r}")
559-
556+
parts = [
557+
f"{field_name}={value!r}"
558+
for field_name in self._betterproto.sorted_field_names
559+
for value in (self.__raw_get(field_name),)
560+
if value is not PLACEHOLDER
561+
]
560562
return f"{self.__class__.__name__}({', '.join(parts)})"
561563

562564
def __getattribute__(self, name: str) -> Any:
565+
"""
566+
Lazily initialize default values to avoid infinite recursion for recursive
567+
message types
568+
"""
563569
value = super().__getattribute__(name)
564570
if value is not PLACEHOLDER:
565571
return value
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"name": "Zues",
3+
"child": {
4+
"name": "Hercules"
5+
},
6+
"intermediate": {
7+
"child": {
8+
"name": "Douglas Adams"
9+
},
10+
"number": 42
11+
}
12+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
message Test {
4+
string name = 1;
5+
Test child = 2;
6+
Intermediate intermediate = 3;
7+
}
8+
9+
10+
message Intermediate {
11+
int32 number = 1;
12+
Test child = 2;
13+
}

tests/test_features.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -319,13 +319,9 @@ def _round_trip_serialization(foo: Foo) -> Foo:
319319
)
320320

321321

322-
@dataclass(eq=False, repr=False)
323-
class RecursiveMessage(betterproto.Message):
324-
child: "RecursiveMessage" = betterproto.message_field(1)
325-
foo: int = betterproto.int32_field(2)
326-
327-
328322
def test_recursive_message():
323+
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
324+
329325
msg = RecursiveMessage()
330326

331327
assert msg.child == RecursiveMessage()
@@ -337,9 +333,35 @@ def test_recursive_message():
337333
assert bytes(msg) == b""
338334

339335

340-
def test_message_repr():
341-
assert repr(RecursiveMessage(foo=1)) == "RecursiveMessage(foo=1)"
342-
assert (
343-
repr(RecursiveMessage(child=RecursiveMessage(), foo=1))
344-
== "RecursiveMessage(child=RecursiveMessage(), foo=1)"
336+
def test_recursive_message_defaults():
337+
from tests.output_betterproto.recursivemessage import (
338+
Test as RecursiveMessage,
339+
Intermediate,
345340
)
341+
342+
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
343+
344+
# set values are as expected
345+
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
346+
347+
# lazy initialized works modifies the message
348+
assert msg != RecursiveMessage(
349+
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
350+
)
351+
msg.child.child.name = "jude"
352+
assert msg == RecursiveMessage(
353+
name="bob",
354+
intermediate=Intermediate(42),
355+
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
356+
)
357+
358+
# lazily initialization recurses as needed
359+
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
360+
assert msg.intermediate.child.intermediate == Intermediate()
361+
362+
363+
def test_message_repr():
364+
from tests.output_betterproto.recursivemessage import Test
365+
366+
assert repr(Test(name="Loki")) == "Test(name='Loki')"
367+
assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())"

0 commit comments

Comments
 (0)