Skip to content

Commit ca150d4

Browse files
committed
Sum observations in model.datalogpt
Fixes #4803 and #4804
1 parent e8396bf commit ca150d4

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

pymc3/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ def varlogpt(self):
809809
@property
810810
def datalogpt(self):
811811
with self:
812-
factors = [logpt(obs, obs.tag.observations) for obs in self.observed_RVs]
812+
factors = [logpt_sum(obs, obs.tag.observations) for obs in self.observed_RVs]
813813

814814
# Convert random variables into their log-likelihood inputs and
815815
# apply their transforms, if any

pymc3/tests/test_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,3 +651,17 @@ def test_set_initval():
651651
y = pm.Normal("y", x, 1)
652652

653653
assert model.rvs_to_values[y] in model.initial_values
654+
655+
656+
def test_datalogpt_multiple_shapes():
657+
with pm.Model() as m:
658+
x = pm.Normal("x", 0, 1)
659+
z1 = pm.Potential("z1", x)
660+
z2 = pm.Potential("z2", at.full((1, 3), x))
661+
y1 = pm.Normal("y1", x, 1, observed=np.array([1]))
662+
y2 = pm.Normal("y2", x, 1, observed=np.array([1, 2]))
663+
y3 = pm.Normal("y3", x, 1, observed=np.array([1, 2, 3]))
664+
665+
# This would raise a TypeError, see #4803 and #4804
666+
x_val = m.rvs_to_values[x]
667+
m.datalogpt.eval({x_val: 0})

0 commit comments

Comments
 (0)