Remove joint_logprob function from tests.logprob.utils #6650

Merged: 3 commits, May 19, 2023
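
This PR removes the `joint_logprob` convenience wrapper from `tests.logprob.utils` and rewrites its callers against the library API: tests that need a joint density now call `factorized_joint_logprob` and combine the returned factors explicitly, while single-variable tests call `logp` directly. Below is a minimal sketch of the combining idiom that recurs throughout the diff; the `summed_joint_logprob` name is hypothetical and only illustrates the pattern:

import pytensor.tensor as pt

from pymc.logprob import factorized_joint_logprob


def summed_joint_logprob(rv_to_value, **kwargs):
    """Illustrative stand-in for the removed test helper.

    `factorized_joint_logprob` returns a dict mapping each value variable
    to its logprob factor; summing all the factors reproduces what the old
    helper computed with its default `sum=True`.
    """
    factors = factorized_joint_logprob(rv_to_value, **kwargs)
    return pt.sum([pt.sum(factor) for factor in factors.values()])
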
119 changes: 37 additions & 82 deletions tests/logprob/test_basic.py
@@ -60,19 +60,19 @@
 from pymc.logprob.utils import rvs_to_value_vars, walk_model
 from pymc.pytensorf import replace_rvs_by_values
 from pymc.testing import assert_no_rvs
-from tests.logprob.utils import joint_logprob


-def test_joint_logprob_basic():
-    # A simple check for when `joint_logprob` is the same as `logprob`
+def test_factorized_joint_logprob_basic():
+    # A simple check for when `factorized_joint_logprob` is the same as `logprob`
     a = pt.random.uniform(0.0, 1.0)
     a.name = "a"
     a_value_var = a.clone()

-    a_logp = joint_logprob({a: a_value_var}, sum=False)
+    a_logp = factorized_joint_logprob({a: a_value_var})
+    a_logp_comb = tuple(a_logp.values())[0]
     a_logp_exp = logp(a, a_value_var)

-    assert equal_computations([a_logp], [a_logp_exp])
+    assert equal_computations([a_logp_comb], [a_logp_exp])

     # Let's try a hierarchical model
     sigma = pt.random.invgamma(0.5, 0.5)
@@ -81,7 +81,8 @@ def test_joint_logprob_basic():
     sigma_value_var = sigma.clone()
     y_value_var = Y.clone()

-    total_ll = joint_logprob({Y: y_value_var, sigma: sigma_value_var}, sum=False)
+    total_ll = factorized_joint_logprob({Y: y_value_var, sigma: sigma_value_var})
+    total_ll_combined = pt.add(*total_ll.values())

     # We need to replace the reference to `sigma` in `Y` with its value
     # variable
@@ -92,7 +93,7 @@ def test_joint_logprob_basic():
     )
     total_ll_exp = logp(sigma, sigma_value_var) + ll_Y

-    assert equal_computations([total_ll], [total_ll_exp])
+    assert equal_computations([total_ll_combined], [total_ll_exp])

     # Now, make sure we can compute a joint log-probability for a hierarchical
     # model with some non-`RandomVariable` nodes
@@ -105,42 +106,46 @@ def test_joint_logprob_basic():
     b_value_var = b.clone()
     c_value_var = c.clone()

-    b_logp = joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var})
+    b_logp = factorized_joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var})
+    b_logp_combined = pt.sum([pt.sum(factor) for factor in b_logp.values()])

     # There shouldn't be any `RandomVariable`s in the resulting graph
-    assert_no_rvs(b_logp)
+    assert_no_rvs(b_logp_combined)

-    res_ancestors = list(walk_model((b_logp,), walk_past_rvs=True))
+    res_ancestors = list(walk_model((b_logp_combined,), walk_past_rvs=True))
     assert b_value_var in res_ancestors
     assert c_value_var in res_ancestors
     assert a_value_var in res_ancestors


-def test_joint_logprob_multi_obs():
+def test_factorized_joint_logprob_multi_obs():
     a = pt.random.uniform(0.0, 1.0)
     b = pt.random.normal(0.0, 1.0)

     a_val = a.clone()
     b_val = b.clone()

-    logp_res = joint_logprob({a: a_val, b: b_val}, sum=False)
+    logp_res = factorized_joint_logprob({a: a_val, b: b_val})
+    logp_res_combined = pt.add(*logp_res.values())
     logp_exp = logp(a, a_val) + logp(b, b_val)

-    assert equal_computations([logp_res], [logp_exp])
+    assert equal_computations([logp_res_combined], [logp_exp])

     x = pt.random.normal(0, 1)
     y = pt.random.normal(x, 1)

     x_val = x.clone()
     y_val = y.clone()

-    logp_res = joint_logprob({x: x_val, y: y_val})
-    exp_logp = joint_logprob({x: x_val, y: y_val})
+    logp_res = factorized_joint_logprob({x: x_val, y: y_val})
+    exp_logp = factorized_joint_logprob({x: x_val, y: y_val})
+    logp_res_comb = pt.sum([pt.sum(factor) for factor in logp_res.values()])
+    exp_logp_comb = pt.sum([pt.sum(factor) for factor in exp_logp.values()])

-    assert equal_computations([logp_res], [exp_logp])
+    assert equal_computations([logp_res_comb], [exp_logp_comb])


-def test_joint_logprob_diff_dims():
+def test_factorized_joint_logprob_diff_dims():
     M = pt.matrix("M")
     x = pt.random.normal(0, 1, size=M.shape[1], name="X")
     y = pt.random.normal(M.dot(x), 1, name="Y")
@@ -150,14 +155,15 @@ def test_joint_logprob_diff_dims():
     y_vv = y.clone()
     y_vv.name = "y"

-    logp = joint_logprob({x: x_vv, y: y_vv})
+    logp = factorized_joint_logprob({x: x_vv, y: y_vv})
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])

     M_val = np.random.normal(size=(10, 3))
     x_val = np.random.normal(size=(3,))
     y_val = np.random.normal(size=(10,))

     point = {M: M_val, x_vv: x_val, y_vv: y_val}
-    logp_val = logp.eval(point)
+    logp_val = logp_combined.eval(point)

     exp_logp_val = (
         sp.norm.logpdf(x_val, 0, 1).sum() + sp.norm.logpdf(y_val, M_val.dot(x_val), 1).sum()
@@ -179,60 +185,6 @@ def test_incsubtensor_original_values_output_dict():
     assert vv in logp_dict


-def test_joint_logprob_subtensor():
-    """Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
-
-    size = 5
-
-    mu_base = np.power(10, np.arange(np.prod(size))).reshape(size)
-    mu = np.stack([mu_base, -mu_base])
-    sigma = 0.001
-    rng = pytensor.shared(np.random.RandomState(232), borrow=True)
-
-    A_rv = pt.random.normal(mu, sigma, rng=rng)
-    A_rv.name = "A"
-
-    p = 0.5
-
-    I_rv = pt.random.bernoulli(p, size=size, rng=rng)
-    I_rv.name = "I"
-
-    A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]]
-
-    assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
-
-    A_idx_value_var = A_idx.type()
-    A_idx_value_var.name = "A_idx_value"
-
-    I_value_var = I_rv.type()
-    I_value_var.name = "I_value"
-
-    A_idx_logp = joint_logprob({A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False)
-
-    logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp)
-
-    # The compiled graph should not contain any `RandomVariables`
-    assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0])
-
-    decimals = 6 if pytensor.config.floatX == "float64" else 4
-
-    test_val_rng = np.random.RandomState(3238)
-
-    for i in range(10):
-        bern_sp = sp.bernoulli(p)
-        I_value = bern_sp.rvs(size=size, random_state=test_val_rng).astype(I_rv.dtype)
-
-        norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma)
-        A_idx_value = norm_sp.rvs(random_state=test_val_rng).astype(A_idx.dtype)
-
-        exp_obs_logps = norm_sp.logpdf(A_idx_value)
-        exp_obs_logps += bern_sp.logpmf(I_value)
-
-        logp_vals = logp_vals_fn(A_idx_value, I_value)
-
-        np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)
-
-
 def test_persist_inputs():
     """Make sure we don't unnecessarily clone variables."""
     x = pt.scalar("x")
@@ -242,24 +194,27 @@ def test_persist_inputs():
     beta_vv = beta_rv.type()
     y_vv = Y_rv.clone()

-    logp = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv})
+    logp = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv})
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])

-    assert x in ancestors([logp])
+    assert x in ancestors([logp_combined])

     # Make sure we don't clone value variables when they're graphs.
     y_vv_2 = y_vv * 2
-    logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
+    logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
+    logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()])

-    assert y_vv in ancestors([logp_2])
-    assert y_vv_2 in ancestors([logp_2])
+    assert y_vv in ancestors([logp_2_combined])
+    assert y_vv_2 in ancestors([logp_2_combined])

     # Even when they are random
     y_vv = pt.random.normal(name="y_vv2")
     y_vv_2 = y_vv * 2
-    logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
+    logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
+    logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()])

-    assert y_vv in ancestors([logp_2])
-    assert y_vv_2 in ancestors([logp_2])
+    assert y_vv in ancestors([logp_2_combined])
+    assert y_vv_2 in ancestors([logp_2_combined])


 def test_warn_random_found_factorized_joint_logprob():
@@ -284,7 +239,7 @@ def test_multiple_rvs_to_same_value_raises():

     msg = "More than one logprob factor was assigned to the value var x"
     with pytest.raises(ValueError, match=msg):
-        joint_logprob({x_rv1: x, x_rv2: x})
+        factorized_joint_logprob({x_rv1: x, x_rv2: x})


 def test_joint_logp_basic():
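
In `tests/logprob/test_censoring.py` below, tests that involve a single value variable skip the factor dict entirely and call `pymc.logp` on the measurable variable itself. A minimal sketch of that pattern, assuming a censored normal like the one in `test_continuous_rv_clip` (variable names are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt
import scipy.stats as st

from pymc import logp

x_rv = pt.random.normal(0.5, 1)
cens_x_rv = pt.clip(x_rv, -1, 1)  # censor the normal to [-1, 1]
cens_x_vv = cens_x_rv.clone()

# `logp` returns the elementwise logprob term; sum it to get a scalar
logprob = pt.sum(logp(cens_x_rv, cens_x_vv))
logp_fn = pytensor.function([cens_x_vv], logprob)

# Strictly inside the censoring bounds the density matches the plain normal
np.testing.assert_allclose(logp_fn(0.0), st.norm(0.5, 1).logpdf(0.0))
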
62 changes: 35 additions & 27 deletions tests/logprob/test_censoring.py
@@ -41,10 +41,10 @@
 import scipy as sp
 import scipy.stats as st

+from pymc import logp
 from pymc.logprob import factorized_joint_logprob
 from pymc.logprob.transforms import LogTransform, TransformValuesRewrite
 from pymc.testing import assert_no_rvs
-from tests.logprob.utils import joint_logprob


 @pytensor.config.change_flags(compute_test_value="raise")
@@ -55,10 +55,10 @@ def test_continuous_rv_clip():
     cens_x_vv = cens_x_rv.clone()
     cens_x_vv.tag.test_value = 0

-    logp = joint_logprob({cens_x_rv: cens_x_vv})
-    assert_no_rvs(logp)
+    logprob = pt.sum(logp(cens_x_rv, cens_x_vv))
+    assert_no_rvs(logprob)

-    logp_fn = pytensor.function([cens_x_vv], logp)
+    logp_fn = pytensor.function([cens_x_vv], logprob)
     ref_scipy = st.norm(0.5, 1)

     assert logp_fn(-3) == -np.inf
@@ -75,10 +75,10 @@ def test_discrete_rv_clip():

     cens_x_vv = cens_x_rv.clone()

-    logp = joint_logprob({cens_x_rv: cens_x_vv})
-    assert_no_rvs(logp)
+    logprob = pt.sum(logp(cens_x_rv, cens_x_vv))
+    assert_no_rvs(logprob)

-    logp_fn = pytensor.function([cens_x_vv], logp)
+    logp_fn = pytensor.function([cens_x_vv], logprob)
     ref_scipy = st.poisson(2)

     assert logp_fn(0) == -np.inf
@@ -97,8 +97,8 @@ def test_one_sided_clip():
     lb_cens_x_vv = lb_cens_x_rv.clone()
     ub_cens_x_vv = ub_cens_x_rv.clone()

-    lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x_vv})
-    ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x_vv})
+    lb_logp = pt.sum(logp(lb_cens_x_rv, lb_cens_x_vv))
+    ub_logp = pt.sum(logp(ub_cens_x_rv, ub_cens_x_vv))
     assert_no_rvs(lb_logp)
     assert_no_rvs(ub_logp)

@@ -117,10 +117,10 @@ def test_useless_clip():

     cens_x_vv = cens_x_rv.clone()

-    logp = joint_logprob({cens_x_rv: cens_x_vv}, sum=False)
-    assert_no_rvs(logp)
+    logprob = logp(cens_x_rv, cens_x_vv)
+    assert_no_rvs(logprob)

-    logp_fn = pytensor.function([cens_x_vv], logp)
+    logp_fn = pytensor.function([cens_x_vv], logprob)
     ref_scipy = st.norm(0.5, 1)

     np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))
@@ -133,10 +133,12 @@ def test_random_clip():

     lb_vv = lb_rv.clone()
     cens_x_vv = cens_x_rv.clone()
-    logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}, sum=False)
-    assert_no_rvs(logp)
+    logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
+    logp_combined = pt.add(*logp.values())

-    logp_fn = pytensor.function([lb_vv, cens_x_vv], logp)
+    assert_no_rvs(logp_combined)
+
+    logp_fn = pytensor.function([lb_vv, cens_x_vv], logp_combined)
     res = logp_fn([0, -1], [-1, -1])
     assert res[0] == -np.inf
     assert res[1] != -np.inf
@@ -150,8 +152,10 @@ def test_broadcasted_clip_constant():
     lb_vv = lb_rv.clone()
     cens_x_vv = cens_x_rv.clone()

-    logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
-    assert_no_rvs(logp)
+    logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
+
+    assert_no_rvs(logp_combined)


 def test_broadcasted_clip_random():
@@ -162,8 +166,10 @@ def test_broadcasted_clip_random():
     lb_vv = lb_rv.clone()
     cens_x_vv = cens_x_rv.clone()

-    logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
-    assert_no_rvs(logp)
+    logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
+
+    assert_no_rvs(logp_combined)


 def test_fail_base_and_clip_have_values():
@@ -199,10 +205,11 @@ def test_deterministic_clipping():

     x_vv = x_rv.clone()
     y_vv = y_rv.clone()
-    logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
-    assert_no_rvs(logp)
+    logp = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
+    assert_no_rvs(logp_combined)

-    logp_fn = pytensor.function([x_vv, y_vv], logp)
+    logp_fn = pytensor.function([x_vv, y_vv], logp_combined)
     assert np.isclose(
         logp_fn(-1, 1),
         st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
@@ -216,10 +223,11 @@ def test_clip_transform():
     cens_x_vv = cens_x_rv.clone()

     transform = TransformValuesRewrite({cens_x_vv: LogTransform()})
-    logp = joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
+    logp = factorized_joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])

     cens_x_vv_testval = -1
-    obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})
+    obs_logp = logp_combined.eval({cens_x_vv: cens_x_vv_testval})
     exp_logp = sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_vv_testval)) + cens_x_vv_testval

     assert np.isclose(obs_logp, exp_logp)
@@ -236,8 +244,8 @@ def test_rounding(rounding_op):
     xr.name = "xr"

     xr_vv = xr.clone()
-    logp = joint_logprob({xr: xr_vv}, sum=False)
-    assert logp is not None
+    logprob = logp(xr, xr_vv)
+    assert logprob is not None

     x_sp = st.norm(loc, scale)
     if rounding_op == pt.round:
@@ -250,6 +258,6 @@
         raise NotImplementedError()

     assert np.allclose(
-        logp.eval({xr_vv: test_value}),
+        logprob.eval({xr_vv: test_value}),
         expected_logp,
     )
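
Note that the rewritten tests use two combining idioms: `pt.add(*factors.values())` keeps the elementwise logprob terms (replacing the old `sum=False` calls), while `pt.sum([pt.sum(factor) for factor in factors.values()])` collapses everything to a scalar (replacing the old default `sum=True`). A minimal sketch of the difference, under the assumption of two vector-valued variables (names are illustrative):

import numpy as np
import pytensor.tensor as pt

from pymc.logprob import factorized_joint_logprob

x = pt.random.normal(0, 1, size=3)
y = pt.random.normal(x, 1)
x_vv, y_vv = x.clone(), y.clone()

factors = factorized_joint_logprob({x: x_vv, y: y_vv})

elemwise = pt.add(*factors.values())  # shape (3,): per-element joint terms
scalar = pt.sum([pt.sum(f) for f in factors.values()])  # 0-d joint logprob

point = {x_vv: np.zeros(3), y_vv: np.zeros(3)}
np.testing.assert_allclose(elemwise.eval(point).sum(), scalar.eval(point))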