Skip to content

Commit 66dfc63

Browse files
committed
Handle mutable default arguments cleanly
When generating code, ensure that default list/dict arguments are initialised in local scope if unspecified or `None`.
1 parent 3273ae4 commit 66dfc63

File tree

6 files changed

+131
-70
lines changed

6 files changed

+131
-70
lines changed

betterproto/plugin.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,24 @@ def generate_code(request, response):
314314
output["typing_imports"].add("Optional")
315315
break
316316

317+
mutable_default_args = []
318+
if (
319+
not method.client_streaming
320+
and input_message
321+
and input_message.get("properties")
322+
):
323+
properties = []
324+
for f in input_message.get("properties"):
325+
if f["zero"] != "None" and (
326+
f["type"].startswith("List[")
327+
or f["type"].startswith("Dict[")
328+
):
329+
output["typing_imports"].add("Optional")
330+
mutable_default_args.append((f["py_name"], f["zero"]))
331+
f["zero"] = "None"
332+
properties.append(f)
333+
input_message["properties"] = properties
334+
317335
data["methods"].append(
318336
{
319337
"name": method.name,
@@ -332,6 +350,7 @@ def generate_code(request, response):
332350
),
333351
"client_streaming": method.client_streaming,
334352
"server_streaming": method.server_streaming,
353+
"mutable_default_args": mutable_default_args,
335354
}
336355
)
337356

betterproto/templates/template.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
8080
{{ method.comment }}
8181

8282
{% endif %}
83+
{%- for py_name, zero in method.mutable_default_args %}
84+
{{ py_name }} = {{ py_name }} or {{ zero }}
85+
{% endfor %}
86+
8387
{% if not method.client_streaming %}
8488
request = {{ method.input }}()
8589
{% for field in method.input_message.properties %}

betterproto/tests/grpc/test_grpclib_client.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import asyncio
2+
import sys
3+
4+
import grpclib
5+
import grpclib.metadata
6+
import pytest
7+
from grpclib.testing import ChannelFor
8+
9+
from betterproto.grpc.util.async_channel import AsyncChannel
210
from betterproto.tests.output_betterproto.service.service import (
3-
DoThingResponse,
411
DoThingRequest,
12+
DoThingResponse,
513
GetThingRequest,
614
TestStub as ThingServiceClient,
715
)
8-
import grpclib
9-
from grpclib.testing import ChannelFor
10-
import pytest
11-
from betterproto.grpc.util.async_channel import AsyncChannel
1216
from .thing_service import ThingService
1317

1418

@@ -35,6 +39,20 @@ async def test_simple_service_call():
3539
await _test_client(ThingServiceClient(channel))
3640

3741

42+
@pytest.mark.asyncio
43+
@pytest.mark.skipif(
44+
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
45+
)
46+
async def test_service_call_mutable_defaults(mocker):
47+
async with ChannelFor([ThingService()]) as channel:
48+
client = ThingServiceClient(channel)
49+
spy = mocker.spy(client, "_unary_unary")
50+
await _test_client(client)
51+
comments = spy.call_args_list[-1].args[1].comments
52+
await _test_client(client)
53+
assert spy.call_args_list[-1].args[1].comments is not comments
54+
55+
3856
@pytest.mark.asyncio
3957
async def test_service_call_with_upfront_request_params():
4058
# Setting deadline

betterproto/tests/inputs/service/service.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package service;
44

55
message DoThingRequest {
66
string name = 1;
7+
repeated string comments = 2;
78
}
89

910
message DoThingResponse {

0 commit comments

Comments
 (0)