diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py
index 6e712e5d..260ec6b2 100644
--- a/pymc_extras/model_builder.py
+++ b/pymc_extras/model_builder.py
@@ -530,6 +530,7 @@ def predict(
         self,
         X_pred: np.ndarray | pd.DataFrame | pd.Series,
         extend_idata: bool = True,
+        predictions: bool = True,
         **kwargs,
     ) -> np.ndarray:
         """
@@ -542,6 +543,9 @@ def predict(
             The input data used for prediction.
         extend_idata : Boolean determining whether the predictions should be added to inference data object.
             Defaults to True.
+        predictions : bool
+            Whether to use the predictions group for posterior predictive sampling.
+            Defaults to True.
         **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive

         Returns
@@ -558,8 +562,10 @@ def predict(
         >>> pred_mean = model.predict(prediction_data)
         """
+        X_pred = self._validate_data(X_pred)
+
         posterior_predictive_samples = self.sample_posterior_predictive(
-            X_pred, extend_idata, combined=False, **kwargs
+            X_pred, extend_idata, combined=False, predictions=predictions, **kwargs
         )

         if self.output_var not in posterior_predictive_samples:
@@ -624,7 +630,9 @@ def sample_prior_predictive(

         return prior_predictive_samples

-    def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
+    def sample_posterior_predictive(
+        self, X_pred, extend_idata, combined, predictions=True, **kwargs
+    ):
         """
         Sample from the model's posterior predictive distribution.

@@ -634,6 +642,8 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
             The input data used for prediction using prior distribution..
         extend_idata : Boolean determining whether the predictions should be added to inference data object.
             Defaults to False.
+        predictions : Boolean determining whether to use the predictions group for posterior predictive sampling.
+            Defaults to True.
         combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
             Defaults to True.
         **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
@@ -646,13 +656,15 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
         self._data_setter(X_pred)
         with self.model:  # sample with new input data
-            post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
+            post_pred = pm.sample_posterior_predictive(
+                self.idata, predictions=predictions, **kwargs
+            )
         if extend_idata:
             self.idata.extend(post_pred, join="right")

-        posterior_predictive_samples = az.extract(
-            post_pred, "posterior_predictive", combined=combined
-        )
+        group_name = "predictions" if predictions else "posterior_predictive"
+
+        posterior_predictive_samples = az.extract(post_pred, group_name, combined=combined)

         return posterior_predictive_samples

@@ -700,6 +712,7 @@ def predict_posterior(
         X_pred: np.ndarray | pd.DataFrame | pd.Series,
         extend_idata: bool = True,
         combined: bool = True,
+        predictions: bool = True,
         **kwargs,
     ) -> xr.DataArray:
         """
@@ -713,6 +726,8 @@ def predict_posterior(
             Defaults to True.
         combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
             Defaults to True.
+        predictions : Boolean determining whether to use the predictions group for posterior predictive sampling.
+            Defaults to True.
         **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive

         Returns
@@ -723,7 +738,7 @@ def predict_posterior(
         X_pred = self._validate_data(X_pred)
         posterior_predictive_samples = self.sample_posterior_predictive(
-            X_pred, extend_idata, combined, **kwargs
+            X_pred, extend_idata, combined, predictions=predictions, **kwargs
         )

         if self.output_var not in posterior_predictive_samples:
diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py
index 9494bb10..8be76ad1 100644
--- a/tests/test_model_builder.py
+++ b/tests/test_model_builder.py
@@ -124,11 +124,16 @@ def _save_input_params(self, idata):
     def output_var(self):
         return "output"

-    def _data_setter(self, x: pd.Series, y: pd.Series = None):
+    def _data_setter(self, X: pd.Series | np.ndarray, y: pd.Series | np.ndarray = None):
         with self.model:
-            pm.set_data({"x": x.values})
+            X = X.values if isinstance(X, pd.Series) else X.ravel()
+
+            pm.set_data({"x": X})
+
             if y is not None:
-                pm.set_data({"y_data": y.values})
+                y = y.values if isinstance(y, pd.Series) else y.ravel()
+
+                pm.set_data({"y_data": y})

     @property
     def _serializable_model_config(self):
@@ -177,8 +182,8 @@ def test_save_load(fitted_model_instance):
     assert fitted_model_instance.id == test_builder2.id
     x_pred = np.random.uniform(low=0, high=1, size=100)
     prediction_data = pd.DataFrame({"input": x_pred})
-    pred1 = fitted_model_instance.predict(prediction_data["input"])
-    pred2 = test_builder2.predict(prediction_data["input"])
+    pred1 = fitted_model_instance.predict(prediction_data[["input"]])
+    pred2 = test_builder2.predict(prediction_data[["input"]])
     assert pred1.shape == pred2.shape
     temp.close()

@@ -205,7 +210,7 @@ def test_empty_sampler_config_fit(toy_X, toy_y):

 def test_fit(fitted_model_instance):
     prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
-    pred = fitted_model_instance.predict(prediction_data["input"])
+    pred = fitted_model_instance.predict(prediction_data[["input"]])
     post_pred = fitted_model_instance.sample_posterior_predictive(
         prediction_data["input"], extend_idata=True, combined=True
     )

@@ -223,7 +228,7 @@ def test_fit_no_y(toy_X):
 def test_predict(fitted_model_instance):
     x_pred = np.random.uniform(low=0, high=1, size=100)
     prediction_data = pd.DataFrame({"input": x_pred})
-    pred = fitted_model_instance.predict(prediction_data["input"])
+    pred = fitted_model_instance.predict(prediction_data[["input"]])
     # Perform elementwise comparison using numpy
     assert isinstance(pred, np.ndarray)
     assert len(pred) > 0
@@ -256,13 +261,16 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat
     prediction_data = pd.DataFrame({"input": x_pred})

     if group == "prior_predictive":
-        prediction_method = fitted_model_instance.sample_prior_predictive
+        pred = fitted_model_instance.sample_prior_predictive(
+            prediction_data["input"], combined=False, extend_idata=extend_idata
+        )
     else:  # group == "posterior_predictive":
-        prediction_method = fitted_model_instance.sample_posterior_predictive
-
-    pred = prediction_method(prediction_data["input"], combined=False, extend_idata=extend_idata)
+        pred = fitted_model_instance.sample_posterior_predictive(
+            prediction_data["input"], combined=False, predictions=False, extend_idata=extend_idata
+        )
     pred_unstacked = pred[output_var].values
+
     idata_now = fitted_model_instance.idata[group][output_var].values

     if extend_idata:
@@ -304,3 +312,45 @@ def test_id():
     ).hexdigest()[:16]

     assert model_builder.id == expected_id
+
+
+@pytest.mark.parametrize("predictions", [True, False])
+@pytest.mark.parametrize("predict_method", ["predict", "predict_posterior"]) +def test_predict_method_respects_predictions_flag( + fitted_model_instance, predictions, predict_method +): + x_pred = np.random.uniform(0, 1, 100) + prediction_data = pd.DataFrame({"input": x_pred}) + output_var = fitted_model_instance.output_var + + # Snapshot the original posterior_predictive values + pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy() + + # Ensure 'predictions' group is not present initially + assert "predictions" not in fitted_model_instance.idata.groups() + + # Run prediction with predictions=True or False + if predict_method == "predict": + fitted_model_instance.predict( + X_pred=prediction_data[["input"]], + extend_idata=True, + predictions=predictions, + ) + else: # predict_method == "predict_posterior": + fitted_model_instance.predict_posterior( + X_pred=prediction_data[["input"]], + extend_idata=True, + predictions=predictions, + ) + + pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values + + # Check predictions group presence + if predictions: + assert "predictions" in fitted_model_instance.idata.groups() + # Posterior predictive should remain unchanged + np.testing.assert_array_equal(pp_before, pp_after) + else: + assert "predictions" not in fitted_model_instance.idata.groups() + # Posterior predictive should be updated + assert not np.array_equal(pp_before, pp_after)