Commit 623ca42

Wrap Minibatch Operation in OpFromGraph
1 parent: e19cd39

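For background: `OpFromGraph` packages a subgraph as a single Op, so downstream code can identify it by type and graph rewrites see one opaque node instead of its internals. A minimal sketch of the mechanism (toy graph, not code from this commit):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.scalar("x")
# Package `x + 1` as a single Op; callers see one node, not the inner graph.
# inline=True lets the compiler expand the inner graph again at compile time,
# so the wrapper costs nothing at runtime.
add_one = OpFromGraph([x], [x + 1], inline=True)

y = add_one(pt.constant(41.0))
assert isinstance(y.owner.op, OpFromGraph)
print(y.eval())  # 42.0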
File tree: 9 files changed, +141 −180 lines

pymc/data.py

Lines changed: 48 additions & 35 deletions
@@ -26,18 +26,19 @@
 import pytensor.tensor as pt
 import xarray as xr
 
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.sharedvalue import SharedVariable
+from pytensor.graph.basic import Variable
 from pytensor.raise_op import Assert
 from pytensor.scalar import Cast
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.basic import IntegersRV
-from pytensor.tensor.subtensor import AdvancedSubtensor
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
 import pymc as pm
 
-from pymc.pytensorf import convert_data, smarttypeX
+from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
 from pymc.vartypes import isgenerator
 
 __all__ = [
@@ -129,46 +130,47 @@ def __hash__(self):
 class MinibatchIndexRV(IntegersRV):
     _print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")
 
-    # Work-around for https://github.com/pymc-devs/pytensor/issues/97
-    def make_node(self, rng, *args, **kwargs):
-        if rng is None:
-            rng = pytensor.shared(np.random.default_rng())
-        return super().make_node(rng, *args, **kwargs)
-
 
 minibatch_index = MinibatchIndexRV()
 
 
-def is_minibatch(v: TensorVariable) -> bool:
-    return (
-        isinstance(v.owner.op, AdvancedSubtensor)
-        and isinstance(v.owner.inputs[1].owner.op, MinibatchIndexRV)
-        and valid_for_minibatch(v.owner.inputs[0])
-    )
+class MinibatchOp(OpFromGraph):
+    """Encapsulate Minibatch random draws in an opaque OFG"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, inline=True)
+
+    def __str__(self):
+        return "Minibatch"
 
 
-def valid_for_minibatch(v: TensorVariable) -> bool:
+def is_valid_observed(v) -> bool:
+    if not isinstance(v, Variable):
+        # Non-symbolic constant
+        return True
+
+    if v.owner is None:
+        # Symbolic root variable (constant or not)
+        return True
+
     return (
-        v.owner is None
         # The only PyTensor operation we allow on observed data is type casting
         # Although we could allow for any graph that does not depend on other RVs
-        or (
+        (
             isinstance(v.owner.op, Elemwise)
-            and v.owner.inputs[0].owner is None
             and isinstance(v.owner.op.scalar_op, Cast)
+            and is_valid_observed(v.owner.inputs[0])
+        )
+        # Or Minibatch
+        or (
+            isinstance(v.owner.op, MinibatchOp)
+            and all(is_valid_observed(inp) for inp in v.owner.inputs)
         )
+        # Or Generator
+        or isinstance(v.owner.op, GeneratorOp)
     )
 
 
-def assert_all_scalars_equal(scalar, *scalars):
-    if len(scalars) == 0:
-        return scalar
-    else:
-        return Assert(
-            "All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
-        )(scalar, pt.all([pt.eq(scalar, s) for s in scalars]))
-
-
 def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int):
     """Get random slices from variables from the leading dimension.
 
@@ -188,18 +190,29 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
     if not isinstance(batch_size, int):
         raise TypeError("batch_size must be an integer")
 
-    tensor, *tensors = tuple(map(pt.as_tensor, (variable, *variables)))
-    upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
-    slc = minibatch_index(0, upper, size=batch_size)
-    for i, v in enumerate((tensor, *tensors)):
-        if not valid_for_minibatch(v):
+    tensors = tuple(map(pt.as_tensor, (variable, *variables)))
+    for i, v in enumerate(tensors):
+        if not is_valid_observed(v):
             raise ValueError(
                 f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
             )
-    result = tuple([v[slc] for v in (tensor, *tensors)])
-    for i, r in enumerate(result):
+
+    upper = tensors[0].shape[0]
+    if len(tensors) > 1:
+        upper = Assert(
+            "All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
+        )(upper, pt.all([pt.eq(upper, other_tensor.shape[0]) for other_tensor in tensors[1:]]))
+
+    rng = pytensor.shared(np.random.default_rng())
+    rng_update, mb_indices = minibatch_index(0, upper, size=batch_size, rng=rng).owner.outputs
+    mb_tensors = [tensor[mb_indices] for tensor in tensors]
+
+    # Wrap graph in OFG so it's easily identifiable and not rewritten accidentally
+    *mb_tensors, _ = MinibatchOp([*tensors, rng], [*mb_tensors, rng_update])(*tensors, rng)
+    for i, r in enumerate(mb_tensors[:-1]):
         r.name = f"minibatch.{i}"
-    return result if tensors else result[0]
+
+    return mb_tensors if len(variables) else mb_tensors[0]
 
 
 def determine_coords(

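Taken together, a sketch of how the wrapped result surfaces to users after this change (toy data; mirrors the updated tests further down):

import numpy as np
import pymc as pm
from pymc.data import MinibatchOp

data = np.random.rand(100, 10)  # toy dataset
mb = pm.Minibatch(data, batch_size=20)

# The result is owned by one opaque MinibatchOp instead of a bare
# AdvancedSubtensor indexing graph, so it is easy to identify and is
# not accidentally rewritten.
assert isinstance(mb.owner.op, MinibatchOp)
print(pm.draw(mb).shape)  # (20, 10); each draw takes a fresh random slice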
pymc/logprob/rewriting.py

Lines changed: 2 additions & 4 deletions
@@ -41,9 +41,7 @@
 from pytensor import config
 from pytensor.compile.mode import optdb
 from pytensor.graph.basic import (
-    Constant,
     Variable,
-    ancestors,
     io_toposort,
     truncated_graph_inputs,
 )
@@ -400,8 +398,8 @@ def construct_ir_fgraph(
     # the old nodes to the new ones; otherwise, we won't be able to use
     # `rv_values`.
     # We start the `dict` with mappings from the value variables to themselves,
-    # to prevent them from being cloned. This also includes ancestors
-    memo = {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
+    # to prevent them from being cloned.
+    memo = {v: v for v in rv_values.values()}
 
     # We add `ShapeFeature` because it will get rid of references to the old
     # `RandomVariable`s that have been lifted; otherwise, it will be difficult

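As a side note, seeding a clone memo with `{v: v}` is what pins those variables during graph cloning; a minimal sketch with PyTensor's `clone_get_equiv` (toy graph, not code from this commit):

import pytensor.tensor as pt
from pytensor.graph.basic import clone_get_equiv

x = pt.vector("x")
y = x * 2
# Seeding the memo maps `x` to itself, so only the rest of the graph is cloned
memo = clone_get_equiv([x], [y], memo={x: x})
assert memo[x] is x       # x was not duplicated
assert memo[y] is not y   # the apply node above it was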
pymc/model/core.py

Lines changed: 2 additions & 15 deletions
@@ -39,15 +39,13 @@
 from pytensor.compile import DeepCopyOp, Function, get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant, Variable, graph_inputs
-from pytensor.scalar import Cast
-from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 from typing_extensions import Self
 
 from pymc.blocking import DictToArrayBijection, RaveledVars
-from pymc.data import GenTensorVariable, is_minibatch
+from pymc.data import is_valid_observed
 from pymc.exceptions import (
     BlockModelAccessError,
     ImputationWarning,
@@ -1294,18 +1292,7 @@ def register_rv(
             self.add_named_variable(rv_var, dims)
             self.set_initval(rv_var, initval)
         else:
-            if (
-                isinstance(observed, Variable)
-                and not isinstance(observed, GenTensorVariable)
-                and observed.owner is not None
-                # The only PyTensor operation we allow on observed data is type casting
-                # Although we could allow for any graph that does not depend on other RVs
-                and not (
-                    isinstance(observed.owner.op, Elemwise)
-                    and isinstance(observed.owner.op.scalar_op, Cast)
-                )
-                and not is_minibatch(observed)
-            ):
+            if not is_valid_observed(observed):
                 raise TypeError(
                     "Variables that depend on other nodes cannot be used for observed data."
                     f"The data variable was: {observed}"

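The net effect on `Model.register_rv` is that a single predicate now decides what may be passed as `observed`; a sketch of the accepted and rejected cases (hypothetical toy data):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

data = np.random.rand(100)

with pm.Model():
    pm.Normal("ok_constant", observed=data)  # plain data: allowed
    pm.Normal("ok_cast", observed=pt.as_tensor(data).astype("float32"))  # cast: allowed
    pm.Normal("ok_minibatch", observed=pm.Minibatch(data, batch_size=10))  # MinibatchOp: allowed

with pm.Model():
    try:
        pm.Normal("bad", observed=pt.as_tensor(data) * 2)  # depends on a computation
    except TypeError as e:
        print(e)  # "Variables that depend on other nodes cannot be used for observed data..."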
pymc/pytensorf.py

Lines changed: 15 additions & 9 deletions
@@ -156,19 +156,25 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
     TypeError
 
     """
+    # TODO: These data functions should be in data.py or model/core.py
+    from pymc.data import MinibatchOp
+
     if isinstance(x, Constant):
         return x.data
     if isinstance(x, SharedVariable):
         return x.get_value()
-    if x.owner and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
-        array_data = extract_obs_data(x.owner.inputs[0])
-        return array_data.astype(x.type.dtype)
-    if x.owner and isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
-        array_data = extract_obs_data(x.owner.inputs[0])
-        mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
-        mask = np.zeros_like(array_data)
-        mask[mask_idx] = 1
-        return np.ma.MaskedArray(array_data, mask)
+    if x.owner is not None:
+        if isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
+            array_data = extract_obs_data(x.owner.inputs[0])
+            return array_data.astype(x.type.dtype)
+        if isinstance(x.owner.op, MinibatchOp):
+            return extract_obs_data(x.owner.inputs[x.owner.outputs.index(x)])
+        if isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
+            array_data = extract_obs_data(x.owner.inputs[0])
+            mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
+            mask = np.zeros_like(array_data)
+            mask[mask_idx] = 1
+            return np.ma.MaskedArray(array_data, mask)
 
     raise TypeError(f"Data cannot be extracted from {x}")

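With the new branch, `extract_obs_data` maps a MinibatchOp output back to the matching input, so it recovers the full underlying data rather than a batch-sized draw; a short sketch (toy data):

import numpy as np
import pymc as pm
from pymc.pytensorf import extract_obs_data

data = np.arange(50)
mb = pm.Minibatch(data, batch_size=10)

# The i-th output of the MinibatchOp corresponds to the i-th input,
# so the lookup resolves `mb` to the original constant, all 50 elements.
np.testing.assert_array_equal(extract_obs_data(mb), data)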
pymc/variational/opvi.py

Lines changed: 2 additions & 2 deletions
@@ -116,13 +116,13 @@ class GroupError(VariationalInferenceError, TypeError):
 
 def _known_scan_ignored_inputs(terms):
     # TODO: remove when scan issue with grads is fixed
-    from pymc.data import MinibatchIndexRV
+    from pymc.data import MinibatchOp
     from pymc.distributions.simulator import SimulatorRV
 
     return [
         n.owner.inputs[0]
         for n in pytensor.graph.ancestors(terms)
-        if n.owner is not None and isinstance(n.owner.op, MinibatchIndexRV | SimulatorRV)
+        if n.owner is not None and isinstance(n.owner.op, MinibatchOp | SimulatorRV)
     ]
 
 
tests/test_data.py

Lines changed: 12 additions & 23 deletions
@@ -14,7 +14,6 @@
 
 import io
 import itertools as it
-import re
 
 from os import path
 
@@ -29,7 +28,7 @@
 
 import pymc as pm
 
-from pymc.data import is_minibatch
+from pymc.data import MinibatchOp
 from pymc.pytensorf import GeneratorOp, floatX
 
 
@@ -593,44 +592,34 @@ class TestMinibatch:
 
     def test_1d(self):
         mb = pm.Minibatch(self.data, batch_size=20)
-        assert is_minibatch(mb)
-        assert mb.eval().shape == (20, 10)
+        assert isinstance(mb.owner.op, MinibatchOp)
+        draw1, draw2 = pm.draw(mb, draws=2)
+        assert draw1.shape == (20, 10)
+        assert draw2.shape == (20, 10)
+        assert not np.all(draw1 == draw2)
 
     def test_allowed(self):
         mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
-        assert is_minibatch(mb)
+        assert isinstance(mb.owner.op, MinibatchOp)
 
-    def test_not_allowed(self):
         with pytest.raises(ValueError, match="not valid for Minibatch"):
-            mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
+            pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
 
-    def test_not_allowed2(self):
         with pytest.raises(ValueError, match="not valid for Minibatch"):
-            mb = pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
+            pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
 
     def test_assert(self):
+        d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
         with pytest.raises(
             AssertionError, match=r"All variables shape\[0\] in Minibatch should be equal"
         ):
-            d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
             d1.eval()
 
     def test_multiple_vars(self):
         A = np.arange(1000)
-        B = np.arange(1000)
+        B = -np.arange(1000)
         mA, mB = pm.Minibatch(A, B, batch_size=10)
 
         [draw_mA, draw_mB] = pm.draw([mA, mB])
         assert draw_mA.shape == (10,)
-        np.testing.assert_allclose(draw_mA, draw_mB)
-
-        # Check invalid dims
-        A = np.arange(1000)
-        C = np.arange(999)
-        mA, mC = pm.Minibatch(A, C, batch_size=10)
-
-        with pytest.raises(
-            AssertionError,
-            match=re.escape("All variables shape[0] in Minibatch should be equal"),
-        ):
-            pm.draw([mA, mC])
+        np.testing.assert_allclose(draw_mA, -draw_mB)

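The reworked `test_multiple_vars` relies on both tensors being sliced inside a single MinibatchOp with one shared rng, so a joint draw uses the same random indices for each; flipping the sign of `B` is what makes that alignment observable:

import numpy as np
import pymc as pm

A = np.arange(1000)
B = -np.arange(1000)
mA, mB = pm.Minibatch(A, B, batch_size=10)

# One joint draw slices A and B at identical random positions,
# so the two minibatches are exact negatives of each other.
draw_mA, draw_mB = pm.draw([mA, mB])
np.testing.assert_allclose(draw_mA, -draw_mB)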