From d7a82b1ce9fac2f98d174831fbdd134a45383e85 Mon Sep 17 00:00:00 2001 From: butterman0 Date: Mon, 17 Feb 2025 14:12:27 +0100 Subject: [PATCH 01/12] Fix group selection for posterior predictive samples when predictions = True --- pymc_extras/model_builder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 6e712e5d7..98fbe7d5e 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -650,8 +650,11 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): if extend_idata: self.idata.extend(post_pred, join="right") + # Determine the correct group dynamically + group_name = "predictions" if kwargs.get("predictions", False) else "posterior_predictive" + posterior_predictive_samples = az.extract( - post_pred, "posterior_predictive", combined=combined + post_pred, group_name, combined=combined ) return posterior_predictive_samples From e86992b886f21cf8df40dc23827800004e940421 Mon Sep 17 00:00:00 2001 From: butterman0 Date: Tue, 18 Feb 2025 12:23:49 +0100 Subject: [PATCH 02/12] refactor: make predictions argument explicit --- pymc_extras/model_builder.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 98fbe7d5e..82fd07ece 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -530,6 +530,7 @@ def predict( self, X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, + predictions: bool = False, **kwargs, ) -> np.ndarray: """ @@ -559,7 +560,7 @@ def predict( """ posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, combined=False, **kwargs + X_pred, extend_idata, predictions, combined=False, **kwargs ) if self.output_var not in posterior_predictive_samples: @@ -624,7 +625,7 @@ def sample_prior_predictive( return prior_predictive_samples - def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): + def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combined, **kwargs): """ Sample from the model's posterior predictive distribution. @@ -646,12 +647,12 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs): self._data_setter(X_pred) with self.model: # sample with new input data - post_pred = pm.sample_posterior_predictive(self.idata, **kwargs) + post_pred = pm.sample_posterior_predictive(self.idata, predictions=predictions, **kwargs) if extend_idata: self.idata.extend(post_pred, join="right") - # Determine the correct group dynamically - group_name = "predictions" if kwargs.get("predictions", False) else "posterior_predictive" + # Determine the correct group + group_name = "predictions" if predictions else "posterior_predictive" posterior_predictive_samples = az.extract( post_pred, group_name, combined=combined @@ -703,6 +704,7 @@ def predict_posterior( X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, combined: bool = True, + predictions: bool = False, **kwargs, ) -> xr.DataArray: """ @@ -726,7 +728,7 @@ def predict_posterior( X_pred = self._validate_data(X_pred) posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, combined, **kwargs + X_pred, extend_idata, predictions, combined, **kwargs ) if self.output_var not in posterior_predictive_samples: From fd1793160f3505abca257608b67954846b4a0c33 Mon Sep 17 00:00:00 2001 From: butterman0 Date: Tue, 18 Feb 2025 12:35:05 +0100 Subject: [PATCH 03/12] refactor: change default predictions to True --- pymc_extras/model_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 82fd07ece..2a4ab2abb 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -530,7 +530,7 @@ def predict( self, X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, - predictions: bool = False, + predictions: bool = True, **kwargs, ) -> np.ndarray: """ @@ -704,7 +704,7 @@ def predict_posterior( X_pred: np.ndarray | pd.DataFrame | pd.Series, extend_idata: bool = True, combined: bool = True, - predictions: bool = False, + predictions: bool = True, **kwargs, ) -> xr.DataArray: """ From b8174c94a480995d89276a7b347d2b96f499101f Mon Sep 17 00:00:00 2001 From: butterman0 Date: Tue, 18 Feb 2025 21:36:42 +0100 Subject: [PATCH 04/12] doc: update docstrings --- pymc_extras/model_builder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 2a4ab2abb..f8fc93de1 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -543,6 +543,9 @@ def predict( The input data used for prediction. extend_idata : Boolean determining whether the predictions should be added to inference data object. Defaults to True. + predictions : bool + Whether to use the predictions group for posterior predictive sampling. + Defaults to True. **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive Returns @@ -651,7 +654,6 @@ def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combine if extend_idata: self.idata.extend(post_pred, join="right") - # Determine the correct group group_name = "predictions" if predictions else "posterior_predictive" posterior_predictive_samples = az.extract( @@ -718,6 +720,8 @@ def predict_posterior( Defaults to True. combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True. + predictions : Boolean determing whether to use the predictions group for posterior predictive sampling. + Defaults to True. **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive Returns From ded48e00be5536d11f6e6b1e2faed51a8591cdcc Mon Sep 17 00:00:00 2001 From: butterman0 Date: Thu, 6 Mar 2025 17:01:41 +0100 Subject: [PATCH 05/12] docs: update --- pymc_extras/model_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index f8fc93de1..dea962a31 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -638,6 +638,8 @@ def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combine The input data used for prediction using prior distribution.. extend_idata : Boolean determining whether the predictions should be added to inference data object. Defaults to False. + predictions : Boolean determing whether to use the predictions group for posterior predictive sampling. + Defaults to True. combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. Defaults to True. **kwargs: Additional arguments to pass to pymc.sample_posterior_predictive From ae2cafaf5c982e32633d151e909155cb9b0c2386 Mon Sep 17 00:00:00 2001 From: butterman0 Date: Thu, 6 Mar 2025 17:05:12 +0100 Subject: [PATCH 06/12] refactor: pass predictions by keyword --- pymc_extras/model_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index dea962a31..ccc29c57d 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -563,7 +563,7 @@ def predict( """ posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, predictions, combined=False, **kwargs + X_pred, extend_idata, combined=False, predictions=predictions, **kwargs ) if self.output_var not in posterior_predictive_samples: @@ -628,7 +628,7 @@ def sample_prior_predictive( return prior_predictive_samples - def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combined, **kwargs): + def sample_posterior_predictive(self, X_pred, extend_idata, combined, predictions = True, **kwargs): """ Sample from the model's posterior predictive distribution. @@ -734,7 +734,7 @@ def predict_posterior( X_pred = self._validate_data(X_pred) posterior_predictive_samples = self.sample_posterior_predictive( - X_pred, extend_idata, predictions, combined, **kwargs + X_pred, extend_idata, combined, predictions=predictions, **kwargs ) if self.output_var not in posterior_predictive_samples: From a8754ee3390513e405e478f2ef4ba786d1c5aa5a Mon Sep 17 00:00:00 2001 From: butterman0 Date: Wed, 9 Apr 2025 15:50:24 +0200 Subject: [PATCH 07/12] test: added test for predictions grouping --- tests/test_model_builder.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 9494bb10e..1a4ffa978 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -304,3 +304,35 @@ def test_id(): ).hexdigest()[:16] assert model_builder.id == expected_id + +@pytest.mark.parametrize("predictions", [True, False]) +def test_predict_respects_predictions_flag(fitted_model_instance, predictions): + x_pred = np.random.uniform(0, 1, 100) + prediction_data = pd.DataFrame({"input": x_pred}) + output_var = fitted_model_instance.output_var + + # Snapshot the original posterior_predictive values + pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy() + + # Ensure 'predictions' group is not present initially + assert "predictions" not in fitted_model_instance.idata.groups() + + # Run prediction with predictions=True or False + fitted_model_instance.predict( + prediction_data["input"], + extend_idata=True, + combined=False, + predictions=predictions, + ) + + pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values + + # Check predictions group presence + if predictions: + assert "predictions" in fitted_model_instance.idata.groups() + # Posterior predictive should remain unchanged + np.testing.assert_array_equal(pp_before, pp_after) + else: + assert "predictions" not in fitted_model_instance.idata.groups() + # Posterior predictive should be updated + np.testing.assert_array_not_equal(pp_before, pp_after) \ No newline at end of file From 7950f99ab303af18c7610334e5a24743244d6012 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 25 Apr 2025 07:41:41 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc_extras/model_builder.py | 12 +++++++----- tests/test_model_builder.py | 5 +++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index ccc29c57d..10d34f29b 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -628,7 +628,9 @@ def sample_prior_predictive( return prior_predictive_samples - def sample_posterior_predictive(self, X_pred, extend_idata, combined, predictions = True, **kwargs): + def sample_posterior_predictive( + self, X_pred, extend_idata, combined, predictions=True, **kwargs + ): """ Sample from the model's posterior predictive distribution. @@ -652,15 +654,15 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, prediction self._data_setter(X_pred) with self.model: # sample with new input data - post_pred = pm.sample_posterior_predictive(self.idata, predictions=predictions, **kwargs) + post_pred = pm.sample_posterior_predictive( + self.idata, predictions=predictions, **kwargs + ) if extend_idata: self.idata.extend(post_pred, join="right") group_name = "predictions" if predictions else "posterior_predictive" - posterior_predictive_samples = az.extract( - post_pred, group_name, combined=combined - ) + posterior_predictive_samples = az.extract(post_pred, group_name, combined=combined) return posterior_predictive_samples diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 1a4ffa978..88dd971df 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -305,6 +305,7 @@ def test_id(): assert model_builder.id == expected_id + @pytest.mark.parametrize("predictions", [True, False]) def test_predict_respects_predictions_flag(fitted_model_instance, predictions): x_pred = np.random.uniform(0, 1, 100) @@ -324,7 +325,7 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions): combined=False, predictions=predictions, ) - + pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values # Check predictions group presence @@ -335,4 +336,4 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions): else: assert "predictions" not in fitted_model_instance.idata.groups() # Posterior predictive should be updated - np.testing.assert_array_not_equal(pp_before, pp_after) \ No newline at end of file + np.testing.assert_array_not_equal(pp_before, pp_after) From e7fe9a270a9dccbcb1d081c2a868e11baee1c387 Mon Sep 17 00:00:00 2001 From: butterman0 Date: Sat, 24 May 2025 17:20:57 +0200 Subject: [PATCH 09/12] add missing call to validate data --- pymc_extras/model_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 10d34f29b..c6836c9b6 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -561,6 +561,8 @@ def predict( >>> prediction_data = pd.DataFrame({'input':x_pred}) >>> pred_mean = model.predict(prediction_data) """ + + X_pred = self._validate_data(X_pred) posterior_predictive_samples = self.sample_posterior_predictive( X_pred, extend_idata, combined=False, predictions=predictions, **kwargs From e698292555b505ac0a79e47e1d9a2230e0a2751c Mon Sep 17 00:00:00 2001 From: butterman0 Date: Sat, 24 May 2025 17:22:23 +0200 Subject: [PATCH 10/12] update predict calls to handle validate data and predictions group --- tests/test_model_builder.py | 65 +++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 14 deletions(-) diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 88dd971df..d846d704e 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -124,11 +124,18 @@ def _save_input_params(self, idata): def output_var(self): return "output" - def _data_setter(self, x: pd.Series, y: pd.Series = None): + def _data_setter(self, X: pd.Series | np.ndarray, y: pd.Series | np.ndarray = None): + with self.model: - pm.set_data({"x": x.values}) + + X = X.values if isinstance(X, pd.Series) else X.ravel() + + pm.set_data({"x": X}) + if y is not None: - pm.set_data({"y_data": y.values}) + y = y.values if isinstance(y, pd.Series) else y.ravel() + + pm.set_data({"y_data": y}) @property def _serializable_model_config(self): @@ -177,8 +184,8 @@ def test_save_load(fitted_model_instance): assert fitted_model_instance.id == test_builder2.id x_pred = np.random.uniform(low=0, high=1, size=100) prediction_data = pd.DataFrame({"input": x_pred}) - pred1 = fitted_model_instance.predict(prediction_data["input"]) - pred2 = test_builder2.predict(prediction_data["input"]) + pred1 = fitted_model_instance.predict(prediction_data[["input"]]) + pred2 = test_builder2.predict(prediction_data[["input"]]) assert pred1.shape == pred2.shape temp.close() @@ -205,7 +212,7 @@ def test_empty_sampler_config_fit(toy_X, toy_y): def test_fit(fitted_model_instance): prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)}) - pred = fitted_model_instance.predict(prediction_data["input"]) + pred = fitted_model_instance.predict(prediction_data[["input"]]) post_pred = fitted_model_instance.sample_posterior_predictive( prediction_data["input"], extend_idata=True, combined=True ) @@ -223,7 +230,7 @@ def test_fit_no_y(toy_X): def test_predict(fitted_model_instance): x_pred = np.random.uniform(low=0, high=1, size=100) prediction_data = pd.DataFrame({"input": x_pred}) - pred = fitted_model_instance.predict(prediction_data["input"]) + pred = fitted_model_instance.predict(prediction_data[["input"]]) # Perform elementwise comparison using numpy assert isinstance(pred, np.ndarray) assert len(pred) > 0 @@ -256,13 +263,12 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat prediction_data = pd.DataFrame({"input": x_pred}) if group == "prior_predictive": - prediction_method = fitted_model_instance.sample_prior_predictive + pred = fitted_model_instance.sample_prior_predictive(prediction_data["input"], combined=False, extend_idata=extend_idata) else: # group == "posterior_predictive": - prediction_method = fitted_model_instance.sample_posterior_predictive - - pred = prediction_method(prediction_data["input"], combined=False, extend_idata=extend_idata) + pred = fitted_model_instance.sample_posterior_predictive(prediction_data["input"], combined=False, predictions=False, extend_idata=extend_idata) pred_unstacked = pred[output_var].values + idata_now = fitted_model_instance.idata[group][output_var].values if extend_idata: @@ -320,9 +326,40 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions): # Run prediction with predictions=True or False fitted_model_instance.predict( - prediction_data["input"], + X_pred=prediction_data[["input"]], + extend_idata=True, + predictions=predictions, + ) + + pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values + + # Check predictions group presence + if predictions: + assert "predictions" in fitted_model_instance.idata.groups() + # Posterior predictive should remain unchanged + np.testing.assert_array_equal(pp_before, pp_after) + else: + assert "predictions" not in fitted_model_instance.idata.groups() + # Posterior predictive should be updated + assert not np.array_equal(pp_before, pp_after) + +@pytest.mark.parametrize("predictions", [True, False]) +def test_predict_posterior_respects_predictions_flag(fitted_model_instance, predictions): + x_pred = np.random.uniform(0, 1, 100) + prediction_data = pd.DataFrame({"input": x_pred}) + output_var = fitted_model_instance.output_var + + # Snapshot the original posterior_predictive values + pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy() + + # Ensure 'predictions' group is not present initially + assert "predictions" not in fitted_model_instance.idata.groups() + + # Run prediction with predictions=True or False + fitted_model_instance.predict_posterior( + X_pred=prediction_data[["input"]], extend_idata=True, - combined=False, + combined=True, predictions=predictions, ) @@ -336,4 +373,4 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions): else: assert "predictions" not in fitted_model_instance.idata.groups() # Posterior predictive should be updated - np.testing.assert_array_not_equal(pp_before, pp_after) + assert not np.array_equal(pp_before, pp_after) From b3f9a6c67289f430638133abd87b0dffc5fa500c Mon Sep 17 00:00:00 2001 From: butterman0 Date: Sat, 24 May 2025 17:26:23 +0200 Subject: [PATCH 11/12] Consolidate test with pytest paramterize --- tests/test_model_builder.py | 53 +++++++++++-------------------------- 1 file changed, 15 insertions(+), 38 deletions(-) diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index d846d704e..e8355dad1 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -313,7 +313,8 @@ def test_id(): @pytest.mark.parametrize("predictions", [True, False]) -def test_predict_respects_predictions_flag(fitted_model_instance, predictions): +@pytest.mark.parametrize("predict_method", ["predict", "predict_posterior"]) +def test_predict_method_respects_predictions_flag(fitted_model_instance, predictions, predict_method): x_pred = np.random.uniform(0, 1, 100) prediction_data = pd.DataFrame({"input": x_pred}) output_var = fitted_model_instance.output_var @@ -325,43 +326,18 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions): assert "predictions" not in fitted_model_instance.idata.groups() # Run prediction with predictions=True or False - fitted_model_instance.predict( - X_pred=prediction_data[["input"]], - extend_idata=True, - predictions=predictions, - ) - - pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values - - # Check predictions group presence - if predictions: - assert "predictions" in fitted_model_instance.idata.groups() - # Posterior predictive should remain unchanged - np.testing.assert_array_equal(pp_before, pp_after) - else: - assert "predictions" not in fitted_model_instance.idata.groups() - # Posterior predictive should be updated - assert not np.array_equal(pp_before, pp_after) - -@pytest.mark.parametrize("predictions", [True, False]) -def test_predict_posterior_respects_predictions_flag(fitted_model_instance, predictions): - x_pred = np.random.uniform(0, 1, 100) - prediction_data = pd.DataFrame({"input": x_pred}) - output_var = fitted_model_instance.output_var - - # Snapshot the original posterior_predictive values - pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy() - - # Ensure 'predictions' group is not present initially - assert "predictions" not in fitted_model_instance.idata.groups() - - # Run prediction with predictions=True or False - fitted_model_instance.predict_posterior( - X_pred=prediction_data[["input"]], - extend_idata=True, - combined=True, - predictions=predictions, - ) + if predict_method == "predict": + fitted_model_instance.predict( + X_pred=prediction_data[["input"]], + extend_idata=True, + predictions=predictions, + ) + else:# predict_method == "predict_posterior": + fitted_model_instance.predict_posterior( + X_pred=prediction_data[["input"]], + extend_idata=True, + predictions=predictions, + ) pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values @@ -374,3 +350,4 @@ def test_predict_posterior_respects_predictions_flag(fitted_model_instance, pred assert "predictions" not in fitted_model_instance.idata.groups() # Posterior predictive should be updated assert not np.array_equal(pp_before, pp_after) + From f4d04c2af046506a90c798bad7a08da41d256041 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 24 May 2025 15:27:15 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc_extras/model_builder.py | 2 +- tests/test_model_builder.py | 25 ++++++++++++++----------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index c6836c9b6..260ec6b2c 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -561,7 +561,7 @@ def predict( >>> prediction_data = pd.DataFrame({'input':x_pred}) >>> pred_mean = model.predict(prediction_data) """ - + X_pred = self._validate_data(X_pred) posterior_predictive_samples = self.sample_posterior_predictive( diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index e8355dad1..8be76ad19 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -125,16 +125,14 @@ def output_var(self): return "output" def _data_setter(self, X: pd.Series | np.ndarray, y: pd.Series | np.ndarray = None): - with self.model: - X = X.values if isinstance(X, pd.Series) else X.ravel() - + pm.set_data({"x": X}) - + if y is not None: y = y.values if isinstance(y, pd.Series) else y.ravel() - + pm.set_data({"y_data": y}) @property @@ -263,12 +261,16 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat prediction_data = pd.DataFrame({"input": x_pred}) if group == "prior_predictive": - pred = fitted_model_instance.sample_prior_predictive(prediction_data["input"], combined=False, extend_idata=extend_idata) + pred = fitted_model_instance.sample_prior_predictive( + prediction_data["input"], combined=False, extend_idata=extend_idata + ) else: # group == "posterior_predictive": - pred = fitted_model_instance.sample_posterior_predictive(prediction_data["input"], combined=False, predictions=False, extend_idata=extend_idata) + pred = fitted_model_instance.sample_posterior_predictive( + prediction_data["input"], combined=False, predictions=False, extend_idata=extend_idata + ) pred_unstacked = pred[output_var].values - + idata_now = fitted_model_instance.idata[group][output_var].values if extend_idata: @@ -314,7 +316,9 @@ def test_id(): @pytest.mark.parametrize("predictions", [True, False]) @pytest.mark.parametrize("predict_method", ["predict", "predict_posterior"]) -def test_predict_method_respects_predictions_flag(fitted_model_instance, predictions, predict_method): +def test_predict_method_respects_predictions_flag( + fitted_model_instance, predictions, predict_method +): x_pred = np.random.uniform(0, 1, 100) prediction_data = pd.DataFrame({"input": x_pred}) output_var = fitted_model_instance.output_var @@ -332,7 +336,7 @@ def test_predict_method_respects_predictions_flag(fitted_model_instance, predict extend_idata=True, predictions=predictions, ) - else:# predict_method == "predict_posterior": + else: # predict_method == "predict_posterior": fitted_model_instance.predict_posterior( X_pred=prediction_data[["input"]], extend_idata=True, @@ -350,4 +354,3 @@ def test_predict_method_respects_predictions_flag(fitted_model_instance, predict assert "predictions" not in fitted_model_instance.idata.groups() # Posterior predictive should be updated assert not np.array_equal(pp_before, pp_after) -