diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 543b6f866..b31be19aa 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,7 +16,13 @@ ## Writing Code 1. Write some code and make sure it's covered by unit tests. All unit tests are in the `tests` directory and the file structure should mirror the structure of the source code in the `openapi_python_client` directory. + +### Run Checks and Tests + 2. When in a Poetry shell (`poetry shell`) run `task check` in order to run most of the same checks CI runs. This will auto-reformat the code, check type annotations, run unit tests, check code coverage, and lint the code. + +### Rework end to end tests + 3. If you're writing a new feature, try to add it to the end to end test. 1. If adding support for a new OpenAPI feature, add it somewhere in `end_to_end_tests/openapi.json` 2. Regenerate the "golden records" with `task regen`. This client is generated from the OpenAPI document used for end to end testing. diff --git a/openapi_python_client/__init__.py b/openapi_python_client/__init__.py index b34841cdc..85e83dbfd 100644 --- a/openapi_python_client/__init__.py +++ b/openapi_python_client/__init__.py @@ -1,5 +1,7 @@ """ Generate modern Python clients from OpenAPI """ +import json +import mimetypes import shutil import subprocess import sys @@ -361,21 +363,40 @@ def update_existing_client( return project.update() +def _load_yaml_or_json(data: bytes, content_type: Optional[str]) -> Union[Dict[str, Any], GeneratorError]: + if content_type == "application/json": + try: + return json.loads(data.decode()) + except ValueError as err: + return GeneratorError(header="Invalid JSON from provided source: {}".format(str(err))) + else: + try: + return yaml.safe_load(data) + except yaml.YAMLError as err: + return GeneratorError(header="Invalid YAML from provided source: {}".format(str(err))) + + def _get_document(*, url: Optional[str], path: Optional[Path]) -> Union[Dict[str, Any], GeneratorError]: yaml_bytes: bytes + content_type: Optional[str] if url is not None and path is not None: return GeneratorError(header="Provide URL or Path, not both.") if url is not None: try: response = httpx.get(url) yaml_bytes = response.content + if "content-type" in response.headers: + content_type = response.headers["content-type"].split(";")[0] + else: + content_type = mimetypes.guess_type(url, strict=True)[0] + except (httpx.HTTPError, httpcore.NetworkError): return GeneratorError(header="Could not get OpenAPI document from provided URL") elif path is not None: yaml_bytes = path.read_bytes() + content_type = mimetypes.guess_type(path.as_uri(), strict=True)[0] + else: return GeneratorError(header="No URL or Path provided") - try: - return yaml.safe_load(yaml_bytes) - except yaml.YAMLError: - return GeneratorError(header="Invalid YAML from provided source") + + return _load_yaml_or_json(yaml_bytes, content_type) diff --git a/openapi_python_client/config.py b/openapi_python_client/config.py index 2597d1aed..a246b3737 100644 --- a/openapi_python_client/config.py +++ b/openapi_python_client/config.py @@ -1,3 +1,5 @@ +import json +import mimetypes from pathlib import Path from typing import Dict, List, Optional @@ -35,6 +37,10 @@ class Config(BaseModel): @staticmethod def load_from_path(path: Path) -> "Config": """Creates a Config from provided JSON or YAML file and sets a bunch of globals from it""" - config_data = yaml.safe_load(path.read_text()) + mime = mimetypes.guess_type(path.as_uri(), strict=True)[0] + if mime == "application/json": + config_data = json.loads(path.read_text()) + else: + config_data = yaml.safe_load(path.read_text()) config = Config(**config_data) return config diff --git a/tests/test___init__.py b/tests/test___init__.py index 34b05b4c5..6d9d32852 100644 --- a/tests/test___init__.py +++ b/tests/test___init__.py @@ -218,7 +218,7 @@ def test__get_document_path_no_url(self, mocker): def test__get_document_bad_yaml(self, mocker): get = mocker.patch("httpx.get") - loads = mocker.patch("yaml.safe_load", side_effect=yaml.YAMLError) + loads = mocker.patch("yaml.safe_load", side_effect=yaml.YAMLError("error line 2")) from openapi_python_client import _get_document @@ -228,7 +228,44 @@ def test__get_document_bad_yaml(self, mocker): get.assert_not_called() path.read_bytes.assert_called_once() loads.assert_called_once_with(path.read_bytes()) - assert result == GeneratorError(header="Invalid YAML from provided source") + assert result == GeneratorError(header="Invalid YAML from provided source: error line 2") + + def test__get_document_json(self, mocker): + class FakeResponse: + content = b'{\n\t"foo": "bar"}' + headers = {"content-type": "application/json; encoding=utf8"} + + get = mocker.patch("httpx.get", return_value=FakeResponse()) + yaml_loads = mocker.patch("yaml.safe_load") + json_result = mocker.MagicMock() + json_loads = mocker.patch("json.loads", return_value=json_result) + + from openapi_python_client import _get_document + + url = mocker.MagicMock() + result = _get_document(url=url, path=None) + + get.assert_called_once() + json_loads.assert_called_once_with(FakeResponse.content.decode()) + assert result == json_result + + def test__get_document_bad_json(self, mocker): + class FakeResponse: + content = b'{"foo"}' + headers = {"content-type": "application/json; encoding=utf8"} + + get = mocker.patch("httpx.get", return_value=FakeResponse()) + json_result = mocker.MagicMock() + + from openapi_python_client import _get_document + + url = mocker.MagicMock() + result = _get_document(url=url, path=None) + + get.assert_called_once() + assert result == GeneratorError( + header="Invalid JSON from provided source: " "Expecting ':' delimiter: line 1 column 7 (char 6)" + ) def make_project(**kwargs): diff --git a/tests/test_config.py b/tests/test_config.py index eb2ba09ee..20c0bfee9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,9 @@ +import json import pathlib +import pytest +import yaml + from openapi_python_client.config import Config @@ -28,3 +32,28 @@ def test_load_from_path(mocker): assert config.project_name_override == "project-name" assert config.package_name_override == "package_name" assert config.package_version_override == "package_version" + + +DATA = {"class_overrides": {"Class1": {"class_name": "ExampleClass", "module_name": "example_module"}}} + + +def json_with_tabs(d): + return json.dumps(d, indent=4).replace(" ", "\t") + + +@pytest.mark.parametrize( + "filename,dump", + [ + ("example.yml", yaml.dump), + ("example.json", json.dumps), + ("example.yaml", yaml.dump), + ("example.json", json_with_tabs), + ], +) +def test_load_filenames(tmp_path, filename, dump): + yml_file = tmp_path.joinpath(filename) + with open(yml_file, "w") as f: + f.write(dump(DATA)) + + config = Config.load_from_path(yml_file) + assert config.class_overrides == DATA["class_overrides"]