Skip to content

Commit 84d15c8

Browse files
authored
Broadcast shapes of alpha and beta in Weibull rng (#7288)
1 parent 29fd732 commit 84d15c8

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2499,6 +2499,8 @@ def __call__(self, alpha, beta, size=None, **kwargs):
24992499

25002500
@classmethod
25012501
def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
2502+
if size is None:
2503+
size = np.broadcast_shapes(alpha.shape, beta.shape)
25022504
return np.asarray(beta * rng.weibull(alpha, size=size))
25032505

25042506

tests/distributions/test_continuous.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,6 +2369,14 @@ def seeded_weibul_rng_fn(self):
23692369
"check_rv_size",
23702370
]
23712371

2372+
def test_rng_different_shapes(self):
2373+
# See issue #7220
2374+
rng = np.random.default_rng(123)
2375+
alpha = np.abs(rng.normal(size=5))
2376+
beta = np.abs(rng.normal(size=(3, 1)))
2377+
draws = pm.draw(pm.Weibull.dist(alpha, beta), random_seed=rng)
2378+
assert len(np.unique(draws)) == draws.size
2379+
23722380

23732381
@pytest.mark.skipif(
23742382
condition=_polyagamma_not_installed,

0 commit comments

Comments
 (0)