diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index c07683555a..d997e00466 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -345,10 +345,13 @@ def draw( return [np.stack(v) for v in drawn_values] -def observed_dependent_deterministics(model: Model): +def observed_dependent_deterministics(model: Model, extra_observeds=None): """Find deterministics that depend directly on observed variables.""" + if extra_observeds is None: + extra_observeds = [] + deterministics = model.deterministics - observed_rvs = set(model.observed_RVs) + observed_rvs = set(model.observed_RVs + extra_observeds) blockers = model.basic_RVs return [ deterministic @@ -767,6 +770,7 @@ def sample_posterior_predictive( if "coords" not in idata_kwargs: idata_kwargs["coords"] = {} idata: InferenceData | None = None + observed_data = None stacked_dims = None if isinstance(trace, InferenceData): _constant_data = getattr(trace, "constant_data", None) @@ -774,6 +778,7 @@ def sample_posterior_predictive( trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()}) constant_data.update({str(k): v.data for k, v in _constant_data.items()}) idata = trace + observed_data = trace.get("observed_data", None) trace = trace["posterior"] if isinstance(trace, xarray.Dataset): trace_coords.update({str(k): v.data for k, v in trace.coords.items()}) @@ -816,7 +821,12 @@ def sample_posterior_predictive( if var_names is not None: vars_ = [model[x] for x in var_names] else: - vars_ = model.observed_RVs + observed_dependent_deterministics(model) + observed_vars = model.observed_RVs + if observed_data is not None: + observed_vars += [ + model[x] for x in observed_data if x in model and x not in observed_vars + ] + vars_ = observed_vars + observed_dependent_deterministics(model, observed_vars) vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) diff --git a/tests/sampling/test_forward.py b/tests/sampling/test_forward.py index 404f74a961..9348296297 100644 --- a/tests/sampling/test_forward.py +++ b/tests/sampling/test_forward.py @@ -540,6 +540,24 @@ def test_normal_scalar_idata(self): ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False) assert ppc["a"].shape == (nchains, ndraws) + def test_external_trace_det(self): + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1, observed=0.0) + b = pm.Deterministic("b", a + 1) + trace = pm.sample(tune=50, draws=50, chains=1, compute_convergence_checks=False) + + # test that trace is used in ppc + with pm.Model() as model_ppc: + mu = pm.Normal("mu", 0.0, 1.0) + a = pm.Normal("a", mu=mu, sigma=1) + c = pm.Deterministic("c", a + 1) + + ppc = pm.sample_posterior_predictive( + trace=trace, model=model_ppc, return_inferencedata=False + ) + assert list(ppc.keys()) == ["a", "c"] + def test_normal_vector(self): with pm.Model() as model: mu = pm.Normal("mu", 0.0, 1.0)