Skip to content

Commit 71a5071

Browse files
Change rng.__getstate__ to rng.bit_generator.state
numpy.random.Generator.__getstate__() now returns none; to see the state of the bit generator, you need to use Generator.bit_generator.state. This change affects `RandomGeneratorType`, and several of the random tests (including some for Jax.)
1 parent 0f1b286 commit 71a5071

File tree

5 files changed

+20
-12
lines changed

5 files changed

+20
-12
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def assert_size_argument_jax_compatible(node):
5656

5757
@jax_typify.register(Generator)
5858
def jax_typify_Generator(rng, **kwargs):
59-
state = rng.__getstate__()
59+
state = rng.bit_generator.state
6060
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
6161

6262
# XXX: Is this a reasonable approach?

pytensor/tensor/random/type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def filter(self, data, strict=False, allow_downcast=None):
8787

8888
@staticmethod
8989
def values_eq(a, b):
90-
sa = a if isinstance(a, dict) else a.__getstate__()
91-
sb = b if isinstance(b, dict) else b.__getstate__()
90+
sa = a if isinstance(a, dict) else a.bit_generator.state
91+
sb = b if isinstance(b, dict) else b.bit_generator.state
9292

9393
def _eq(sa, sb):
9494
for key in sa:

tests/link/jax/test_random.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def test_random_updates(rng_ctor):
6363
assert all(
6464
a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
6565
for a, b in zip(
66-
rng.get_value().__getstate__(), original_value.__getstate__(), strict=True
66+
rng.get_value().bit_generator.state,
67+
original_value.bit_generator.state,
68+
strict=True,
6769
)
6870
)
6971

tests/tensor/random/test_type.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_filter(self):
5252
with pytest.raises(TypeError):
5353
rng_type.filter(1)
5454

55-
rng_dict = rng.__getstate__()
55+
rng_dict = rng.bit_generator.state
5656

5757
assert rng_type.is_valid_value(rng_dict) is False
5858
assert rng_type.is_valid_value(rng_dict, strict=False)
@@ -88,13 +88,13 @@ def test_values_eq(self):
8888
assert rng_type.values_eq(bitgen_g, bitgen_h)
8989

9090
assert rng_type.is_valid_value(bitgen_a, strict=True)
91-
assert rng_type.is_valid_value(bitgen_b.__getstate__(), strict=False)
91+
assert rng_type.is_valid_value(bitgen_b.bit_generator.state, strict=False)
9292
assert rng_type.is_valid_value(bitgen_c, strict=True)
93-
assert rng_type.is_valid_value(bitgen_d.__getstate__(), strict=False)
93+
assert rng_type.is_valid_value(bitgen_d.bit_generator.state, strict=False)
9494
assert rng_type.is_valid_value(bitgen_e, strict=True)
95-
assert rng_type.is_valid_value(bitgen_f.__getstate__(), strict=False)
95+
assert rng_type.is_valid_value(bitgen_f.bit_generator.state, strict=False)
9696
assert rng_type.is_valid_value(bitgen_g, strict=True)
97-
assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False)
97+
assert rng_type.is_valid_value(bitgen_h.bit_generator.state, strict=False)
9898

9999
def test_may_share_memory(self):
100100
bg_a = np.random.PCG64()

tests/tensor/random/test_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,20 @@ def test_seed(self, rng_ctor):
165165
state_rng = random.state_updates[0][0].get_value(borrow=True)
166166

167167
if hasattr(state_rng, "get_state"):
168-
ref_state = ref_rng.get_state()
169168
random_state = state_rng.get_state()
169+
170+
# hack to try to get something reasonable for ref_rng
171+
try:
172+
ref_state = ref_rng.get_state()
173+
except AttributeError:
174+
ref_state = list(ref_rng.bit_generator.state.values())
175+
170176
assert np.array_equal(random_state[1], ref_state[1])
171177
assert random_state[0] == ref_state[0]
172178
assert random_state[2:] == ref_state[2:]
173179
else:
174-
ref_state = ref_rng.__getstate__()
175-
random_state = state_rng.__getstate__()
180+
ref_state = ref_rng.bit_generator.state
181+
random_state = state_rng.bit_generator.state
176182
assert random_state["bit_generator"] == ref_state["bit_generator"]
177183
assert random_state["state"] == ref_state["state"]
178184

0 commit comments

Comments
 (0)