Skip to content

Commit 3de6d6f

Browse files
committed
Fixed failure with dataclasses that have field with init=False #252.
1 parent 1fce989 commit 3de6d6f

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ v4.20.1 (2023-03-??)
1717

1818
Fixed
1919
^^^^^
20-
- Allow ``discard_init_args_on_class_path_change`` to handle more nested contexts `#247
21-
<https://github.com/omni-us/jsonargparse/issues/247>`__.
22-
2320
- Dump not working for partial callable with return instance
2421
`pytorch-lightning#15340 (comment)
2522
<https://github.com/Lightning-AI/lightning/issues/15340#issuecomment-1439203008>`__.
23+
- Allow ``discard_init_args_on_class_path_change`` to handle more nested
24+
contexts `#247 <https://github.com/omni-us/jsonargparse/issues/247>`__.
25+
- Failure with dataclasses that have field with ``init=False`` `#252
26+
<https://github.com/omni-us/jsonargparse/issues/252>`__.
2627

2728

2829
v4.20.0 (2023-02-20)

jsonargparse/signatures.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -416,14 +416,15 @@ def add_dataclass_arguments(
416416
added_args: List[str] = []
417417
params = {p.name: p for p in get_signature_parameters(theclass, None, logger=self.logger)}
418418
for field in dataclasses.fields(theclass):
419-
self._add_signature_parameter(
420-
group,
421-
nested_key,
422-
params[field.name],
423-
added_args,
424-
fail_untyped=fail_untyped,
425-
default=defaults.get(field.name, inspect_empty),
426-
)
419+
if field.name in params:
420+
self._add_signature_parameter(
421+
group,
422+
nested_key,
423+
params[field.name],
424+
added_args,
425+
fail_untyped=fail_untyped,
426+
default=defaults.get(field.name, inspect_empty),
427+
)
427428

428429
return added_args
429430

jsonargparse_tests/test_signatures.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,6 +1615,19 @@ class MyDataClass:
16151615
self.assertRaises(ArgumentError, lambda: parser.parse_args([]))
16161616

16171617

1618+
def test_dataclass_field_init_false(self):
1619+
1620+
@dataclasses.dataclass
1621+
class DataInitFalse:
1622+
p1: str = '-'
1623+
p2: str = dataclasses.field(init=False)
1624+
1625+
parser = ArgumentParser(exit_on_error=False)
1626+
added = parser.add_dataclass_arguments(DataInitFalse, 'd')
1627+
self.assertEqual(added, ['d.p1'])
1628+
self.assertEqual(parser.get_defaults(), Namespace(d=Namespace(p1='-')))
1629+
1630+
16181631
def test_dataclass_field_default_factory(self):
16191632

16201633
@dataclasses.dataclass

0 commit comments

Comments
 (0)