Skip to content

Commit 8c93bb5

Browse files
committed
Fix Minibatch for multiple variables
1 parent db15ae4 commit 8c93bb5

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

pymc/data.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,11 @@ def assert_all_scalars_equal(scalar, *scalars):
170170
else:
171171
return Assert(
172172
"All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
173-
)(scalar, pt.all([scalar == s for s in scalars]))
173+
)(scalar, pt.all([pt.eq(scalar, s) for s in scalars]))
174174

175175

176176
def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int):
177-
"""
178-
Get random slices from variables from the leading dimension.
179-
177+
"""Get random slices from variables from the leading dimension.
180178
181179
Parameters
182180
----------
@@ -191,6 +189,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
191189
>>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10)
192190
"""
193191

192+
if not isinstance(batch_size, int):
193+
raise TypeError("batch_size must be an integer")
194+
194195
tensor, *tensors = tuple(map(pt.as_tensor, (variable, *variables)))
195196
upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
196197
slc = minibatch_index(0, upper, size=batch_size)

tests/test_data.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import io
1616
import itertools as it
17+
import re
1718

1819
import cloudpickle
1920
import numpy as np
@@ -614,3 +615,23 @@ def test_assert(self):
614615
):
615616
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
616617
d1.eval()
618+
619+
def test_multiple_vars(self):
620+
A = np.arange(1000)
621+
B = np.arange(1000)
622+
mA, mB = pm.Minibatch(A, B, batch_size=10)
623+
624+
[draw_mA, draw_mB] = pm.draw([mA, mB])
625+
assert draw_mA.shape == (10,)
626+
np.testing.assert_allclose(draw_mA, draw_mB)
627+
628+
# Check invalid dims
629+
A = np.arange(1000)
630+
C = np.arange(999)
631+
mA, mC = pm.Minibatch(A, C, batch_size=10)
632+
633+
with pytest.raises(
634+
AssertionError,
635+
match=re.escape("All variables shape[0] in Minibatch should be equal"),
636+
):
637+
pm.draw([mA, mC])

0 commit comments

Comments
 (0)