diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index ed0cfd7960..3824af860f 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -961,7 +961,7 @@ def log_jac_det(self, value, *inputs): N = N.astype(value.dtype) sum_value = pt.sum(value, -1, keepdims=True) value_sum_expanded = value + sum_value - value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1) + value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros_like(sum_value)], -1) logsumexp_value_expanded = pt.logsumexp(value_sum_expanded, -1, keepdims=True) res = pt.log(N) + (N * sum_value) - (N * logsumexp_value_expanded) return pt.sum(res, -1) @@ -977,7 +977,7 @@ def forward(self, value, *inputs): return pt.as_tensor_variable(value) def log_jac_det(self, value, *inputs): - return pt.zeros(value.shape) + return pt.zeros_like(value) class ChainedTransform(Transform):