diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 7280105d03..d5b20c5390 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -1073,11 +1073,11 @@ def symbolic_normalizing_constant(self): for v in self.group if isinstance(v.owner.op, MinibatchRandomVariable) ] - + [1.0] # To avoid empty max + + [pm.floatX(1.0)] # To avoid empty max ) ) t = self.symbolic_single_sample(t) - return pm.floatX(t) + return t @node_property def symbolic_logq_not_scaled(self):