Skip to content

Commit ace9892

Browse files
committed
Move the joint logprob test for subtensors to test_mixture.py
1 parent ccb9cfb commit ace9892

File tree

1 file changed

+61
-1
lines changed

1 file changed

+61
-1
lines changed

tests/logprob/test_mixture.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@
4545
from pytensor.ifelse import ifelse
4646
from pytensor.tensor.random.basic import CategoricalRV
4747
from pytensor.tensor.shape import shape_tuple
48-
from pytensor.tensor.subtensor import as_index_constant
48+
from pytensor.tensor.subtensor import (
49+
AdvancedSubtensor,
50+
AdvancedSubtensor1,
51+
Subtensor,
52+
as_index_constant,
53+
)
4954

5055
from pymc.logprob.basic import factorized_joint_logprob
5156
from pymc.logprob.mixture import MixtureRV, expand_indices
@@ -1054,3 +1059,58 @@ def test_ifelse_mixture_shared_component():
10541059
),
10551060
decimal=6,
10561061
)
1062+
1063+
1064+
def test_joint_logprob_subtensor():
1065+
"""Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
1066+
1067+
size = 5
1068+
1069+
mu_base = np.power(10, np.arange(np.prod(size))).reshape(size)
1070+
mu = np.stack([mu_base, -mu_base])
1071+
sigma = 0.001
1072+
rng = pytensor.shared(np.random.RandomState(232), borrow=True)
1073+
1074+
A_rv = pt.random.normal(mu, sigma, rng=rng)
1075+
A_rv.name = "A"
1076+
1077+
p = 0.5
1078+
1079+
I_rv = pt.random.bernoulli(p, size=size, rng=rng)
1080+
I_rv.name = "I"
1081+
1082+
A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]]
1083+
1084+
assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
1085+
1086+
A_idx_value_var = A_idx.type()
1087+
A_idx_value_var.name = "A_idx_value"
1088+
1089+
I_value_var = I_rv.type()
1090+
I_value_var.name = "I_value"
1091+
1092+
A_idx_logp = factorized_joint_logprob({A_idx: A_idx_value_var, I_rv: I_value_var})
1093+
A_idx_logp_comb = pt.add(*A_idx_logp.values())
1094+
1095+
logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp_comb)
1096+
1097+
# The compiled graph should not contain any `RandomVariables`
1098+
assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0])
1099+
1100+
decimals = 6 if pytensor.config.floatX == "float64" else 4
1101+
1102+
test_val_rng = np.random.RandomState(3238)
1103+
1104+
for i in range(10):
1105+
bern_sp = sp.bernoulli(p)
1106+
I_value = bern_sp.rvs(size=size, random_state=test_val_rng).astype(I_rv.dtype)
1107+
1108+
norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma)
1109+
A_idx_value = norm_sp.rvs(random_state=test_val_rng).astype(A_idx.dtype)
1110+
1111+
exp_obs_logps = norm_sp.logpdf(A_idx_value)
1112+
exp_obs_logps += bern_sp.logpmf(I_value)
1113+
1114+
logp_vals = logp_vals_fn(A_idx_value, I_value)
1115+
1116+
np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)

0 commit comments

Comments
 (0)