From ccb9cfb32f9d4165b855dc12ee6c19e44191fe0f Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Tue, 4 Apr 2023 20:27:42 +0530 Subject: [PATCH 1/3] Remove joint_logprob function from tests.logprob.utils --- tests/logprob/test_censoring.py | 62 ++++++++------ tests/logprob/test_composite_logprob.py | 34 +++++--- tests/logprob/test_cumsum.py | 27 +++--- tests/logprob/test_mixture.py | 28 +++--- tests/logprob/test_rewriting.py | 7 +- tests/logprob/test_scan.py | 10 +-- tests/logprob/test_tensor.py | 66 +++++++++------ tests/logprob/test_transforms.py | 108 +++++++++++++----------- tests/logprob/utils.py | 27 ------ 9 files changed, 195 insertions(+), 174 deletions(-) diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index cf92eaf0f8..4607dca45f 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -41,10 +41,10 @@ import scipy as sp import scipy.stats as st +from pymc import logp from pymc.logprob import factorized_joint_logprob from pymc.logprob.transforms import LogTransform, TransformValuesRewrite from pymc.testing import assert_no_rvs -from tests.logprob.utils import joint_logprob @pytensor.config.change_flags(compute_test_value="raise") @@ -55,10 +55,10 @@ def test_continuous_rv_clip(): cens_x_vv = cens_x_rv.clone() cens_x_vv.tag.test_value = 0 - logp = joint_logprob({cens_x_rv: cens_x_vv}) - assert_no_rvs(logp) + logprob = pt.sum(logp(cens_x_rv, cens_x_vv)) + assert_no_rvs(logprob) - logp_fn = pytensor.function([cens_x_vv], logp) + logp_fn = pytensor.function([cens_x_vv], logprob) ref_scipy = st.norm(0.5, 1) assert logp_fn(-3) == -np.inf @@ -75,10 +75,10 @@ def test_discrete_rv_clip(): cens_x_vv = cens_x_rv.clone() - logp = joint_logprob({cens_x_rv: cens_x_vv}) - assert_no_rvs(logp) + logprob = pt.sum(logp(cens_x_rv, cens_x_vv)) + assert_no_rvs(logprob) - logp_fn = pytensor.function([cens_x_vv], logp) + logp_fn = pytensor.function([cens_x_vv], logprob) ref_scipy = st.poisson(2) assert logp_fn(0) == -np.inf @@ -97,8 +97,8 @@ def test_one_sided_clip(): lb_cens_x_vv = lb_cens_x_rv.clone() ub_cens_x_vv = ub_cens_x_rv.clone() - lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x_vv}) - ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x_vv}) + lb_logp = pt.sum(logp(lb_cens_x_rv, lb_cens_x_vv)) + ub_logp = pt.sum(logp(ub_cens_x_rv, ub_cens_x_vv)) assert_no_rvs(lb_logp) assert_no_rvs(ub_logp) @@ -117,10 +117,10 @@ def test_useless_clip(): cens_x_vv = cens_x_rv.clone() - logp = joint_logprob({cens_x_rv: cens_x_vv}, sum=False) - assert_no_rvs(logp) + logprob = logp(cens_x_rv, cens_x_vv) + assert_no_rvs(logprob) - logp_fn = pytensor.function([cens_x_vv], logp) + logp_fn = pytensor.function([cens_x_vv], logprob) ref_scipy = st.norm(0.5, 1) np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2])) @@ -133,10 +133,12 @@ def test_random_clip(): lb_vv = lb_rv.clone() cens_x_vv = cens_x_rv.clone() - logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}, sum=False) - assert_no_rvs(logp) + logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}) + logp_combined = pt.add(*logp.values()) - logp_fn = pytensor.function([lb_vv, cens_x_vv], logp) + assert_no_rvs(logp_combined) + + logp_fn = pytensor.function([lb_vv, cens_x_vv], logp_combined) res = logp_fn([0, -1], [-1, -1]) assert res[0] == -np.inf assert res[1] != -np.inf @@ -150,8 +152,10 @@ def test_broadcasted_clip_constant(): lb_vv = lb_rv.clone() cens_x_vv = cens_x_rv.clone() - logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}) - assert_no_rvs(logp) + 
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) + + assert_no_rvs(logp_combined) def test_broadcasted_clip_random(): @@ -162,8 +166,10 @@ def test_broadcasted_clip_random(): lb_vv = lb_rv.clone() cens_x_vv = cens_x_rv.clone() - logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}) - assert_no_rvs(logp) + logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) + + assert_no_rvs(logp_combined) def test_fail_base_and_clip_have_values(): @@ -199,10 +205,11 @@ def test_deterministic_clipping(): x_vv = x_rv.clone() y_vv = y_rv.clone() - logp = joint_logprob({x_rv: x_vv, y_rv: y_vv}) - assert_no_rvs(logp) + logp = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv}) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) + assert_no_rvs(logp_combined) - logp_fn = pytensor.function([x_vv, y_vv], logp) + logp_fn = pytensor.function([x_vv, y_vv], logp_combined) assert np.isclose( logp_fn(-1, 1), st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1), @@ -216,10 +223,11 @@ def test_clip_transform(): cens_x_vv = cens_x_rv.clone() transform = TransformValuesRewrite({cens_x_vv: LogTransform()}) - logp = joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform) + logp = factorized_joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) cens_x_vv_testval = -1 - obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval}) + obs_logp = logp_combined.eval({cens_x_vv: cens_x_vv_testval}) exp_logp = sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_vv_testval)) + cens_x_vv_testval assert np.isclose(obs_logp, exp_logp) @@ -236,8 +244,8 @@ def test_rounding(rounding_op): xr.name = "xr" xr_vv = xr.clone() - logp = joint_logprob({xr: xr_vv}, sum=False) - assert logp is not None + logprob = logp(xr, xr_vv) + assert logprob is not None x_sp = st.norm(loc, scale) if rounding_op == pt.round: @@ -250,6 +258,6 @@ def test_rounding(rounding_op): raise NotImplementedError() assert np.allclose( - logp.eval({xr_vv: test_value}), + logprob.eval({xr_vv: test_value}), expected_logp, ) diff --git a/tests/logprob/test_composite_logprob.py b/tests/logprob/test_composite_logprob.py index f5de69ac26..f6419608ce 100644 --- a/tests/logprob/test_composite_logprob.py +++ b/tests/logprob/test_composite_logprob.py @@ -39,10 +39,11 @@ import pytensor.tensor as pt import scipy.stats as st +from pymc import logp +from pymc.logprob.basic import factorized_joint_logprob from pymc.logprob.censoring import MeasurableClip from pymc.logprob.rewriting import construct_ir_fgraph from pymc.testing import assert_no_rvs -from tests.logprob.utils import joint_logprob def test_scalar_clipped_mixture(): @@ -61,9 +62,10 @@ def test_scalar_clipped_mixture(): idxs_vv = idxs.clone() idxs_vv.name = "idxs_val" - logp = joint_logprob({idxs: idxs_vv, mix: mix_vv}) + logp = factorized_joint_logprob({idxs: idxs_vv, mix: mix_vv}) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) - logp_fn = pytensor.function([idxs_vv, mix_vv], logp) + logp_fn = pytensor.function([idxs_vv, mix_vv], logp_combined) assert logp_fn(0, 0.4) == -np.inf assert np.isclose(logp_fn(0, 0.5), st.norm.logcdf(0.5, 1) + np.log(0.6)) assert np.isclose(logp_fn(0, 1.3), st.norm.logpdf(1.3, 1) + np.log(0.6)) @@ -98,8 +100,12 @@ def test_nested_scalar_mixtures(): idxs12_vv = idxs12.clone() mix12_vv = mix12.clone() - logp = 
joint_logprob({idxs1: idxs1_vv, idxs2: idxs2_vv, idxs12: idxs12_vv, mix12: mix12_vv}) - logp_fn = pytensor.function([idxs1_vv, idxs2_vv, idxs12_vv, mix12_vv], logp) + logp = factorized_joint_logprob( + {idxs1: idxs1_vv, idxs2: idxs2_vv, idxs12: idxs12_vv, mix12: mix12_vv} + ) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) + + logp_fn = pytensor.function([idxs1_vv, idxs2_vv, idxs12_vv, mix12_vv], logp_combined) expected_mu_logpdf = st.norm.logpdf(0) + np.log(0.5) * 3 assert np.isclose(logp_fn(0, 0, 0, -50), expected_mu_logpdf) @@ -144,9 +150,9 @@ def test_shifted_cumsum(): y.name = "y" y_vv = y.clone() - logp = joint_logprob({y: y_vv}) + logprob = logp(y, y_vv) assert np.isclose( - logp.eval({y_vv: np.arange(5) + 1 + 5}), + logprob.eval({y_vv: np.arange(5) + 1 + 5}).sum(), st.norm.logpdf(1) * 5, ) @@ -157,8 +163,8 @@ def test_double_log_transform_rv(): y_rv.name = "y" y_vv = y_rv.clone() - logp = joint_logprob({y_rv: y_vv}, sum=False) - logp_fn = pytensor.function([y_vv], logp) + logprob = logp(y_rv, y_vv) + logp_fn = pytensor.function([y_vv], logprob) log_log_y_val = np.asarray(0.5) log_y_val = np.exp(log_log_y_val) @@ -178,9 +184,9 @@ def test_affine_transform_rv(): y_rv.name = "y" y_vv = y_rv.clone() - logp = joint_logprob({y_rv: y_vv}, sum=False) - assert_no_rvs(logp) - logp_fn = pytensor.function([loc, scale, y_vv], logp) + logprob = logp(y_rv, y_vv) + assert_no_rvs(logprob) + logp_fn = pytensor.function([loc, scale, y_vv], logprob) loc_test_val = 4.0 scale_test_val = np.full(rv_size, 0.5) @@ -200,8 +206,8 @@ def test_affine_log_transform_rv(): y_vv = y_rv.clone() - logp = joint_logprob({y_rv: y_vv}, sum=False) - logp_fn = pytensor.function([a, b, y_vv], logp) + logprob = logp(y_rv, y_vv) + logp_fn = pytensor.function([a, b, y_vv], logprob) a_val = -1.5 b_val = 3.0 diff --git a/tests/logprob/test_cumsum.py b/tests/logprob/test_cumsum.py index cdc43f162c..94ea39a5fd 100644 --- a/tests/logprob/test_cumsum.py +++ b/tests/logprob/test_cumsum.py @@ -40,8 +40,9 @@ import pytest import scipy.stats as st +from pymc import logp +from pymc.logprob.basic import factorized_joint_logprob from pymc.testing import assert_no_rvs -from tests.logprob.utils import joint_logprob @pytest.mark.parametrize( @@ -59,12 +60,12 @@ def test_normal_cumsum(size, axis): rv = pt.random.normal(0, 1, size=size).cumsum(axis) vv = rv.clone() - logp = joint_logprob({rv: vv}) - assert_no_rvs(logp) + logprob = logp(rv, vv) + assert_no_rvs(logprob) assert np.isclose( st.norm(0, 1).logpdf(np.ones(size)).sum(), - logp.eval({vv: np.ones(size).cumsum(axis)}), + logprob.eval({vv: np.ones(size).cumsum(axis)}).sum(), ) @@ -83,12 +84,12 @@ def test_normal_cumsum(size, axis): def test_bernoulli_cumsum(size, axis): rv = pt.random.bernoulli(0.9, size=size).cumsum(axis) vv = rv.clone() - logp = joint_logprob({rv: vv}) - assert_no_rvs(logp) + logprob = logp(rv, vv) + assert_no_rvs(logprob) assert np.isclose( st.bernoulli(0.9).logpmf(np.ones(size)).sum(), - logp.eval({vv: np.ones(size, int).cumsum(axis)}), + logprob.eval({vv: np.ones(size, int).cumsum(axis)}).sum(), ) @@ -97,7 +98,7 @@ def test_destructive_cumsum_fails(): x_rv = pt.random.normal(size=(2, 2, 2)).cumsum() x_vv = x_rv.clone() with pytest.raises(RuntimeError, match="could not be derived"): - joint_logprob({x_rv: x_vv}) + factorized_joint_logprob({x_rv: x_vv}) def test_deterministic_cumsum(): @@ -108,11 +109,13 @@ def test_deterministic_cumsum(): x_vv = x_rv.clone() y_vv = y_rv.clone() - logp = joint_logprob({x_rv: x_vv, y_rv: y_vv}) - 
assert_no_rvs(logp)
-    logp_fn = pytensor.function([x_vv, y_vv], logp)
+    logp = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
+    assert_no_rvs(logp_combined)
+
+    logp_fn = pytensor.function([x_vv, y_vv], logp_combined)
     assert np.isclose(
-        logp_fn(np.ones(5), np.arange(5) + 1),
+        logp_fn(np.ones(5), np.arange(5) + 1).sum(),
         st.norm(1, 1).logpdf(1) * 10,
     )
diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py
index 6534b22d85..8767e97d66 100644
--- a/tests/logprob/test_mixture.py
+++ b/tests/logprob/test_mixture.py
@@ -52,7 +52,7 @@
 from pymc.logprob.rewriting import construct_ir_fgraph
 from pymc.logprob.utils import dirac_delta
 from pymc.testing import assert_no_rvs
-from tests.logprob.utils import joint_logprob, scipy_logprob
+from tests.logprob.utils import scipy_logprob
 
 
 def test_mixture_basics():
@@ -101,7 +101,7 @@ def create_mix_model(size, axis):
         i_vv = env["i_vv"]
         M_rv = env["M_rv"]
         m_vv = env["m_vv"]
-        joint_logprob({M_rv: m_vv, I_rv: i_vv})
+        factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
 
 
 @pytensor.config.change_flags(compute_test_value="warn")
@@ -134,9 +134,10 @@ def test_compute_test_value(op_constructor):
 
     del M_rv.tag.test_value
 
-    M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+    M_logp = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
+    M_logp_combined = pt.add(*M_logp.values())
 
-    assert isinstance(M_logp.tag.test_value, np.ndarray)
+    assert isinstance(M_logp_combined.tag.test_value, np.ndarray)
 
 
 @pytest.mark.parametrize(
@@ -179,13 +180,14 @@ def test_hetero_mixture_binomial(p_val, size, supported):
     m_vv.name = "m"
 
     if supported:
-        M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+        M_logp = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
+        M_logp_combined = pt.add(*M_logp.values())
     else:
         with pytest.raises(RuntimeError, match="could not be derived: {m}"):
-            joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
+            factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
         return
 
-    M_logp_fn = pytensor.function([p_at, m_vv, i_vv], M_logp)
+    M_logp_fn = pytensor.function([p_at, m_vv, i_vv], M_logp_combined)
 
     assert_no_rvs(M_logp_fn.maker.fgraph.outputs[0])
 
@@ -936,14 +938,16 @@ def test_switch_mixture():
 
     assert equal_computations(fgraph.outputs, fgraph2.outputs)
 
-    z1_logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
-    z2_logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv})
+    z1_logp = factorized_joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
+    z2_logp = factorized_joint_logprob({Z2_rv: z_vv, I_rv: i_vv})
+    z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()])
+    z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()])
 
     # below should follow immediately from the equal_computations assertion above
-    assert equal_computations([z1_logp], [z2_logp])
+    assert equal_computations([z1_logp_combined], [z2_logp_combined])
 
-    np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
-    np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))
+    np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 0}))
+    np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 0}))
 
 
 def test_ifelse_mixture_one_component():
diff --git a/tests/logprob/test_rewriting.py b/tests/logprob/test_rewriting.py
index 7107b18d3f..b8836bbce5 100644
--- a/tests/logprob/test_rewriting.py
+++ b/tests/logprob/test_rewriting.py
@@ -50,9 +50,9 @@
     Subtensor,
 )
 
+from pymc.logprob.basic import 
factorized_joint_logprob from pymc.logprob.rewriting import local_lift_DiracDelta from pymc.logprob.utils import DiracDelta, dirac_delta -from tests.logprob.utils import joint_logprob def test_local_lift_DiracDelta(): @@ -120,9 +120,10 @@ def test_joint_logprob_incsubtensor(indices, size): assert isinstance(Y_rv.owner.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)) - Y_rv_logp = joint_logprob({Y_rv: y_value_var}, sum=False) + Y_rv_logp = factorized_joint_logprob({Y_rv: y_value_var}) + Y_rv_logp_combined = pt.add(*Y_rv_logp.values()) - obs_logps = Y_rv_logp.eval({y_value_var: y_val}) + obs_logps = Y_rv_logp_combined.eval({y_value_var: y_val}) y_val_idx = y_val.copy() y_val_idx[indices] = data diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index b617edb49a..748a4405fc 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -52,7 +52,6 @@ get_random_outer_outputs, ) from pymc.testing import assert_no_rvs -from tests.logprob.utils import joint_logprob def create_inner_out_logp(value_map): @@ -336,7 +335,8 @@ def scan_fn(mus_t, sigma_t, Gamma_t): s_vv = S_rv.clone() s_vv.name = "s" - y_logp = joint_logprob({Y_rv: y_vv, S_rv: s_vv, Gamma_rv: Gamma_vv}) + y_logp = factorized_joint_logprob({Y_rv: y_vv, S_rv: s_vv, Gamma_rv: Gamma_vv}) + y_logp_combined = pt.sum([pt.sum(factor) for factor in y_logp.values()]) y_val = np.arange(10) s_val = np.array([0, 1, 0, 1, 1, 0, 0, 0, 1, 1]) @@ -350,7 +350,7 @@ def scan_fn(mus_t, sigma_t, Gamma_t): Gamma_vv: Gamma_val, } - y_logp_fn = pytensor.function(list(test_point.keys()), y_logp) + y_logp_fn = pytensor.function(list(test_point.keys()), y_logp_combined) assert_no_rvs(y_logp_fn.maker.fgraph.outputs[0]) @@ -381,7 +381,7 @@ def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t): assert_no_rvs(y_logp_ref) - y_logp_val = y_logp.eval(test_point) + y_logp_val = y_logp_combined.eval(test_point) y_logp_ref_val = y_logp_ref.eval(test_point) @@ -451,7 +451,7 @@ def test_mode_is_kept(remove_asserts): ) x.name = "x" x_vv = x.clone() - x_logp = pytensor.function([x_vv], joint_logprob({x: x_vv})) + x_logp = pytensor.function([x_vv], pt.sum(logp(x, x_vv))) x_test_val = np.full((10,), -1) if remove_asserts: diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py index e44d9809fc..64b4cdf6e1 100644 --- a/tests/logprob/test_tensor.py +++ b/tests/logprob/test_tensor.py @@ -45,11 +45,10 @@ from pytensor.tensor.extra_ops import BroadcastTo from scipy import stats as st -from pymc.logprob import factorized_joint_logprob +from pymc.logprob.basic import factorized_joint_logprob, logp from pymc.logprob.rewriting import logprob_rewrites_db from pymc.logprob.tensor import naive_bcast_rv_lift from pymc.testing import assert_no_rvs -from tests.logprob.utils import joint_logprob def test_naive_bcast_rv_lift(): @@ -91,14 +90,16 @@ def test_bcast_rv_logp(): broadcasted_x_rv.name = "broadcasted_x" broadcasted_x_vv = broadcasted_x_rv.clone() - logp = joint_logprob({broadcasted_x_rv: broadcasted_x_vv}, sum=False) - valid_logp = logp.eval({broadcasted_x_vv: [0, 0]}) + logp = factorized_joint_logprob({broadcasted_x_rv: broadcasted_x_vv}) + logp_combined = pt.add(*logp.values()) + valid_logp = logp_combined.eval({broadcasted_x_vv: [0, 0]}) + assert valid_logp.shape == () assert np.isclose(valid_logp, st.norm.logpdf(0)) # It's not possible for broadcasted dimensions to have different values # This should either raise or return -inf - invalid_logp = logp.eval({broadcasted_x_vv: [0, 1]}) + invalid_logp = 
logp_combined.eval({broadcasted_x_vv: [0, 1]}) assert invalid_logp == -np.inf @@ -114,15 +115,19 @@ def test_measurable_make_vector(): base3_vv = base3_rv.clone() y_vv = y_rv.clone() - ref_logp = joint_logprob({base1_rv: base1_vv, base2_rv: base2_vv, base3_rv: base3_vv}) - make_vector_logp = joint_logprob({y_rv: y_vv}, sum=False) + ref_logp = factorized_joint_logprob( + {base1_rv: base1_vv, base2_rv: base2_vv, base3_rv: base3_vv} + ) + ref_logp_combined = pt.sum([pt.sum(factor) for factor in ref_logp.values()]) + + make_vector_logp = logp(y_rv, y_vv) base1_testval = base1_rv.eval() base2_testval = base2_rv.eval() base3_testval = base3_rv.eval() y_testval = np.stack((base1_testval, base2_testval, base3_testval)) - ref_logp_eval_eval = ref_logp.eval( + ref_logp_eval_eval = ref_logp_combined.eval( {base1_vv: base1_testval, base2_vv: base2_testval, base3_vv: base3_testval} ) make_vector_logp_eval = make_vector_logp.eval({y_vv: y_testval}) @@ -151,21 +156,25 @@ def test_measurable_make_vector_interdependent(reverse): x_vv = x.clone() ys_vv = ys.clone() - logp = joint_logprob({x: x_vv, ys: ys_vv}) - assert_no_rvs(logp) + logp = factorized_joint_logprob({x: x_vv, ys: ys_vv}) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) + assert_no_rvs(logp_combined) y0_vv = y_rvs[0].clone() y1_vv = y_rvs[1].clone() y2_vv = y_rvs[2].clone() - ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv}) + ref_logp = factorized_joint_logprob( + {x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv} + ) + ref_logp_combined = pt.sum([pt.sum(factor) for factor in ref_logp.values()]) rng = np.random.default_rng() x_vv_test = rng.normal() ys_vv_test = rng.normal(size=3) np.testing.assert_allclose( - logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}), - ref_logp.eval( + logp_combined.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}).sum(), + ref_logp_combined.eval( {x_vv: x_vv_test, y0_vv: ys_vv_test[0], y1_vv: ys_vv_test[1], y2_vv: ys_vv_test[2]} ), ) @@ -191,21 +200,25 @@ def test_measurable_join_interdependent(reverse): x_vv = x.clone() ys_vv = ys.clone() - logp = joint_logprob({x: x_vv, ys: ys_vv}) - assert_no_rvs(logp) + logp = factorized_joint_logprob({x: x_vv, ys: ys_vv}) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) + assert_no_rvs(logp_combined) y0_vv = y_rvs[0].clone() y1_vv = y_rvs[1].clone() y2_vv = y_rvs[2].clone() - ref_logp = joint_logprob({x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv}) + ref_logp = factorized_joint_logprob( + {x: x_vv, y_rvs[0]: y0_vv, y_rvs[1]: y1_vv, y_rvs[2]: y2_vv} + ) + ref_logp_combined = pt.sum([pt.sum(factor) for factor in ref_logp.values()]) rng = np.random.default_rng() x_vv_test = rng.normal() ys_vv_test = rng.normal(size=(3, 2)) np.testing.assert_allclose( - logp.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}), - ref_logp.eval( + logp_combined.eval({x_vv: x_vv_test, ys_vv: ys_vv_test}), + ref_logp_combined.eval( { x_vv: x_vv_test, y0_vv: ys_vv_test[0:1], @@ -246,7 +259,7 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate): base_logps = pt.concatenate(base_logps, axis=axis) else: base_logps = pt.stack(base_logps, axis=axis) - y_logp = joint_logprob({y_rv: y_vv}, sum=False) + y_logp = logp(y_rv, y_vv) assert_no_rvs(y_logp) base1_testval = base1_rv.eval() @@ -314,7 +327,7 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis else: axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1) base_logps = pt.stack(base_logps, axis=axis_norm 
- 1)
-    y_logp = joint_logprob({y_rv: y_vv}, sum=False)
+    y_logp = logp(y_rv, y_vv)
     assert_no_rvs(y_logp)
 
     base1_testval = base1_rv.eval()
@@ -336,7 +349,7 @@ def test_join_mixed_ndim_supp():
     y_vv = y_rv.clone()
 
     with pytest.raises(ValueError, match="Joined logps have different number of dimensions"):
-        joint_logprob({y_rv: y_vv})
+        logp(y_rv, y_vv)
 
 
 @pytensor.config.change_flags(cxx="")
@@ -375,17 +388,18 @@ def test_measurable_dimshuffle(ds_order, multivariate):
     else:
         logp_ds_order = ds_order
 
-    ref_logp = joint_logprob({base_rv: base_vv}, sum=False).dimshuffle(logp_ds_order)
+    ref_logp = logp(base_rv, base_vv).dimshuffle(logp_ds_order)
 
     # Disable local_dimshuffle_rv_lift to test fallback Aeppl rewrite
     ir_rewriter = logprob_rewrites_db.query(
         RewriteDatabaseQuery(include=["basic"]).excluding("dimshuffle_lift")
     )
-    ds_logp = joint_logprob({ds_rv: ds_vv}, sum=False, ir_rewriter=ir_rewriter)
-    assert ds_logp is not None
+    ds_logp = factorized_joint_logprob({ds_rv: ds_vv}, ir_rewriter=ir_rewriter)
+    ds_logp_combined = pt.add(*ds_logp.values())
+    assert ds_logp_combined is not None
 
     ref_logp_fn = pytensor.function([base_vv], ref_logp)
-    ds_logp_fn = pytensor.function([ds_vv], ds_logp)
+    ds_logp_fn = pytensor.function([ds_vv], ds_logp_combined)
 
     base_test_value = base_rv.eval()
     ds_test_value = pt.constant(base_test_value).dimshuffle(ds_order).eval()
@@ -412,4 +426,4 @@ def test_unmeargeable_dimshuffles():
     w_vv = w.clone()
     # TODO: Check that logp is correct if this type of graphs is ever supported
     with pytest.raises(RuntimeError, match="could not be derived"):
-        joint_logprob({w: w_vv})
+        factorized_joint_logprob({w: w_vv})
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index ba24fadee1..22912d0928 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -49,7 +49,7 @@
 from pymc.distributions.transforms import _default_transform, log, logodds
 from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
-from pymc.logprob.basic import factorized_joint_logprob
+from pymc.logprob.basic import factorized_joint_logprob, logp
 from pymc.logprob.transforms import (
     ChainedTransform,
     ExpTransform,
@@ -64,7 +64,6 @@
     transformed_variable,
 )
 from pymc.testing import assert_no_rvs
-from tests.logprob.utils import joint_logprob
 
 
 class DirichletScipyDist:
@@ -235,11 +234,14 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
         transform = _default_transform(a.owner.op, a)
     transform_rewrite = TransformValuesRewrite({a_value_var: transform})
 
-    res = joint_logprob({a: a_value_var, b: b_value_var}, extra_rewrites=transform_rewrite)
+    res = factorized_joint_logprob(
+        {a: a_value_var, b: b_value_var}, extra_rewrites=transform_rewrite
+    )
+    res_combined = pt.sum([pt.sum(factor) for factor in res.values()])
 
     test_val_rng = np.random.RandomState(3238)
 
-    logp_vals_fn = pytensor.function([a_value_var, b_value_var], res)
+    logp_vals_fn = pytensor.function([a_value_var, b_value_var], res_combined)
 
     a_forward_fn = pytensor.function([a_value_var], transform.forward(a_value_var, *a.owner.inputs))
     a_backward_fn = pytensor.function(
@@ -308,12 +310,13 @@ def test_simple_transformed_logprob_nojac(use_jacobian):
     x_vv.name = "x"
 
     transform_rewrite = TransformValuesRewrite({x_vv: log})
-    tr_logp = joint_logprob(
+    tr_logp = factorized_joint_logprob(
         {X_rv: x_vv}, extra_rewrites=transform_rewrite, use_jacobian=use_jacobian
     )
+    tr_logp_combined = pt.sum([pt.sum(factor) for factor in tr_logp.values()])
 
     assert np.isclose(
-        
tr_logp.eval({x_vv: np.log(2.5)}), + tr_logp_combined.eval({x_vv: np.log(2.5)}), sp.stats.halfnorm(0, 3).logpdf(2.5) + (np.log(2.5) if use_jacobian else 0.0), ) @@ -366,13 +369,14 @@ def test_hierarchical_uniform_transform(): x: _default_transform(x_rv.owner.op, x_rv), } ) - logp = joint_logprob( + logp = factorized_joint_logprob( {lower_rv: lower, upper_rv: upper, x_rv: x}, extra_rewrites=transform_rewrite, ) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) - assert_no_rvs(logp) - assert not np.isinf(logp.eval({lower: -10, upper: 20, x: -20})) + assert_no_rvs(logp_combined) + assert not np.isinf(logp_combined.eval({lower: -10, upper: 20, x: -20})) def test_nondefault_transforms(): @@ -392,10 +396,11 @@ def test_nondefault_transforms(): } ) - logp = joint_logprob( + logp = factorized_joint_logprob( {loc_rv: loc, scale_rv: scale, x_rv: x}, extra_rewrites=transform_rewrite, ) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) # Check numerical evaluation matches with expected transforms loc_val = 0 @@ -413,7 +418,7 @@ def test_nondefault_transforms(): exp_logp += x_val_tr # log log_jac_det assert np.isclose( - logp.eval({loc: loc_val, scale: scale_val_tr, x: x_val_tr}), + logp_combined.eval({loc: loc_val, scale: scale_val_tr, x: x_val_tr}), exp_logp, ) @@ -429,13 +434,14 @@ def test_default_transform_multiout(): transform_rewrite = TransformValuesRewrite({x: None}) - logp = joint_logprob( + logp = factorized_joint_logprob( {x_rv: x}, extra_rewrites=transform_rewrite, ) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) assert np.isclose( - logp.eval({x: 1}), + logp_combined.eval({x: 1}), sp.stats.norm(0, 1).logpdf(1), ) @@ -481,7 +487,8 @@ def test_nondefault_transform_multiout(transform_x, transform_y, multiout_measur } ) - logp = joint_logprob({x: x_vv, y: y_vv}, extra_rewrites=transform_rewrite) + logp = factorized_joint_logprob({x: x_vv, y: y_vv}, extra_rewrites=transform_rewrite) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) x_vv_test = np.random.normal() y_vv_test = np.abs(np.random.normal()) @@ -497,7 +504,9 @@ def test_nondefault_transform_multiout(transform_x, transform_y, multiout_measur else: expected_logp += np.log(y_vv_test) + 2 - np.log(y_vv_test) - np.testing.assert_almost_equal(logp.eval({x_vv: x_vv_test, y_vv: y_vv_test}), expected_logp) + np.testing.assert_almost_equal( + logp_combined.eval({x_vv: x_vv_test, y_vv: y_vv_test}), expected_logp + ) def test_TransformValuesMapping(): @@ -547,31 +556,33 @@ def test_mixture_transform(): y_vv = Y_rv.clone() y_vv.name = "y" - logp_no_trans = joint_logprob( + logp_no_trans = factorized_joint_logprob( {Y_rv: y_vv, I_rv: i_vv}, ) + logp_no_trans_comb = pt.sum([pt.sum(factor) for factor in logp_no_trans.values()]) transform_rewrite = TransformValuesRewrite({y_vv: LogTransform()}) with pytest.warns(None) as record: # This shouldn't raise any warnings - logp_trans = joint_logprob( + logp_trans = factorized_joint_logprob( {Y_rv: y_vv, I_rv: i_vv}, extra_rewrites=transform_rewrite, use_jacobian=False, ) + logp_trans_combined = pt.sum([pt.sum(factor) for factor in logp_trans.values()]) assert not record.list # The untransformed graph should be the same as the transformed graph after # replacing the `Y_rv` value variable with a transformed version of itself - logp_nt_fg = FunctionGraph(outputs=[logp_no_trans], clone=False) + logp_nt_fg = FunctionGraph(outputs=[logp_no_trans_comb], clone=False) y_trans = transformed_variable(pt.exp(y_vv), y_vv) y_trans.name 
= "y_log" logp_nt_fg.replace(y_vv, y_trans) logp_nt = logp_nt_fg.outputs[0] - assert equal_computations([logp_nt], [logp_trans]) + assert equal_computations([logp_nt], [logp_trans_combined]) def test_invalid_interval_transform(): @@ -637,8 +648,8 @@ def test_exp_transform_rv(): y_rv.name = "y" y_vv = y_rv.clone() - logp = joint_logprob({y_rv: y_vv}, sum=False) - logp_fn = pytensor.function([y_vv], logp) + logprob = logp(y_rv, y_vv) + logp_fn = pytensor.function([y_vv], logprob) y_val = [-2.0, 0.1, 0.3] np.testing.assert_allclose( @@ -653,8 +664,8 @@ def test_log_transform_rv(): y_rv.name = "y" y_vv = y_rv.clone() - logp = joint_logprob({y_rv: y_vv}, sum=False) - logp_fn = pytensor.function([y_vv], logp) + logprob = logp(y_rv, y_vv) + logp_fn = pytensor.function([y_vv], logprob) y_val = [0.1, 0.3] np.testing.assert_allclose( @@ -680,9 +691,9 @@ def test_loc_transform_rv(rv_size, loc_type, addition): y_rv.name = "y" y_vv = y_rv.clone() - logp = joint_logprob({y_rv: y_vv}, sum=False) - assert_no_rvs(logp) - logp_fn = pytensor.function([loc, y_vv], logp) + logprob = logp(y_rv, y_vv) + assert_no_rvs(logprob) + logp_fn = pytensor.function([loc, y_vv], logprob) loc_test_val = np.full(rv_size, 4.0) y_test_val = np.full(rv_size, 1.0) @@ -710,9 +721,9 @@ def test_scale_transform_rv(rv_size, scale_type, product): y_rv.name = "y" y_vv = y_rv.clone() - logp = joint_logprob({y_rv: y_vv}, sum=False) - assert_no_rvs(logp) - logp_fn = pytensor.function([scale, y_vv], logp) + logprob = logp(y_rv, y_vv) + assert_no_rvs(logprob) + logp_fn = pytensor.function([scale, y_vv], logprob) scale_test_val = np.full(rv_size, 4.0) y_test_val = np.full(rv_size, 1.0) @@ -730,9 +741,10 @@ def test_transformed_rv_and_value(): transform_rewrite = TransformValuesRewrite({y_vv: LogTransform()}) - logp = joint_logprob({y_rv: y_vv}, extra_rewrites=transform_rewrite) - assert_no_rvs(logp) - logp_fn = pytensor.function([y_vv], logp) + logp = factorized_joint_logprob({y_rv: y_vv}, extra_rewrites=transform_rewrite) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) + assert_no_rvs(logp_combined) + logp_fn = pytensor.function([y_vv], logp_combined) y_test_val = -5 @@ -750,7 +762,7 @@ def test_loc_transform_multiple_rvs_fails1(): y = y_rv.clone() with pytest.raises(RuntimeError, match="could not be derived"): - joint_logprob({y_rv: y}) + factorized_joint_logprob({y_rv: y}) def test_nested_loc_transform_multiple_rvs_fails2(): @@ -761,19 +773,19 @@ def test_nested_loc_transform_multiple_rvs_fails2(): y = y_rv.clone() with pytest.raises(RuntimeError, match="could not be derived"): - joint_logprob({y_rv: y}) + factorized_joint_logprob({y_rv: y}) def test_discrete_rv_unary_transform_fails(): y_rv = pt.exp(pt.random.poisson(1)) with pytest.raises(RuntimeError, match="could not be derived"): - joint_logprob({y_rv: y_rv.clone()}) + factorized_joint_logprob({y_rv: y_rv.clone()}) def test_discrete_rv_multinary_transform_fails(): y_rv = 5 + pt.random.poisson(1) with pytest.raises(RuntimeError, match="could not be derived"): - joint_logprob({y_rv: y_rv.clone()}) + factorized_joint_logprob({y_rv: y_rv.clone()}) @pytest.mark.xfail(reason="Check not implemented yet") @@ -784,8 +796,8 @@ def test_invalid_broadcasted_transform_rv_fails(): y_vv = y_rv.clone() # This logp derivation should fail or count only once the values that are broadcasted - logp = joint_logprob({y_rv: y_vv}, sum=False) - assert logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == () + logprob = logp(y_rv, y_vv) + assert logprob.eval({y_vv: [0, 0, 0, 0], 
loc: [0, 0, 0, 0]}).shape == () @pytest.mark.parametrize("numerator", (1.0, 2.0)) @@ -796,7 +808,7 @@ def test_reciprocal_rv_transform(numerator): x_rv.name = "x" x_vv = x_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False)) + x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv)) x_test_val = np.r_[-0.5, 1.5] assert np.allclose( @@ -811,7 +823,7 @@ def test_sqr_transform(): x_rv.name = "x" x_vv = x_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False)) + x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv)) x_test_val = np.r_[-0.5, 0.5, 1, 2.5] assert np.allclose( @@ -826,7 +838,7 @@ def test_sqrt_transform(): x_rv.name = "x" x_vv = x_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False)) + x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv)) x_test_val = np.r_[-2.5, 0.5, 1, 2.5] assert np.allclose( @@ -842,7 +854,7 @@ def test_negative_value_odd_power_transform(power): x_rv.name = "x" x_vv = x_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False)) + x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv)) assert np.isfinite(x_logp_fn(1)) assert np.isfinite(x_logp_fn(-1)) @@ -855,7 +867,7 @@ def test_negative_value_even_power_transform(power): x_rv.name = "x" x_vv = x_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False)) + x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv)) assert np.isfinite(x_logp_fn(1)) assert np.isneginf(x_logp_fn(-1)) @@ -868,7 +880,7 @@ def test_negative_value_frac_power_transform(power): x_rv.name = "x" x_vv = x_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False)) + x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv)) assert np.isfinite(x_logp_fn(2.5)) assert np.isneginf(x_logp_fn(-2.5)) @@ -881,8 +893,8 @@ def test_absolute_transform(test_val): x_vv = x_rv.clone() y_vv = y_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False)) - y_logp_fn = pytensor.function([y_vv], joint_logprob({y_rv: y_vv}, sum=False)) + x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv)) + y_logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv)) assert np.allclose(x_logp_fn(test_val), y_logp_fn(test_val)) @@ -892,7 +904,7 @@ def test_negated_rv_transform(): x_rv.name = "x" x_vv = x_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv})) + x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv))) assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5)) @@ -903,7 +915,7 @@ def test_subtracted_rv_transform(): x_rv.name = "x" x_vv = x_rv.clone() - x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv})) + x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv))) assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0)) diff --git a/tests/logprob/utils.py b/tests/logprob/utils.py index 2218d06044..368b77a5f6 100644 --- a/tests/logprob/utils.py +++ b/tests/logprob/utils.py @@ -47,33 +47,6 @@ from pymc.logprob.utils import ignore_logprob -def joint_logprob(*args, sum: bool = True, **kwargs) -> Optional[TensorVariable]: - """Create a graph representing the joint log-probability/measure of a graph. - - This function calls `factorized_joint_logprob` and returns the combined - log-probability factors as a single graph. - - Parameters - ---------- - sum: bool - If ``True`` each factor is collapsed to a scalar via ``sum`` before - being joined with the remaining factors. 
This may be necessary to - avoid incorrect broadcasting among independent factors. - - """ - logprob = factorized_joint_logprob(*args, **kwargs) - if not logprob: - return None - if len(logprob) == 1: - logprob = tuple(logprob.values())[0] - if sum: - return pt.sum(logprob) - return logprob - if sum: - return pt.sum([pt.sum(factor) for factor in logprob.values()]) - return pt.add(*logprob.values()) - - def simulate_poiszero_hmm( N, mu=10.0, pi_0_a=np.r_[1, 1], p_0_a=np.r_[5, 1], p_1_a=np.r_[1, 1], seed=None ): From ace98925501646f9fbdb9876692cb156f29c69da Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Tue, 4 Apr 2023 22:35:37 +0530 Subject: [PATCH 2/3] Move the joint logprob test for subtensors to test_mixture.py --- tests/logprob/test_mixture.py | 62 ++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index 8767e97d66..465ecabe79 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -45,7 +45,12 @@ from pytensor.ifelse import ifelse from pytensor.tensor.random.basic import CategoricalRV from pytensor.tensor.shape import shape_tuple -from pytensor.tensor.subtensor import as_index_constant +from pytensor.tensor.subtensor import ( + AdvancedSubtensor, + AdvancedSubtensor1, + Subtensor, + as_index_constant, +) from pymc.logprob.basic import factorized_joint_logprob from pymc.logprob.mixture import MixtureRV, expand_indices @@ -1054,3 +1059,58 @@ def test_ifelse_mixture_shared_component(): ), decimal=6, ) + + +def test_joint_logprob_subtensor(): + """Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables.""" + + size = 5 + + mu_base = np.power(10, np.arange(np.prod(size))).reshape(size) + mu = np.stack([mu_base, -mu_base]) + sigma = 0.001 + rng = pytensor.shared(np.random.RandomState(232), borrow=True) + + A_rv = pt.random.normal(mu, sigma, rng=rng) + A_rv.name = "A" + + p = 0.5 + + I_rv = pt.random.bernoulli(p, size=size, rng=rng) + I_rv.name = "I" + + A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]] + + assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1)) + + A_idx_value_var = A_idx.type() + A_idx_value_var.name = "A_idx_value" + + I_value_var = I_rv.type() + I_value_var.name = "I_value" + + A_idx_logp = factorized_joint_logprob({A_idx: A_idx_value_var, I_rv: I_value_var}) + A_idx_logp_comb = pt.add(*A_idx_logp.values()) + + logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp_comb) + + # The compiled graph should not contain any `RandomVariables` + assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0]) + + decimals = 6 if pytensor.config.floatX == "float64" else 4 + + test_val_rng = np.random.RandomState(3238) + + for i in range(10): + bern_sp = sp.bernoulli(p) + I_value = bern_sp.rvs(size=size, random_state=test_val_rng).astype(I_rv.dtype) + + norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma) + A_idx_value = norm_sp.rvs(random_state=test_val_rng).astype(A_idx.dtype) + + exp_obs_logps = norm_sp.logpdf(A_idx_value) + exp_obs_logps += bern_sp.logpmf(I_value) + + logp_vals = logp_vals_fn(A_idx_value, I_value) + + np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals) From 1b8bbfd31023cbb86030f8804282a282951a75e9 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Sat, 15 Apr 2023 17:41:29 +0530 Subject: [PATCH 3/3] Solve conflicts in test_basic.py --- tests/logprob/test_basic.py | 119 +++++++++++------------------------- 1 file 
changed, 37 insertions(+), 82 deletions(-) diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py index 3e2d6bb9b8..052e6031ca 100644 --- a/tests/logprob/test_basic.py +++ b/tests/logprob/test_basic.py @@ -60,19 +60,19 @@ from pymc.logprob.utils import rvs_to_value_vars, walk_model from pymc.pytensorf import replace_rvs_by_values from pymc.testing import assert_no_rvs -from tests.logprob.utils import joint_logprob -def test_joint_logprob_basic(): - # A simple check for when `joint_logprob` is the same as `logprob` +def test_factorized_joint_logprob_basic(): + # A simple check for when `factorized_joint_logprob` is the same as `logprob` a = pt.random.uniform(0.0, 1.0) a.name = "a" a_value_var = a.clone() - a_logp = joint_logprob({a: a_value_var}, sum=False) + a_logp = factorized_joint_logprob({a: a_value_var}) + a_logp_comb = tuple(a_logp.values())[0] a_logp_exp = logp(a, a_value_var) - assert equal_computations([a_logp], [a_logp_exp]) + assert equal_computations([a_logp_comb], [a_logp_exp]) # Let's try a hierarchical model sigma = pt.random.invgamma(0.5, 0.5) @@ -81,7 +81,8 @@ def test_joint_logprob_basic(): sigma_value_var = sigma.clone() y_value_var = Y.clone() - total_ll = joint_logprob({Y: y_value_var, sigma: sigma_value_var}, sum=False) + total_ll = factorized_joint_logprob({Y: y_value_var, sigma: sigma_value_var}) + total_ll_combined = pt.add(*total_ll.values()) # We need to replace the reference to `sigma` in `Y` with its value # variable @@ -92,7 +93,7 @@ def test_joint_logprob_basic(): ) total_ll_exp = logp(sigma, sigma_value_var) + ll_Y - assert equal_computations([total_ll], [total_ll_exp]) + assert equal_computations([total_ll_combined], [total_ll_exp]) # Now, make sure we can compute a joint log-probability for a hierarchical # model with some non-`RandomVariable` nodes @@ -105,28 +106,30 @@ def test_joint_logprob_basic(): b_value_var = b.clone() c_value_var = c.clone() - b_logp = joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var}) + b_logp = factorized_joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var}) + b_logp_combined = pt.sum([pt.sum(factor) for factor in b_logp.values()]) # There shouldn't be any `RandomVariable`s in the resulting graph - assert_no_rvs(b_logp) + assert_no_rvs(b_logp_combined) - res_ancestors = list(walk_model((b_logp,), walk_past_rvs=True)) + res_ancestors = list(walk_model((b_logp_combined,), walk_past_rvs=True)) assert b_value_var in res_ancestors assert c_value_var in res_ancestors assert a_value_var in res_ancestors -def test_joint_logprob_multi_obs(): +def test_factorized_joint_logprob_multi_obs(): a = pt.random.uniform(0.0, 1.0) b = pt.random.normal(0.0, 1.0) a_val = a.clone() b_val = b.clone() - logp_res = joint_logprob({a: a_val, b: b_val}, sum=False) + logp_res = factorized_joint_logprob({a: a_val, b: b_val}) + logp_res_combined = pt.add(*logp_res.values()) logp_exp = logp(a, a_val) + logp(b, b_val) - assert equal_computations([logp_res], [logp_exp]) + assert equal_computations([logp_res_combined], [logp_exp]) x = pt.random.normal(0, 1) y = pt.random.normal(x, 1) @@ -134,13 +137,15 @@ def test_joint_logprob_multi_obs(): x_val = x.clone() y_val = y.clone() - logp_res = joint_logprob({x: x_val, y: y_val}) - exp_logp = joint_logprob({x: x_val, y: y_val}) + logp_res = factorized_joint_logprob({x: x_val, y: y_val}) + exp_logp = factorized_joint_logprob({x: x_val, y: y_val}) + logp_res_comb = pt.sum([pt.sum(factor) for factor in logp_res.values()]) + exp_logp_comb = pt.sum([pt.sum(factor) for factor in 
exp_logp.values()]) - assert equal_computations([logp_res], [exp_logp]) + assert equal_computations([logp_res_comb], [exp_logp_comb]) -def test_joint_logprob_diff_dims(): +def test_factorized_joint_logprob_diff_dims(): M = pt.matrix("M") x = pt.random.normal(0, 1, size=M.shape[1], name="X") y = pt.random.normal(M.dot(x), 1, name="Y") @@ -150,14 +155,15 @@ def test_joint_logprob_diff_dims(): y_vv = y.clone() y_vv.name = "y" - logp = joint_logprob({x: x_vv, y: y_vv}) + logp = factorized_joint_logprob({x: x_vv, y: y_vv}) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) M_val = np.random.normal(size=(10, 3)) x_val = np.random.normal(size=(3,)) y_val = np.random.normal(size=(10,)) point = {M: M_val, x_vv: x_val, y_vv: y_val} - logp_val = logp.eval(point) + logp_val = logp_combined.eval(point) exp_logp_val = ( sp.norm.logpdf(x_val, 0, 1).sum() + sp.norm.logpdf(y_val, M_val.dot(x_val), 1).sum() @@ -179,60 +185,6 @@ def test_incsubtensor_original_values_output_dict(): assert vv in logp_dict -def test_joint_logprob_subtensor(): - """Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables.""" - - size = 5 - - mu_base = np.power(10, np.arange(np.prod(size))).reshape(size) - mu = np.stack([mu_base, -mu_base]) - sigma = 0.001 - rng = pytensor.shared(np.random.RandomState(232), borrow=True) - - A_rv = pt.random.normal(mu, sigma, rng=rng) - A_rv.name = "A" - - p = 0.5 - - I_rv = pt.random.bernoulli(p, size=size, rng=rng) - I_rv.name = "I" - - A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]] - - assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1)) - - A_idx_value_var = A_idx.type() - A_idx_value_var.name = "A_idx_value" - - I_value_var = I_rv.type() - I_value_var.name = "I_value" - - A_idx_logp = joint_logprob({A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False) - - logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp) - - # The compiled graph should not contain any `RandomVariables` - assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0]) - - decimals = 6 if pytensor.config.floatX == "float64" else 4 - - test_val_rng = np.random.RandomState(3238) - - for i in range(10): - bern_sp = sp.bernoulli(p) - I_value = bern_sp.rvs(size=size, random_state=test_val_rng).astype(I_rv.dtype) - - norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma) - A_idx_value = norm_sp.rvs(random_state=test_val_rng).astype(A_idx.dtype) - - exp_obs_logps = norm_sp.logpdf(A_idx_value) - exp_obs_logps += bern_sp.logpmf(I_value) - - logp_vals = logp_vals_fn(A_idx_value, I_value) - - np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals) - - def test_persist_inputs(): """Make sure we don't unnecessarily clone variables.""" x = pt.scalar("x") @@ -242,24 +194,27 @@ def test_persist_inputs(): beta_vv = beta_rv.type() y_vv = Y_rv.clone() - logp = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv}) + logp = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv}) + logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()]) - assert x in ancestors([logp]) + assert x in ancestors([logp_combined]) # Make sure we don't clone value variables when they're graphs. 
y_vv_2 = y_vv * 2 - logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2}) + logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2}) + logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()]) - assert y_vv in ancestors([logp_2]) - assert y_vv_2 in ancestors([logp_2]) + assert y_vv in ancestors([logp_2_combined]) + assert y_vv_2 in ancestors([logp_2_combined]) # Even when they are random y_vv = pt.random.normal(name="y_vv2") y_vv_2 = y_vv * 2 - logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2}) + logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2}) + logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()]) - assert y_vv in ancestors([logp_2]) - assert y_vv_2 in ancestors([logp_2]) + assert y_vv in ancestors([logp_2_combined]) + assert y_vv_2 in ancestors([logp_2_combined]) def test_warn_random_found_factorized_joint_logprob(): @@ -284,7 +239,7 @@ def test_multiple_rvs_to_same_value_raises(): msg = "More than one logprob factor was assigned to the value var x" with pytest.raises(ValueError, match=msg): - joint_logprob({x_rv1: x, x_rv2: x}) + factorized_joint_logprob({x_rv1: x, x_rv2: x}) def test_joint_logp_basic():
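
A closing note on the pattern applied throughout this series: every call to the
removed joint_logprob helper is replaced either by pymc.logp (for a single
variable) or by factorized_joint_logprob plus an explicit combination of the
returned factors. Below is a minimal sketch of that combination step, mirroring
the body of the deleted helper; it is illustrative only, and x_rv and x_vv are
placeholder names rather than code from any patch above.

    import pytensor.tensor as pt

    from pymc.logprob.basic import factorized_joint_logprob

    x_rv = pt.random.normal(0, 1)
    x_rv.name = "x"
    x_vv = x_rv.clone()

    # factorized_joint_logprob returns a dict mapping each value variable to
    # its log-probability factor. Summing every factor matches what the old
    # joint_logprob(..., sum=True) returned, while pt.add(*factors.values())
    # matches the sum=False behaviour.
    factors = factorized_joint_logprob({x_rv: x_vv})
    logp_combined = pt.sum([pt.sum(factor) for factor in factors.values()])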