|
45 | 45 | from pytensor.ifelse import ifelse
|
46 | 46 | from pytensor.tensor.random.basic import CategoricalRV
|
47 | 47 | from pytensor.tensor.shape import shape_tuple
|
48 |
| -from pytensor.tensor.subtensor import as_index_constant |
| 48 | +from pytensor.tensor.subtensor import ( |
| 49 | + AdvancedSubtensor, |
| 50 | + AdvancedSubtensor1, |
| 51 | + Subtensor, |
| 52 | + as_index_constant, |
| 53 | +) |
49 | 54 |
|
50 | 55 | from pymc.logprob.basic import factorized_joint_logprob
|
51 | 56 | from pymc.logprob.mixture import MixtureRV, expand_indices
|
@@ -1054,3 +1059,58 @@ def test_ifelse_mixture_shared_component():
|
1054 | 1059 | ),
|
1055 | 1060 | decimal=6,
|
1056 | 1061 | )
|
| 1062 | + |
| 1063 | + |
| 1064 | +def test_joint_logprob_subtensor(): |
| 1065 | + """Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables.""" |
| 1066 | + |
| 1067 | + size = 5 |
| 1068 | + |
| 1069 | + mu_base = np.power(10, np.arange(np.prod(size))).reshape(size) |
| 1070 | + mu = np.stack([mu_base, -mu_base]) |
| 1071 | + sigma = 0.001 |
| 1072 | + rng = pytensor.shared(np.random.RandomState(232), borrow=True) |
| 1073 | + |
| 1074 | + A_rv = pt.random.normal(mu, sigma, rng=rng) |
| 1075 | + A_rv.name = "A" |
| 1076 | + |
| 1077 | + p = 0.5 |
| 1078 | + |
| 1079 | + I_rv = pt.random.bernoulli(p, size=size, rng=rng) |
| 1080 | + I_rv.name = "I" |
| 1081 | + |
| 1082 | + A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]] |
| 1083 | + |
| 1084 | + assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1)) |
| 1085 | + |
| 1086 | + A_idx_value_var = A_idx.type() |
| 1087 | + A_idx_value_var.name = "A_idx_value" |
| 1088 | + |
| 1089 | + I_value_var = I_rv.type() |
| 1090 | + I_value_var.name = "I_value" |
| 1091 | + |
| 1092 | + A_idx_logp = factorized_joint_logprob({A_idx: A_idx_value_var, I_rv: I_value_var}) |
| 1093 | + A_idx_logp_comb = pt.add(*A_idx_logp.values()) |
| 1094 | + |
| 1095 | + logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp_comb) |
| 1096 | + |
| 1097 | + # The compiled graph should not contain any `RandomVariables` |
| 1098 | + assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0]) |
| 1099 | + |
| 1100 | + decimals = 6 if pytensor.config.floatX == "float64" else 4 |
| 1101 | + |
| 1102 | + test_val_rng = np.random.RandomState(3238) |
| 1103 | + |
| 1104 | + for i in range(10): |
| 1105 | + bern_sp = sp.bernoulli(p) |
| 1106 | + I_value = bern_sp.rvs(size=size, random_state=test_val_rng).astype(I_rv.dtype) |
| 1107 | + |
| 1108 | + norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma) |
| 1109 | + A_idx_value = norm_sp.rvs(random_state=test_val_rng).astype(A_idx.dtype) |
| 1110 | + |
| 1111 | + exp_obs_logps = norm_sp.logpdf(A_idx_value) |
| 1112 | + exp_obs_logps += bern_sp.logpmf(I_value) |
| 1113 | + |
| 1114 | + logp_vals = logp_vals_fn(A_idx_value, I_value) |
| 1115 | + |
| 1116 | + np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals) |
0 commit comments