Skip to content

Commit 714b4a0

Browse files
vandaltricardoV94
andauthored
Fix compute_test_value error when creating observed variables (#6982)
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
1 parent 58217cc commit 714b4a0

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

pymc/model/core.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,15 +1347,6 @@ def make_obs_var(
13471347
"Dimensionality of data and RV don't match.", actual=data.ndim, expected=rv_var.ndim
13481348
)
13491349

1350-
if pytensor.config.compute_test_value != "off":
1351-
test_value = getattr(rv_var.tag, "test_value", None)
1352-
1353-
if test_value is not None:
1354-
# We try to reuse the old test value
1355-
rv_var.tag.test_value = np.broadcast_to(test_value, rv_var.shape)
1356-
else:
1357-
rv_var.tag.test_value = data
1358-
13591350
mask = getattr(data, "mask", None)
13601351
if mask is not None:
13611352
impute_message = (

tests/model/test_core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,14 @@ def test_observed_type(self):
223223
assert x1.type.dtype == X.type.dtype
224224
assert x2.type.dtype == X.type.dtype
225225

226+
@pytensor.config.change_flags(compute_test_value="raise")
227+
def test_observed_compute_test_value(self):
228+
data = np.zeros(100)
229+
with pm.Model():
230+
obs = pm.Normal("obs", mu=pt.zeros_like(data), sigma=1, observed=data)
231+
assert obs.tag.test_value.shape == data.shape
232+
assert obs.tag.test_value.dtype == data.dtype
233+
226234

227235
def test_duplicate_vars():
228236
with pytest.raises(ValueError) as err:

0 commit comments

Comments
 (0)