@@ -158,9 +158,7 @@ def valid_for_minibatch(v):
158
158
)
159
159
160
160
161
- def Minibatch (
162
- variable : TensorVariable , * variables : TensorVariable , batch_size : int
163
- ) -> Tuple [TensorVariable ]:
161
+ def Minibatch (variable : TensorVariable , * variables : TensorVariable , batch_size : int ):
164
162
"""
165
163
Get random slices from variables from the leading dimension.
166
164
@@ -181,20 +179,20 @@ def Minibatch(
181
179
rng = RandomStream ()
182
180
slc = rng .gen (minibatch_index , 0 , variable .shape [0 ], size = batch_size )
183
181
if variables :
184
- variables = list (map (at .as_tensor , (variable , * variables )))
185
- for i , v in enumerate (variables ):
182
+ tensors = tuple (map (at .as_tensor , (variable , * variables )))
183
+ for i , v in enumerate (tensors ):
186
184
if not valid_for_minibatch (v ):
187
185
raise ValueError (
188
186
f"{ i } : { v } is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
189
187
)
190
- return tuple ([v [slc ] for v in variables ])
188
+ return tuple ([v [slc ] for v in tensors ])
191
189
else :
192
- variable = at .as_tensor (variable )
193
- if not valid_for_minibatch (variable ):
190
+ tensor = at .as_tensor (variable )
191
+ if not valid_for_minibatch (tensor ):
194
192
raise ValueError (
195
193
f"{ variable } is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
196
194
)
197
- return variable [slc ]
195
+ return tensor [slc ]
198
196
199
197
200
198
def determine_coords (
0 commit comments