Skip to content

Commit 45c8656

Browse files
committed
StandandardNormalRV is now just a helper function
1 parent 6af4abc commit 45c8656

File tree

4 files changed

+25
-42
lines changed

4 files changed

+25
-42
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def sample_fn(rng, size, dtype, *parameters):
145145
@jax_sample_fn.register(ptr.LaplaceRV)
146146
@jax_sample_fn.register(ptr.LogisticRV)
147147
@jax_sample_fn.register(ptr.NormalRV)
148-
@jax_sample_fn.register(ptr.StandardNormalRV)
149148
def jax_sample_fn_loc_scale(op):
150149
"""JAX implementation of random variables in the loc-scale families.
151150

pytensor/tensor/random/basic.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -278,38 +278,24 @@ def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs):
278278
normal = NormalRV()
279279

280280

281-
class StandardNormalRV(NormalRV):
282-
r"""A standard normal continuous random variable.
281+
def standard_normal(*, size=None, rng=None, dtype=None):
282+
"""Draw samples from a standard normal distribution.
283283
284-
The probability density function for `standard_normal` is:
284+
Signature
285+
---------
285286
286-
.. math::
287+
`nil -> ()`
287288
288-
f(x) = \frac{1}{\sqrt{2 \pi}} e^{-\frac{x^2}{2}}
289+
Parameters
290+
----------
291+
size
292+
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
293+
independent, identically distributed random variables are
294+
returned. Default is `None` in which case a single random variable
295+
is returned.
289296
290297
"""
291-
292-
def __call__(self, size=None, **kwargs):
293-
"""Draw samples from a standard normal distribution.
294-
295-
Signature
296-
---------
297-
298-
`nil -> ()`
299-
300-
Parameters
301-
----------
302-
size
303-
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
304-
independent, identically distributed random variables are
305-
returned. Default is `None` in which case a single random variable
306-
is returned.
307-
308-
"""
309-
return super().__call__(loc=0.0, scale=1.0, size=size, **kwargs)
310-
311-
312-
standard_normal = StandardNormalRV()
298+
return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype)
313299

314300

315301
class HalfNormalRV(ScipyRandomVariable):

pytensor/tensor/random/utils.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,9 @@ def __init__(
218218
if namespace is None:
219219
from pytensor.tensor.random import basic # pylint: disable=import-self
220220

221-
self.namespaces = [basic]
221+
self.namespaces = [(basic, set(basic.__all__))]
222222
else:
223-
self.namespaces = [namespace]
223+
self.namespaces = [(namespace, set(namespace.__all__))]
224224

225225
self.default_instance_seed = seed
226226
self.state_updates = []
@@ -235,22 +235,20 @@ def rng_ctor(seed):
235235

236236
def __getattr__(self, obj):
237237
ns_obj = next(
238-
(getattr(ns, obj) for ns in self.namespaces if hasattr(ns, obj)), None
238+
(
239+
getattr(ns, obj)
240+
for ns, all_ in self.namespaces
241+
if obj in all_ and hasattr(ns, obj)
242+
),
243+
None,
239244
)
240245

241246
if ns_obj is None:
242247
raise AttributeError(f"No attribute {obj}.")
243248

244-
from pytensor.tensor.random.op import RandomVariable
245-
246-
if isinstance(ns_obj, RandomVariable):
247-
248-
@wraps(ns_obj)
249-
def meta_obj(*args, **kwargs):
250-
return self.gen(ns_obj, *args, **kwargs)
251-
252-
else:
253-
raise AttributeError(f"No attribute {obj}.")
249+
@wraps(ns_obj)
250+
def meta_obj(*args, **kwargs):
251+
return self.gen(ns_obj, *args, **kwargs)
254252

255253
setattr(self, obj, meta_obj)
256254
return getattr(self, obj)

tests/tensor/random/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_basics(self, rng_ctor):
114114
assert hasattr(random, "standard_normal")
115115

116116
with pytest.raises(AttributeError):
117-
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
117+
np_random = RandomStream(namespace=np.random, rng_ctor=rng_ctor)
118118
np_random.ndarray
119119

120120
fn = function([], random.uniform(0, 1, size=(2, 2)), updates=random.updates())

0 commit comments

Comments
 (0)