Skip to content

Commit 06a626d

Browse files
committed
Handle mutable default arguments cleanly
When generating code, ensure that default list/dict arguments are initialised in local scope if unspecified or `None`.
1 parent 8864f4f commit 06a626d

File tree

6 files changed

+68
-6
lines changed

6 files changed

+68
-6
lines changed

betterproto/plugin.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ def lookup_method_input_type(method, types):
369369
return known_type
370370

371371

372+
def is_mutable_field_type(field_type: str) -> bool:
373+
return field_type.startswith("List[") or field_type.startswith("Dict[")
374+
375+
372376
def read_protobuf_service(
373377
service: ServiceDescriptorProto, index, proto_file, content, output_types
374378
):
@@ -384,8 +388,23 @@ def read_protobuf_service(
384388
for j, method in enumerate(service.method):
385389
method_input_message = lookup_method_input_type(method, output_types)
386390

391+
# This section ensures that method arguments having a default
392+
# value that is initialised as a List/Dict (mutable) is replaced
393+
# with None and initialisation is deferred to the beginning of the
394+
# method definition. This is done so to avoid any side-effects.
395+
# Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
396+
mutable_default_args = []
397+
387398
if method_input_message:
388399
for field in method_input_message["properties"]:
400+
if (
401+
not method.client_streaming
402+
and field["zero"] != "None"
403+
and is_mutable_field_type(field["type"])
404+
):
405+
mutable_default_args.append((field["py_name"], field["zero"]))
406+
field["zero"] = "None"
407+
389408
if field["zero"] == "None":
390409
template_data["typing_imports"].add("Optional")
391410

@@ -407,6 +426,7 @@ def read_protobuf_service(
407426
),
408427
"client_streaming": method.client_streaming,
409428
"server_streaming": method.server_streaming,
429+
"mutable_default_args": mutable_default_args,
410430
}
411431
)
412432

betterproto/templates/template.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
8080
{{ method.comment }}
8181

8282
{% endif %}
83+
{%- for py_name, zero in method.mutable_default_args %}
84+
{{ py_name }} = {{ py_name }} or {{ zero }}
85+
{% endfor %}
86+
8387
{% if not method.client_streaming %}
8488
request = {{ method.input }}()
8589
{% for field in method.input_message.properties %}

betterproto/tests/grpc/test_grpclib_client.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import asyncio
2+
import sys
3+
4+
import grpclib
5+
import grpclib.metadata
6+
import pytest
7+
from grpclib.testing import ChannelFor
8+
9+
from betterproto.grpc.util.async_channel import AsyncChannel
210
from betterproto.tests.output_betterproto.service.service import (
3-
DoThingResponse,
411
DoThingRequest,
12+
DoThingResponse,
513
GetThingRequest,
614
TestStub as ThingServiceClient,
715
)
8-
import grpclib
9-
from grpclib.testing import ChannelFor
10-
import pytest
11-
from betterproto.grpc.util.async_channel import AsyncChannel
1216
from .thing_service import ThingService
1317

1418

@@ -35,6 +39,20 @@ async def test_simple_service_call():
3539
await _test_client(ThingServiceClient(channel))
3640

3741

42+
@pytest.mark.asyncio
43+
@pytest.mark.skipif(
44+
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
45+
)
46+
async def test_service_call_mutable_defaults(mocker):
47+
async with ChannelFor([ThingService()]) as channel:
48+
client = ThingServiceClient(channel)
49+
spy = mocker.spy(client, "_unary_unary")
50+
await _test_client(client)
51+
comments = spy.call_args_list[-1].args[1].comments
52+
await _test_client(client)
53+
assert spy.call_args_list[-1].args[1].comments is not comments
54+
55+
3856
@pytest.mark.asyncio
3957
async def test_service_call_with_upfront_request_params():
4058
# Setting deadline

betterproto/tests/inputs/service/service.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package service;
44

55
message DoThingRequest {
66
string name = 1;
7+
repeated string comments = 2;
78
}
89

910
message DoThingResponse {

poetry.lock

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ protobuf = "^3.12.2"
2929
pytest = "^5.4.2"
3030
pytest-asyncio = "^0.12.0"
3131
pytest-cov = "^2.9.0"
32+
pytest-mock = "^3.1.1"
3233
tox = "^3.15.1"
3334

3435
[tool.poetry.scripts]

0 commit comments

Comments
 (0)