diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 7763a001..c3550712 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -24,8 +24,6 @@ import pandas as pd import pymc as pm import xarray as xr -from pymc.backends import NDArray -from pymc.backends.base import MultiTrace from pymc.util import RandomState # If scikit-learn is available, use its data validator @@ -427,7 +425,6 @@ def fit( self, X: pd.DataFrame, y: Optional[pd.Series] = None, - fit_method="mcmc", progressbar: bool = True, predictor_names: List[str] = None, random_seed: RandomState = None, @@ -444,8 +441,6 @@ def fit( The training input samples. y : array-like if sklearn is available, otherwise array, shape (n_obs,) The target values (real numbers). - fit_method : str - Which method to use to infer model parameters. One of ["mcmc", "MAP"]. progressbar : bool Specifies whether the fit progressbar should be displayed predictor_names: List[str] = None, @@ -454,14 +449,19 @@ def fit( random_seed : RandomState Provides sampler with initial random seed for obtaining reproducible samples **kwargs : Any - Parameters to pass to the inference method. See `_fit_mcmc` or `_fit_MAP` for - method-specific parameters. + Custom sampler settings can be provided in form of keyword arguments. + + Returns + ------- + self : az.InferenceData + returns inference data of the fitted model. + Examples + -------- + >>> model = MyModel() + >>> idata = model.fit(data) + Auto-assigning NUTS sampler... + Initializing NUTS using jitter+adapt_diag... """ - available_methods = ["mcmc", "MAP"] - if fit_method not in available_methods: - raise ValueError( - f"Inference method {fit_method} not found. Choose one of {available_methods}." - ) if predictor_names is None: predictor_names = [] if y is None: @@ -474,11 +474,7 @@ def fit( sampler_config["progressbar"] = progressbar sampler_config["random_seed"] = random_seed sampler_config.update(**kwargs) - - if fit_method == "mcmc": - self.idata = self.sample_model(**sampler_config) - elif fit_method == "MAP": - self.idata = self._fit_MAP(**sampler_config) + self.idata = self.sample_model(**sampler_config) X_df = pd.DataFrame(X, columns=X.columns) combined_data = pd.concat([X_df, y], axis=1) @@ -486,62 +482,6 @@ def fit( self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore return self.idata # type: ignore - def _fit_MAP( - self, - **kwargs, - ): - """Find model maximum a posteriori using scipy optimizer""" - - model = self.model - find_MAP_args = {**self.sampler_config, **kwargs} - if "random_seed" in find_MAP_args: - # find_MAP takes a different argument name for seed than sample_* do. - find_MAP_args["seed"] = find_MAP_args["random_seed"] - # Extra unknown arguments cause problems for SciPy minimize - allowed_args = [ # find_MAP args - "start", - "vars", - "method", - # "return_raw", # probably causes a problem if set spuriously - # "include_transformed", # probably causes a problem if set spuriously - "progressbar", - "maxeval", - "seed", - ] - allowed_args += [ # scipy.optimize.minimize args - # "fun", # used by find_MAP - # "x0", # used by find_MAP - "args", - "method", - # "jac", # used by find_MAP - # "hess", # probably causes a problem if set spuriously - # "hessp", # probably causes a problem if set spuriously - "bounds", - "constraints", - "tol", - "callback", - "options", - ] - for arg in list(find_MAP_args): - if arg not in allowed_args: - del find_MAP_args[arg] - - map_res = pm.find_MAP(model=model, **find_MAP_args) - # Filter non-value variables - value_vars_names = {v.name for v in model.value_vars} - map_res = {k: v for k, v in map_res.items() if k in value_vars_names} - - # Convert map result to InferenceData - map_strace = NDArray(model=model) - map_strace.setup(draws=1, chain=0) - map_strace.record(map_res) - map_strace.close() - trace = MultiTrace([map_strace]) - idata = pm.to_inference_data(trace, model=model) - self.set_idata_attrs(idata) - - return idata - def predict( self, X_pred: Union[np.ndarray, pd.DataFrame, pd.Series], diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index 18bfce7c..dd4a88ab 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -40,8 +40,8 @@ def toy_y(toy_X): return y -@pytest.fixture(scope="module", params=["mcmc", "MAP"]) -def fitted_model_instance(request, toy_X, toy_y): +@pytest.fixture(scope="module") +def fitted_model_instance(toy_X, toy_y): sampler_config = { "draws": 500, "tune": 300, @@ -54,11 +54,12 @@ def fitted_model_instance(request, toy_X, toy_y): "obs_error": 2, } model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config) - model.fit(toy_X, toy_y, fit_method=request.param) + model.fit(toy_X) return model class test_ModelBuilder(ModelBuilder): + _model_type = "LinearModel" version = "0.1" @@ -150,10 +151,9 @@ def test_fit(fitted_model_instance): post_pred[fitted_model_instance.output_var].shape[0] == prediction_data.input.shape -@pytest.mark.parametrize("fit_method", ["mcmc", "MAP"]) -def test_fit_no_y(toy_X, fit_method): +def test_fit_no_y(toy_X): model_builder = test_ModelBuilder() - model_builder.idata = model_builder.fit(X=toy_X, fit_method=fit_method) + model_builder.idata = model_builder.fit(X=toy_X) assert model_builder.model is not None assert model_builder.idata is not None assert "posterior" in model_builder.idata.groups() @@ -163,16 +163,17 @@ def test_fit_no_y(toy_X, fit_method): sys.platform == "win32", reason="Permissions for temp files not granted on windows CI." ) def test_save_load(fitted_model_instance): - with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as temp: - fitted_model_instance.save(temp.name) - test_builder2 = test_ModelBuilder.load(temp.name) - assert sorted(fitted_model_instance.idata.groups()) == sorted(test_builder2.idata.groups()) + temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) + fitted_model_instance.save(temp.name) + test_builder2 = test_ModelBuilder.load(temp.name) + assert fitted_model_instance.idata.groups() == test_builder2.idata.groups() x_pred = np.random.uniform(low=0, high=1, size=100) prediction_data = pd.DataFrame({"input": x_pred}) pred1 = fitted_model_instance.predict(prediction_data["input"]) pred2 = test_builder2.predict(prediction_data["input"]) assert pred1.shape == pred2.shape + temp.close() def test_predict(fitted_model_instance): @@ -192,8 +193,8 @@ def test_sample_posterior_predictive(fitted_model_instance, combined): pred = fitted_model_instance.sample_posterior_predictive( prediction_data["input"], combined=combined, extend_idata=True ) - chains = fitted_model_instance.idata.posterior.dims["chain"] - draws = fitted_model_instance.idata.posterior.dims["draw"] + chains = fitted_model_instance.idata.sample_stats.dims["chain"] + draws = fitted_model_instance.idata.sample_stats.dims["draw"] expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred) assert pred[fitted_model_instance.output_var].shape == expected_shape assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)