Skip to content

Commit 75cf84d

Browse files
committed
Derive logprob of SetSubtensor operations
1 parent 4781957 commit 75cf84d

File tree

4 files changed

+375
-0
lines changed

4 files changed

+375
-0
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ jobs:
121121
tests/logprob/test_order.py
122122
tests/logprob/test_rewriting.py
123123
tests/logprob/test_scan.py
124+
tests/logprob/test_set_subtensor.py
124125
tests/logprob/test_tensor.py
125126
tests/logprob/test_transform_value.py
126127
tests/logprob/test_transforms.py

pymc/logprob/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import pymc.logprob.mixture
5454
import pymc.logprob.order
5555
import pymc.logprob.scan
56+
import pymc.logprob.set_subtensor
5657
import pymc.logprob.tensor
5758
import pymc.logprob.transforms
5859

pymc/logprob/set_subtensor.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from pytensor.graph.basic import Variable
15+
from pytensor.graph.rewriting.basic import node_rewriter
16+
from pytensor.tensor import eq
17+
from pytensor.tensor.subtensor import (
18+
AdvancedIncSubtensor,
19+
AdvancedIncSubtensor1,
20+
IncSubtensor,
21+
indices_from_subtensor,
22+
)
23+
from pytensor.tensor.type import TensorType
24+
from pytensor.tensor.type_other import NoneTypeT
25+
26+
from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
27+
from pymc.logprob.checks import MeasurableCheckAndRaise
28+
from pymc.logprob.rewriting import measurable_ir_rewrites_db
29+
from pymc.logprob.utils import (
30+
check_potential_measurability,
31+
dirac_delta,
32+
filter_measurable_variables,
33+
)
34+
35+
36+
class MeasurableSetSubtensor(IncSubtensor, MeasurableOp):
    """`IncSubtensor` subclass that also carries the `MeasurableOp` mixin."""

    def __str__(self):
        # Reuse the parent's description and tag it with a "Measurable" prefix
        parent_description = super().__str__()
        return f"Measurable{parent_description}"
41+
42+
43+
class MeasurableAdvancedSetSubtensor(AdvancedIncSubtensor, MeasurableOp):
    """`AdvancedIncSubtensor` subclass that also carries the `MeasurableOp` mixin."""

    def __str__(self):
        # Reuse the parent's description and tag it with a "Measurable" prefix
        parent_description = super().__str__()
        return f"Measurable{parent_description}"
