Commit afad6ac

Refactor logpt and raise more informative ValueErrors
1 parent c28b9c8 commit afad6ac

File tree

4 files changed (+100 −108 lines):

pymc/distributions/logprob.py
pymc/tests/test_distributions.py
pymc/tests/test_distributions_random.py
pymc/tests/test_transforms.py

pymc/distributions/logprob.py

Lines changed: 50 additions & 66 deletions
```diff
@@ -24,10 +24,8 @@
 from aeppl.logprob import logcdf as logcdf_aeppl
 from aeppl.logprob import logprob as logp_aeppl
 from aeppl.transforms import TransformValuesOpt
-from aesara import config
 from aesara.graph.basic import graph_inputs, io_toposort
-from aesara.graph.op import Op, compute_test_value
-from aesara.tensor.random.op import RandomVariable
+from aesara.graph.op import Op
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -164,100 +162,86 @@ def logpt(
     # joint_logprob directly.

     # If var is not a list make it one.
-    if not isinstance(var, list):
+    if not isinstance(var, (list, tuple)):
         var = [var]

-    # If logpt isn't provided values and the variable (provided in var)
-    # is an RV, it is assumed that the tagged value var or observation is
-    # the value variable for that particular RV.
+    # If logpt isn't provided values it is assumed that the tagged value var or
+    # observation is the value variable for that particular RV.
     if rv_values is None:
         rv_values = {}
-        for _var in var:
-            if isinstance(_var.owner.op, RandomVariable):
-                rv_value_var = getattr(
-                    _var.tag, "observations", getattr(_var.tag, "value_var", _var)
-                )
-                rv_values = {_var: rv_value_var}
+        for rv in var:
+            value_var = getattr(rv.tag, "observations", getattr(rv.tag, "value_var", None))
+            if value_var is None:
+                raise ValueError(f"No value variable found for var {rv}")
+            rv_values[rv] = value_var
+    # Else we assume we were given a single rv and respective value
     elif not isinstance(rv_values, Mapping):
-        # Else if we're given a single value and a single variable we assume a mapping among them.
-        rv_values = (
-            {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)} if len(var) == 1 else {}
-        )
-
-    # Since the filtering of logp graph is based on value variables
-    # provided to this function
-    if not rv_values:
-        warnings.warn("No value variables provided the logp will be an empty graph")
+        if len(var) == 1:
+            rv_values = {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)}
+        else:
+            raise ValueError("rv_values must be a dict if more than one var is requested")

     if scaling:
         rv_scalings = {}
-        for _var in var:
-            rv_value_var = getattr(_var.tag, "observations", getattr(_var.tag, "value_var", _var))
-            rv_scalings[rv_value_var] = _get_scaling(
-                getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
+        for rv, value_var in rv_values.items():
+            rv_scalings[value_var] = _get_scaling(
+                getattr(rv.tag, "total_size", None), value_var.shape, value_var.ndim
             )

     # Aeppl needs all rv-values pairs, not just that of the requested var.
     # Hence we iterate through the graph to collect them.
     tmp_rvs_to_values = rv_values.copy()
-    transform_map = {}
     for node in io_toposort(graph_inputs(var), var):
         try:
             curr_vars = [node.default_output()]
         except ValueError:
             curr_vars = node.outputs
         for curr_var in curr_vars:
-            rv_value_var = getattr(
+            if curr_var in tmp_rvs_to_values:
+                continue
+            # Check if variable has a value variable
+            value_var = getattr(
                 curr_var.tag, "observations", getattr(curr_var.tag, "value_var", None)
             )
-            if rv_value_var is None:
-                continue
-            rv_value = rv_values.get(curr_var, rv_value_var)
-            tmp_rvs_to_values[curr_var] = rv_value
-            # Along with value variables we also check for transforms if any.
-            if hasattr(rv_value_var.tag, "transform") and transformed:
-                transform_map[rv_value] = rv_value_var.tag.transform
+            if value_var is not None:
+                tmp_rvs_to_values[curr_var] = value_var
+
+    # After collecting all necessary rvs and values, we check for any value transforms
+    transform_map = {}
+    if transformed:
+        for rv, value_var in tmp_rvs_to_values.items():
+            if hasattr(value_var.tag, "transform"):
+                transform_map[value_var] = value_var.tag.transform
+            # If the provided value_variable does not have transform information, we
+            # check if the original `rv.tag.value_var` does.
+            # TODO: This logic should be replaced by an explicit dict of
+            # `{value_var: transform}` similar to `rv_values`.
+            else:
+                original_value_var = getattr(rv.tag, "value_var", None)
+                if original_value_var is not None and hasattr(original_value_var.tag, "transform"):
+                    transform_map[value_var] = original_value_var.tag.transform

     transform_opt = TransformValuesOpt(transform_map)
     temp_logp_var_dict = factorized_joint_logprob(
         tmp_rvs_to_values, extra_rewrites=transform_opt, use_jacobian=jacobian, **kwargs
     )

     # aeppl returns the logpt for every single value term we provided to it. This includes
-    # the extra values we plugged in above so we need to filter those out.
+    # the extra values we plugged in above, so we filter those we actually wanted in the
+    # same order they were given in.
     logp_var_dict = {}
-    for value_var, _logp in temp_logp_var_dict.items():
-        if value_var in rv_values.values():
-            logp_var_dict[value_var] = _logp
+    for value_var in rv_values.values():
+        logp_var_dict[value_var] = temp_logp_var_dict[value_var]

-    # If it's an empty dictionary the logp is None
-    if not logp_var_dict:
-        logp_var = None
-    else:
-        # Otherwise apply appropriate scalings and at.add and/or at.sum the
-        # graphs accordingly.
-        if scaling:
-            for _value in logp_var_dict.keys():
-                if _value in rv_scalings:
-                    logp_var_dict[_value] *= rv_scalings[_value]
-
-        if len(logp_var_dict) == 1:
-            logp_var_dict = tuple(logp_var_dict.values())[0]
-            if sum:
-                logp_var = at.sum(logp_var_dict)
-            else:
-                logp_var = logp_var_dict
-        else:
-            if sum:
-                logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()])
-            else:
-                logp_var = at.add(*logp_var_dict.values())
+    if scaling:
+        for value_var in logp_var_dict.keys():
+            if value_var in rv_scalings:
+                logp_var_dict[value_var] *= rv_scalings[value_var]

-    # Recompute test values for the changes introduced by the replacements
-    # above.
-    if config.compute_test_value != "off":
-        for node in io_toposort(graph_inputs((logp_var,)), (logp_var,)):
-            compute_test_value(node)
+    if sum:
+        logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()])
+    else:
+        logp_var = at.add(*logp_var_dict.values())

     return logp_var
```
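Taken together, the refactor replaces the old warning-and-empty-graph behavior with explicit errors. A minimal sketch of how the two new `ValueError`s would surface (illustrative only, not part of the commit; it assumes the v4 development APIs visible in this diff — `pm.Normal.dist`, `logpt`, `aesara.tensor`):

```python
import aesara.tensor as at

import pymc as pm
from pymc.distributions.logprob import logpt

# An RV built outside a model carries no tagged value variable, so
# logpt(rv) now raises instead of warning and returning an empty graph.
x = pm.Normal.dist(0.0, 1.0)
try:
    logpt(x)
except ValueError as err:
    print(err)  # No value variable found for var ...

# Multiple vars paired with a single non-dict value are rejected explicitly.
y = pm.Normal.dist(0.0, 1.0)
try:
    logpt([x, y], at.vector("v"))
except ValueError as err:
    print(err)  # rv_values must be a dict if more than one var is requested
```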

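Transforms are now looked up only after all rv-value pairs have been collected, with a fallback to the original `rv.tag.value_var` when the provided value carries no transform information. A rough sketch of the default transformed case (again illustrative; `Model.rvs_to_values` is the mapping the tests below rely on):

```python
import pymc as pm
from pymc.distributions.logprob import logpt

with pm.Model() as m:
    s = pm.HalfNormal("s", 1.0)  # registered with a log-transformed value variable

value = m.rvs_to_values[s]
# The value variable carries its transform, which logpt picks up when
# building transform_map (transformed=True is the default).
print(value.tag.transform)
logp_graph = logpt(s, {s: value})
```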
pymc/tests/test_distributions.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -2521,9 +2521,11 @@ def test_continuous(self):
         assert logpt(InfBoundedNormal, 0).eval() != -np.inf
         assert logpt(InfBoundedNormal, 11).eval() != -np.inf

-        value = at.dscalar("x")
+        value = model.rvs_to_values[LowerNormalTransform]
         assert logpt(LowerNormalTransform, value).eval({value: -1}) != -np.inf
+        value = model.rvs_to_values[UpperNormalTransform]
         assert logpt(UpperNormalTransform, value).eval({value: 1}) != -np.inf
+        value = model.rvs_to_values[BoundedNormalTransform]
         assert logpt(BoundedNormalTransform, value).eval({value: 0}) != -np.inf
         assert logpt(BoundedNormalTransform, value).eval({value: 11}) != -np.inf
```

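The tests now fetch each RV's registered value variable instead of a fresh `at.dscalar("x")`, since `logpt` resolves tags (transforms, observations) through the value variable it is handed. A hypothetical minimal version of the pattern, under the same branch APIs:

```python
import numpy as np

import pymc as pm
from pymc.distributions.logprob import logpt

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)

# Use the value variable the model registered for x, rather than an
# unrelated scalar, so its tags are respected when building the graph.
value = model.rvs_to_values[x]
assert logpt(x, value).eval({value: 0}) != -np.inf
```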
pymc/tests/test_distributions_random.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -45,7 +45,7 @@ def random_polyagamma(*args, **kwargs):
 from pymc.distributions.continuous import get_tau_sigma, interpolated
 from pymc.distributions.discrete import _OrderedLogistic, _OrderedProbit
 from pymc.distributions.dist_math import clipped_beta_rvs
-from pymc.distributions.logprob import logpt
+from pymc.distributions.logprob import logp
 from pymc.distributions.multivariate import _OrderedMultinomial, quaddist_matrix
 from pymc.distributions.shape_utils import to_tuple
 from pymc.tests.helpers import SeededTest, select_by_precision
@@ -1626,8 +1626,8 @@ def test_errors(self):
             rowcov=np.eye(3),
             colcov=np.eye(3),
         )
-        with pytest.raises(TypeError):
-            logpt(matrixnormal, aesara.tensor.ones((3, 3, 3)))
+        with pytest.raises(ValueError):
+            logp(matrixnormal, aesara.tensor.ones((3, 3, 3)))

         with pm.Model():
             with pytest.warns(FutureWarning):
@@ -1856,7 +1856,7 @@ def test_density_dist_without_random(self):
         pm.DensityDist(
             "density_dist",
             mu,
-            logp=lambda value, mu: logpt(pm.Normal.dist(mu, 1, size=100), value),
+            logp=lambda value, mu: logp(pm.Normal.dist(mu, 1, size=100), value),
             observed=np.random.randn(100),
             initval=0,
         )
```

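For context, `logp` is the single-pair convenience helper imported above, and a shape mismatch (such as a (3, 3, 3) value for a (3, 3) MatrixNormal) now surfaces as a `ValueError` rather than a `TypeError`. A self-contained sketch of the `DensityDist` pattern from the last hunk (illustrative, assuming the v4 development API):

```python
import numpy as np

import pymc as pm
from pymc.distributions.logprob import logp

data = np.random.randn(100)

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.DensityDist(
        "density_dist",
        mu,
        # Reuse a Normal's logp as the custom density, as in the test above.
        logp=lambda value, mu: logp(pm.Normal.dist(mu, 1, size=100), value),
        observed=data,
        initval=0,
    )
```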
pymc/tests/test_transforms.py

Lines changed: 43 additions & 37 deletions
```diff
@@ -23,7 +23,7 @@
 import pymc as pm
 import pymc.distributions.transforms as tr

-from pymc.aesaraf import jacobian
+from pymc.aesaraf import floatX, jacobian
 from pymc.distributions import logpt
 from pymc.tests.checks import close_to, close_to_logical
 from pymc.tests.helpers import SeededTest
@@ -285,40 +285,46 @@ def build_model(self, distfam, params, size, transform, initval=None):

     def check_transform_elementwise_logp(self, model):
         x = model.free_RVs[0]
-        x0 = x.tag.value_var
-        assert x.ndim == logpt(x, sum=False).ndim
+        x_val_transf = x.tag.value_var

-        pt = model.initial_point
-        array = np.random.randn(*pt[x0.name].shape)
-        transform = x0.tag.transform
-        logp_notrans = logpt(x, transform.backward(array, *x.owner.inputs), transformed=False)
+        pt = model.recompute_initial_point(0)
+        test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape))
+        transform = x_val_transf.tag.transform
+        test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval()

-        jacob_det = transform.log_jac_det(aesara.shared(array), *x.owner.inputs)
-        assert logpt(x, sum=False).ndim == jacob_det.ndim
+        # Create input variable with same dimensionality as untransformed test_array
+        x_val_untransf = at.constant(test_array_untransf).type()

-        v1 = logpt(x, array, jacobian=False).eval()
-        v2 = logp_notrans.eval()
+        jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
+        assert logpt(x, sum=False).ndim == x.ndim == jacob_det.ndim
+
+        v1 = logpt(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf})
+        v2 = logpt(x, x_val_untransf, transformed=False).eval({x_val_untransf: test_array_untransf})
         close_to(v1, v2, tol)

-    def check_vectortransform_elementwise_logp(self, model, vect_opt=0):
+    def check_vectortransform_elementwise_logp(self, model):
         x = model.free_RVs[0]
-        x0 = x.tag.value_var
-        # TODO: For some reason the ndim relations
-        # dont hold up here. But final log-probablity
-        # values are what we expected.
-        # assert (x.ndim - 1) == logpt(x, sum=False).ndim
-
-        pt = model.initial_point
-        array = np.random.randn(*pt[x0.name].shape)
-        transform = x0.tag.transform
-        logp_nojac = logpt(x, transform.backward(array, *x.owner.inputs), transformed=False)
-
-        jacob_det = transform.log_jac_det(aesara.shared(array), *x.owner.inputs)
-        # assert logpt(x).ndim == jacob_det.ndim
-
+        x_val_transf = x.tag.value_var
+
+        pt = model.recompute_initial_point(0)
+        test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape))
+        transform = x_val_transf.tag.transform
+        test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval()
+
+        # Create input variable with same dimensionality as untransformed test_array
+        x_val_untransf = at.constant(test_array_untransf).type()
+
+        jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
+        # Original distribution is univariate
+        if x.owner.op.ndim_supp == 0:
+            assert logpt(x, sum=False).ndim == x.ndim == (jacob_det.ndim + 1)
+        # Original distribution is multivariate
+        else:
+            assert logpt(x, sum=False).ndim == (x.ndim - 1) == jacob_det.ndim
+
+        a = logpt(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf})
+        b = logpt(x, x_val_untransf, transformed=False).eval({x_val_untransf: test_array_untransf})
         # Hack to get relative tolerance
-        a = logpt(x, array.astype(aesara.config.floatX), jacobian=False).eval()
-        b = logp_nojac.eval()
         close_to(a, b, np.abs(0.5 * (a + b) * tol))

     @pytest.mark.parametrize(
@@ -406,7 +412,7 @@ def test_vonmises(self, mu, kappa, size):
     )
     def test_dirichlet(self, a, size):
         model = self.build_model(pm.Dirichlet, {"a": a}, size=size, transform=tr.simplex)
-        self.check_vectortransform_elementwise_logp(model, vect_opt=1)
+        self.check_vectortransform_elementwise_logp(model)

     def test_normal_ordered(self):
         model = self.build_model(
@@ -416,7 +422,7 @@ def test_normal_ordered(self):
             initval=np.asarray([-1.0, 1.0, 4.0]),
             transform=tr.ordered,
         )
-        self.check_vectortransform_elementwise_logp(model, vect_opt=0)
+        self.check_vectortransform_elementwise_logp(model)

     @pytest.mark.parametrize(
         "sd,size",
@@ -434,7 +440,7 @@ def test_half_normal_ordered(self, sd, size):
             initval=initval,
             transform=tr.Chain([tr.log, tr.ordered]),
         )
-        self.check_vectortransform_elementwise_logp(model, vect_opt=0)
+        self.check_vectortransform_elementwise_logp(model)

     @pytest.mark.parametrize("lam,size", [(2.5, (2,)), (np.ones(3), (4, 3))])
     def test_exponential_ordered(self, lam, size):
@@ -446,7 +452,7 @@ def test_exponential_ordered(self, lam, size):
             initval=initval,
             transform=tr.Chain([tr.log, tr.ordered]),
         )
-        self.check_vectortransform_elementwise_logp(model, vect_opt=0)
+        self.check_vectortransform_elementwise_logp(model)

     @pytest.mark.parametrize(
         "a,b,size",
@@ -468,7 +474,7 @@ def test_beta_ordered(self, a, b, size):
             initval=initval,
             transform=tr.Chain([tr.logodds, tr.ordered]),
         )
-        self.check_vectortransform_elementwise_logp(model, vect_opt=0)
+        self.check_vectortransform_elementwise_logp(model)

     @pytest.mark.parametrize(
         "lower,upper,size",
@@ -491,7 +497,7 @@ def transform_params(*inputs):
             initval=initval,
             transform=tr.Chain([interval, tr.ordered]),
         )
-        self.check_vectortransform_elementwise_logp(model, vect_opt=1)
+        self.check_vectortransform_elementwise_logp(model)

     @pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))])
     def test_vonmises_ordered(self, mu, kappa, size):
@@ -503,7 +509,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
             initval=initval,
             transform=tr.Chain([tr.circular, tr.ordered]),
         )
-        self.check_vectortransform_elementwise_logp(model, vect_opt=0)
+        self.check_vectortransform_elementwise_logp(model)

     @pytest.mark.parametrize(
         "lower,upper,size,transform",
@@ -522,7 +528,7 @@ def test_uniform_other(self, lower, upper, size, transform):
             initval=initval,
             transform=transform,
         )
-        self.check_vectortransform_elementwise_logp(model, vect_opt=1)
+        self.check_vectortransform_elementwise_logp(model)

     @pytest.mark.parametrize(
         "mu,cov,size,shape",
@@ -536,7 +542,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape):
         model = self.build_model(
             pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered
         )
-        self.check_vectortransform_elementwise_logp(model, vect_opt=1)
+        self.check_vectortransform_elementwise_logp(model)


 def test_triangular_transform():
```

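All the rewritten checks exercise the same identity: the transformed-space log-probability without the Jacobian term should equal the untransformed log-probability at the back-transformed point. A standalone sketch of that identity (illustrative only; it reuses the branch APIs seen above — `logpt`, `floatX`, and the tagged transforms):

```python
import aesara.tensor as at
import numpy as np

import pymc as pm
from pymc.aesaraf import floatX
from pymc.distributions import logpt

with pm.Model() as m:
    s = pm.HalfNormal("s", 1.0)

s_val = s.tag.value_var                      # value variable in log space
transform = s_val.tag.transform
test_transf = floatX(np.array(0.5))          # a point in transformed space
test_untransf = transform.backward(test_transf, *s.owner.inputs).eval()

# Fresh input variable with the untransformed value's type
s_val_untransf = at.constant(test_untransf).type()

# logp(transformed, no Jacobian) == logp(untransformed) at backward(point)
v1 = logpt(s, s_val, jacobian=False).eval({s_val: test_transf})
v2 = logpt(s, s_val_untransf, transformed=False).eval({s_val_untransf: test_untransf})
np.testing.assert_allclose(v1, v2, rtol=1e-5)
```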