diff --git a/.gitignore b/.gitignore index e1c52fb4c..bba4f232a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,9 +19,11 @@ dmypy.json # JetBrains .idea/ +test-reports/ + /coverage.xml /.coverage htmlcov/ # Generated end to end test data -my-test-api-client +my-test-api-client \ No newline at end of file diff --git a/README.md b/README.md index 04cbe06c9..bcf7fb049 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,19 @@ get an error. > For more usage details run `openapi-python-client --help` or read [usage](usage.md) + +### Using custom templates + +This feature leverages Jinja2's [ChoiceLoader](https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.ChoiceLoader) and [FileSystemLoader](https://jinja.palletsprojects.com/en/2.11.x/api/#jinja2.FileSystemLoader). This means you do _not_ need to customize every template. Simply copy the template(s) you want to customize from [the default template directory](openapi_python_client/templates) to your own custom template directory (file names _must_ match exactly) and pass the template directory through the `custom_template_path` flag to the `generate` and `update` commands. For instance, + +``` +openapi-python-client update \ + --url https://my.api.com/openapi.json \ + --custom-template-path=relative/path/to/mytemplates +``` + +_Be forewarned, this is a beta-level feature in the sense that the API exposed in the templates is undocumented and unstable._ + ## What You Get 1. A `pyproject.toml` file with some basic metadata intended to be used with [Poetry]. diff --git a/end_to_end_tests/golden-record-custom/.gitignore b/end_to_end_tests/golden-record-custom/.gitignore new file mode 100644 index 000000000..ed29cb977 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/.gitignore @@ -0,0 +1,23 @@ +__pycache__/ +build/ +dist/ +*.egg-info/ +.pytest_cache/ + +# pyenv +.python-version + +# Environments +.env +.venv + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# JetBrains +.idea/ + +/coverage.xml +/.coverage \ No newline at end of file diff --git a/end_to_end_tests/golden-record-custom/README.md b/end_to_end_tests/golden-record-custom/README.md new file mode 100644 index 000000000..fbbe00c2f --- /dev/null +++ b/end_to_end_tests/golden-record-custom/README.md @@ -0,0 +1,61 @@ +# my-test-api-client +A client library for accessing My Test API + +## Usage +First, create a client: + +```python +from my_test_api_client import Client + +client = Client(base_url="https://api.example.com") +``` + +If the endpoints you're going to hit require authentication, use `AuthenticatedClient` instead: + +```python +from my_test_api_client import AuthenticatedClient + +client = AuthenticatedClient(base_url="https://api.example.com", token="SuperSecretToken") +``` + +Now call your endpoint and use your models: + +```python +from my_test_api_client.models import MyDataModel +from my_test_api_client.api.my_tag import get_my_data_model + +my_data: MyDataModel = get_my_data_model(client=client) +``` + +Or do the same thing with an async version: + +```python +from my_test_api_client.models import MyDataModel +from my_test_api_client.async_api.my_tag import get_my_data_model + +my_data: MyDataModel = await get_my_data_model(client=client) +``` + +Things to know: +1. Every path/method combo becomes a Python function with type annotations. +1. All path/query params, and bodies become method arguments. +1. If your endpoint had any tags on it, the first tag will be used as a module name for the function (my_tag above) +1. Any endpoint which did not have a tag will be in `my_test_api_client.api.default` +1. If the API returns a response code that was not declared in the OpenAPI document, a + `my_test_api_client.api.errors.ApiResponseError` wil be raised + with the `response` attribute set to the `httpx.Response` that was received. + + +## Building / publishing this Client +This project uses [Poetry](https://python-poetry.org/) to manage dependencies and packaging. Here are the basics: +1. Update the metadata in pyproject.toml (e.g. authors, version) +1. If you're using a private repository, configure it with Poetry + 1. `poetry config repositories. ` + 1. `poetry config http-basic. ` +1. Publish the client with `poetry publish --build -r ` or, if for public PyPI, just `poetry publish --build` + +If you want to install this client into another project without publishing it (e.g. for development) then: +1. If that project **is using Poetry**, you can simply do `poetry add ` from that project +1. If that project is not using Poetry: + 1. Build a wheel with `poetry build -f wheel` + 1. Install that wheel from the other project `pip install ` \ No newline at end of file diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/__init__.py b/end_to_end_tests/golden-record-custom/my_test_api_client/__init__.py new file mode 100644 index 000000000..c8d0f1760 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/__init__.py @@ -0,0 +1 @@ +""" A client library for accessing My Test API """ diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/__init__.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/__init__.py new file mode 100644 index 000000000..dc035f4ce --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/__init__.py @@ -0,0 +1 @@ +""" Contains methods for accessing the API """ diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/default/__init__.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/default/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/default/ping_ping_get.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/default/ping_ping_get.py new file mode 100644 index 000000000..5e4cc9a2d --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/default/ping_ping_get.py @@ -0,0 +1,33 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + + +def _parse_response(*, response: httpx.Response) -> Optional[bool]: + if response.status_code == 200: + return bool(response.text) + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[bool]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, +) -> httpx.Response[bool]: + + response = client.request( + "get", + "/ping", + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/__init__.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/defaults_tests_defaults_post.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/defaults_tests_defaults_post.py new file mode 100644 index 000000000..84347423e --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/defaults_tests_defaults_post.py @@ -0,0 +1,99 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + +import datetime +from typing import Dict, List, Optional, Union, cast + +from dateutil.parser import isoparse + +from ...models.an_enum import AnEnum +from ...models.http_validation_error import HTTPValidationError + + +def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPValidationError]]: + if response.status_code == 200: + return None + if response.status_code == 422: + return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json())) + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[Union[None, HTTPValidationError]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, + json_body: Dict[Any, Any], + string_prop: Optional[str] = "the default string", + datetime_prop: Optional[datetime.datetime] = isoparse("1010-10-10T00:00:00"), + date_prop: Optional[datetime.date] = isoparse("1010-10-10").date(), + float_prop: Optional[float] = 3.14, + int_prop: Optional[int] = 7, + boolean_prop: Optional[bool] = False, + list_prop: Optional[List[AnEnum]] = None, + union_prop: Optional[Union[Optional[float], Optional[str]]] = "not a float", + enum_prop: Optional[AnEnum] = None, +) -> httpx.Response[Union[None, HTTPValidationError]]: + + json_datetime_prop = datetime_prop.isoformat() if datetime_prop else None + + json_date_prop = date_prop.isoformat() if date_prop else None + + if list_prop is None: + json_list_prop = None + else: + json_list_prop = [] + for list_prop_item_data in list_prop: + list_prop_item = list_prop_item_data.value + + json_list_prop.append(list_prop_item) + + if union_prop is None: + json_union_prop: Optional[Union[Optional[float], Optional[str]]] = None + elif isinstance(union_prop, float): + json_union_prop = union_prop + else: + json_union_prop = union_prop + + json_enum_prop = enum_prop.value if enum_prop else None + + params: Dict[str, Any] = {} + if string_prop is not None: + params["string_prop"] = string_prop + if datetime_prop is not None: + params["datetime_prop"] = json_datetime_prop + if date_prop is not None: + params["date_prop"] = json_date_prop + if float_prop is not None: + params["float_prop"] = float_prop + if int_prop is not None: + params["int_prop"] = int_prop + if boolean_prop is not None: + params["boolean_prop"] = boolean_prop + if list_prop is not None: + params["list_prop"] = json_list_prop + if union_prop is not None: + params["union_prop"] = json_union_prop + if enum_prop is not None: + params["enum_prop"] = json_enum_prop + + json_json_body = json_body + + response = client.request( + "post", + "/tests/defaults", + json=json_json_body, + params=params, + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_booleans.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_booleans.py new file mode 100644 index 000000000..20d7c43ab --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_booleans.py @@ -0,0 +1,33 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + + +def _parse_response(*, response: httpx.Response) -> Optional[List[bool]]: + if response.status_code == 200: + return [bool(item) for item in cast(List[bool], response.json())] + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[List[bool]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, +) -> httpx.Response[List[bool]]: + + response = client.request( + "get", + "/tests/basic_lists/booleans", + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_floats.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_floats.py new file mode 100644 index 000000000..e3fc2e4d4 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_floats.py @@ -0,0 +1,33 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + + +def _parse_response(*, response: httpx.Response) -> Optional[List[float]]: + if response.status_code == 200: + return [float(item) for item in cast(List[float], response.json())] + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[List[float]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, +) -> httpx.Response[List[float]]: + + response = client.request( + "get", + "/tests/basic_lists/floats", + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_integers.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_integers.py new file mode 100644 index 000000000..28ec4963c --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_integers.py @@ -0,0 +1,33 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + + +def _parse_response(*, response: httpx.Response) -> Optional[List[int]]: + if response.status_code == 200: + return [int(item) for item in cast(List[int], response.json())] + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[List[int]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, +) -> httpx.Response[List[int]]: + + response = client.request( + "get", + "/tests/basic_lists/integers", + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_strings.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_strings.py new file mode 100644 index 000000000..1acdf6a40 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_basic_list_of_strings.py @@ -0,0 +1,33 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + + +def _parse_response(*, response: httpx.Response) -> Optional[List[str]]: + if response.status_code == 200: + return [str(item) for item in cast(List[str], response.json())] + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[List[str]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, +) -> httpx.Response[List[str]]: + + response = client.request( + "get", + "/tests/basic_lists/strings", + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_user_list.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_user_list.py new file mode 100644 index 000000000..d3a8591bc --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/get_user_list.py @@ -0,0 +1,62 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + +import datetime +from typing import Dict, List, Union, cast + +from ...models.a_model import AModel +from ...models.an_enum import AnEnum +from ...models.http_validation_error import HTTPValidationError + + +def _parse_response(*, response: httpx.Response) -> Optional[Union[List[AModel], HTTPValidationError]]: + if response.status_code == 200: + return [AModel.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())] + if response.status_code == 422: + return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json())) + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[Union[List[AModel], HTTPValidationError]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, + an_enum_value: List[AnEnum], + some_date: Union[datetime.date, datetime.datetime], +) -> httpx.Response[Union[List[AModel], HTTPValidationError]]: + + json_an_enum_value = [] + for an_enum_value_item_data in an_enum_value: + an_enum_value_item = an_enum_value_item_data.value + + json_an_enum_value.append(an_enum_value_item) + + if isinstance(some_date, datetime.date): + json_some_date = some_date.isoformat() + + else: + json_some_date = some_date.isoformat() + + params: Dict[str, Any] = { + "an_enum_value": json_an_enum_value, + "some_date": json_some_date, + } + + response = client.request( + "get", + "/tests/", + params=params, + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py new file mode 100644 index 000000000..068e6e9c7 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/int_enum_tests_int_enum_post.py @@ -0,0 +1,48 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + +from typing import Dict, cast + +from ...models.an_int_enum import AnIntEnum +from ...models.http_validation_error import HTTPValidationError + + +def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPValidationError]]: + if response.status_code == 200: + return None + if response.status_code == 422: + return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json())) + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[Union[None, HTTPValidationError]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, + int_enum: AnIntEnum, +) -> httpx.Response[Union[None, HTTPValidationError]]: + + json_int_enum = int_enum.value + + params: Dict[str, Any] = { + "int_enum": json_int_enum, + } + + response = client.request( + "post", + "/tests/int_enum", + params=params, + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/json_body_tests_json_body_post.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/json_body_tests_json_body_post.py new file mode 100644 index 000000000..76f4ffe84 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/json_body_tests_json_body_post.py @@ -0,0 +1,44 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + +from typing import Dict, cast + +from ...models.a_model import AModel +from ...models.http_validation_error import HTTPValidationError + + +def _parse_response(*, response: httpx.Response) -> Optional[Union[None, HTTPValidationError]]: + if response.status_code == 200: + return None + if response.status_code == 422: + return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json())) + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[Union[None, HTTPValidationError]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, + json_body: AModel, +) -> httpx.Response[Union[None, HTTPValidationError]]: + + json_json_body = json_body.to_dict() + + response = client.request( + "post", + "/tests/json_body", + json=json_json_body, + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/no_response_tests_no_response_get.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/no_response_tests_no_response_get.py new file mode 100644 index 000000000..4f4d89cbd --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/no_response_tests_no_response_get.py @@ -0,0 +1,25 @@ +import httpx + +Client = httpx.Client + + +def _build_response(*, response: httpx.Response) -> httpx.Response[None]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=None, + ) + + +def httpx_request( + *, + client: Client, +) -> httpx.Response[None]: + + response = client.request( + "get", + "/tests/no_response", + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/octet_stream_tests_octet_stream_get.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/octet_stream_tests_octet_stream_get.py new file mode 100644 index 000000000..8f1b83adb --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/octet_stream_tests_octet_stream_get.py @@ -0,0 +1,33 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + + +def _parse_response(*, response: httpx.Response) -> Optional[bytes]: + if response.status_code == 200: + return bytes(response.content) + return None + + +def _build_response(*, response: httpx.Response) -> httpx.Response[bytes]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request( + *, + client: Client, +) -> httpx.Response[bytes]: + + response = client.request( + "get", + "/tests/octet_stream", + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/unsupported_content_tests_unsupported_content_get.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/unsupported_content_tests_unsupported_content_get.py new file mode 100644 index 000000000..c1019e884 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/unsupported_content_tests_unsupported_content_get.py @@ -0,0 +1,25 @@ +import httpx + +Client = httpx.Client + + +def _build_response(*, response: httpx.Response) -> httpx.Response[None]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=None, + ) + + +def httpx_request( + *, + client: Client, +) -> httpx.Response[None]: + + response = client.request( + "get", + "/tests/unsupported_content", + ) + + return _build_response(response=response) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/upload_file_tests_upload_post.py b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/upload_file_tests_upload_post.py new file mode 100644 index 000000000..1ef04185b --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/api/tests/upload_file_tests_upload_post.py @@ -0,0 +1,57 @@ +from typing import Optional + +import httpx + +Client = httpx.Client + +from typing import Optional + +from ...models.body_upload_file_tests_upload_post import BodyUploadFileTestsUploadPost +from ...models.http_validation_error import HTTPValidationError + + +def _parse_response(*, response: httpx.Response) -> Optional[Union[ + None, + HTTPValidationError +]]: + if response.status_code == 200: + return None + if response.status_code == 422: + return HTTPValidationError.from_dict(cast(Dict[str, Any], response.json())) + return None + + + +def _build_response(*, response: httpx.Response) -> httpx.Response[Union[ + None, + HTTPValidationError +]]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + parsed=_parse_response(response=response), + ) + + +def httpx_request(*, + client: Client, + multipart_data: BodyUploadFileTestsUploadPost, + keep_alive: Optional[bool] = None, +) -> httpx.Response[Union[ + None, + HTTPValidationError +]]: + if keep_alive is not None: + headers["keep-alive"] = keep_alive + + + + + response = client.request( + "post", + "/tests/upload", + "files": multipart_data.to_dict(), + ) + + return _build_response(response=response) \ No newline at end of file diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/client.py b/end_to_end_tests/golden-record-custom/my_test_api_client/client.py new file mode 100644 index 000000000..c3074040c --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/client.py @@ -0,0 +1,46 @@ +from typing import Dict + +import attr + + +@attr.s(auto_attribs=True) +class Client: + """ A class for keeping track of data related to the API """ + + base_url: str + cookies: Dict[str, str] = attr.ib(factory=dict, kw_only=True) + headers: Dict[str, str] = attr.ib(factory=dict, kw_only=True) + timeout: float = attr.ib(5.0, kw_only=True) + + def get_headers(self) -> Dict[str, str]: + """ Get headers to be used in all endpoints """ + return {**self.headers} + + def with_headers(self, headers: Dict[str, str]) -> "Client": + """ Get a new client matching this one with additional headers """ + return attr.evolve(self, headers={**self.headers, **headers}) + + def get_cookies(self) -> Dict[str, str]: + return {**self.cookies} + + def with_cookies(self, cookies: Dict[str, str]) -> "Client": + """ Get a new client matching this one with additional cookies """ + return attr.evolve(self, cookies={**self.cookies, **cookies}) + + def get_timeout(self) -> float: + return self.timeout + + def with_timeout(self, timeout: float) -> "Client": + """ Get a new client matching this one with a new timeout (in seconds) """ + return attr.evolve(self, timeout=timeout) + + +@attr.s(auto_attribs=True) +class AuthenticatedClient(Client): + """ A Client which has been authenticated for use on secured endpoints """ + + token: str + + def get_headers(self) -> Dict[str, str]: + """ Get headers to be used in authenticated endpoints """ + return {"Authorization": f"Bearer {self.token}", **self.headers} diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/models/__init__.py b/end_to_end_tests/golden-record-custom/my_test_api_client/models/__init__.py new file mode 100644 index 000000000..9f38702e3 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/models/__init__.py @@ -0,0 +1 @@ +""" Contains all the data models used in inputs/outputs """ diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/models/a_model.py b/end_to_end_tests/golden-record-custom/my_test_api_client/models/a_model.py new file mode 100644 index 000000000..a1a0ace0c --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/models/a_model.py @@ -0,0 +1,100 @@ +import datetime +from typing import Any, Dict, List, Optional, Union + +import attr +from dateutil.parser import isoparse + +from ..models.an_enum import AnEnum +from ..models.different_enum import DifferentEnum + + +@attr.s(auto_attribs=True) +class AModel: + """ A Model for testing all the ways custom objects can be used """ + + an_enum_value: AnEnum + some_dict: Optional[Dict[Any, Any]] + a_camel_date_time: Union[datetime.datetime, datetime.date] + a_date: datetime.date + nested_list_of_enums: Optional[List[List[DifferentEnum]]] = None + attr_1_leading_digit: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + an_enum_value = self.an_enum_value.value + + some_dict = self.some_dict + + if isinstance(self.a_camel_date_time, datetime.datetime): + a_camel_date_time = self.a_camel_date_time.isoformat() + + else: + a_camel_date_time = self.a_camel_date_time.isoformat() + + a_date = self.a_date.isoformat() + + if self.nested_list_of_enums is None: + nested_list_of_enums = None + else: + nested_list_of_enums = [] + for nested_list_of_enums_item_data in self.nested_list_of_enums: + nested_list_of_enums_item = [] + for nested_list_of_enums_item_item_data in nested_list_of_enums_item_data: + nested_list_of_enums_item_item = nested_list_of_enums_item_item_data.value + + nested_list_of_enums_item.append(nested_list_of_enums_item_item) + + nested_list_of_enums.append(nested_list_of_enums_item) + + attr_1_leading_digit = self.attr_1_leading_digit + + return { + "an_enum_value": an_enum_value, + "some_dict": some_dict, + "aCamelDateTime": a_camel_date_time, + "a_date": a_date, + "nested_list_of_enums": nested_list_of_enums, + "1_leading_digit": attr_1_leading_digit, + } + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "AModel": + an_enum_value = AnEnum(d["an_enum_value"]) + + some_dict = d["some_dict"] + + def _parse_a_camel_date_time(data: Dict[str, Any]) -> Union[datetime.datetime, datetime.date]: + a_camel_date_time: Union[datetime.datetime, datetime.date] + try: + a_camel_date_time = isoparse(d["aCamelDateTime"]) + + return a_camel_date_time + except: # noqa: E722 + pass + a_camel_date_time = isoparse(d["aCamelDateTime"]).date() + + return a_camel_date_time + + a_camel_date_time = _parse_a_camel_date_time(d["aCamelDateTime"]) + + a_date = isoparse(d["a_date"]).date() + + nested_list_of_enums = [] + for nested_list_of_enums_item_data in d.get("nested_list_of_enums") or []: + nested_list_of_enums_item = [] + for nested_list_of_enums_item_item_data in nested_list_of_enums_item_data: + nested_list_of_enums_item_item = DifferentEnum(nested_list_of_enums_item_item_data) + + nested_list_of_enums_item.append(nested_list_of_enums_item_item) + + nested_list_of_enums.append(nested_list_of_enums_item) + + attr_1_leading_digit = d.get("1_leading_digit") + + return AModel( + an_enum_value=an_enum_value, + some_dict=some_dict, + a_camel_date_time=a_camel_date_time, + a_date=a_date, + nested_list_of_enums=nested_list_of_enums, + attr_1_leading_digit=attr_1_leading_digit, + ) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/models/an_enum.py b/end_to_end_tests/golden-record-custom/my_test_api_client/models/an_enum.py new file mode 100644 index 000000000..9616ca82e --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/models/an_enum.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class AnEnum(str, Enum): + FIRST_VALUE = "FIRST_VALUE" + SECOND_VALUE = "SECOND_VALUE" diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/models/an_int_enum.py b/end_to_end_tests/golden-record-custom/my_test_api_client/models/an_int_enum.py new file mode 100644 index 000000000..6048add0f --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/models/an_int_enum.py @@ -0,0 +1,7 @@ +from enum import IntEnum + + +class AnIntEnum(IntEnum): + VALUE_NEGATIVE_1 = -1 + VALUE_1 = 1 + VALUE_2 = 2 diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/models/body_upload_file_tests_upload_post.py b/end_to_end_tests/golden-record-custom/my_test_api_client/models/body_upload_file_tests_upload_post.py new file mode 100644 index 000000000..4fe7f8476 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/models/body_upload_file_tests_upload_post.py @@ -0,0 +1,27 @@ +from typing import Any, Dict + +import attr + +from ..types import File + + +@attr.s(auto_attribs=True) +class BodyUploadFileTestsUploadPost: + """ """ + + some_file: File + + def to_dict(self) -> Dict[str, Any]: + some_file = self.some_file.to_tuple() + + return { + "some_file": some_file, + } + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "BodyUploadFileTestsUploadPost": + some_file = d["some_file"] + + return BodyUploadFileTestsUploadPost( + some_file=some_file, + ) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/models/different_enum.py b/end_to_end_tests/golden-record-custom/my_test_api_client/models/different_enum.py new file mode 100644 index 000000000..00357ab7a --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/models/different_enum.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class DifferentEnum(str, Enum): + DIFFERENT = "DIFFERENT" + OTHER = "OTHER" diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/models/http_validation_error.py b/end_to_end_tests/golden-record-custom/my_test_api_client/models/http_validation_error.py new file mode 100644 index 000000000..90cd71e8c --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/models/http_validation_error.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, List, Optional + +import attr + +from ..models.validation_error import ValidationError + + +@attr.s(auto_attribs=True) +class HTTPValidationError: + """ """ + + detail: Optional[List[ValidationError]] = None + + def to_dict(self) -> Dict[str, Any]: + if self.detail is None: + detail = None + else: + detail = [] + for detail_item_data in self.detail: + detail_item = detail_item_data.to_dict() + + detail.append(detail_item) + + return { + "detail": detail, + } + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "HTTPValidationError": + detail = [] + for detail_item_data in d.get("detail") or []: + detail_item = ValidationError.from_dict(detail_item_data) + + detail.append(detail_item) + + return HTTPValidationError( + detail=detail, + ) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/models/validation_error.py b/end_to_end_tests/golden-record-custom/my_test_api_client/models/validation_error.py new file mode 100644 index 000000000..1e415c476 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/models/validation_error.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, List + +import attr + + +@attr.s(auto_attribs=True) +class ValidationError: + """ """ + + loc: List[str] + msg: str + type: str + + def to_dict(self) -> Dict[str, Any]: + loc = self.loc + + msg = self.msg + type = self.type + + return { + "loc": loc, + "msg": msg, + "type": type, + } + + @staticmethod + def from_dict(d: Dict[str, Any]) -> "ValidationError": + loc = d["loc"] + + msg = d["msg"] + + type = d["type"] + + return ValidationError( + loc=loc, + msg=msg, + type=type, + ) diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/py.typed b/end_to_end_tests/golden-record-custom/my_test_api_client/py.typed new file mode 100644 index 000000000..1aad32711 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 \ No newline at end of file diff --git a/end_to_end_tests/golden-record-custom/my_test_api_client/types.py b/end_to_end_tests/golden-record-custom/my_test_api_client/types.py new file mode 100644 index 000000000..951227435 --- /dev/null +++ b/end_to_end_tests/golden-record-custom/my_test_api_client/types.py @@ -0,0 +1,33 @@ +""" Contains some shared types for properties """ +from typing import BinaryIO, Generic, MutableMapping, Optional, TextIO, Tuple, TypeVar, Union + +import attr + + +@attr.s(auto_attribs=True) +class File: + """ Contains information for file uploads """ + + payload: Union[BinaryIO, TextIO] + file_name: Optional[str] = None + mime_type: Optional[str] = None + + def to_tuple(self) -> Tuple[Optional[str], Union[BinaryIO, TextIO], Optional[str]]: + """ Return a tuple representation that httpx will accept for multipart/form-data """ + return self.file_name, self.payload, self.mime_type + + +T = TypeVar("T") + + +@attr.s(auto_attribs=True) +class Response(Generic[T]): + """ A response from an endpoint """ + + status_code: int + content: bytes + headers: MutableMapping[str, str] + parsed: Optional[T] + + +__all__ = ["File", "Response"] diff --git a/end_to_end_tests/golden-record-custom/pyproject.toml b/end_to_end_tests/golden-record-custom/pyproject.toml new file mode 100644 index 000000000..eeb1a9e4e --- /dev/null +++ b/end_to_end_tests/golden-record-custom/pyproject.toml @@ -0,0 +1,41 @@ +[tool.poetry] +name = "my-test-api-client" +version = "0.1.0" +description = "A client library for accessing My Test API" + +authors = [] + +readme = "README.md" +packages = [ + {include = "my_test_api_client"}, +] +include = ["CHANGELOG.md", "my_test_api_client/py.typed"] + + +[tool.poetry.dependencies] +python = "^3.8" +httpx = "^0.15.0" +attrs = "^20.1.0" +python-dateutil = "^2.8.1" + +[tool.black] +line-length = 120 +target_version = ['py38'] +exclude = ''' +( + /( + | \.git + | \.venv + | \.mypy_cache + )/ +) +''' + +[tool.isort] +line_length = 120 +multi_line_output = 3 +include_trailing_comma = true + +[build-system] +requires = ["poetry>=1.0"] +build-backend = "poetry.masonry.api" \ No newline at end of file diff --git a/end_to_end_tests/regen_golden_record.py b/end_to_end_tests/regen_golden_record.py index 4fe06cb7b..37400defc 100644 --- a/end_to_end_tests/regen_golden_record.py +++ b/end_to_end_tests/regen_golden_record.py @@ -1,5 +1,6 @@ """ Regenerate golden-record """ import shutil +import sys from pathlib import Path from typer.testing import CliRunner @@ -9,13 +10,30 @@ if __name__ == "__main__": runner = CliRunner() openapi_path = Path(__file__).parent / "openapi.json" - gr_path = Path(__file__).parent / "golden-record" - shutil.rmtree(gr_path, ignore_errors=True) + output_path = Path.cwd() / "my-test-api-client" + if sys.argv[1] == "custom": + gr_path = Path(__file__).parent / "golden-record-custom" + else: + gr_path = Path(__file__).parent / "golden-record" + + shutil.rmtree(gr_path, ignore_errors=True) shutil.rmtree(output_path, ignore_errors=True) config_path = Path(__file__).parent / "config.yml" - result = runner.invoke(app, [f"--config={config_path}", "generate", f"--path={openapi_path}"]) + if sys.argv[1] == "custom": + result = runner.invoke( + app, + [ + f"--config={config_path}", + "generate", + f"--path={openapi_path}", + "--custom-template-path=end_to_end_tests/test_custom_templates", + ], + ) + else: + result = runner.invoke(app, [f"--config={config_path}", "generate", f"--path={openapi_path}"]) + if result.stdout: print(result.stdout) if result.exception: diff --git a/end_to_end_tests/test_custom_templates/endpoint_module.pyi b/end_to_end_tests/test_custom_templates/endpoint_module.pyi new file mode 100644 index 000000000..a00c442de --- /dev/null +++ b/end_to_end_tests/test_custom_templates/endpoint_module.pyi @@ -0,0 +1,63 @@ +import httpx +from typing import Optional + + +Client = httpx.Client + +{% for relative in endpoint.relative_imports %} +{{ relative }} +{% endfor %} + +{% from "endpoint_macros.pyi" import header_params, query_params, json_body, return_type, arguments, client, kwargs, parse_response %} + +{% set return_string = return_type(endpoint) %} +{% set parsed_responses = (endpoint.responses | length > 0) and return_string != "None" %} + + +{% if parsed_responses %} +def _parse_response(*, response: httpx.Response) -> Optional[{{ return_string }}]: + {% for response in endpoint.responses %} + if response.status_code == {{ response.status_code }}: + return {{ response.constructor() }} + {% endfor %} + return None +{% endif %} + + + +def _build_response(*, response: httpx.Response) -> httpx.Response[{{ return_string }}]: + return httpx.Response( + status_code=response.status_code, + content=response.content, + headers=response.headers, + {% if parsed_responses %} + parsed=_parse_response(response=response), + {% else %} + parsed=None, + {% endif %} + ) + + +def httpx_request({{ arguments(endpoint) | indent(4) }}) -> httpx.Response[{{ return_string }}]: + {{ header_params(endpoint) | indent(4) }} + {{ query_params(endpoint) | indent(4) }} + {{ json_body(endpoint) | indent(4) }} + + response = client.request( + "{{ endpoint.method }}", + "{{ endpoint.path }}", + {% if endpoint.json_body %} + json={{ "json_" + endpoint.json_body.python_name }}, + {% endif %} + {% if endpoint.query_parameters %} + params=params, + {% endif %} + {% if endpoint.form_body_reference %} + "data": asdict(form_data), + {% endif %} + {% if endpoint.multipart_body_reference %} + "files": multipart_data.to_dict(), + {% endif %} + ) + + return _build_response(response=response) \ No newline at end of file diff --git a/end_to_end_tests/test_end_to_end.py b/end_to_end_tests/test_end_to_end.py index 51e895a7b..ce4942f7f 100644 --- a/end_to_end_tests/test_end_to_end.py +++ b/end_to_end_tests/test_end_to_end.py @@ -47,3 +47,28 @@ def test_end_to_end(): assert status == 0, f"Type checking client failed: {out}" shutil.rmtree(output_path) + + +def test_end_to_end_w_custom_templates(): + runner = CliRunner() + openapi_path = Path(__file__).parent / "openapi.json" + config_path = Path(__file__).parent / "config.yml" + gr_path = Path(__file__).parent / "golden-record-custom" + output_path = Path.cwd() / "my-test-api-client" + shutil.rmtree(output_path, ignore_errors=True) + + result = runner.invoke( + app, + [ + f"--config={config_path}", + "generate", + f"--path={openapi_path}", + "--custom-template-path=end_to_end_tests/test_custom_templates", + ], + ) + + if result.exit_code != 0: + raise result.exception + _compare_directories(gr_path, output_path) + + shutil.rmtree(output_path) diff --git a/openapi_python_client/__init__.py b/openapi_python_client/__init__.py index 2fcca23b3..3a4d48a2b 100644 --- a/openapi_python_client/__init__.py +++ b/openapi_python_client/__init__.py @@ -9,7 +9,7 @@ import httpcore import httpx import yaml -from jinja2 import Environment, PackageLoader +from jinja2 import Environment, PackageLoader, ChoiceLoader, FileSystemLoader from openapi_python_client import utils @@ -30,9 +30,20 @@ class Project: project_name_override: Optional[str] = None package_name_override: Optional[str] = None - def __init__(self, *, openapi: GeneratorData) -> None: + def __init__(self, *, openapi: GeneratorData, custom_template_path: Optional[Path] = None) -> None: self.openapi: GeneratorData = openapi - self.env: Environment = Environment(loader=PackageLoader(__package__), trim_blocks=True, lstrip_blocks=True) + + package_loader = PackageLoader(__package__) + if custom_template_path is not None: + loader = ChoiceLoader( + [ + FileSystemLoader(str(custom_template_path)), + package_loader, + ] + ) + else: + loader = package_loader + self.env: Environment = Environment(loader=loader, trim_blocks=True, lstrip_blocks=True) self.project_name: str = self.project_name_override or f"{utils.kebab_case(openapi.title).lower()}-client" self.project_dir: Path = Path.cwd() / self.project_name @@ -191,37 +202,43 @@ def _build_api(self) -> None: module_path.write_text(endpoint_template.render(endpoint=endpoint)) -def _get_project_for_url_or_path(url: Optional[str], path: Optional[Path]) -> Union[Project, GeneratorError]: +def _get_project_for_url_or_path( + url: Optional[str], path: Optional[Path], custom_template_path: Optional[Path] = None +) -> Union[Project, GeneratorError]: data_dict = _get_document(url=url, path=path) if isinstance(data_dict, GeneratorError): return data_dict openapi = GeneratorData.from_dict(data_dict) if isinstance(openapi, GeneratorError): return openapi - return Project(openapi=openapi) + return Project(openapi=openapi, custom_template_path=custom_template_path) -def create_new_client(*, url: Optional[str], path: Optional[Path]) -> Sequence[GeneratorError]: +def create_new_client( + *, url: Optional[str], path: Optional[Path], custom_template_path: Optional[Path] = None +) -> Sequence[GeneratorError]: """ Generate the client library Returns: A list containing any errors encountered when generating. """ - project = _get_project_for_url_or_path(url=url, path=path) + project = _get_project_for_url_or_path(url=url, path=path, custom_template_path=custom_template_path) if isinstance(project, GeneratorError): return [project] return project.build() -def update_existing_client(*, url: Optional[str], path: Optional[Path]) -> Sequence[GeneratorError]: +def update_existing_client( + *, url: Optional[str], path: Optional[Path], custom_template_path: Optional[Path] = None +) -> Sequence[GeneratorError]: """ Update an existing client library Returns: A list containing any errors encountered when generating. """ - project = _get_project_for_url_or_path(url=url, path=path) + project = _get_project_for_url_or_path(url=url, path=path, custom_template_path=custom_template_path) if isinstance(project, GeneratorError): return [project] return project.update() diff --git a/openapi_python_client/cli.py b/openapi_python_client/cli.py index 9bfbe3366..2643ce85d 100644 --- a/openapi_python_client/cli.py +++ b/openapi_python_client/cli.py @@ -96,10 +96,20 @@ def handle_errors(errors: Sequence[GeneratorError]) -> None: raise typer.Exit(code=1) +custom_template_path_options = { + "help": "A path to a directory containing custom template(s)", + "file_okay": False, + "dir_okay": True, + "readable": True, + "resolve_path": True, +} + + @app.command() def generate( url: Optional[str] = typer.Option(None, help="A URL to read the JSON from"), path: Optional[pathlib.Path] = typer.Option(None, help="A path to the JSON file"), + custom_template_path: Optional[pathlib.Path] = typer.Option(None, **custom_template_path_options), ) -> None: """ Generate a new OpenAPI Client library """ from . import create_new_client @@ -110,7 +120,7 @@ def generate( if url and path: typer.secho("Provide either --url or --path, not both", fg=typer.colors.RED) raise typer.Exit(code=1) - errors = create_new_client(url=url, path=path) + errors = create_new_client(url=url, path=path, custom_template_path=custom_template_path) handle_errors(errors) @@ -118,6 +128,7 @@ def generate( def update( url: Optional[str] = typer.Option(None, help="A URL to read the JSON from"), path: Optional[pathlib.Path] = typer.Option(None, help="A path to the JSON file"), + custom_template_path: Optional[pathlib.Path] = typer.Option(None, **custom_template_path_options), ) -> None: """ Update an existing OpenAPI Client library """ from . import update_existing_client @@ -129,5 +140,5 @@ def update( typer.secho("Provide either --url or --path, not both", fg=typer.colors.RED) raise typer.Exit(code=1) - errors = update_existing_client(url=url, path=path) + errors = update_existing_client(url=url, path=path, custom_template_path=custom_template_path) handle_errors(errors) diff --git a/pyproject.toml b/pyproject.toml index 0253d3f3b..ffbfab4ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,8 @@ exclude = ''' | \.mypy_cache | openapi_python_client/templates | tests/test_templates + | end_to_end_tests/test_custom_templates + | end_to_end_tests/golden-record-custom )/ ) ''' diff --git a/tests/test___init__.py b/tests/test___init__.py index d7970e054..3a42b4442 100644 --- a/tests/test___init__.py +++ b/tests/test___init__.py @@ -23,7 +23,7 @@ def test__get_project_for_url_or_path(mocker): _get_document.assert_called_once_with(url=url, path=path) from_dict.assert_called_once_with(data_dict) - _Project.assert_called_once_with(openapi=openapi) + _Project.assert_called_once_with(openapi=openapi, custom_template_path=None) assert project == _Project() @@ -75,7 +75,7 @@ def test_create_new_client(mocker): result = create_new_client(url=url, path=path) - _get_project_for_url_or_path.assert_called_once_with(url=url, path=path) + _get_project_for_url_or_path.assert_called_once_with(url=url, path=path, custom_template_path=None) project.build.assert_called_once() assert result == project.build.return_value @@ -92,7 +92,7 @@ def test_create_new_client_project_error(mocker): result = create_new_client(url=url, path=path) - _get_project_for_url_or_path.assert_called_once_with(url=url, path=path) + _get_project_for_url_or_path.assert_called_once_with(url=url, path=path, custom_template_path=None) assert result == [error] @@ -108,7 +108,7 @@ def test_update_existing_client(mocker): result = update_existing_client(url=url, path=path) - _get_project_for_url_or_path.assert_called_once_with(url=url, path=path) + _get_project_for_url_or_path.assert_called_once_with(url=url, path=path, custom_template_path=None) project.update.assert_called_once() assert result == project.update.return_value @@ -125,7 +125,7 @@ def test_update_existing_client_project_error(mocker): result = update_existing_client(url=url, path=path) - _get_project_for_url_or_path.assert_called_once_with(url=url, path=path) + _get_project_for_url_or_path.assert_called_once_with(url=url, path=path, custom_template_path=None) assert result == [error] @@ -416,3 +416,27 @@ def test__get_errors(mocker): project = Project(openapi=openapi) assert project._get_errors() == [1, 2, 3] + + +def test__custom_templates(mocker): + from openapi_python_client import GeneratorData, Project + from openapi_python_client.parser.openapi import EndpointCollection, Schemas + + openapi = mocker.MagicMock( + autospec=GeneratorData, + title="My Test API", + endpoint_collections_by_tag={ + "default": mocker.MagicMock(autospec=EndpointCollection, parse_errors=[1]), + "other": mocker.MagicMock(autospec=EndpointCollection, parse_errors=[2]), + }, + schemas=mocker.MagicMock(autospec=Schemas, errors=[3]), + ) + + project = Project(openapi=openapi) + assert isinstance(project.env.loader, jinja2.PackageLoader) + + project = Project(openapi=openapi, custom_template_path="../end_to_end_tests/test_custom_templates") + assert isinstance(project.env.loader, jinja2.ChoiceLoader) + assert len(project.env.loader.loaders) == 2 + assert isinstance(project.env.loader.loaders[0], jinja2.FileSystemLoader) + assert isinstance(project.env.loader.loaders[1], jinja2.PackageLoader) diff --git a/tests/test_cli.py b/tests/test_cli.py index f4fd9b9cb..1995b7228 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -36,7 +36,7 @@ def test_config_arg(mocker, _create_new_client): assert result.exit_code == 0 load_config.assert_called_once_with(path=Path(config_path)) - _create_new_client.assert_called_once_with(url=None, path=Path(path)) + _create_new_client.assert_called_once_with(url=None, path=Path(path), custom_template_path=None) def test_bad_config(mocker, _create_new_client): @@ -80,7 +80,7 @@ def test_generate_url(self, _create_new_client): result = runner.invoke(app, ["generate", f"--url={url}"]) assert result.exit_code == 0 - _create_new_client.assert_called_once_with(url=url, path=None) + _create_new_client.assert_called_once_with(url=url, path=None, custom_template_path=None) def test_generate_path(self, _create_new_client): path = "cool/path" @@ -89,7 +89,7 @@ def test_generate_path(self, _create_new_client): result = runner.invoke(app, ["generate", f"--path={path}"]) assert result.exit_code == 0 - _create_new_client.assert_called_once_with(url=None, path=Path(path)) + _create_new_client.assert_called_once_with(url=None, path=Path(path), custom_template_path=None) def test_generate_handle_errors(self, _create_new_client): _create_new_client.return_value = [GeneratorError(detail="this is a message")] @@ -159,7 +159,7 @@ def test_update_url(self, _update_existing_client): result = runner.invoke(app, ["update", f"--url={url}"]) assert result.exit_code == 0 - _update_existing_client.assert_called_once_with(url=url, path=None) + _update_existing_client.assert_called_once_with(url=url, path=None, custom_template_path=None) def test_update_path(self, _update_existing_client): path = "cool/path" @@ -168,4 +168,4 @@ def test_update_path(self, _update_existing_client): result = runner.invoke(app, ["update", f"--path={path}"]) assert result.exit_code == 0 - _update_existing_client.assert_called_once_with(url=None, path=Path(path)) + _update_existing_client.assert_called_once_with(url=None, path=Path(path), custom_template_path=None)