
refactor ModelBuilder.fit #196

Open
@michaelraczycki

Description

The fit method should be modified so that the fitting method (MCMC or MAP) can be specified. To do this without breaking the current API:
1. Rename the current fit to _fit_mcmc.
2. Create a new fit method that takes X, y, fit_method: str and **kwargs, and calls either _fit_mcmc or _fit_MAP depending on fit_method.
3. Create _fit_MAP. Make sure it includes the idata-preserving steps from the current fit, since they are crucial for the save/load functionality.
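A minimal sketch of how the new fit could dispatch to _fit_mcmc or _fit_MAP. It assumes that fit_method defaults to "mcmc" (so existing calls keep their current behaviour), that the accepted values are "mcmc" and "map" (case-insensitive), and that the result is stored on self.idata for save/load. ModelBuilderSketch, build_model and the placeholder return values are hypothetical; in the real class the two private methods would return InferenceData.

```python
from typing import Any, Dict


class ModelBuilderSketch:
    """Hypothetical stand-in for ModelBuilder, showing only the fit dispatch."""

    # Assumption: these are the accepted (lower-cased) fit_method values.
    _FIT_METHODS = ("mcmc", "map")

    def __init__(self) -> None:
        self.model = None
        self.idata = None

    def build_model(self, X, y) -> None:
        # Hypothetical placeholder: the real class builds a PyMC model here.
        self.model = {"X": X, "y": y}

    def _fit_mcmc(self, **kwargs) -> Dict[str, Any]:
        # Placeholder for the renamed, current MCMC-based fit.
        return {"method": "mcmc", "kwargs": kwargs}

    def _fit_MAP(self, **kwargs) -> Dict[str, Any]:
        # Placeholder for the MAP fit (pm.find_MAP in the real code).
        return {"method": "map", "kwargs": kwargs}

    def fit(self, X, y, fit_method: str = "mcmc", **kwargs) -> Dict[str, Any]:
        """Build the model, then dispatch to the requested fitting routine."""
        method = fit_method.lower()
        if method not in self._FIT_METHODS:
            raise ValueError(
                f"fit_method must be one of {self._FIT_METHODS}, got {fit_method!r}"
            )
        self.build_model(X, y)
        if method == "mcmc":
            idata = self._fit_mcmc(**kwargs)
        else:
            idata = self._fit_MAP(**kwargs)
        # Keep the result on the instance so that save/load can find it.
        self.idata = idata
        return idata
```

Keeping the default at "mcmc" means a call such as builder.fit(X, y, draws=500) behaves as before, while builder.fit(X, y, fit_method="MAP") opts into the new path; unknown values fail early with a ValueError instead of silently falling back to one method.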

The following can be used as a starting point for _fit_MAP:
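```python
import pymc as pm
from pymc.backends.base import MultiTrace
from pymc.backends.ndarray import NDArray


# Method to be added to ModelBuilder.
def _fit_MAP(self, **kwargs):
    """Find the model's maximum a posteriori (MAP) point using a scipy optimizer."""
    model = self.model
    map_res = pm.find_MAP(model=model, **kwargs)
    # Keep only the value variables (e.g. log-transformed ones); the trace
    # backend recomputes the remaining variables from them.
    value_vars_names = {v.name for v in model.value_vars}
    map_res = {k: v for k, v in map_res.items() if k in value_vars_names}
    # Convert the MAP result to InferenceData via a single-draw, single-chain trace.
    map_strace = NDArray(model=model)
    map_strace.setup(draws=1, chain=0)
    map_strace.record(map_res)
    map_strace.close()
    trace = MultiTrace([map_strace])
    return pm.to_inference_data(trace, model=model)
```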
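To illustrate what the filtering step in _fit_MAP does, here is a pure-Python example with a made-up MAP result. It assumes a model with a free variable mu, a positive variable sigma that is optimized on the log scale (value variable sigma_log__), and a Deterministic mu_plus_sigma. With its default include_transformed=True, find_MAP returns both the transformed and the constrained values, so only the value-variable entries are kept before recording the point.

```python
# Made-up find_MAP-style result for the hypothetical model described above.
map_res = {"mu": 0.1, "sigma_log__": -0.5, "sigma": 0.61, "mu_plus_sigma": 0.71}

# Names of the model's value variables (what model.value_vars would give).
value_vars_names = {"mu", "sigma_log__"}

# Same dict comprehension as in _fit_MAP: drop everything that is not a value variable.
filtered = {k: v for k, v in map_res.items() if k in value_vars_names}
print(filtered)  # {'mu': 0.1, 'sigma_log__': -0.5}
```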
