Skip to content

Commit c9832a7

Browse files
aseyboldtricardoV94
authored andcommitted
Create value variable on transformed space
1 parent 90be5bc commit c9832a7

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

pymc/model.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,28 +1464,37 @@ def create_value_var(
14641464
this branch of the conditional.
14651465
14661466
"""
1467-
if value_var is None:
1468-
value_var = rv_var.type()
1469-
value_var.name = rv_var.name
1470-
1471-
if aesara.config.compute_test_value != "off":
1472-
value_var.tag.test_value = rv_var.tag.test_value
1473-
1474-
_add_future_warning_tag(value_var)
1475-
rv_var.tag.value_var = value_var
14761467

14771468
# Make the value variable a transformed value variable,
14781469
# if there's an applicable transform
1479-
if transform is UNSET and rv_var.owner:
1480-
transform = _default_transform(rv_var.owner.op, rv_var)
1470+
if transform is UNSET:
1471+
if rv_var.owner is None:
1472+
transform = None
1473+
else:
1474+
transform = _default_transform(rv_var.owner.op, rv_var)
14811475

1482-
if transform is not None and transform is not UNSET:
1476+
if value_var is not None:
1477+
if transform is not None:
1478+
raise ValueError("Cannot use transform when providing a pre-defined value_var")
1479+
elif transform is None:
1480+
# Create value variable with the same type as the RV
1481+
value_var = rv_var.type()
1482+
value_var.name = rv_var.name
1483+
if aesara.config.compute_test_value != "off":
1484+
value_var.tag.test_value = rv_var.tag.test_value
1485+
else:
1486+
# Create value variable with the same type as the transformed RV
1487+
value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
1488+
value_var.name = f"{rv_var.name}_{transform.name}__"
14831489
value_var.tag.transform = transform
1484-
value_var.name = f"{value_var.name}_{transform.name}__"
14851490
if aesara.config.compute_test_value != "off":
14861491
value_var.tag.test_value = transform.forward(
1487-
value_var, *rv_var.owner.inputs
1492+
rv_var, *rv_var.owner.inputs
14881493
).tag.test_value
1494+
1495+
_add_future_warning_tag(value_var)
1496+
rv_var.tag.value_var = value_var
1497+
14891498
self.rvs_to_transforms[rv_var] = transform
14901499
self.rvs_to_values[rv_var] = value_var
14911500
self.values_to_rvs[value_var] = rv_var

0 commit comments

Comments
 (0)