Skip to content

Commit 7ef9b2e

Browse files
ndim_supp cases in sumto1 transform
1 parent 0372528 commit 7ef9b2e

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pymc/distributions/transforms.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ class SumTo1(RVTransform):
104104

105105
name = "sumto1"
106106

107+
def __init__(self, ndim_supp=0):
108+
self.ndim_supp = ndim_supp
109+
107110
def backward(self, value, *inputs):
108111
remaining = 1 - at.sum(value[..., :], axis=-1, keepdims=True)
109112
return at.concatenate([value[..., :], remaining], axis=-1)
@@ -113,7 +116,10 @@ def forward(self, value, *inputs):
113116

114117
def log_jac_det(self, value, *inputs):
115118
y = at.zeros(value.shape)
116-
return at.sum(y, axis=-1, keepdims=True)
119+
if self.ndim_supp == 0:
120+
return at.sum(y, axis=-1, keepdims=True)
121+
else:
122+
return at.sum(y, axis=-1)
117123

118124

119125
class CholeskyCovPacked(RVTransform):

0 commit comments

Comments
 (0)