diff --git a/src/betterproto2_compiler/plugin/compiler.py b/src/betterproto2_compiler/plugin/compiler.py index c35697e7..62f1658e 100644 --- a/src/betterproto2_compiler/plugin/compiler.py +++ b/src/betterproto2_compiler/plugin/compiler.py @@ -33,12 +33,22 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: loader=jinja2.FileSystemLoader(templates_folder), undefined=jinja2.StrictUndefined, ) - # Load the body first so we have a compleate list of imports needed. + + # List of the symbols that should appear in the `__all__` variable of the file + all: list[str] = [] + + def add_to_all(name: str) -> str: + all.append(name) + return name + + env.filters["add_to_all"] = add_to_all + body_template = env.get_template("template.py.j2") header_template = env.get_template("header.py.j2") + # Load the body first do know the symbols defined in the file code = body_template.render(output_file=output_file) - code = header_template.render(output_file=output_file, version=version) + "\n" + code + code = header_template.render(output_file=output_file, version=version, all=all) + "\n" + code try: # Sort imports, delete unused ones diff --git a/src/betterproto2_compiler/templates/header.py.j2 b/src/betterproto2_compiler/templates/header.py.j2 index d6bb310f..64a2de57 100644 --- a/src/betterproto2_compiler/templates/header.py.j2 +++ b/src/betterproto2_compiler/templates/header.py.j2 @@ -6,15 +6,8 @@ # This file has been @generated __all__ = ( - {% for _, enum in output_file.enums|dictsort(by="key") %} - "{{ enum.py_name }}", - {%- endfor -%} - {% for _, message in output_file.messages|dictsort(by="key") %} - "{{ message.py_name }}", - {%- endfor -%} - {% for _, service in output_file.services|dictsort(by="key") %} - "{{ service.py_name }}Stub", - "{{ service.py_name }}Base", + {%- for name in all -%} + "{{ name }}", {%- endfor -%} ) diff --git a/src/betterproto2_compiler/templates/service_stub.py.j2 b/src/betterproto2_compiler/templates/service_stub.py.j2 index f0958055..54cec014 100644 --- a/src/betterproto2_compiler/templates/service_stub.py.j2 +++ b/src/betterproto2_compiler/templates/service_stub.py.j2 @@ -1,4 +1,4 @@ -class {% block class_name %}{% endblock %}({% block inherit_from %}{% endblock %}): +class {% filter add_to_all %}{% block class_name %}{% endblock %}{% endfilter %}({% block inherit_from %}{% endblock %}): {% block service_docstring scoped %} {% if service.comment %} """ diff --git a/src/betterproto2_compiler/templates/template.py.j2 b/src/betterproto2_compiler/templates/template.py.j2 index 4f89d832..f51a3b39 100644 --- a/src/betterproto2_compiler/templates/template.py.j2 +++ b/src/betterproto2_compiler/templates/template.py.j2 @@ -1,5 +1,5 @@ {% for _, enum in output_file.enums|dictsort(by="key") %} -class {{ enum.py_name }}(betterproto2.Enum): +class {{ enum.py_name | add_to_all }}(betterproto2.Enum): {% if enum.comment %} """ {{ enum.comment | indent(4) }} @@ -31,7 +31,7 @@ class {{ enum.py_name }}(betterproto2.Enum): {% else %} @dataclass(eq=False, repr=False) {% endif %} -class {{ message.py_name }}(betterproto2.Message): +class {{ message.py_name | add_to_all }}(betterproto2.Message): {% if message.comment or message.oneofs %} """ {{ message.comment | indent(4) }} @@ -104,7 +104,7 @@ default_message_pool.register_message("{{ output_file.package }}", "{{ message.p {% if output_file.settings.server_generation == "async" %} {% for _, service in output_file.services|dictsort(by="key") %} -class {{ service.py_name }}Base(ServiceBase): +class {{ (service.py_name + "Base") | add_to_all }}(ServiceBase): {% if service.comment %} """ {{ service.comment | indent(4) }}