Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.

Fix __all__ definition #77

Merged
merged 2 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/betterproto2_compiler/plugin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,22 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
loader=jinja2.FileSystemLoader(templates_folder),
undefined=jinja2.StrictUndefined,
)
# Load the body first so we have a compleate list of imports needed.

# List of the symbols that should appear in the `__all__` variable of the file
all: list[str] = []

def add_to_all(name: str) -> str:
all.append(name)
return name

env.filters["add_to_all"] = add_to_all

body_template = env.get_template("template.py.j2")
header_template = env.get_template("header.py.j2")

# Load the body first do know the symbols defined in the file
code = body_template.render(output_file=output_file)
code = header_template.render(output_file=output_file, version=version) + "\n" + code
code = header_template.render(output_file=output_file, version=version, all=all) + "\n" + code

try:
# Sort imports, delete unused ones
Expand Down
11 changes: 2 additions & 9 deletions src/betterproto2_compiler/templates/header.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,8 @@
# This file has been @generated

__all__ = (
{% for _, enum in output_file.enums|dictsort(by="key") %}
"{{ enum.py_name }}",
{%- endfor -%}
{% for _, message in output_file.messages|dictsort(by="key") %}
"{{ message.py_name }}",
{%- endfor -%}
{% for _, service in output_file.services|dictsort(by="key") %}
"{{ service.py_name }}Stub",
"{{ service.py_name }}Base",
{%- for name in all -%}
"{{ name }}",
{%- endfor -%}
)

Expand Down
2 changes: 1 addition & 1 deletion src/betterproto2_compiler/templates/service_stub.py.j2
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class {% block class_name %}{% endblock %}({% block inherit_from %}{% endblock %}):
class {% filter add_to_all %}{% block class_name %}{% endblock %}{% endfilter %}({% block inherit_from %}{% endblock %}):
{% block service_docstring scoped %}
{% if service.comment %}
"""
Expand Down
6 changes: 3 additions & 3 deletions src/betterproto2_compiler/templates/template.py.j2
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% for _, enum in output_file.enums|dictsort(by="key") %}
class {{ enum.py_name }}(betterproto2.Enum):
class {{ enum.py_name | add_to_all }}(betterproto2.Enum):
{% if enum.comment %}
"""
{{ enum.comment | indent(4) }}
Expand Down Expand Up @@ -31,7 +31,7 @@ class {{ enum.py_name }}(betterproto2.Enum):
{% else %}
@dataclass(eq=False, repr=False)
{% endif %}
class {{ message.py_name }}(betterproto2.Message):
class {{ message.py_name | add_to_all }}(betterproto2.Message):
{% if message.comment or message.oneofs %}
"""
{{ message.comment | indent(4) }}
Expand Down Expand Up @@ -104,7 +104,7 @@ default_message_pool.register_message("{{ output_file.package }}", "{{ message.p

{% if output_file.settings.server_generation == "async" %}
{% for _, service in output_file.services|dictsort(by="key") %}
class {{ service.py_name }}Base(ServiceBase):
class {{ (service.py_name + "Base") | add_to_all }}(ServiceBase):
{% if service.comment %}
"""
{{ service.comment | indent(4) }}
Expand Down
Loading