@@ -191,6 +191,8 @@ def _infer_shape(
191
191
192
192
"""
193
193
194
+ from pytensor .tensor .extra_ops import broadcast_shape_iter
195
+
194
196
size_len = get_vector_length (size )
195
197
196
198
if size_len > 0 :
@@ -216,57 +218,52 @@ def _infer_shape(
216
218
217
219
# Broadcast the parameters
218
220
param_shapes = params_broadcast_shapes (
219
- param_shapes or [shape_tuple (p ) for p in dist_params ], self .ndims_params
221
+ param_shapes or [shape_tuple (p ) for p in dist_params ],
222
+ self .ndims_params ,
220
223
)
221
224
222
- def slice_ind_dims (p , ps , n ):
225
+ def extract_batch_shape (p , ps , n ):
223
226
shape = tuple (ps )
224
227
225
228
if n == 0 :
226
- return ( p , shape )
229
+ return shape
227
230
228
- ind_slice = (slice (None ),) * (p .ndim - n ) + (0 ,) * n
229
- ind_shape = [
231
+ batch_shape = [
230
232
s if b is False else constant (1 , "int64" )
231
- for s , b in zip (shape [:- n ], p .broadcastable [:- n ])
233
+ for s , b in zip (shape [:- n ], p .type . broadcastable [:- n ])
232
234
]
233
- return (
234
- p [ind_slice ],
235
- ind_shape ,
236
- )
235
+ return batch_shape
237
236
238
237
# These are versions of our actual parameters with the anticipated
239
238
# dimensions (i.e. support dimensions) removed so that only the
240
239
# independent variate dimensions are left.
241
- params_ind_slice = tuple (
242
- slice_ind_dims (p , ps , n )
240
+ params_batch_shape = tuple (
241
+ extract_batch_shape (p , ps , n )
243
242
for p , ps , n in zip (dist_params , param_shapes , self .ndims_params )
244
243
)
245
244
246
- if len (params_ind_slice ) == 1 :
247
- _ , shape_ind = params_ind_slice [ 0 ]
248
- elif len (params_ind_slice ) > 1 :
245
+ if len (params_batch_shape ) == 1 :
246
+ [ batch_shape ] = params_batch_shape
247
+ elif len (params_batch_shape ) > 1 :
249
248
# If there are multiple parameters, the dimensions of their
250
249
# independent variates should broadcast together.
251
- p_slices , p_shapes = zip (* params_ind_slice )
252
-
253
- shape_ind = pytensor .tensor .extra_ops .broadcast_shape_iter (
254
- p_shapes , arrays_are_shapes = True
250
+ batch_shape = broadcast_shape_iter (
251
+ params_batch_shape ,
252
+ arrays_are_shapes = True ,
255
253
)
256
-
257
254
else :
258
255
# Distribution has no parameters
259
- shape_ind = ()
256
+ batch_shape = ()
260
257
261
258
if self .ndim_supp == 0 :
262
- shape_supp = ()
259
+ supp_shape = ()
263
260
else :
264
- shape_supp = self ._supp_shape_from_params (
261
+ supp_shape = self ._supp_shape_from_params (
265
262
dist_params ,
266
263
param_shapes = param_shapes ,
267
264
)
268
265
269
- shape = tuple (shape_ind ) + tuple (shape_supp )
266
+ shape = tuple (batch_shape ) + tuple (supp_shape )
270
267
if not shape :
271
268
shape = constant ([], dtype = "int64" )
272
269
0 commit comments