48+
49+
50+
# Check Op applied to a measurable SetSubtensor output together with the conditions
# that must hold; configured to raise `NotImplementedError` with the message below.
set_subtensor_does_not_broadcast = MeasurableCheckAndRaise(
    msg="Measurable SetSubtensor not supported when set value is broadcasted.",
    exc_type=NotImplementedError,
)
54+
55+
56+
@node_rewriter(tracks=[IncSubtensor, AdvancedIncSubtensor1, AdvancedIncSubtensor])
def find_measurable_set_subtensor(fgraph, node) -> list | None:
    """Find `SetSubtensor` for which a `logprob` can be computed.

    The node is left untouched (``None`` is returned) when:

    * its Op is already a `MeasurableOp`;
    * it increments instead of sets (``set_instead_of_inc`` is false);
    * the written value ``y`` is not measurable;
    * the base ``x`` is not measurable yet but is potentially measurable;
    * any index element is potentially measurable;
    * the static shapes prove that ``y`` is broadcasted into the indexed block.

    Otherwise a single-element list with the measurable replacement is returned.
    When broadcasting of ``y`` cannot be ruled out statically, the replacement is
    wrapped in `set_subtensor_does_not_broadcast` so it is checked at runtime.
    """
    if isinstance(node.op, MeasurableOp):
        return None

    if not node.op.set_instead_of_inc:
        return None

    x, y, *idx_elements = node.inputs

    measurable_inputs = filter_measurable_variables([x, y])

    if y not in measurable_inputs:
        return None

    if x not in measurable_inputs:
        # x is potentially measurable, wait for its logprob IR to be inferred
        if check_potential_measurability([x]):
            return None
        # x has no link to measurable variables, so its value should be constant
        x = dirac_delta(x, rtol=0, atol=0)

    if check_potential_measurability(idx_elements):
        return None

    measurable_class: type[MeasurableSetSubtensor | MeasurableAdvancedSetSubtensor]
    if isinstance(node.op, IncSubtensor):
        measurable_class = MeasurableSetSubtensor
        # Basic indexing stores part of the index on the Op; rebuild the full index
        idx = indices_from_subtensor(idx_elements, node.op.idx_list)
    else:
        measurable_class = MeasurableAdvancedSetSubtensor
        idx = tuple(idx_elements)

    # Check that y is not certainly broadcasted.
    indexed_block = x[idx]
    static_block_shape = indexed_block.type.shape
    missing_y_dims = indexed_block.type.ndim - y.type.ndim
    # Dimensions that y lacks on the left are implicitly broadcastable
    y_bcast = [True] * missing_y_dims + list(y.type.broadcastable)
    if any(
        y_dim_bcast and indexed_block_dim_len not in (None, 1)
        for y_dim_bcast, indexed_block_dim_len in zip(y_bcast, static_block_shape, strict=True)
    ):
        return None

    measurable_set_subtensor = measurable_class(**node.op._props_dict())(x, y, *idx_elements)

    # Often with indexing we don't know the static shape of the indexed block.
    # And, what's more, the indexing operations actually support runtime broadcasting.
    # As the logp is not valid under broadcasting, we have to add a runtime check.
    # This will hopefully be removed during shape inference when not violated.
    potential_broadcasted_dims = [
        i
        for i, (y_bcast_dim, indexed_block_dim_len) in enumerate(
            zip(y_bcast, static_block_shape, strict=True)
        )
        if y_bcast_dim and indexed_block_dim_len is None
    ]
    if potential_broadcasted_dims:
        indexed_block_shape = tuple(indexed_block.shape)
        measurable_set_subtensor = set_subtensor_does_not_broadcast(
            measurable_set_subtensor,
            *(eq(indexed_block_shape[i], 1) for i in potential_broadcasted_dims),
        )

    return [measurable_set_subtensor]
124+
125+
126+
# Register the rewrite in the measurable IR database under its own name,
# passing the "basic" and "set_subtensor" tags.
measurable_ir_rewrites_db.register(
    find_measurable_set_subtensor.__name__, find_measurable_set_subtensor, "basic", "set_subtensor"
)
132+
133+
134+
def indexed_dims(idx) -> list[int | None]:
135+
"""Return the indices of the dimensions of the indexed tensor that are being indexed."""
136+
dims: list[int | None] = []
137+
idx_counter = 0
138+
for idx_elem in idx:
139+
if isinstance(idx_elem, Variable) and isinstance(idx_elem.type, NoneTypeT):
140+
# None in indexes correspond to newaxis, and don't map to any existing dimension
141+
dims.append(None)
142+
143+
elif (
144+
isinstance(idx_elem, Variable)
145+
and isinstance(idx_elem.type, TensorType)
146+
and idx_elem.type.dtype == "bool"
147+
):
148+
# Boolean indexes map to as many dimensions as the mask has
149+
for i in range(idx_elem.type.ndim):
150+
dims.append(idx_counter)
151+
idx_counter += 1
152+
else:
153+
dims.append(idx_counter)
154+
idx_counter += 1
155+
156+
return dims
157+
158+
159+
@_logprob.register(MeasurableSetSubtensor)
@_logprob.register(MeasurableAdvancedSetSubtensor)
def logprob_setsubtensor(op, values, x, y, *idx_elements, **kwargs):
    """Compute the log-likelihood graph for a `SetSubtensor`.

    For a generative graph like:
        o = zeros(2)
        x = o[0].set(X)
        y = x[1].set(Y)

    The log-likelihood graph is:
        logp(y, value) = (
            logp(x, value)
            [1].set(logp(y, value[1]))
        )

    Unrolling the logp(x, value) gives:
        logp(y, value) = (
            DiracDelta(zeros(2), value)  # Irrelevant if all entries are set
            [0].set(logp(x, value[0]))
            [1].set(logp(y, value[1]))
        )

    Raises
    ------
    NotImplementedError
        If either logp has reduced (core) dimensions and the index touches them.
    """
    [value] = values
    # Basic indexing needs the index rebuilt from the Op's idx_list;
    # advanced indexing can use the inputs as they are.
    idx = (
        indices_from_subtensor(idx_elements, op.idx_list)
        if isinstance(op, MeasurableSetSubtensor)
        else tuple(idx_elements)
    )

    x_logp = _logprob_helper(x, value)
    y_logp = _logprob_helper(y, value[idx])

    # Number of trailing dimensions each logp reduced away
    x_ndim_supp = x.type.ndim - x_logp.type.ndim
    y_ndim_supp = x[idx].type.ndim - y_logp.type.ndim
    ndim_supp = max(y_ndim_supp, x_ndim_supp)
    if ndim_supp > 0:
        # Multivariate logp only valid if we are not doing indexing along the reduced dimensions
        # Otherwise we don't know if successive writings are overlapping or not
        core_dims = set(range(x.type.ndim)[-ndim_supp:])
        if not core_dims.isdisjoint(indexed_dims(idx)):
            # When we have IR meta-info about support_ndim, we can fail at the rewriting stage
            raise NotImplementedError(
                "Indexing along core dimensions of multivariate SetSubtensor not supported"
            )

        extra_y_core_dims = y_ndim_supp - x_ndim_supp
        if extra_y_core_dims > 0:
            # y_logp would have fewer dimensions than the indexed x_logp: reduce x first
            x_logp = x_logp.sum(axis=tuple(range(-extra_y_core_dims, 0)))
        elif extra_y_core_dims < 0:
            # The indexed x_logp would have fewer dimensions than y_logp: reduce y first
            y_logp = y_logp.sum(axis=tuple(range(extra_y_core_dims, 0)))

    return x_logp[idx].set(y_logp)

