Skip to content

Commit 17fcf0a

Browse files
committed
- Fixed dataclass from pydantic not working (#100 comment).
- Added support for pydantic models and attr defines similar to dataclasses. - Support for python 3.6 will be removed in v5.0.0. New features added in future v4 releases are not guaranteed to work with python 3.6.
1 parent 8a54a31 commit 17fcf0a

File tree

8 files changed

+477
-285
lines changed

8 files changed

+477
-285
lines changed

CHANGELOG.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,18 @@ paths are considered internals and can change in minor and patch releases.
1515
v4.21.0 (2023-04-??)
1616
--------------------
1717

18+
Added
19+
^^^^^
20+
- Support for pydantic models and attr defines similar to dataclasses.
21+
1822
Fixed
1923
^^^^^
2024
- `str` parameter in subclass incorrectly parsed as dict with implicit `null`
2125
value (`#262 <https://github.com/omni-us/jsonargparse/issues/262>`__).
2226
- Wrong error indentation for subclass in union (`pytorch-lightning#17254
2327
<https://github.com/Lightning-AI/lightning/issues/17254>`__).
28+
- ``dataclass`` from pydantic not working (`#100 (comment)
29+
<https://github.com/omni-us/jsonargparse/issues/100#issuecomment-1408413796>`__).
2430

2531
Changed
2632
^^^^^^^
@@ -30,6 +36,11 @@ Changed
3036
(`pytorch-lightning#17247
3137
<https://github.com/Lightning-AI/lightning/issues/17247>`__).
3238

39+
Deprecated
40+
^^^^^^^^^^
41+
- Support for python 3.6 will be removed in v5.0.0. New features added in
42+
future v4 releases are not guaranteed to work with python 3.6.
43+
3344

3445
v4.20.1 (2023-03-30)
3546
--------------------

README.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,9 @@ Some notes about this support are:
433433
- ``dataclasses`` are supported as a type but only for pure data classes and not
434434
nested in a type. By pure it is meant that the class only inherits from data
435435
classes. Not a mixture of normal classes and data classes. Data classes as
436-
fields of other data classes is supported.
436+
fields of other data classes is supported. Pydantic's ``dataclass`` decorator
437+
and ``BaseModel`` classes, and attrs' ``define`` decorator are supported
438+
like standard dataclasses. Though, this support is currently experimental.
437439

438440
- To set a value to ``None`` it is required to use ``null`` since this is how
439441
json/yaml defines it. To avoid confusion in the help, ``NoneType`` is

jsonargparse/optionals.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
ruyaml_support = find_spec('ruyaml') is not None
2626
omegaconf_support = find_spec('omegaconf') is not None
2727
reconplogger_support = find_spec('reconplogger') is not None
28+
pydantic_support = find_spec('pydantic') is not None
29+
attrs_support = find_spec('attrs') is not None
2830

2931
_config_read_mode = 'fr'
3032
_docstring_parse_options = {
@@ -133,6 +135,18 @@ def import_reconplogger(importer):
133135
return reconplogger
134136

135137

138+
def import_pydantic(importer):
139+
with missing_package_raise('pydantic', importer):
140+
import pydantic
141+
return pydantic
142+
143+
144+
def import_attrs(importer):
145+
with missing_package_raise('attrs', importer):
146+
import attrs
147+
return attrs
148+
149+
136150
def set_config_read_mode(
137151
urls_enabled: bool = False,
138152
fsspec_enabled: bool = False,
@@ -207,15 +221,16 @@ def parse_docstring(component, params=False, logger=None):
207221

208222

209223
def parse_docs(component, parent, logger):
210-
docs = []
224+
docs = {}
211225
if docstring_parser_support:
212226
doc_sources = [component]
213227
if inspect.isclass(parent) and component.__name__ == '__init__':
214-
doc_sources = [parent] + doc_sources
228+
doc_sources += [parent]
215229
for src in doc_sources:
216230
doc = parse_docstring(src, params=True, logger=logger)
217231
if doc:
218-
docs.append(doc)
232+
for param in doc.params:
233+
docs[param.arg_name] = param.description
219234
return docs
220235

221236

jsonargparse/parameter_resolvers.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def is_staticmethod(attr) -> bool:
8080

8181

8282
def is_method(attr) -> bool:
83-
return inspect.isfunction(attr) and not is_staticmethod(attr)
83+
return (inspect.isfunction(attr) or attr.__class__.__name__ == 'cython_function_or_method') and not is_staticmethod(attr)
8484

8585

8686
def is_property(attr) -> bool:
@@ -265,10 +265,7 @@ def get_signature_parameters_and_indexes(component, parent, logger):
265265
params = params[1:]
266266
args_idx = get_arg_kind_index(params, kinds.VAR_POSITIONAL)
267267
kwargs_idx = get_arg_kind_index(params, kinds.VAR_KEYWORD)
268-
doc_params = {}
269-
for doc in parse_docs(component, parent, logger):
270-
for param in doc.params:
271-
doc_params[param.arg_name] = param.description
268+
doc_params = parse_docs(component, parent, logger)
272269
for num, param in enumerate(params):
273270
params[num] = ParamData(
274271
doc=doc_params.get(param.name),
@@ -767,6 +764,37 @@ def get_parameters_by_assumptions(
767764
return params
768765

769766

767+
def get_parameters_from_pydantic(
768+
function_or_class: Union[Callable, Type],
769+
method_or_property: Optional[str],
770+
logger: logging.Logger,
771+
) -> Optional[ParamList]:
772+
from .optionals import import_pydantic, pydantic_support
773+
if not pydantic_support or method_or_property:
774+
return None
775+
pydantic = import_pydantic('get_parameters_from_pydantic')
776+
if not is_subclass(function_or_class, pydantic.BaseModel):
777+
return None
778+
params = []
779+
doc_params = parse_docs(function_or_class, None, logger)
780+
for field in function_or_class.__fields__.values(): # type: ignore
781+
if field.required:
782+
default = inspect._empty
783+
elif field.default_factory:
784+
default = field.default_factory()
785+
else:
786+
default = field.default
787+
params.append(ParamData(
788+
name=field.name,
789+
annotation=field.annotation,
790+
default=default,
791+
kind=kinds.KEYWORD_ONLY,
792+
doc=field.field_info.description or doc_params.get(field.name),
793+
component=function_or_class,
794+
))
795+
return params
796+
797+
770798
def get_signature_parameters(
771799
function_or_class: Union[Callable, Type],
772800
method_or_property: Optional[str] = None,
@@ -787,6 +815,9 @@ def get_signature_parameters(
787815
"""
788816
logger = parse_logger(logger, 'get_signature_parameters')
789817
try:
818+
params = get_parameters_from_pydantic(function_or_class, method_or_property, logger)
819+
if params is not None:
820+
return params
790821
visitor = ParametersVisitor(function_or_class, method_or_property, logger=logger)
791822
return visitor.get_parameters()
792823
except Exception as ex:

jsonargparse/signatures.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union
99

1010
from .actions import _ActionConfigLoad
11-
from .optionals import get_doc_short_description
11+
from .optionals import (
12+
attrs_support,
13+
get_doc_short_description,
14+
import_attrs,
15+
import_pydantic,
16+
pydantic_support,
17+
)
1218
from .parameter_resolvers import (
1319
ParamData,
1420
get_parameter_origins,
@@ -411,20 +417,18 @@ def add_dataclass_arguments(
411417
default = theclass(**default)
412418
if not isinstance(default, theclass):
413419
raise ValueError(f'Expected "default" argument to be an instance of "{theclass.__name__}" or its kwargs dict, given {default}')
414-
defaults = dataclasses.asdict(default)
420+
defaults = dataclass_to_dict(default)
415421

416422
added_args: List[str] = []
417-
params = {p.name: p for p in get_signature_parameters(theclass, None, logger=self.logger)}
418-
for field in dataclasses.fields(theclass):
419-
if field.name in params:
420-
self._add_signature_parameter(
421-
group,
422-
nested_key,
423-
params[field.name],
424-
added_args,
425-
fail_untyped=fail_untyped,
426-
default=defaults.get(field.name, inspect_empty),
427-
)
423+
for param in get_signature_parameters(theclass, None, logger=self.logger):
424+
self._add_signature_parameter(
425+
group,
426+
nested_key,
427+
param,
428+
added_args,
429+
fail_untyped=fail_untyped,
430+
default=defaults.get(param.name, inspect_empty),
431+
)
428432

429433
return added_args
430434

@@ -550,7 +554,24 @@ def is_pure_dataclass(value):
550554
if not inspect.isclass(value):
551555
return False
552556
classes = [c for c in inspect.getmro(value) if c != object]
553-
return all(dataclasses.is_dataclass(c) for c in classes)
557+
all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)
558+
if not all_dataclasses and pydantic_support:
559+
pydantic = import_pydantic('is_pure_dataclass')
560+
classes = [c for c in classes if c != pydantic.utils.Representation]
561+
all_dataclasses = all(is_subclass(c, pydantic.BaseModel) for c in classes)
562+
if not all_dataclasses and attrs_support:
563+
attrs = import_attrs('is_pure_dataclass')
564+
if attrs.has(value):
565+
return True
566+
return all_dataclasses
567+
568+
569+
def dataclass_to_dict(value):
570+
if pydantic_support:
571+
pydantic = import_pydantic('dataclass_to_dict')
572+
if isinstance(value, pydantic.BaseModel):
573+
return value.dict()
574+
return dataclasses.asdict(value)
554575

555576

556577
def compose_dataclasses(*args):

0 commit comments

Comments
 (0)