diff --git a/openapi_python_client/__init__.py b/openapi_python_client/__init__.py index 44fc39cd4..3eeaa382c 100644 --- a/openapi_python_client/__init__.py +++ b/openapi_python_client/__init__.py @@ -104,6 +104,8 @@ def __init__( package_version=self.version, project_name=self.project_name, project_dir=self.project_dir, + openapi=self.openapi, + endpoint_collections_by_tag=self.openapi.endpoint_collections_by_tag, ) self.errors: List[GeneratorError] = [] @@ -264,18 +266,13 @@ def _build_api(self) -> None: client_path.write_text(client_template.render(), encoding=self.file_encoding) # Generate endpoints - endpoint_collections_by_tag = self.openapi.endpoint_collections_by_tag api_dir = self.package_dir / "api" api_dir.mkdir() api_init_path = api_dir / "__init__.py" api_init_template = self.env.get_template("api_init.py.jinja") - api_init_path.write_text( - api_init_template.render( - endpoint_collections_by_tag=endpoint_collections_by_tag, - ), - encoding=self.file_encoding, - ) + api_init_path.write_text(api_init_template.render(), encoding=self.file_encoding) + endpoint_collections_by_tag = self.openapi.endpoint_collections_by_tag endpoint_template = self.env.get_template( "endpoint_module.py.jinja", globals={"isbool": lambda obj: obj.get_base_type_string() == "bool"} ) diff --git a/openapi_python_client/templates/api_init.py.jinja b/openapi_python_client/templates/api_init.py.jinja index dc035f4ce..1464ff516 100644 --- a/openapi_python_client/templates/api_init.py.jinja +++ b/openapi_python_client/templates/api_init.py.jinja @@ -1 +1,38 @@ """ Contains methods for accessing the API """ + +from functools import cached_property +from typing import Awaitable, Callable, Union +import attr + +from ..client import Client +{% for tag in endpoint_collections_by_tag.keys() %} +from . import {{ tag }} as tag_{{ tag }} +{% endfor %} + +@attr.s(auto_attribs=True) +class SyncApi: + _client: Union[Client, Callable[[], Client]] + + {% for tag in endpoint_collections_by_tag.keys() %} + @cached_property + def {{ tag }}(self) -> tag_{{ tag }}.Endpoints: + """ Group of endpoints tagged with {{ tag }} """ + return tag_{{ tag }}.Endpoints(self._client) + {% endfor %} + +@attr.s(auto_attribs=True) +class AsyncApi: + _client: Union[Client, Callable[[], Client], Callable[[], Awaitable[Client]]] + + {% for tag in endpoint_collections_by_tag.keys() %} + @cached_property + def {{ tag }}(self) -> tag_{{ tag }}.AsyncEndpoints: + """ Group of endpoints tagged with {{ tag }} """ + return tag_{{ tag }}.AsyncEndpoints(self._client) + {% endfor %} + + +__all__ = ( + "Api", + "AsyncApi", +) diff --git a/openapi_python_client/templates/client.py.jinja b/openapi_python_client/templates/client.py.jinja index 418ea17b0..f4c54fc95 100644 --- a/openapi_python_client/templates/client.py.jinja +++ b/openapi_python_client/templates/client.py.jinja @@ -45,4 +45,9 @@ class AuthenticatedClient(Client): def get_headers(self) -> Dict[str, str]: auth_header_value = f"{self.prefix} {self.token}" if self.prefix else self.token """Get headers to be used in authenticated endpoints""" - return {self.auth_header_name: auth_header_value, **self.headers} \ No newline at end of file + return {self.auth_header_name: auth_header_value, **self.headers} + +__all__ = ( + "Client", + "AuthenticatedClient", +) diff --git a/openapi_python_client/templates/endpoint_init.py.jinja b/openapi_python_client/templates/endpoint_init.py.jinja index e69de29bb..189082be1 100644 --- a/openapi_python_client/templates/endpoint_init.py.jinja +++ b/openapi_python_client/templates/endpoint_init.py.jinja @@ -0,0 +1,100 @@ +{% from "endpoint_macros.py.jinja" import header_params, cookie_params, query_params, json_body, multipart_body, arguments, client, kwargs, parse_response, docstring %} + +from inspect import isawaitable +from typing import Any, Dict, List, Optional, Union, Callable, Awaitable +import attr + +from ...client import Client +from ...types import Response, UNSET + +{% for endpoint in endpoint_collection.endpoints %} +{% for relative in endpoint.relative_imports %} +{{ relative }} +{% endfor %} +{% endfor %} + +@attr.s(auto_attribs=True) +class Endpoints: + _client: Union[Client, Callable[[], Client]] + + def _get_client(self) -> Client: + client = self._client + if callable(client): + client = client() + assert isinstance(client, Client) + return client + + {% for endpoint in endpoint_collection.endpoints %} + {% set f_name = python_identifier(endpoint.name) %} + {% set return_string = endpoint.response_type() %} + {% set parsed_responses = (endpoint.responses | length > 0) and return_string != "Any" %} + def {{f_name}}_detailed( + self, + {{ arguments(endpoint, include_client=False) | indent(4) }} + ) -> Response[{{ return_string }}]: + {{ docstring(endpoint, return_string) | indent(4) }} + + from . import {{ f_name }} as _endpoint + return _endpoint.sync_detailed( + {{ kwargs(endpoint, include_client = False) }} + client = self._get_client(), + ) + + {% if parsed_responses %} + def {{f_name}}( + self, + {{ arguments(endpoint, include_client=False) | indent(4) }} + ) -> Optional[{{ return_string }}]: + {{ docstring(endpoint, return_string) | indent(4) }} + + from . import {{ f_name }} as _endpoint + return _endpoint.sync( + {{ kwargs(endpoint, include_client = False) }} + client = self._get_client(), + ) + {% endif %} + {% endfor %} + +@attr.s(auto_attribs=True) +class AsyncEndpoints: + _client: Union[Client, Callable[[], Client], Callable[[], Awaitable[Client]]] + + async def _get_client(self) -> Client: + client = self._client + if callable(client): + client = client() # type: ignore + if isawaitable(client): + client = await client + assert isinstance(client, Client) + return client + + {% for endpoint in endpoint_collection.endpoints %} + {% set f_name = python_identifier(endpoint.name) %} + {% set return_string = endpoint.response_type() %} + {% set parsed_responses = (endpoint.responses | length > 0) and return_string != "Any" %} + async def {{f_name}}_detailed( + self, + {{ arguments(endpoint, include_client = False) | indent(4) }} + ) -> Response[{{ return_string }}]: + {{ docstring(endpoint, return_string) | indent(4) }} + + from . import {{ f_name }} as _endpoint + return await _endpoint.asyncio_detailed( + {{ kwargs(endpoint, include_client = False) }} + client = await self._get_client(), + ) + + {% if parsed_responses %} + async def {{f_name}}( + self, + {{ arguments(endpoint, include_client=False) | indent(4) }} + ) -> Optional[{{ return_string }}]: + {{ docstring(endpoint, return_string) | indent(4) }} + + from . import {{ f_name }} as _endpoint + return await _endpoint.asyncio( + {{ kwargs(endpoint, include_client = False) }} + client = await self._get_client(), + ) + {% endif %} + {% endfor %} diff --git a/openapi_python_client/templates/endpoint_macros.py.jinja b/openapi_python_client/templates/endpoint_macros.py.jinja index 6eb4be11c..4b86fa173 100644 --- a/openapi_python_client/templates/endpoint_macros.py.jinja +++ b/openapi_python_client/templates/endpoint_macros.py.jinja @@ -77,18 +77,22 @@ params = {k: v for k, v in params.items() if v is not UNSET and v is not None} {% endmacro %} {# The all the kwargs passed into an endpoint (and variants thereof)) #} -{% macro arguments(endpoint) %} +{% macro arguments(endpoint, include_client=True) %} {# path parameters #} {% for parameter in endpoint.path_parameters.values() %} {{ parameter.to_string() }}, {% endfor %} +{% if include_client or endpoint.form_body or endpoint.multipart_body or endpoint.json_body or endpoint.query_parameters or endpoint.header_parameters or endpoint.cookie_parameters %} *, +{% endif %} +{% if include_client%} {# Proper client based on whether or not the endpoint requires authentication #} {% if endpoint.requires_security %} client: AuthenticatedClient, {% else %} client: Client, {% endif %} +{% endif %} {# Form data if any #} {% if endpoint.form_body %} form_data: {{ endpoint.form_body.get_type_string() }}, @@ -115,11 +119,13 @@ json_body: {{ endpoint.json_body.get_type_string() }}, {% endmacro %} {# Just lists all kwargs to endpoints as name=name for passing to other functions #} -{% macro kwargs(endpoint) %} +{% macro kwargs(endpoint, include_client=True) %} {% for parameter in endpoint.path_parameters.values() %} {{ parameter.python_name }}={{ parameter.python_name }}, {% endfor %} +{% if include_client %} client=client, +{% endif %} {% if endpoint.form_body %} form_data=form_data, {% endif %} diff --git a/openapi_python_client/templates/package_init.py.jinja b/openapi_python_client/templates/package_init.py.jinja index 366a7e508..2c61d0656 100644 --- a/openapi_python_client/templates/package_init.py.jinja +++ b/openapi_python_client/templates/package_init.py.jinja @@ -1,7 +1,10 @@ """ {{ package_description }} """ -from .client import AuthenticatedClient, Client +from .client import Client, AuthenticatedClient +from .api import SyncApi, AsyncApi __all__ = ( - "AuthenticatedClient", "Client", + "AuthenticatedClient", + "SyncApi", + "AsyncApi", )