10
10
from pytensor .graph .basic import Apply , Variable , equal_computations
11
11
from pytensor .graph .op import Op
12
12
from pytensor .graph .replace import _vectorize_node
13
- from pytensor .misc .safe_asarray import _asarray
14
13
from pytensor .scalar import ScalarVariable
15
14
from pytensor .tensor .basic import (
16
15
as_tensor_variable ,
@@ -389,7 +388,6 @@ def dist_params(self, node) -> Sequence[Variable]:
389
388
390
389
def perform (self , node , inputs , outputs ):
391
390
rng_var_out , smpl_out = outputs
392
-
393
391
rng , size , * args = inputs
394
392
395
393
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
@@ -400,12 +398,11 @@ def perform(self, node, inputs, outputs):
400
398
401
399
if size is not None :
402
400
size = tuple (size )
403
- smpl_val = self .rng_fn (rng , * ([* args , size ]))
404
-
405
- if not isinstance (smpl_val , np .ndarray ) or str (smpl_val .dtype ) != self .dtype :
406
- smpl_val = _asarray (smpl_val , dtype = self .dtype )
407
401
408
- smpl_out [0 ] = smpl_val
402
+ smpl_out [0 ] = np .asarray (
403
+ self .rng_fn (rng , * args , size ),
404
+ dtype = self .dtype ,
405
+ )
409
406
410
407
def grad (self , inputs , outputs ):
411
408
return [
0 commit comments