Skip to content

Commit b1925ba

Browse files
Make if_else accesible from root
1 parent 30b760f commit b1925ba

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

pytensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def get_underlying_scalar_constant(v):
159159
# isort: off
160160
import pytensor.tensor.random.var
161161
import pytensor.sparse
162+
from pytensor.ifelse import ifelse
162163
from pytensor.scan import checkpoints
163164
from pytensor.scan.basic import scan
164165
from pytensor.scan.views import foldl, foldr, map, reduce

pytensor/tensor/random/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,10 @@ def chisquare(df, size=None, **kwargs):
526526
return gamma(shape=df / 2.0, scale=2.0, size=size, **kwargs)
527527

528528

529+
def rayleigh(scale=1.0, *, size=None, **kwargs):
530+
return chisquare(df=2, size=size, **kwargs) * as_tensor_variable(scale)
531+
532+
529533
class ParetoRV(ScipyRandomVariable):
530534
r"""A pareto continuous random variable.
531535

tests/tensor/random/test_basic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
permutation,
5151
poisson,
5252
randint,
53+
rayleigh,
5354
standard_normal,
5455
t,
5556
triangular,
@@ -390,6 +391,18 @@ def test_chisquare_samples(df, size):
390391
compare_sample_values(chisquare, df, size=size, test_fn=fixed_scipy_rvs("chi2"))
391392

392393

394+
@pytest.mark.parametrize(
395+
"size",
396+
[
397+
(None),
398+
([]),
399+
],
400+
)
401+
def test_rayleigh_samples(size):
402+
compare_sample_values(rayleigh, size=size, test_fn=fixed_scipy_rvs("rayleigh"))
403+
compare_sample_values(rayleigh)
404+
405+
393406
@pytest.mark.parametrize(
394407
"mu, beta, size",
395408
[

0 commit comments

Comments
 (0)