Skip to content

Commit a64055d

Browse files
committed
Fail early in RandomVariable.make_node when size is incompatible with parameters dimensionality
1 parent 237f54f commit a64055d

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

pytensor/tensor/random/op.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,19 @@ def _infer_shape(
192192
size_len = get_vector_length(size)
193193

194194
if size_len > 0:
195+
196+
# Fail early when size is incompatible with parameters
197+
for i, (param, param_ndim_supp) in enumerate(
198+
zip(dist_params, self.ndims_params)
199+
):
200+
param_batched_dims = getattr(param, "ndim", 0) - param_ndim_supp
201+
if param_batched_dims > size_len:
202+
raise ValueError(
203+
f"Size length is incompatible with batched dimensions of parameter {i} {param}:\n"
204+
f"len(size) = {size_len}, len(batched dims {param}) = {param_batched_dims}. "
205+
f"Size length must be 0 or >= {param_batched_dims}"
206+
)
207+
195208
if self.ndim_supp == 0:
196209
return size
197210
else:

tests/tensor/random/test_op.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,17 @@ def test_random_maker_ops_no_seed():
217217
z = function(inputs=[], outputs=[default_rng()])()
218218
aes_res = z[0]
219219
assert isinstance(aes_res, np.random.Generator)
220+
221+
222+
def test_RandomVariable_incompatible_size():
223+
rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
224+
with pytest.raises(
225+
ValueError, match="Size length is incompatible with batched dimensions"
226+
):
227+
rv_op(np.zeros((1, 3)), 1, size=(3,))
228+
229+
rv_op = RandomVariable("dirichlet", 0, [1], config.floatX, inplace=True)
230+
with pytest.raises(
231+
ValueError, match="Size length is incompatible with batched dimensions"
232+
):
233+
rv_op(np.zeros((2, 4, 3)), 1, size=(4,))

0 commit comments

Comments
 (0)