Skip to content

Commit 6c4d4eb

Browse files
Fix bug that does not correctly set the dtype of determinsitic variab… (#6425)
* Fix bug that does not correctly set the dtype of determinsitic variable after automatic imputation * Change `at.zeros` to `at.empty` when creating combined observed/missing vector Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
1 parent c3b8ff4 commit 6c4d4eb

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

pymc/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1456,7 +1456,7 @@ def make_obs_var(
14561456

14571457
# Create deterministic that combines observed and missing
14581458
# Note: This can widely increase memory consumption during sampling for large datasets
1459-
rv_var = at.zeros(data.shape)
1459+
rv_var = at.empty(data.shape, dtype=observed_rv_var.type.dtype)
14601460
rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
14611461
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
14621462
rv_var = Deterministic(name, rv_var, self, dims)

pymc/tests/test_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,9 @@ def test_missing_data(self):
356356

357357
assert m["x2_missing"].type == gf._extra_vars_shared["x2_missing"].type
358358

359+
# The dtype of the merged observed/missing deterministic should match the RV dtype
360+
assert m.deterministics[0].type.dtype == x2.type.dtype
361+
359362
pnt = m.initial_point(random_seed=None).copy()
360363
del pnt["x2_missing"]
361364

0 commit comments

Comments
 (0)