Commit c29963b

Speedup random perform
1 parent e382828 commit c29963b

File tree

1 file changed: +4 -7 lines changed

  • pytensor/tensor/random

pytensor/tensor/random/op.py

Lines changed: 4 additions & 7 deletions
@@ -10,7 +10,6 @@
 from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
-from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
     as_tensor_variable,
@@ -389,7 +388,6 @@ def dist_params(self, node) -> Sequence[Variable]:
 
     def perform(self, node, inputs, outputs):
         rng_var_out, smpl_out = outputs
-
         rng, size, *args = inputs
 
         # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
@@ -400,12 +398,11 @@ def perform(self, node, inputs, outputs):
 
         if size is not None:
             size = tuple(size)
-        smpl_val = self.rng_fn(rng, *([*args, size]))
-
-        if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
-            smpl_val = _asarray(smpl_val, dtype=self.dtype)
 
-        smpl_out[0] = smpl_val
+        smpl_out[0] = np.asarray(
+            self.rng_fn(rng, *args, size),
+            dtype=self.dtype,
+        )
 
     def grad(self, inputs, outputs):
         return [
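
Note on the change (not part of the commit): the old code drew a sample, then manually checked whether it was already an ndarray of the expected dtype before converting it with the `_asarray` helper. The replacement hands that check to `np.asarray`, which already returns its input unchanged (no copy) when it is an ndarray of the requested dtype and otherwise performs the conversion. A minimal sketch of that behavior outside PyTensor, with a plain NumPy Generator draw standing in for `self.rng_fn(rng, *args, size)`:

import numpy as np

rng = np.random.default_rng(0)

# Stand-in for `self.rng_fn(rng, *args, size)`: already a float64 ndarray.
draw = rng.standard_normal(size=(3,))

# np.asarray is a pass-through (same object, no copy) when the input is
# already an ndarray of the requested dtype.
out = np.asarray(draw, dtype="float64")
assert out is draw

# When the draw is not an ndarray (e.g. a plain Python scalar) or has a
# different dtype, np.asarray does the conversion that the old
# isinstance/_asarray branch handled explicitly.
out = np.asarray(3.0, dtype="float32")
assert isinstance(out, np.ndarray) and out.dtype == np.float32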
