Skip to content

Commit 1322dfa

Browse files
author
Maxim Kochurov
committed
more strict approach for minibatch validation
1 parent 0c9ae26 commit 1322dfa

File tree

2 files changed

+46
-14
lines changed

2 files changed

+46
-14
lines changed

pymc/data.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,22 +133,27 @@ class MinibatchIndexRV(IntegersRV):
133133

134134

135135
def is_minibatch(v):
136-
from aesara.scalar import Cast
137-
from aesara.tensor.elemwise import Elemwise
138136
from aesara.tensor.subtensor import AdvancedSubtensor
139137

140138
return (
141139
isinstance(v.owner.op, AdvancedSubtensor)
142140
and isinstance(v.owner.inputs[1].owner.op, MinibatchIndexRV)
143-
and (
144-
v.owner.inputs[0].owner is None
145-
# The only Aesara operation we allow on observed data is type casting
146-
# Although we could allow for any graph that does not depend on other RVs
147-
or (
148-
isinstance(v.owner.inputs[0].owner.op, Elemwise)
149-
and v.owner.inputs[0].owner.inputs[0].owner is None
150-
and isinstance(v.owner.inputs[0].owner.op.scalar_op, Cast)
151-
)
141+
and valid_for_minibatch(v.owner.inputs[0])
142+
)
143+
144+
145+
def valid_for_minibatch(v):
146+
from aesara.scalar import Cast
147+
from aesara.tensor.elemwise import Elemwise
148+
149+
return (
150+
v.owner is None
151+
# The only Aesara operation we allow on observed data is type casting
152+
# Although we could allow for any graph that does not depend on other RVs
153+
or (
154+
isinstance(v.owner.op, Elemwise)
155+
and v.owner.inputs[0].owner is None
156+
and isinstance(v.owner.op.scalar_op, Cast)
152157
)
153158
)
154159

@@ -176,10 +181,20 @@ def Minibatch(
176181
rng = RandomStream()
177182
slc = rng.gen(minibatch_index, 0, variable.shape[0], size=batch_size)
178183
if variables:
179-
variables = (variable, *variables)
180-
return tuple([at.as_tensor(v)[slc] for v in variables])
184+
variables = list(map(at.as_tensor, (variable, *variables)))
185+
for i, v in enumerate(variables):
186+
if not valid_for_minibatch(v):
187+
raise ValueError(
188+
f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
189+
)
190+
return tuple([v[slc] for v in variables])
181191
else:
182-
return at.as_tensor(variable)[slc]
192+
variable = at.as_tensor(variable)
193+
if not valid_for_minibatch(variable):
194+
raise ValueError(
195+
f"{variable} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
196+
)
197+
return variable[slc]
183198

184199

185200
def determine_coords(

pymc/tests/test_data.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,5 +720,22 @@ class TestMinibatch:
720720
data = np.random.rand(30, 10)
721721

722722
def test_1d(self):
723+
from pymc.data import is_minibatch
724+
723725
mb = pm.Minibatch(self.data, batch_size=20)
726+
assert is_minibatch(mb)
724727
assert mb.eval().shape == (20, 10)
728+
729+
def test_allowed(self):
730+
from pymc.data import is_minibatch
731+
732+
mb = pm.Minibatch(at.as_tensor(self.data).astype(int), batch_size=20)
733+
assert is_minibatch(mb)
734+
735+
def test_not_allowed(self):
736+
with pytest.raises(ValueError):
737+
mb = pm.Minibatch(at.as_tensor(self.data) * 2, batch_size=20)
738+
739+
def test_not_allowed2(self):
740+
with pytest.raises(ValueError):
741+
mb = pm.Minibatch(self.data, at.as_tensor(self.data) * 2, batch_size=20)

0 commit comments

Comments
 (0)