Skip to content

Commit 6a767e7

Browse files
Changed copy to deepcopy for rng
This was done for the python linker and numba linker. deepcopy seems to be the recommended method for copying a numpy Generator. After this numpy PR: numpy/numpy@44ba7ca `copy` didn't seem to actually make an independent copy of the `np.random.Generator` objects spawned by `RandomStream`. This was causing the "test values" computed by e.g. `RandomStream.uniform` to increment the RNG state, which was causing tests that rely on `RandomStream` to fail. Here is some related discussion: numpy/numpy#24086 I didn't see any official documentation about a change in numpy that would make copy stop working.
1 parent 2d6efd6 commit 6a767e7

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

pytensor/link/numba/dispatch/random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable
2-
from copy import copy
2+
from copy import copy, deepcopy
33
from functools import singledispatch
44
from textwrap import dedent
55

@@ -34,7 +34,7 @@ def copy_NumPyRandomGenerator(rng):
3434
def impl(rng):
3535
# TODO: Open issue on Numba?
3636
with numba.objmode(new_rng=types.npy_rng):
37-
new_rng = copy(rng)
37+
new_rng = deepcopy(rng)
3838

3939
return new_rng
4040

pytensor/tensor/random/op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22
from collections.abc import Sequence
3-
from copy import copy
3+
from copy import deepcopy
44
from typing import Any, cast
55

66
import numpy as np
@@ -395,7 +395,7 @@ def perform(self, node, inputs, outputs):
395395

396396
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
397397
if not self.inplace:
398-
rng = copy(rng)
398+
rng = deepcopy(rng)
399399

400400
outputs[0][0] = rng
401401
outputs[1][0] = np.asarray(

tests/tensor/random/test_basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pickle
22
import re
3-
from copy import copy
3+
from copy import deepcopy
44

55
import numpy as np
66
import pytest
@@ -113,7 +113,9 @@ def test_fn(*args, random_state=None, **kwargs):
113113

114114
pt_rng = shared(rng, borrow=True)
115115

116-
numpy_res = np.asarray(test_fn(*param_vals, random_state=copy(rng), **kwargs_vals))
116+
numpy_res = np.asarray(
117+
test_fn(*param_vals, random_state=deepcopy(rng), **kwargs_vals)
118+
)
117119

118120
pytensor_res = rv(*params, rng=pt_rng, **kwargs)
119121

0 commit comments

Comments
 (0)