diff --git a/pymc/model.py b/pymc/model.py index 1080543d1a..953ace0614 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1460,14 +1460,10 @@ def create_value_var( """ if value_var is None: - value_var = rv_var.type() - value_var.name = rv_var.name - - if aesara.config.compute_test_value != "off": - value_var.tag.test_value = rv_var.tag.test_value - - _add_future_warning_tag(value_var) - rv_var.tag.value_var = value_var + untransformed_value_var = rv_var.type() + untransformed_value_var.name = rv_var.name + else: + untransformed_value_var = value_var # Make the value variable a transformed value variable, # if there's an applicable transform @@ -1475,13 +1471,23 @@ def create_value_var( transform = _default_transform(rv_var.owner.op, rv_var) if transform is not None and transform is not UNSET: + value_var = transform.forward(untransformed_value_var, *rv_var.owner.inputs).type() + value_var.name = f"{untransformed_value_var.name}_{transform.name}__" value_var.tag.transform = transform - value_var.name = f"{value_var.name}_{transform.name}__" if aesara.config.compute_test_value != "off": value_var.tag.test_value = transform.forward( value_var, *rv_var.owner.inputs ).tag.test_value self.named_vars[value_var.name] = value_var + else: + value_var = untransformed_value_var + + if aesara.config.compute_test_value != "off": + value_var.tag.test_value = rv_var.tag.test_value + + _add_future_warning_tag(value_var) + rv_var.tag.value_var = value_var + self.rvs_to_transforms[rv_var] = transform self.rvs_to_values[rv_var] = value_var self.values_to_rvs[value_var] = rv_var