@@ -133,22 +133,27 @@ class MinibatchIndexRV(IntegersRV):
133
133
134
134
135
135
def is_minibatch(v):
    """Return ``True`` if `v` is a minibatch view of observed data.

    A minibatch view is an ``AdvancedSubtensor`` indexing operation whose
    index variable was produced by a ``MinibatchIndexRV`` and whose indexed
    input is data that is valid for minibatching (see
    ``valid_for_minibatch``).

    Parameters
    ----------
    v : TensorVariable
        The variable to inspect.

    Returns
    -------
    bool
        ``True`` only when `v` has the exact minibatch graph structure;
        ``False`` otherwise (never raises for malformed/ownerless inputs).
    """
    from aesara.tensor.subtensor import AdvancedSubtensor

    # Guard: a root variable (constant/shared, no owner) is not a
    # minibatch view — return False instead of raising AttributeError.
    if v.owner is None or not isinstance(v.owner.op, AdvancedSubtensor):
        return False
    data, index = v.owner.inputs[0], v.owner.inputs[1]
    # The index must itself be owned by a MinibatchIndexRV draw; check the
    # owner exists before touching its op to avoid AttributeError.
    return (
        index.owner is not None
        and isinstance(index.owner.op, MinibatchIndexRV)
        and valid_for_minibatch(data)
    )
143
+
144
+
145
def valid_for_minibatch(v):
    """Check whether `v` may be used as observed data in a Minibatch.

    Only two shapes of graph are accepted: a root variable (no owner),
    or a single type cast applied directly to a root variable. Any other
    Aesara operation on the data is rejected.

    Parameters
    ----------
    v : TensorVariable
        The candidate observed-data variable.

    Returns
    -------
    bool
    """
    from aesara.scalar import Cast
    from aesara.tensor.elemwise import Elemwise

    # Raw data (constants, shared variables) has no owner — always valid.
    if v.owner is None:
        return True
    # The only Aesara operation we allow on observed data is type casting:
    # an Elemwise whose scalar op is Cast, applied to an owner-less input.
    # Although we could allow for any graph that does not depend on other RVs.
    cast_op = v.owner.op
    return (
        isinstance(cast_op, Elemwise)
        and isinstance(cast_op.scalar_op, Cast)
        and v.owner.inputs[0].owner is None
    )
154
159
@@ -176,10 +181,20 @@ def Minibatch(
176
181
rng = RandomStream ()
177
182
slc = rng .gen (minibatch_index , 0 , variable .shape [0 ], size = batch_size )
178
183
if variables :
179
- variables = (variable , * variables )
180
- return tuple ([at .as_tensor (v )[slc ] for v in variables ])
184
+ variables = list (map (at .as_tensor , (variable , * variables )))
185
+ for i , v in enumerate (variables ):
186
+ if not valid_for_minibatch (v ):
187
+ raise ValueError (
188
+ f"{ i } : { v } is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
189
+ )
190
+ return tuple ([v [slc ] for v in variables ])
181
191
else :
182
- return at .as_tensor (variable )[slc ]
192
+ variable = at .as_tensor (variable )
193
+ if not valid_for_minibatch (variable ):
194
+ raise ValueError (
195
+ f"{ variable } is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
196
+ )
197
+ return variable [slc ]
183
198
184
199
185
200
def determine_coords (
0 commit comments