Skip to content

Commit f65dc90

Browse files
author
Maxim Kochurov
committed
mypy does not understand tensor slicing
1 parent 1322dfa commit f65dc90

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

pymc/data.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,7 @@ def valid_for_minibatch(v):
158158
)
159159

160160

161-
def Minibatch(
162-
variable: TensorVariable, *variables: TensorVariable, batch_size: int
163-
) -> Tuple[TensorVariable]:
161+
def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int):
164162
"""
165163
Get random slices from variables from the leading dimension.
166164
@@ -181,20 +179,20 @@ def Minibatch(
181179
rng = RandomStream()
182180
slc = rng.gen(minibatch_index, 0, variable.shape[0], size=batch_size)
183181
if variables:
184-
variables = list(map(at.as_tensor, (variable, *variables)))
185-
for i, v in enumerate(variables):
182+
tensors = tuple(map(at.as_tensor, (variable, *variables)))
183+
for i, v in enumerate(tensors):
186184
if not valid_for_minibatch(v):
187185
raise ValueError(
188186
f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
189187
)
190-
return tuple([v[slc] for v in variables])
188+
return tuple([v[slc] for v in tensors])
191189
else:
192-
variable = at.as_tensor(variable)
193-
if not valid_for_minibatch(variable):
190+
tensor = at.as_tensor(variable)
191+
if not valid_for_minibatch(tensor):
194192
raise ValueError(
195193
f"{variable} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
196194
)
197-
return variable[slc]
195+
return tensor[slc]
198196

199197

200198
def determine_coords(

0 commit comments

Comments
 (0)