Skip to content

Add a more fluent API wrapper #690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions openapi_python_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(
package_version=self.version,
project_name=self.project_name,
project_dir=self.project_dir,
openapi=self.openapi,
endpoint_collections_by_tag=self.openapi.endpoint_collections_by_tag,
)
self.errors: List[GeneratorError] = []

Expand Down Expand Up @@ -264,18 +266,13 @@ def _build_api(self) -> None:
client_path.write_text(client_template.render(), encoding=self.file_encoding)

# Generate endpoints
endpoint_collections_by_tag = self.openapi.endpoint_collections_by_tag
api_dir = self.package_dir / "api"
api_dir.mkdir()
api_init_path = api_dir / "__init__.py"
api_init_template = self.env.get_template("api_init.py.jinja")
api_init_path.write_text(
api_init_template.render(
endpoint_collections_by_tag=endpoint_collections_by_tag,
),
encoding=self.file_encoding,
)
api_init_path.write_text(api_init_template.render(), encoding=self.file_encoding)

endpoint_collections_by_tag = self.openapi.endpoint_collections_by_tag
endpoint_template = self.env.get_template(
"endpoint_module.py.jinja", globals={"isbool": lambda obj: obj.get_base_type_string() == "bool"}
)
Expand Down
37 changes: 37 additions & 0 deletions openapi_python_client/templates/api_init.py.jinja
Original file line number Diff line number Diff line change
@@ -1 +1,38 @@
""" Contains methods for accessing the API """

from functools import cached_property
from typing import Awaitable, Callable, Union
import attr

from ..client import Client
{% for tag in endpoint_collections_by_tag.keys() %}
from . import {{ tag }} as tag_{{ tag }}
{% endfor %}

@attr.s(auto_attribs=True)
class SyncApi:
_client: Union[Client, Callable[[], Client]]

{% for tag in endpoint_collections_by_tag.keys() %}
@cached_property
def {{ tag }}(self) -> tag_{{ tag }}.Endpoints:
""" Group of endpoints tagged with {{ tag }} """
return tag_{{ tag }}.Endpoints(self._client)
{% endfor %}

@attr.s(auto_attribs=True)
class AsyncApi:
_client: Union[Client, Callable[[], Client], Callable[[], Awaitable[Client]]]

{% for tag in endpoint_collections_by_tag.keys() %}
@cached_property
def {{ tag }}(self) -> tag_{{ tag }}.AsyncEndpoints:
""" Group of endpoints tagged with {{ tag }} """
return tag_{{ tag }}.AsyncEndpoints(self._client)
{% endfor %}


__all__ = (
"Api",
"AsyncApi",
)
7 changes: 6 additions & 1 deletion openapi_python_client/templates/client.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,9 @@ class AuthenticatedClient(Client):
def get_headers(self) -> Dict[str, str]:
auth_header_value = f"{self.prefix} {self.token}" if self.prefix else self.token
"""Get headers to be used in authenticated endpoints"""
return {self.auth_header_name: auth_header_value, **self.headers}
return {self.auth_header_name: auth_header_value, **self.headers}

__all__ = (
"Client",
"AuthenticatedClient",
)
100 changes: 100 additions & 0 deletions openapi_python_client/templates/endpoint_init.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{% from "endpoint_macros.py.jinja" import header_params, cookie_params, query_params, json_body, multipart_body, arguments, client, kwargs, parse_response, docstring %}

from inspect import isawaitable
from typing import Any, Dict, List, Optional, Union, Callable, Awaitable
import attr

from ...client import Client
from ...types import Response, UNSET

{% for endpoint in endpoint_collection.endpoints %}
{% for relative in endpoint.relative_imports %}
{{ relative }}
{% endfor %}
{% endfor %}

@attr.s(auto_attribs=True)
class Endpoints:
_client: Union[Client, Callable[[], Client]]

def _get_client(self) -> Client:
client = self._client
if callable(client):
client = client()
assert isinstance(client, Client)
return client

{% for endpoint in endpoint_collection.endpoints %}
{% set f_name = python_identifier(endpoint.name) %}
{% set return_string = endpoint.response_type() %}
{% set parsed_responses = (endpoint.responses | length > 0) and return_string != "Any" %}
def {{f_name}}_detailed(
self,
{{ arguments(endpoint, include_client=False) | indent(4) }}
) -> Response[{{ return_string }}]:
{{ docstring(endpoint, return_string) | indent(4) }}

