Skip to content

Commit bde1bbc

Browse files
committed
Simplify RandomVariable._infer_shape
1 parent 9be43d0 commit bde1bbc

File tree

1 file changed

+21
-24
lines changed
  • pytensor/tensor/random

1 file changed

+21
-24
lines changed

pytensor/tensor/random/op.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ def _infer_shape(
191191
192192
"""
193193

194+
from pytensor.tensor.extra_ops import broadcast_shape_iter
195+
194196
size_len = get_vector_length(size)
195197

196198
if size_len > 0:
@@ -216,57 +218,52 @@ def _infer_shape(
216218

217219
# Broadcast the parameters
218220
param_shapes = params_broadcast_shapes(
219-
param_shapes or [shape_tuple(p) for p in dist_params], self.ndims_params
221+
param_shapes or [shape_tuple(p) for p in dist_params],
222+
self.ndims_params,
220223
)
221224

222-
def slice_ind_dims(p, ps, n):
225+
def extract_batch_shape(p, ps, n):
223226
shape = tuple(ps)
224227

225228
if n == 0:
226-
return (p, shape)
229+
return shape
227230

228-
ind_slice = (slice(None),) * (p.ndim - n) + (0,) * n
229-
ind_shape = [
231+
batch_shape = [
230232
s if b is False else constant(1, "int64")
231-
for s, b in zip(shape[:-n], p.broadcastable[:-n])
233+
for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
232234
]
233-
return (
234-
p[ind_slice],
235-
ind_shape,
236-
)
235+
return batch_shape
237236

238237
# These are versions of our actual parameters with the anticipated
239238
# dimensions (i.e. support dimensions) removed so that only the
240239
# independent variate dimensions are left.
241-
params_ind_slice = tuple(
242-
slice_ind_dims(p, ps, n)
240+
params_batch_shape = tuple(
241+
extract_batch_shape(p, ps, n)
243242
for p, ps, n in zip(dist_params, param_shapes, self.ndims_params)
244243
)
245244

246-
if len(params_ind_slice) == 1:
247-
_, shape_ind = params_ind_slice[0]
248-
elif len(params_ind_slice) > 1:
245+
if len(params_batch_shape) == 1:
246+
[batch_shape] = params_batch_shape
247+
elif len(params_batch_shape) > 1:
249248
# If there are multiple parameters, the dimensions of their
250249
# independent variates should broadcast together.
251-
p_slices, p_shapes = zip(*params_ind_slice)
252-
253-
shape_ind = pytensor.tensor.extra_ops.broadcast_shape_iter(
254-
p_shapes, arrays_are_shapes=True
250+
batch_shape = broadcast_shape_iter(
251+
params_batch_shape,
252+
arrays_are_shapes=True,
255253
)
256-
257254
else:
258255
# Distribution has no parameters
259-
shape_ind = ()
256+
batch_shape = ()
260257

261258
if self.ndim_supp == 0:
262-
shape_supp = ()
259+
supp_shape = ()
263260
else:
264-
shape_supp = self._supp_shape_from_params(
261+
supp_shape = self._supp_shape_from_params(
265262
dist_params,
266263
param_shapes=param_shapes,
267264
)
268265

269-
shape = tuple(shape_ind) + tuple(shape_supp)
266+
shape = tuple(batch_shape) + tuple(supp_shape)
270267
if not shape:
271268
shape = constant([], dtype="int64")
272269

0 commit comments

Comments
 (0)