Skip to content

Commit aa74cf2

Browse files
author
Maxim Kochurov
committed
implement index in a different way
1 parent cd86ed8 commit aa74cf2

File tree

4 files changed

+43
-12
lines changed

4 files changed

+43
-12
lines changed

pymc/data.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@
2626
import numpy as np
2727

2828
from aesara.compile.sharedvalue import SharedVariable
29+
from aesara.tensor.random import RandomStream
30+
from aesara.tensor.random.basic import IntegersRV
2931
from aesara.tensor.type import TensorType
3032
from aesara.tensor.var import TensorConstant, TensorVariable
3133

3234
import pymc as pm
3335

34-
from pymc.aesaraf import at_rng, convert_observed_data
36+
from pymc.aesaraf import convert_observed_data
3537

3638
__all__ = [
3739
"get_data",
@@ -123,6 +125,34 @@ def __hash__(self):
123125
return hash(id(self))
124126

125127

128+
class MinibatchIndexRV(IntegersRV):
129+
_print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")
130+
131+
132+
minibatch_index = MinibatchIndexRV()
133+
134+
135+
def is_minibatch(v):
136+
from aesara.scalar import Cast
137+
from aesara.tensor.elemwise import Elemwise
138+
from aesara.tensor.subtensor import AdvancedSubtensor
139+
140+
return (
141+
isinstance(v.owner.op, AdvancedSubtensor)
142+
and isinstance(v.owner.inputs[1].owner.op, MinibatchIndexRV)
143+
and (
144+
v.owner.inputs[0].owner is None
145+
# The only Aesara operation we allow on observed data is type casting
146+
# Although we could allow for any graph that does not depend on other RVs
147+
or (
148+
isinstance(v.owner.inputs[0].owner.op, Elemwise)
149+
and v.owner.inputs[0].owner.inputs[0].owner is None
150+
and isinstance(v.owner.inputs[0].owner.op.scalar_op, Cast)
151+
)
152+
)
153+
)
154+
155+
126156
def Minibatch(
127157
variable: TensorVariable, *variables: TensorVariable, batch_size: int
128158
) -> Tuple[TensorVariable]:
@@ -143,17 +173,13 @@ def Minibatch(
143173
>>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10)
144174
"""
145175

146-
def _minibatch_name(v1, v0):
147-
base_name = getattr(v0, "name", "")
148-
v1.name = f"minibatch_{base_name}_{id(v0)}"
149-
return v1
150-
151-
slc = at_rng().uniform(0, variable.shape[0], size=batch_size).astype(np.int64)
176+
rng = RandomStream()
177+
slc = rng.gen(minibatch_index, 0, variable.shape[0], size=batch_size)
152178
if variables:
153179
variables = (variable, *variables)
154-
return tuple([_minibatch_name(at.as_tensor(v)[slc], v) for v in variables])
180+
return tuple([at.as_tensor(v)[slc] for v in variables])
155181
else:
156-
return _minibatch_name(at.as_tensor(variable)[slc], variable)
182+
return at.as_tensor(variable)[slc]
157183

158184

159185
def determine_coords(

pymc/distributions/logprob.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from aesara.tensor.var import TensorVariable
3333

3434
from pymc.aesaraf import constant_fold, floatX
35+
from pymc.data import MinibatchIndexRV
3536

3637
TOTAL_SIZE = Union[int, Sequence[int], None]
3738

@@ -116,6 +117,7 @@ def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
116117
node.owner
117118
and isinstance(node.owner.op, RandomVariable)
118119
and not isinstance(node.owner.op, SimulatorRV)
120+
and not isinstance(node.owner.op, MinibatchIndexRV)
119121
)
120122
]
121123
if unexpected_rv_nodes:

pymc/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
replace_rvs_by_values,
6161
)
6262
from pymc.blocking import DictToArrayBijection, RaveledVars
63-
from pymc.data import GenTensorVariable
63+
from pymc.data import GenTensorVariable, is_minibatch
6464
from pymc.distributions.logprob import _joint_logp
6565
from pymc.distributions.transforms import _default_transform
6666
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
@@ -1311,7 +1311,7 @@ def register_rv(
13111311
isinstance(data.owner.op, Elemwise)
13121312
and isinstance(data.owner.op.scalar_op, Cast)
13131313
)
1314-
and not data.name.startswith("minibatch")
1314+
and not is_minibatch(data)
13151315
):
13161316
raise TypeError(
13171317
"Variables that depend on other nodes cannot be used for observed data."

pymc/tests/test_data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,10 +695,13 @@ def test_common_errors(self):
695695
assert "Double Ellipsis" in str(e.value)
696696

697697
def test_mixed1(self):
698+
from pymc.distributions.logprob import joint_logp
699+
698700
with pm.Model():
699701
data = np.random.rand(10, 20)
700702
mb = pm.Minibatch(data, batch_size=5)
701-
pm.Normal("n", observed=mb, total_size=10)
703+
v = pm.Normal("n", observed=mb, total_size=10)
704+
assert joint_logp(v) is not None, "Check index is allowed in graph"
702705

703706
def test_free_rv(self):
704707
with pm.Model() as model4:

0 commit comments

Comments
 (0)