
Refactor modelbuilder fit #198


Merged
4 commits merged on Jun 13, 2023
86 changes: 73 additions & 13 deletions pymc_experimental/model_builder.py
@@ -24,6 +24,8 @@
import pandas as pd
import pymc as pm
import xarray as xr
from pymc.backends import NDArray
from pymc.backends.base import MultiTrace
from pymc.util import RandomState

# If scikit-learn is available, use its data validator
@@ -425,6 +427,7 @@ def fit(
self,
X: pd.DataFrame,
y: Optional[pd.Series] = None,
fit_method="mcmc",
progressbar: bool = True,
predictor_names: List[str] = None,
random_seed: RandomState = None,
@@ -441,6 +444,8 @@
The training input samples.
y : array-like if sklearn is available, otherwise array, shape (n_obs,)
The target values (real numbers).
fit_method : str
Which method to use to infer model parameters. One of ["mcmc", "MAP"].
progressbar : bool
Specifies whether the fit progressbar should be displayed
predictor_names: List[str] = None,
random_seed : RandomState
Provides sampler with initial random seed for obtaining reproducible samples
**kwargs : Any
Custom sampler settings can be provided in form of keyword arguments.

Returns
-------
self : az.InferenceData
returns inference data of the fitted model.
Examples
--------
>>> model = MyModel()
>>> idata = model.fit(data)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Parameters to pass to the chosen inference routine. For "mcmc" these are
forwarded to `sample_model`; for "MAP" see `_fit_MAP`.
"""
available_methods = ["mcmc", "MAP"]
if fit_method not in available_methods:
raise ValueError(
f"Inference method {fit_method} not found. Choose one of {available_methods}."
)
if predictor_names is None:
predictor_names = []
if y is None:
@@ -474,14 +474,74 @@
sampler_config["progressbar"] = progressbar
sampler_config["random_seed"] = random_seed
sampler_config.update(**kwargs)
self.idata = self.sample_model(**sampler_config)

if fit_method == "mcmc":
self.idata = self.sample_model(**sampler_config)
elif fit_method == "MAP":
self.idata = self._fit_MAP(**sampler_config)

X_df = pd.DataFrame(X, columns=X.columns)
combined_data = pd.concat([X_df, y], axis=1)
assert all(combined_data.columns), "All columns must have non-empty names"
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore
return self.idata # type: ignore
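
A minimal usage sketch of the new fit_method argument (MyModel stands in for any concrete ModelBuilder subclass, and the extra keyword arguments are only illustrative; they are forwarded to the selected inference routine):

model = MyModel()
idata_mcmc = model.fit(X, y)  # default fit_method="mcmc"; kwargs go to the sampler
idata_map = model.fit(X, y, fit_method="MAP", method="L-BFGS-B")  # kwargs are filtered and passed to pm.find_MAP
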

def _fit_MAP(
self,
**kwargs,
):
"""Find model maximum a posteriori using scipy optimizer"""

model = self.model
find_MAP_args = {**self.sampler_config, **kwargs}
if "random_seed" in find_MAP_args:
# find_MAP takes a different argument name for seed than sample_* do.
find_MAP_args["seed"] = find_MAP_args["random_seed"]
# Extra unknown arguments cause problems for SciPy minimize
allowed_args = [ # find_MAP args
"start",
"vars",
"method",
# "return_raw", # probably causes a problem if set spuriously
# "include_transformed", # probably causes a problem if set spuriously
"progressbar",
"maxeval",
"seed",
]
allowed_args += [ # scipy.optimize.minimize args
# "fun", # used by find_MAP
# "x0", # used by find_MAP
"args",
"method",
# "jac", # used by find_MAP
# "hess", # probably causes a problem if set spuriously
# "hessp", # probably causes a problem if set spuriously
"bounds",
"constraints",
"tol",
"callback",
"options",
]
for arg in list(find_MAP_args):
if arg not in allowed_args:
del find_MAP_args[arg]

map_res = pm.find_MAP(model=model, **find_MAP_args)
# Filter non-value variables
value_vars_names = {v.name for v in model.value_vars}
map_res = {k: v for k, v in map_res.items() if k in value_vars_names}

# Convert map result to InferenceData
map_strace = NDArray(model=model)
map_strace.setup(draws=1, chain=0)
map_strace.record(map_res)
map_strace.close()
trace = MultiTrace([map_strace])
idata = pm.to_inference_data(trace, model=model)
self.set_idata_attrs(idata)

return idata
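
Because only a single draw is recorded, the returned InferenceData holds the point estimate as a posterior with one chain and one draw, so downstream code can treat it like a (degenerate) trace. A rough sketch of inspecting the result, assuming a hypothetical unconstrained model parameter named "a":

idata = model.fit(X, y, fit_method="MAP")
print(idata.posterior.sizes)  # expect chain: 1, draw: 1
a_hat = float(idata.posterior["a"].isel(chain=0, draw=0))  # MAP estimate of "a"
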

def predict(
self,
X_pred: Union[np.ndarray, pd.DataFrame, pd.Series],
25 changes: 12 additions & 13 deletions pymc_experimental/tests/test_model_builder.py
@@ -40,8 +40,8 @@ def toy_y(toy_X):
return y


@pytest.fixture(scope="module")
def fitted_model_instance(toy_X, toy_y):
@pytest.fixture(scope="module", params=["mcmc", "MAP"])
def fitted_model_instance(request, toy_X, toy_y):
sampler_config = {
"draws": 500,
"tune": 300,
@@ -54,12 +54,11 @@ def fitted_model_instance(toy_X, toy_y):
"obs_error": 2,
}
model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config)
model.fit(toy_X)
model.fit(toy_X, toy_y, fit_method=request.param)
return model


class test_ModelBuilder(ModelBuilder):

_model_type = "LinearModel"
version = "0.1"

@@ -151,9 +150,10 @@ def test_fit(fitted_model_instance):
post_pred[fitted_model_instance.output_var].shape[0] == prediction_data.input.shape


def test_fit_no_y(toy_X):
@pytest.mark.parametrize("fit_method", ["mcmc", "MAP"])
def test_fit_no_y(toy_X, fit_method):
model_builder = test_ModelBuilder()
model_builder.idata = model_builder.fit(X=toy_X)
model_builder.idata = model_builder.fit(X=toy_X, fit_method=fit_method)
assert model_builder.model is not None
assert model_builder.idata is not None
assert "posterior" in model_builder.idata.groups()
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load(fitted_model_instance):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
fitted_model_instance.save(temp.name)
test_builder2 = test_ModelBuilder.load(temp.name)
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as temp:
fitted_model_instance.save(temp.name)
test_builder2 = test_ModelBuilder.load(temp.name)
assert sorted(fitted_model_instance.idata.groups()) == sorted(test_builder2.idata.groups())

x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred1 = fitted_model_instance.predict(prediction_data["input"])
pred2 = test_builder2.predict(prediction_data["input"])
assert pred1.shape == pred2.shape
temp.close()


def test_predict(fitted_model_instance):
@@ -193,8 +192,8 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
pred = fitted_model_instance.sample_posterior_predictive(
prediction_data["input"], combined=combined, extend_idata=True
)
chains = fitted_model_instance.idata.sample_stats.dims["chain"]
draws = fitted_model_instance.idata.sample_stats.dims["draw"]
chains = fitted_model_instance.idata.posterior.dims["chain"]
draws = fitted_model_instance.idata.posterior.dims["draw"]
expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
assert pred[fitted_model_instance.output_var].shape == expected_shape
assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)