diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index c35507129..7763a0014 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -24,6 +24,8 @@ import pandas as pd import pymc as pm import xarray as xr +from pymc.backends import NDArray +from pymc.backends.base import MultiTrace from pymc.util import RandomState # If scikit-learn is available, use its data validator @@ -425,6 +427,7 @@ def fit( self, X: pd.DataFrame, y: Optional[pd.Series] = None, + fit_method="mcmc", progressbar: bool = True, predictor_names: List[str] = None, random_seed: RandomState = None, @@ -441,6 +444,8 @@ def fit( The training input samples. y : array-like if sklearn is available, otherwise array, shape (n_obs,) The target values (real numbers). + fit_method : str + Which method to use to infer model parameters. One of ["mcmc", "MAP"]. progressbar : bool Specifies whether the fit progressbar should be displayed predictor_names: List[str] = None, @@ -449,19 +454,14 @@ def fit( random_seed : RandomState Provides sampler with initial random seed for obtaining reproducible samples **kwargs : Any - Custom sampler settings can be provided in form of keyword arguments. - - Returns - ------- - self : az.InferenceData - returns inference data of the fitted model. - Examples - -------- - >>> model = MyModel() - >>> idata = model.fit(data) - Auto-assigning NUTS sampler... - Initializing NUTS using jitter+adapt_diag... + Parameters to pass to the inference method. See `sample_model` or `_fit_MAP` for + method-specific parameters. """ + available_methods = ["mcmc", "MAP"] + if fit_method not in available_methods: + raise ValueError( + f"Inference method {fit_method} not found. Choose one of {available_methods}." 
+ ) if predictor_names is None: predictor_names = [] if y is None: @@ -474,7 +474,11 @@ def fit( sampler_config["progressbar"] = progressbar sampler_config["random_seed"] = random_seed sampler_config.update(**kwargs) - self.idata = self.sample_model(**sampler_config) + + if fit_method == "mcmc": + self.idata = self.sample_model(**sampler_config) + elif fit_method == "MAP": + self.idata = self._fit_MAP(**sampler_config) X_df = pd.DataFrame(X, columns=X.columns) combined_data = pd.concat([X_df, y], axis=1) @@ -482,6 +486,62 @@ def fit( self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore return self.idata # type: ignore + def _fit_MAP( + self, + **kwargs, + ): + """Find model maximum a posteriori using scipy optimizer""" + + model = self.model + find_MAP_args = {**self.sampler_config, **kwargs} + if "random_seed" in find_MAP_args: + # find_MAP takes a different argument name for seed than sample_* do. + find_MAP_args["seed"] = find_MAP_args["random_seed"] + # Extra unknown arguments cause problems for SciPy minimize + allowed_args = [ # find_MAP args + "start", + "vars", + "method", + # "return_raw", # probably causes a problem if set spuriously + # "include_transformed", # probably causes a problem if set spuriously + "progressbar", + "maxeval", + "seed", + ] + allowed_args += [ # scipy.optimize.minimize args + # "fun", # used by find_MAP + # "x0", # used by find_MAP + "args", + "method", + # "jac", # used by find_MAP + # "hess", # probably causes a problem if set spuriously + # "hessp", # probably causes a problem if set spuriously + "bounds", + "constraints", + "tol", + "callback", + "options", + ] + for arg in list(find_MAP_args): + if arg not in allowed_args: + del find_MAP_args[arg] + + map_res = pm.find_MAP(model=model, **find_MAP_args) + # Filter non-value variables + value_vars_names = {v.name for v in model.value_vars} + map_res = {k: v for k, v in map_res.items() if k in value_vars_names} + + # Convert map result to InferenceData + 
map_strace = NDArray(model=model) + map_strace.setup(draws=1, chain=0) + map_strace.record(map_res) + map_strace.close() + trace = MultiTrace([map_strace]) + idata = pm.to_inference_data(trace, model=model) + self.set_idata_attrs(idata) + + return idata + def predict( self, X_pred: Union[np.ndarray, pd.DataFrame, pd.Series], diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index dd4a88abb..18bfce7c2 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -40,8 +40,8 @@ def toy_y(toy_X): return y -@pytest.fixture(scope="module") -def fitted_model_instance(toy_X, toy_y): +@pytest.fixture(scope="module", params=["mcmc", "MAP"]) +def fitted_model_instance(request, toy_X, toy_y): sampler_config = { "draws": 500, "tune": 300, @@ -54,12 +54,11 @@ def fitted_model_instance(toy_X, toy_y): "obs_error": 2, } model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config) - model.fit(toy_X) + model.fit(toy_X, toy_y, fit_method=request.param) return model class test_ModelBuilder(ModelBuilder): - _model_type = "LinearModel" version = "0.1" @@ -151,9 +150,10 @@ def test_fit(fitted_model_instance): post_pred[fitted_model_instance.output_var].shape[0] == prediction_data.input.shape -def test_fit_no_y(toy_X): +@pytest.mark.parametrize("fit_method", ["mcmc", "MAP"]) +def test_fit_no_y(toy_X, fit_method): model_builder = test_ModelBuilder() - model_builder.idata = model_builder.fit(X=toy_X) + model_builder.idata = model_builder.fit(X=toy_X, fit_method=fit_method) assert model_builder.model is not None assert model_builder.idata is not None assert "posterior" in model_builder.idata.groups() @@ -163,17 +163,16 @@ def test_fit_no_y(toy_X): sys.platform == "win32", reason="Permissions for temp files not granted on windows CI." 
) def test_save_load(fitted_model_instance): - temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) - fitted_model_instance.save(temp.name) - test_builder2 = test_ModelBuilder.load(temp.name) - assert fitted_model_instance.idata.groups() == test_builder2.idata.groups() + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as temp: + fitted_model_instance.save(temp.name) + test_builder2 = test_ModelBuilder.load(temp.name) + assert sorted(fitted_model_instance.idata.groups()) == sorted(test_builder2.idata.groups()) x_pred = np.random.uniform(low=0, high=1, size=100) prediction_data = pd.DataFrame({"input": x_pred}) pred1 = fitted_model_instance.predict(prediction_data["input"]) pred2 = test_builder2.predict(prediction_data["input"]) assert pred1.shape == pred2.shape - temp.close() def test_predict(fitted_model_instance): @@ -193,8 +192,8 @@ def test_sample_posterior_predictive(fitted_model_instance, combined): pred = fitted_model_instance.sample_posterior_predictive( prediction_data["input"], combined=combined, extend_idata=True ) - chains = fitted_model_instance.idata.sample_stats.dims["chain"] - draws = fitted_model_instance.idata.sample_stats.dims["draw"] + chains = fitted_model_instance.idata.posterior.dims["chain"] + draws = fitted_model_instance.idata.posterior.dims["draw"] expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred) assert pred[fitted_model_instance.output_var].shape == expected_shape assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)