From af323f7efa50c4540c698486a0a9b3648fc5ccc5 Mon Sep 17 00:00:00 2001 From: Christopher Chambers Date: Sun, 26 Jul 2020 15:08:07 -0400 Subject: [PATCH 1/9] Add a failing test for recursive messages --- tests/test_features.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_features.py b/tests/test_features.py index b5b381126..044795076 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -317,3 +317,20 @@ def _round_trip_serialization(foo: Foo) -> Foo: == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") == ("", None) ) + + +@dataclass(eq=False) +class RecursiveMessage(betterproto.Message): + child: "RecursiveMessage" = betterproto.message_field(1) + + +def test_recursive_message(): + msg = RecursiveMessage() + + assert msg.child == RecursiveMessage() + + # Lazily-created zero-value children must not affect equality. + assert msg == RecursiveMessage() + + # Lazily-created zero-value children must not affect serialization. + assert bytes(msg) == b"" From c0bfdb5a9d88b56062e56f5280a882570622607a Mon Sep 17 00:00:00 2001 From: Christopher Chambers Date: Sun, 26 Jul 2020 15:09:59 -0400 Subject: [PATCH 2/9] Add support for recursive messages Changes message initialization (`__post_init__`) so that default values are no longer eagerly created to prevent infinite recursion when initializing recursive messages. As a result, `PLACEHOLDER` will be present in the message for any uninitialized fields. So, an implementation of `__get_attribute__` is added that checks for `PLACEHOLDER` and lazily creates and stores default field values. And, because `PLACEHOLDER` values don't compare equal with zero values, a custom implementation of `__eq__` is provided, and the code generation template is updated so that messages generate with `@dataclass(eq=False)`. --- src/betterproto/__init__.py | 45 +++++++++++++++++++----- src/betterproto/templates/template.py.j2 | 4 +-- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 483551663..842d6276b 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -513,23 +513,52 @@ def __post_init__(self) -> None: if meta.group: group_current.setdefault(meta.group) - if getattr(self, field_name) != PLACEHOLDER: - # Skip anything not set to the sentinel value + if self.__raw_get(field_name) != PLACEHOLDER: + # Found a non-sentinel value all_sentinel = False if meta.group: # This was set, so make it the selected value of the one-of. group_current[meta.group] = field_name - continue - - setattr(self, field_name, self._get_field_default(field_name)) - # Now that all the defaults are set, reset it! self.__dict__["_serialized_on_wire"] = not all_sentinel self.__dict__["_unknown_fields"] = b"" self.__dict__["_group_current"] = group_current + def __raw_get(self, name: str) -> Any: + return super().__getattribute__(name) + + def __eq__(self, other): + if type(self) != type(other): + return False + + equal = True + for field_name in self._betterproto.meta_by_field_name: + self_val = self.__raw_get(field_name) + other_val = other.__raw_get(field_name) + if self_val is PLACEHOLDER and other_val is PLACEHOLDER: + continue + elif self_val is PLACEHOLDER: + self_val = self._get_field_default(field_name) + elif other_val is PLACEHOLDER: + other_val = other._get_field_default(field_name) + + if self_val != other_val: + equal = False + break + + return equal + + def __getattribute__(self, name: str) -> Any: + value = super().__getattribute__(name) + if value is not PLACEHOLDER: + return value + + value = self._get_field_default(name) + super().__setattr__(name, value) + return value + def __setattr__(self, attr: str, value: Any) -> None: if attr != "_serialized_on_wire": # Track when a field has been set. @@ -542,9 +571,7 @@ def __setattr__(self, attr: str, value: Any) -> None: if field.name == attr: self._group_current[group] = field.name else: - super().__setattr__( - field.name, self._get_field_default(field.name) - ) + super().__setattr__(field.name, PLACEHOLDER) super().__setattr__(attr, value) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 7fd046307..f6cddbc56 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -37,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum): {% endfor %} {% endif %} {% for message in output_file.messages %} -@dataclass +@dataclass(eq=False) class {{ message.py_name }}(betterproto.Message): {% if message.comment %} {{ message.comment }} @@ -82,7 +82,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): Optional[{{ field.annotation }}] {%- else -%} {{ field.annotation }} - {%- endif -%} = + {%- endif -%} = {%- if field.py_name not in method.mutable_default_args -%} {{ field.default_value_string }} {%- else -%} From 87bd56542f3939667d81748ebe49f1f6afa82e55 Mon Sep 17 00:00:00 2001 From: Christopher Chambers Date: Sun, 26 Jul 2020 16:08:17 -0400 Subject: [PATCH 3/9] Add a failing test for recursive message repr --- tests/test_features.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_features.py b/tests/test_features.py index 044795076..e4f908418 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -319,9 +319,10 @@ def _round_trip_serialization(foo: Foo) -> Foo: ) -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class RecursiveMessage(betterproto.Message): child: "RecursiveMessage" = betterproto.message_field(1) + foo: int = betterproto.int32_field(2) def test_recursive_message(): @@ -334,3 +335,11 @@ def test_recursive_message(): # Lazily-created zero-value children must not affect serialization. assert bytes(msg) == b"" + + +def test_message_repr(): + assert repr(RecursiveMessage(foo=1)) == "RecursiveMessage(foo=1)" + assert ( + repr(RecursiveMessage(child=RecursiveMessage(), foo=1)) + == "RecursiveMessage(child=RecursiveMessage(),foo=1)" + ) From c2e9d794c7ba6a69566bf49c2deb4f260cf45475 Mon Sep 17 00:00:00 2001 From: Christopher Chambers Date: Sun, 26 Jul 2020 16:09:21 -0400 Subject: [PATCH 4/9] Add a repr implementation to Message that supports recursive messages --- src/betterproto/__init__.py | 15 +++++++++++++++ src/betterproto/templates/template.py.j2 | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 842d6276b..f78eaed06 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -550,6 +550,21 @@ def __eq__(self, other): return equal + def __repr__(self): + parts = [self.__class__.__name__, "("] + found = False + for field_name in self._betterproto.meta_by_field_name: + value = self.__raw_get(field_name) + if value is PLACEHOLDER: + continue + found = True + parts.extend([field_name, "=", repr(value), ","]) + + if found: + parts.pop() + parts.append(")") + return "".join(parts) + def __getattribute__(self, name: str) -> Any: value = super().__getattribute__(name) if value is not PLACEHOLDER: diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index f6cddbc56..fb10c5f43 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -37,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum): {% endfor %} {% endif %} {% for message in output_file.messages %} -@dataclass(eq=False) +@dataclass(eq=False, repr=False) class {{ message.py_name }}(betterproto.Message): {% if message.comment %} {{ message.comment }} From 79623f6fc8b3692c770eea2f65c65e57817dbc80 Mon Sep 17 00:00:00 2001 From: Christopher Chambers Date: Mon, 27 Jul 2020 15:59:52 -0400 Subject: [PATCH 5/9] Address recursive message PR notes - Uses `is not` to compare types in `Message.__eq__` - Adds a space after each comma in `Message.__repr__` --- src/betterproto/__init__.py | 4 ++-- tests/test_features.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f78eaed06..f75bb45c6 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -530,7 +530,7 @@ def __raw_get(self, name: str) -> Any: return super().__getattribute__(name) def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False equal = True @@ -558,7 +558,7 @@ def __repr__(self): if value is PLACEHOLDER: continue found = True - parts.extend([field_name, "=", repr(value), ","]) + parts.extend([field_name, "=", repr(value), ", "]) if found: parts.pop() diff --git a/tests/test_features.py b/tests/test_features.py index e4f908418..e3fd73b2b 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -341,5 +341,5 @@ def test_message_repr(): assert repr(RecursiveMessage(foo=1)) == "RecursiveMessage(foo=1)" assert ( repr(RecursiveMessage(child=RecursiveMessage(), foo=1)) - == "RecursiveMessage(child=RecursiveMessage(),foo=1)" + == "RecursiveMessage(child=RecursiveMessage(), foo=1)" ) From 34a243ec3e213708d6a5cc896fe867f0a6943ac5 Mon Sep 17 00:00:00 2001 From: Christopher Chambers Date: Mon, 27 Jul 2020 17:29:54 -0400 Subject: [PATCH 6/9] Revert accidental template whitespace change --- src/betterproto/templates/template.py.j2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index fb10c5f43..753d340c7 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -82,7 +82,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): Optional[{{ field.annotation }}] {%- else -%} {{ field.annotation }} - {%- endif -%} = + {%- endif -%} = {%- if field.py_name not in method.mutable_default_args -%} {{ field.default_value_string }} {%- else -%} From 8c821c229d7ac4665306110b7de684b1dcae5784 Mon Sep 17 00:00:00 2001 From: nat Date: Sun, 30 Aug 2020 19:45:40 +0300 Subject: [PATCH 7/9] Update src/betterproto/__init__.py Co-authored-by: James <50501825+Gobot1234@users.noreply.github.com> --- src/betterproto/__init__.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f75bb45c6..c7032808f 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -550,20 +550,15 @@ def __eq__(self, other): return equal - def __repr__(self): - parts = [self.__class__.__name__, "("] - found = False + def __repr__(self) -> str: + parts = [] for field_name in self._betterproto.meta_by_field_name: value = self.__raw_get(field_name) if value is PLACEHOLDER: continue - found = True - parts.extend([field_name, "=", repr(value), ", "]) + parts.append(f"{field_name}={value!r}") - if found: - parts.pop() - parts.append(")") - return "".join(parts) + return f"{self.__class__.__name__}({', '.join(parts)})" def __getattribute__(self, name: str) -> Any: value = super().__getattribute__(name) From 9b3b45190afd1fe2a7c2ec570ad862f30c0b4d3c Mon Sep 17 00:00:00 2001 From: nat Date: Sun, 30 Aug 2020 19:53:12 +0300 Subject: [PATCH 8/9] Update src/betterproto/__init__.py Co-authored-by: James <50501825+Gobot1234@users.noreply.github.com> --- src/betterproto/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index c7032808f..8cab4b0f0 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -529,26 +529,25 @@ def __post_init__(self) -> None: def __raw_get(self, name: str) -> Any: return super().__getattribute__(name) - def __eq__(self, other): + def __eq__(self, other: T) -> bool: if type(self) is not type(other): return False - equal = True for field_name in self._betterproto.meta_by_field_name: self_val = self.__raw_get(field_name) other_val = other.__raw_get(field_name) if self_val is PLACEHOLDER and other_val is PLACEHOLDER: continue - elif self_val is PLACEHOLDER: + + if self_val is PLACEHOLDER: self_val = self._get_field_default(field_name) elif other_val is PLACEHOLDER: other_val = other._get_field_default(field_name) if self_val != other_val: - equal = False - break + return False - return equal + return True def __repr__(self) -> str: parts = [] From 2af6f634046323c68ec2758b016def13a4afa561 Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Sun, 30 Aug 2020 20:54:12 +0200 Subject: [PATCH 9/9] Finish off #130 - Add more test cases using the standard test pattern - Add sorted_field_names to message subclass metadata to support stable ordering of keys in repr. - Tweak code in new message methods --- src/betterproto/__init__.py | 28 +++++++----- .../recursivemessage/recursivemessage.json | 12 +++++ .../recursivemessage/recursivemessage.proto | 13 ++++++ tests/test_features.py | 44 ++++++++++++++----- 4 files changed, 75 insertions(+), 22 deletions(-) create mode 100644 tests/inputs/recursivemessage/recursivemessage.json create mode 100644 tests/inputs/recursivemessage/recursivemessage.proto diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 8cab4b0f0..598579878 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -428,6 +428,7 @@ class ProtoClassMetadata: "cls_by_field", "field_name_by_number", "meta_by_field_name", + "sorted_field_names", ) def __init__(self, cls: Type["Message"]): @@ -453,6 +454,9 @@ def __init__(self, cls: Type["Message"]): self.oneof_field_by_group = by_group self.field_name_by_number = by_field_number self.meta_by_field_name = by_field_name + self.sorted_field_names = tuple( + by_field_number[number] for number in sorted(by_field_number.keys()) + ) self.default_gen = self._get_default_gen(cls, fields) self.cls_by_field = self._get_cls_by_field(cls, fields) @@ -529,17 +533,16 @@ def __post_init__(self) -> None: def __raw_get(self, name: str) -> Any: return super().__getattribute__(name) - def __eq__(self, other: T) -> bool: + def __eq__(self, other) -> bool: if type(self) is not type(other): return False for field_name in self._betterproto.meta_by_field_name: self_val = self.__raw_get(field_name) other_val = other.__raw_get(field_name) - if self_val is PLACEHOLDER and other_val is PLACEHOLDER: - continue - if self_val is PLACEHOLDER: + if other_val is PLACEHOLDER: + continue self_val = self._get_field_default(field_name) elif other_val is PLACEHOLDER: other_val = other._get_field_default(field_name) @@ -550,16 +553,19 @@ def __eq__(self, other: T) -> bool: return True def __repr__(self) -> str: - parts = [] - for field_name in self._betterproto.meta_by_field_name: - value = self.__raw_get(field_name) - if value is PLACEHOLDER: - continue - parts.append(f"{field_name}={value!r}") - + parts = [ + f"{field_name}={value!r}" + for field_name in self._betterproto.sorted_field_names + for value in (self.__raw_get(field_name),) + if value is not PLACEHOLDER + ] return f"{self.__class__.__name__}({', '.join(parts)})" def __getattribute__(self, name: str) -> Any: + """ + Lazily initialize default values to avoid infinite recursion for recursive + message types + """ value = super().__getattribute__(name) if value is not PLACEHOLDER: return value diff --git a/tests/inputs/recursivemessage/recursivemessage.json b/tests/inputs/recursivemessage/recursivemessage.json new file mode 100644 index 000000000..e92c3fbfa --- /dev/null +++ b/tests/inputs/recursivemessage/recursivemessage.json @@ -0,0 +1,12 @@ +{ + "name": "Zues", + "child": { + "name": "Hercules" + }, + "intermediate": { + "child": { + "name": "Douglas Adams" + }, + "number": 42 + } +} diff --git a/tests/inputs/recursivemessage/recursivemessage.proto b/tests/inputs/recursivemessage/recursivemessage.proto new file mode 100644 index 000000000..f988316e3 --- /dev/null +++ b/tests/inputs/recursivemessage/recursivemessage.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +message Test { + string name = 1; + Test child = 2; + Intermediate intermediate = 3; +} + + +message Intermediate { + int32 number = 1; + Test child = 2; +} diff --git a/tests/test_features.py b/tests/test_features.py index e3fd73b2b..f5482643c 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -319,13 +319,9 @@ def _round_trip_serialization(foo: Foo) -> Foo: ) -@dataclass(eq=False, repr=False) -class RecursiveMessage(betterproto.Message): - child: "RecursiveMessage" = betterproto.message_field(1) - foo: int = betterproto.int32_field(2) - - def test_recursive_message(): + from tests.output_betterproto.recursivemessage import Test as RecursiveMessage + msg = RecursiveMessage() assert msg.child == RecursiveMessage() @@ -337,9 +333,35 @@ def test_recursive_message(): assert bytes(msg) == b"" -def test_message_repr(): - assert repr(RecursiveMessage(foo=1)) == "RecursiveMessage(foo=1)" - assert ( - repr(RecursiveMessage(child=RecursiveMessage(), foo=1)) - == "RecursiveMessage(child=RecursiveMessage(), foo=1)" +def test_recursive_message_defaults(): + from tests.output_betterproto.recursivemessage import ( + Test as RecursiveMessage, + Intermediate, ) + + msg = RecursiveMessage(name="bob", intermediate=Intermediate(42)) + + # set values are as expected + assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42)) + + # lazy initialized works modifies the message + assert msg != RecursiveMessage( + name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude") + ) + msg.child.child.name = "jude" + assert msg == RecursiveMessage( + name="bob", + intermediate=Intermediate(42), + child=RecursiveMessage(child=RecursiveMessage(name="jude")), + ) + + # lazily initialization recurses as needed + assert msg.child.child.child.child.child.child.child == RecursiveMessage() + assert msg.intermediate.child.intermediate == Intermediate() + + +def test_message_repr(): + from tests.output_betterproto.recursivemessage import Test + + assert repr(Test(name="Loki")) == "Test(name='Loki')" + assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())"