Description
This is about using jsonargparse within LightningCLI to generate a nice CLI easily, but the use case is likely more general.
pydantic is a fantastic way to use structured types in a type-safe way, with a lot of other benefits, so I am using it in a project to structure the parameters of a complex ML model. In particular, the parameter dataset_params of the LightningDataModule is of type dataset.DatasetParams, which happens to be a subclass of pydantic.BaseModel to get all the niceties. Running this basic training script
from pytorch_lightning.utilities.cli import LightningCLI
from models import Model
from dataset import DataModule # This is the LightningDataModule.
cli = LightningCLI(Model, DataModule)
via
# naive way to provide the dict to be maybe parsed using pydantic
python train.py --data.dataset_params="{'filename': features.hdf5}"
yields the error message
train.py: error: Parser key "data.dataset_params": Type <class 'dataset.DatasetParams'> expects an str or a Dict/Namespace with a class_path entry but got "{'filename': 'features.hdf5'}"
I found an easy workaround that is not (well) documented. Although jsonargparse does not directly support pydantic, it does support dataclasses. pydantic, in turn, has a nice "compatibility mode for dataclasses" that lets you use a "standard" dataclass instead of a subclass of pydantic.BaseModel. The nice thing is that it is sufficient to apply this trick to dataset.DatasetParams, not to the other nested types that are subclasses of pydantic.BaseModel. So the following makes the above simple script work:
from pydantic.dataclasses import dataclass
# ModelParams is a subclass of pydantic.BaseModel (that has fields that are themselves
# subclasses of pydantic.BaseModel).
from models import ModelParams
# Use pydantic.dataclasses.dataclass here instead of deriving from pydantic.BaseModel to make the magic work.
@dataclass
class DatasetParams:
    """Parameters for accessing a dataset."""

    filename: str
    model_params: ModelParams = ModelParams()
where the call is now as simple as expected:
python train.py --data.dataset_params.filename=features.hdf5
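To illustrate why a dataclass type is easier for a CLI parser to handle than an opaque class, here is a minimal standard-library sketch of how dotted options such as --data.dataset_params.filename can be derived from dataclass fields. This is not jsonargparse's actual implementation, only an illustration of the idea. It uses plain dataclasses rather than pydantic, and the helper names (add_dataclass_arguments, build_dataclass, parse_dataset_params) and the hidden_size field are hypothetical.

```python
import argparse
import dataclasses
from dataclasses import dataclass, field


@dataclass
class ModelParams:
    """Hypothetical stand-in for the real ModelParams."""

    hidden_size: int = 128  # hypothetical field, for illustration only


@dataclass
class DatasetParams:
    """Parameters for accessing a dataset."""

    filename: str
    model_params: ModelParams = field(default_factory=ModelParams)


def add_dataclass_arguments(parser, cls, prefix):
    """Register one --<prefix>.<field> option per field, recursing into nested dataclasses."""
    for f in dataclasses.fields(cls):
        name = f"{prefix}.{f.name}"
        if dataclasses.is_dataclass(f.type):
            add_dataclass_arguments(parser, f.type, name)
            continue
        if f.default is not dataclasses.MISSING:
            default, required = f.default, False
        elif f.default_factory is not dataclasses.MISSING:
            default, required = f.default_factory(), False
        else:
            default, required = None, True
        parser.add_argument(f"--{name}", type=f.type, default=default, required=required)


def build_dataclass(cls, values, prefix):
    """Rebuild a (possibly nested) dataclass instance from the flat parsed values."""
    kwargs = {}
    for f in dataclasses.fields(cls):
        name = f"{prefix}.{f.name}"
        if dataclasses.is_dataclass(f.type):
            kwargs[f.name] = build_dataclass(f.type, values, name)
        else:
            kwargs[f.name] = values[name]
    return cls(**kwargs)


def parse_dataset_params(argv):
    """Parse argv into a DatasetParams instance using the dotted option names."""
    prefix = "data.dataset_params"
    parser = argparse.ArgumentParser()
    add_dataclass_arguments(parser, DatasetParams, prefix)
    return build_dataclass(DatasetParams, vars(parser.parse_args(argv)), prefix)


if __name__ == "__main__":
    print(parse_dataset_params(["--data.dataset_params.filename=features.hdf5"]))
```

Because the fields and their types are declared up front, each one can become its own option, and nested parameters get their own dotted prefix without passing a dict as a string.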
Could this be mentioned explicitly in the docs? It is really useful.
Thank you for the excellent work that provided exactly what I was looking for.