from . import {{ f_name }} as _endpoint
return _endpoint.sync_detailed(
{{ kwargs(endpoint, include_client = False) }}
client = self._get_client(),
)

{% if parsed_responses %}
def {{f_name}}(
self,
{{ arguments(endpoint, include_client=False) | indent(4) }}
) -> Optional[{{ return_string }}]:
{{ docstring(endpoint, return_string) | indent(4) }}

from . import {{ f_name }} as _endpoint
return _endpoint.sync(
{{ kwargs(endpoint, include_client = False) }}
client = self._get_client(),
)
{% endif %}
{% endfor %}

@attr.s(auto_attribs=True)
class AsyncEndpoints:
_client: Union[Client, Callable[[], Client], Callable[[], Awaitable[Client]]]

async def _get_client(self) -> Client:
client = self._client
if callable(client):
client = client() # type: ignore
if isawaitable(client):
client = await client
assert isinstance(client, Client)
return client

{% for endpoint in endpoint_collection.endpoints %}
{% set f_name = python_identifier(endpoint.name) %}
{% set return_string = endpoint.response_type() %}
{% set parsed_responses = (endpoint.responses | length > 0) and return_string != "Any" %}
async def {{f_name}}_detailed(
self,
{{ arguments(endpoint, include_client = False) | indent(4) }}
) -> Response[{{ return_string }}]:
{{ docstring(endpoint, return_string) | indent(4) }}

from . import {{ f_name }} as _endpoint
return await _endpoint.asyncio_detailed(
{{ kwargs(endpoint, include_client = False) }}
client = await self._get_client(),
)

{% if parsed_responses %}
async def {{f_name}}(
self,
{{ arguments(endpoint, include_client=False) | indent(4) }}
) -> Optional[{{ return_string }}]:
{{ docstring(endpoint, return_string) | indent(4) }}

from . import {{ f_name }} as _endpoint
return await _endpoint.asyncio(
{{ kwargs(endpoint, include_client = False) }}
client = await self._get_client(),
)
{% endif %}
{% endfor %}
10 changes: 8 additions & 2 deletions openapi_python_client/templates/endpoint_macros.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,22 @@ params = {k: v for k, v in params.items() if v is not UNSET and v is not None}
{% endmacro %}

{# The all the kwargs passed into an endpoint (and variants thereof)) #}
{% macro arguments(endpoint) %}
{% macro arguments(endpoint, include_client=True) %}
{# path parameters #}
{% for parameter in endpoint.path_parameters.values() %}
{{ parameter.to_string() }},
{% endfor %}
{% if include_client or endpoint.form_body or endpoint.multipart_body or endpoint.json_body or endpoint.query_parameters or endpoint.header_parameters or endpoint.cookie_parameters %}
*,
{% endif %}
{% if include_client%}
{# Proper client based on whether or not the endpoint requires authentication #}
{% if endpoint.requires_security %}
client: AuthenticatedClient,
{% else %}
client: Client,
{% endif %}
{% endif %}
{# Form data if any #}
{% if endpoint.form_body %}
form_data: {{ endpoint.form_body.get_type_string() }},
Expand All @@ -115,11 +119,13 @@ json_body: {{ endpoint.json_body.get_type_string() }},
{% endmacro %}

{# Just lists all kwargs to endpoints as name=name for passing to other functions #}
{% macro kwargs(endpoint) %}
{% macro kwargs(endpoint, include_client=True) %}
{% for parameter in endpoint.path_parameters.values() %}
{{ parameter.python_name }}={{ parameter.python_name }},
{% endfor %}
{% if include_client %}
client=client,
{% endif %}
{% if endpoint.form_body %}
form_data=form_data,
{% endif %}
Expand Down
7 changes: 5 additions & 2 deletions openapi_python_client/templates/package_init.py.jinja
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
""" {{ package_description }} """
from .client import AuthenticatedClient, Client
from .client import Client, AuthenticatedClient
from .api import SyncApi, AsyncApi

__all__ = (
"AuthenticatedClient",
"Client",
"AuthenticatedClient",
"SyncApi",
"AsyncApi",
)