File tree 2 files changed +14
-5
lines changed
2 files changed +14
-5
lines changed Original file line number Diff line number Diff line change @@ -238,7 +238,7 @@ def _infer_shape(
238
238
raise ValueError (
239
239
f"Size length is incompatible with batched dimensions of parameter { i } { param } :\n "
240
240
f"len(size) = { size_len } , len(batched dims { param } ) = { param_batched_dims } . "
241
- f"Size length must be 0 or >= { param_batched_dims } "
241
+ f"Size must be None or have length >= { param_batched_dims } "
242
242
)
243
243
244
244
return tuple (size ) + supp_shape
@@ -454,11 +454,10 @@ def vectorize_random_variable(
454
454
455
455
original_dist_params = op .dist_params (node )
456
456
old_size = op .size_param (node )
457
- len_old_size = (
458
- None if isinstance (old_size .type , NoneTypeT ) else get_vector_length (old_size )
459
- )
460
457
461
- if len_old_size and equal_computations ([old_size ], [size ]):
458
+ if not isinstance (old_size .type , NoneTypeT ) and equal_computations (
459
+ [old_size ], [size ]
460
+ ):
462
461
# If the original RV had a size variable and a new one has not been provided,
463
462
# we need to define a new size as the concatenation of the original size dimensions
464
463
# and the novel ones implied by new broadcasted batched parameters dimensions.
Original file line number Diff line number Diff line change @@ -296,6 +296,16 @@ def test_vectorize():
296
296
assert vect_node .default_output ().type .shape == (10 , 2 , 5 )
297
297
298
298
299
+ def test_vectorize_empty_size ():
300
+ scalar_mu = pt .scalar ("scalar_mu" )
301
+ scalar_x = pt .random .normal (loc = scalar_mu , size = ())
302
+ assert scalar_x .type .shape == ()
303
+
304
+ vector_mu = pt .vector ("vector_mu" , shape = (5 ,))
305
+ vector_x = vectorize_graph (scalar_x , {scalar_mu : vector_mu })
306
+ assert vector_x .type .shape == (5 ,)
307
+
308
+
299
309
def test_size_none_vs_empty ():
300
310
rv = RandomVariable (
301
311
"normal" ,
You can’t perform that action at this time.
0 commit comments