Skip to content

Commit cebe643

Browse files
authored
Feat/detect discriminator bnch 49688 (#143)
* add discriminator * initialize discriminator mappings * getter for discriminator python_name * remove unused method * fix discriminator.mapping null case * more inline discriminator properties * add discriminator_property.pyi * move discriminator property over to aurelia * handle unknowntype for discriminator properties * walk oneOf as well as anyOf references * add parsing of toplevel UnionProperty * support toplevel unionProperty * add nested parameter to construct macro
1 parent 46e4d85 commit cebe643

14 files changed

+75
-19
lines changed

openapi_python_client/.flake8

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[flake8]
2+
per-file-ignores =
3+
parser/properties/__init__.py: E402

openapi_python_client/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from .parser import GeneratorData, import_string_from_reference
1717
from .parser.errors import GeneratorError
18+
from .parser.properties import UnionProperty
1819
from .utils import snake_case
1920

2021
if sys.version_info.minor < 8: # version did not exist before 3.8, need to use a backport
@@ -46,7 +47,6 @@ class Project:
4647

4748
def __init__(self, *, openapi: GeneratorData, custom_template_path: Optional[Path] = None) -> None:
4849
self.openapi: GeneratorData = openapi
49-
5050
package_loader = PackageLoader(__package__)
5151
loader: BaseLoader
5252
if custom_template_path is not None:
@@ -174,10 +174,18 @@ def _build_models(self) -> None:
174174
imports = []
175175

176176
model_template = self.env.get_template("model.pyi")
177+
union_property_template = self.env.get_template("polymorphic_model.pyi")
178+
177179
for model in self.openapi.models.values():
178-
module_path = models_dir / f"{model.reference.module_name}.py"
179-
module_path.write_text(model_template.render(model=model))
180-
imports.append(import_string_from_reference(model.reference))
180+
if isinstance(model, UnionProperty):
181+
template = union_property_template
182+
else:
183+
template = model_template
184+
185+
module_path = models_dir / f"{model.module_name}.py"
186+
module_path.write_text(template.render(model=model))
187+
if not isinstance(model, UnionProperty):
188+
imports.append(import_string_from_reference(model.reference))
181189

182190
# Generate enums
183191
str_enum_template = self.env.get_template("str_enum.pyi")

openapi_python_client/parser/properties/__init__.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
_property = property # isort: skip
2+
13
from itertools import chain
24
from typing import Any, ClassVar, Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
35

@@ -160,10 +162,11 @@ class UnionProperty(Property):
160162
""" A property representing a Union (anyOf) of other properties """
161163

162164
inner_properties: List[Property]
165+
relative_imports: Set[str] = set()
163166
template: ClassVar[str] = "union_property.pyi"
164167
has_properties_without_templates: bool = attr.ib(init=False)
165-
discriminator_property: Optional[str]
166-
discriminator_mappings: Dict[str, Property]
168+
discriminator_property: Optional[str] = None
169+
discriminator_mappings: Dict[str, Property] = {}
167170

168171
def __attrs_post_init__(self) -> None:
169172
super().__attrs_post_init__()
@@ -181,6 +184,14 @@ def _get_inner_type_strings(self, json: bool = False) -> List[str]:
181184
def get_base_type_string(self, json: bool = False) -> str:
182185
return f"Union[{', '.join(self._get_inner_type_strings(json=json))}]"
183186

187+
def resolve_references(self, components, schemas):
188+
self.relative_imports.update(self.get_imports(prefix=".."))
189+
return schemas
190+
191+
@_property
192+
def module_name(self):
193+
return self.python_name
194+
184195
def get_type_strings_in_union(
185196
self, no_optional: bool = False, query_parameter: bool = False, json: bool = False
186197
) -> List[str]:
@@ -284,6 +295,14 @@ def build_model_property(
284295
Used to infer the type name if a `title` property is not available.
285296
schemas: Existing Schemas which have already been processed (to check name conflicts)
286297
"""
298+
if data.anyOf or data.oneOf:
299+
prop, schemas = build_union_property(
300+
data=data, name=name, required=required, schemas=schemas, parent_name=parent_name
301+
)
302+
if not isinstance(prop, PropertyError):
303+
schemas = attr.evolve(schemas, models={**schemas.models, prop.name: prop})
304+
return prop, schemas
305+
287306
required_set = set(data.required or [])
288307
required_properties: List[Property] = []
289308
optional_properties: List[Property] = []
@@ -317,6 +336,11 @@ def build_model_property(
317336
optional_properties.append(prop)
318337
relative_imports.update(prop.get_imports(prefix=".."))
319338

339+
discriminator_mappings: Dict[str, Property] = {}
340+
if data.discriminator is not None:
341+
for k, v in (data.discriminator.mapping or {}).items():
342+
discriminator_mappings[k] = Reference.from_ref(v)
343+
320344
additional_properties: Union[bool, Property, PropertyError]
321345
if data.additionalProperties is None:
322346
additional_properties = True
@@ -347,6 +371,8 @@ def build_model_property(
347371
description=data.description or "",
348372
default=None,
349373
nullable=data.nullable,
374+
discriminator_property=data.discriminator.propertyName if data.discriminator else None,
375+
discriminator_mappings=discriminator_mappings,
350376
required=required,
351377
name=name,
352378
additional_properties=additional_properties,
@@ -446,7 +472,7 @@ def build_union_property(
446472

447473
discriminator_mappings: Dict[str, Property] = {}
448474
if data.discriminator is not None:
449-
for k, v in (data.discriminator.mapping if data.discriminator else {}).items():
475+
for k, v in (data.discriminator.mapping or {}).items():
450476
ref_class_name = Reference.from_ref(v).class_name
451477
discriminator_mappings[k] = reference_name_to_subprop[ref_class_name]
452478

openapi_python_client/parser/properties/model_property.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Iterable
4-
from typing import TYPE_CHECKING, ClassVar, Dict, List, Set, Union
4+
from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Set, Union
55

66
import attr
77

@@ -22,6 +22,9 @@ class ModelProperty(Property):
2222
references: List[oai.Reference]
2323
required_properties: List[Property]
2424
optional_properties: List[Property]
25+
discriminator_property: Optional[str]
26+
discriminator_mappings: Dict[str, Property]
27+
2528
description: str
2629
relative_imports: Set[str]
2730
additional_properties: Union[bool, Property]
@@ -30,6 +33,10 @@ class ModelProperty(Property):
3033
template: ClassVar[str] = "model_property.pyi"
3134
json_is_dict: ClassVar[bool] = True
3235

36+
@property
37+
def module_name(self):
38+
return self.reference.module_name
39+
3340
def resolve_references(
3441
self, components: Dict[str, Union[oai.Reference, oai.Schema]], schemas: Schemas
3542
) -> Union[Schemas, PropertyError]:
@@ -44,7 +51,7 @@ def resolve_references(
4451
assert isinstance(referenced_prop, oai.Schema)
4552
for p, val in (referenced_prop.properties or {}).items():
4653
props[p] = (val, source_name)
47-
for sub_prop in referenced_prop.allOf or []:
54+
for sub_prop in referenced_prop.allOf or referenced_prop.anyOf or referenced_prop.oneOf or []:
4855
if isinstance(sub_prop, oai.Reference):
4956
self.references.append(sub_prop)
5057
else:
@@ -71,9 +78,17 @@ def resolve_references(
7178
self.optional_properties.append(prop)
7279
self.relative_imports.update(prop.get_imports(prefix=".."))
7380

81+
for _, value in self.discriminator_mappings.items():
82+
self.relative_imports.add(f"from ..models.{value.module_name} import {value.class_name}")
83+
7484
return schemas
7585

7686
def get_base_type_string(self) -> str:
87+
if getattr(self, "discriminator_mappings", None):
88+
discriminator_types = ", ".join(
89+
[ref.class_name for ref in self.discriminator_mappings.values()] + ["UnknownType"]
90+
)
91+
return f"Union[{discriminator_types}]"
7792
return self.reference.class_name
7893

7994
def get_imports(self, *, prefix: str) -> Set[str]:

openapi_python_client/templates/property_templates/date_property.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% macro construct(property, source, initial_value="None") %}
1+
{% macro construct(property, source, initial_value="None", nested=False) %}
22
{% if property.required and not property.nullable %}
33
{{ property.python_name }} = isoparse({{ source }}).date()
44
{% else %}

openapi_python_client/templates/property_templates/datetime_property.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% macro construct(property, source, initial_value="None") %}
1+
{% macro construct(property, source, initial_value="None", nested=False) %}
22
{% if property.required %}
33
{% if property.nullable %}
44
{{ property.python_name }} = {{ source }}

openapi_python_client/templates/property_templates/dict_property.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% macro construct(property, source, initial_value="None") %}
1+
{% macro construct(property, source, initial_value="None", nested=False) %}
22
{% if property.required %}
33
{{ property.python_name }} = {{ source }}
44
{% else %}

openapi_python_client/templates/property_templates/enum_property.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% macro construct(property, source, initial_value="None") %}
1+
{% macro construct(property, source, initial_value="None", nested=False) %}
22
{% if property.required %}
33
{{ property.python_name }} = {{ property.reference.class_name }}({{ source }})
44
{% else %}

openapi_python_client/templates/property_templates/file_property.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% macro construct(property, source, initial_value=None) %}
1+
{% macro construct(property, source, initial_value=None, nested=False) %}
22
{{ property.python_name }} = File(
33
payload = BytesIO({{ source }})
44
)

openapi_python_client/templates/property_templates/list_property.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% macro construct(property, source, initial_value="[]") %}
1+
{% macro construct(property, source, initial_value="[]", nested=False) %}
22
{% set inner_property = property.inner_property %}
33
{% if inner_property.template %}
44
{% set inner_source = inner_property.python_name + "_data" %}

openapi_python_client/templates/property_templates/model_property.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
{% macro construct(property, source, initial_value=None) %}
1+
{# This file is shadowed by the template with the same name
2+
# in aurelia/packages/api_client_generation/templates #}
3+
{% macro construct(property, source, initial_value=None, nested=False) %}
24
{% if property.required and not property.nullable %}
35
{% if source == "response.yaml" %}
46
yaml_dict = yaml.safe_load(response.text.encode("utf-8"))

openapi_python_client/templates/property_templates/none_property.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% macro construct(property, source, initial_value="None") %}
1+
{% macro construct(property, source, initial_value="None", nested=False) %}
22
{{ property.python_name }} = {{ initial_value }}
33
{% endmacro %}
44

openapi_python_client/templates/property_templates/union_property.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
{% macro construct(property, source, initial_value=None) %}
1+
{# This file is shadowed by the template with the same name
2+
# in aurelia/packages/api_client_generation/templates #}
3+
{% macro construct(property, source, initial_value=None, nested=False) %}
24
def _parse_{{ property.python_name }}(data: {{ property.get_type_string(json=True) }}) -> {{ property.get_type_string() }}:
35
{{ property.python_name }}: {{ property.get_type_string() }}
46
{% if "None" in property.get_type_strings_in_union(json=True) %}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "openapi-python-client"
33
# Our versions have diverged and have no relation to upstream code changes
44
# Henceforth, openapi-python-package will be maintained internally
5-
version = "1.0.3"
5+
version = "1.0.4"
66
description = "Generate modern Python clients from OpenAPI"
77
repository = "https://github.com/triaxtec/openapi-python-client"
88
license = "MIT"

0 commit comments

Comments
 (0)