Skip to content

Commit 05d376f

Browse files
committed
Fix bug in vectorize_random_variable when size is empty but not None
1 parent 8550622 commit 05d376f

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

pytensor/tensor/random/op.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def _infer_shape(
238238
raise ValueError(
239239
f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n"
240240
f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. "
241-
f"Size length must be 0 or >= {param_batched_dims}"
241+
f"Size must be None or have length >= {param_batched_dims}"
242242
)
243243

244244
return tuple(size) + supp_shape
@@ -454,11 +454,10 @@ def vectorize_random_variable(
454454

455455
original_dist_params = op.dist_params(node)
456456
old_size = op.size_param(node)
457-
len_old_size = (
458-
None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size)
459-
)
460457

461-
if len_old_size and equal_computations([old_size], [size]):
458+
if not isinstance(old_size.type, NoneTypeT) and equal_computations(
459+
[old_size], [size]
460+
):
462461
# If the original RV had a size variable and a new one has not been provided,
463462
# we need to define a new size as the concatenation of the original size dimensions
464463
# and the novel ones implied by new broadcasted batched parameters dimensions.

tests/tensor/random/test_op.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,16 @@ def test_vectorize():
296296
assert vect_node.default_output().type.shape == (10, 2, 5)
297297

298298

299+
def test_vectorize_empty_size():
300+
scalar_mu = pt.scalar("scalar_mu")
301+
scalar_x = pt.random.normal(loc=scalar_mu, size=())
302+
assert scalar_x.type.shape == ()
303+
304+
vector_mu = pt.vector("vector_mu", shape=(5,))
305+
vector_x = vectorize_graph(scalar_x, {scalar_mu: vector_mu})
306+
assert vector_x.type.shape == (5,)
307+
308+
299309
def test_size_none_vs_empty():
300310
rv = RandomVariable(
301311
"normal",

0 commit comments

Comments
 (0)