
Extend docs to explain how to make jsonargparse work with pydantic models #100

Closed
@bzfhille


This is about using jsonargparse within LightningCLI to generate a nice CLI easily, but the use case is likely more general.

pydantic is a fantastic way to use structured types in a type-safe manner, with many other benefits, so I am using it in a project to structure the parameters of a complex ML model. In particular, the parameter dataset_params of the LightningDataModule has type dataset.DatasetParams, which is a subclass of pydantic.BaseModel to get all the niceties. Running this very simple trainer script

```python
from pytorch_lightning.utilities.cli import LightningCLI
from models import Model
from dataset import DataModule  # This is the LightningDataModule.

cli = LightningCLI(Model, DataModule)
```

with the command

```bash
# Naive attempt: pass the dict, hoping it gets parsed by pydantic
python train.py --data.dataset_params="{'filename': features.hdf5}"
```

yields the error message

```
train.py: error: Parser key "data.dataset_params": Type <class 'dataset.DatasetParams'> expects an str or a Dict/Namespace with a class_path entry but got "{'filename': 'features.hdf5'}"
```

I found an easy workaround for this that is not (well) documented. Although jsonargparse does not directly support pydantic, it does support dataclasses. pydantic, in turn, offers a "compatibility mode for dataclasses" (pydantic.dataclasses.dataclass), which lets you define a standard dataclass instead of a subclass of pydantic.BaseModel. The nice thing is that it is sufficient to apply this trick to dataset.DatasetParams; the other nested types can remain subclasses of pydantic.BaseModel. So the following makes the above simple script work:

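```python
from dataclasses import field

from pydantic.dataclasses import dataclass

# ModelParams is a subclass of pydantic.BaseModel (that has fields that are themselves
# subclasses of pydantic.BaseModel).
from models import ModelParams


# Use pydantic.dataclasses.dataclass here instead of deriving from pydantic.BaseModel to make the magic work.
@dataclass
class DatasetParams:
    """Parameters for accessing a dataset."""

    filename: str
    # default_factory gives each DatasetParams its own ModelParams instance;
    # Python 3.11+ dataclasses reject unhashable defaults such as a non-frozen model.
    model_params: ModelParams = field(default_factory=ModelParams)
```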

With this, the call is now as simple as expected:

```bash
python train.py --data.dataset_params.filename=features.hdf5
```
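Why does the dataclass version work? Roughly: a dataclass describes its own fields (via dataclasses.fields), so a parser can expand it into one dotted option per field, whereas for an arbitrary class jsonargparse instead expects a class_path entry, as the error message above says. The stdlib-only sketch below illustrates that idea under those assumptions. It is not jsonargparse's actual implementation; option_names, parse_dotted, build and the plain ModelParams stand-in (for the pydantic.BaseModel subclass) are hypothetical, and parsed values are kept as strings with no type conversion.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Dict, List


class ModelParams:
    """Hypothetical stand-in for the pydantic.BaseModel subclass (opaque to this sketch)."""


@dataclass
class DatasetParams:
    """Stdlib stand-in for the pydantic dataclass from the issue."""

    filename: str
    model_params: ModelParams = field(default_factory=ModelParams)


def _is_dataclass_type(tp: Any) -> bool:
    # True only for dataclass *types*, not for instances or plain classes.
    return isinstance(tp, type) and is_dataclass(tp)


def option_names(cls: type, prefix: str) -> List[str]:
    """List one dotted '--option' per field, expanding nested dataclass types."""
    names: List[str] = []
    for f in fields(cls):
        key = f"{prefix}.{f.name}"
        if _is_dataclass_type(f.type):
            names.extend(option_names(f.type, key))
        else:
            # Any other type (such as ModelParams here) stays a single opaque option.
            names.append(f"--{key}")
    return names


def parse_dotted(argv: List[str], prefix: str) -> Dict[str, Any]:
    """Collect '--<prefix>.a.b=value' arguments into a nested dict (values stay strings)."""
    result: Dict[str, Any] = {}
    for arg in argv:
        if not arg.startswith(f"--{prefix}.") or "=" not in arg:
            continue  # only the --key=value form under this prefix is handled
        key, value = arg[2:].split("=", 1)
        *parents, leaf = key[len(prefix) + 1:].split(".")
        node = result
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return result


def build(cls: type, values: Dict[str, Any]) -> Any:
    """Instantiate a dataclass from a nested dict, recursing into dataclass fields."""
    kwargs = {}
    for f in fields(cls):
        if f.name not in values:
            continue  # fall back to the field's default
        value = values[f.name]
        if _is_dataclass_type(f.type) and isinstance(value, dict):
            value = build(f.type, value)
        kwargs[f.name] = value
    return cls(**kwargs)


prefix = "data.dataset_params"
print(option_names(DatasetParams, prefix))
# ['--data.dataset_params.filename', '--data.dataset_params.model_params']

params = build(DatasetParams, parse_dotted(["--data.dataset_params.filename=features.hdf5"], prefix))
print(params.filename)  # features.hdf5
```

Note how only the dataclass is expanded into per-field options, while the opaque ModelParams field stays a single option, which mirrors why converting DatasetParams alone is enough.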

I request that this be mentioned explicitly in the docs, since it is really useful.

Thank you for the excellent work that provided exactly what I was looking for.

Labels: bug (Something isn't working), documentation (Improvements to the documentation)