tests/logprob/test_set_subtensor.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor
16+
import pytensor.tensor as pt
17+
import pytest
18+
19+
from pymc.distributions import Beta, Dirichlet, MvNormal, MvStudentT, Normal, StudentT
20+
from pymc.logprob.basic import logp
21+
22+
23+
@pytest.mark.parametrize("univariate", [True, False])
def test_complete_set_subtensor(univariate):
    """Writing every entry of a base tensor must match the logp of joining the written RVs."""
    if univariate:
        rv0 = Normal.dist(mu=-10)
        rv1 = StudentT.dist(nu=3, mu=0)
        rv2 = Normal.dist(mu=10, sigma=3)
        rv34 = Beta.dist(alpha=[np.pi, np.e], beta=[1, 1])
        base = pt.empty((5,))
        test_val = [2, 0, -2, 0.25, 0.5]
    else:
        rv0 = MvNormal.dist(mu=[-11, -9], cov=pt.eye(2))
        rv1 = MvStudentT.dist(nu=3, mu=[-1, 1], cov=pt.eye(2))
        rv2 = MvNormal.dist(mu=[9, 11], cov=pt.eye(2) * 3)
        rv34 = Dirichlet.dist(a=[[np.pi, 1], [np.e, 1]])
        # Five rows are written below (indices 0..4), so the base needs five rows too
        base = pt.empty((5, 2))
        test_val = [[2, 0], [0, -2], [-2, 2], [0.25, 0.75], [0.5, 0.5]]

    # fmt: off
    rv = (
        # Boolean indexing
        base[np.array([True, False, False, False, False])].set(rv0)
        # Slice indexing
        [1:2].set(rv1)
        # Integer indexing
        [2].set(rv2)
        # Vector indexing
        [[3, 4]].set(rv34)
    )
    # fmt: on
    ref_rv = pt.join(0, [rv0], [rv1], [rv2], rv34)

    np.testing.assert_allclose(
        logp(rv, test_val).eval(),
        logp(ref_rv, test_val).eval(),
    )
58+
59+
60+
def test_partial_set_subtensor():
    """Entries that are not overwritten keep the logp contributed by the base tensor."""
    rv123 = Normal.dist(mu=[-10, 0, 10])
    full_value = [0, 0, 0, 1, np.pi]

    # When base is empty, it doesn't matter what the missing values are
    empty_based_rv = pt.empty((5,))[:3].set(rv123)
    np.testing.assert_allclose(
        logp(empty_based_rv, full_value).eval(),
        [*logp(rv123, [0, 0, 0]).eval(), 0, 0],
    )

    # Otherwise they should match
    ones_based_rv = pt.ones((5,))[:3].set(rv123)
    np.testing.assert_allclose(
        logp(ones_based_rv, full_value).eval(),
        [*logp(rv123, [0, 0, 0]).eval(), 0, -np.inf],
    )
80+
81+
82+
def test_overwrite_set_subtensor():
    """Test that order of overwriting in the generative graph is respected."""
    first = Normal.dist(mu=[0, 1, 2])
    second = first[1:].set(Normal.dist([10, 20]))
    final = second[2:].set(Normal.dist([300]))

    point = [0, 0, 0]
    # Only the last write to each position should count
    expected = logp(Normal.dist([0, 10, 300]), point).eval()
    np.testing.assert_allclose(logp(final, point).eval(), expected)
92+
93+
94+
def test_mixed_dimensionality_set_subtensor():
    """Mixing univariate and multivariate writes yields a logp reduced over the core dimension."""
    x = Normal.dist(mu=0, size=(3, 2))
    y = x[1].set(MvNormal.dist(mu=[1, 1], cov=np.eye(2)))
    z = y[2].set(Normal.dist(mu=2, size=(2,)))

    # Because `y` is multivariate the last dimension of `z` must be summed over
    test_val = np.zeros((3, 2))
    logp_eval = logp(z, test_val).eval()
    assert logp_eval.shape == (3,)

    reference_rv = Normal.dist(mu=[[0, 0], [1, 1], [2, 2]])
    expected = logp(reference_rv, test_val).sum(-1).eval()
    np.testing.assert_allclose(logp_eval, expected)
107+
108+
109+
def test_invalid_indexing_core_dims():
    """Indexing that touches the core dimensions of a multivariate write must be rejected."""
    x = pt.empty((2, 2))
    rv = MvNormal.dist(cov=np.eye(2))
    vv = x.type()

    match_msg = "Indexing along core dimensions of multivariate SetSubtensor not supported"

    def assert_logp_not_supported(var):
        with pytest.raises(NotImplementedError, match=match_msg):
            logp(var, vv)

    # Integer-vector indexing on both dimensions
    y = x[[0, 1], [1, 0]].set(rv)
    assert_logp_not_supported(y)

    # Two-dimensional boolean mask
    y = x[np.array([[False, True], [True, False]])].set(rv)
    assert_logp_not_supported(y)

    # Univariate indexing above multivariate core dims also not supported
    z = y[0].set(rv)[0, 1].set(Normal.dist())
    assert_logp_not_supported(z)
128+
129+
130+
def test_invalid_broadcasted_set_subtensor():
    """Broadcasted writes are rejected, either when deriving the logp or when evaluating it."""
    rv_bcast = Normal.dist(mu=0)
    base = pt.empty((5,))

    static_rv = base[:3].set(rv_bcast)
    vv = static_rv.type()

    # Broadcasting is known at write time, and PyMC does not attempt to make SetSubtensor measurable
    with pytest.raises(NotImplementedError):
        logp(static_rv, vv)

    mask = pt.tensor(shape=(5,), dtype=bool)
    masked_rv = base[mask].set(rv_bcast)

    # Broadcasting is only known at runtime, and PyMC raises an error when it happens
    fn = pytensor.function([mask, vv], logp(masked_rv, vv))
    test_vv = np.zeros(5)

    # A single selected entry means no broadcasting: the logp is defined
    single_entry_mask = [False, False, True, False, False]
    np.testing.assert_allclose(
        fn(single_entry_mask, test_vv),
        [0, 0, -0.91893853, 0, 0],
    )

    # Two selected entries force the scalar RV to broadcast: the runtime check fires
    two_entry_mask = [False, False, True, False, True]
    with pytest.raises(
        NotImplementedError,
        match="Measurable SetSubtensor not supported when set value is broadcasted.",
    ):
        fn(two_entry_mask, test_vv)

0 commit comments

Comments
 (0)