Commit 08152d0

Cleanup Rop tests and fix Max Rop implementation
1 parent ea1da5d commit 08152d0
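
(Context: pytensor.gradient.Rop(f, wrt, eval_points) computes the R-operator, i.e. the Jacobian-vector product of f with respect to wrt applied to eval_points, while Lop computes the transposed vector-Jacobian product. The Max Op's R_op fixed in this commit implements that product for a max reduction over one axis of a matrix.)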

5 files changed: +121 −85 lines changed

pytensor/tensor/math.py

Lines changed: 15 additions & 10 deletions
@@ -431,20 +431,25 @@ def L_op(self, inputs, outputs, grads):
         return (g_x,)
 
     def R_op(self, inputs, eval_points):
+        [x] = inputs
         if eval_points[0] is None:
-            return [None, None]
-        if len(self.axis) != 1:
-            raise ValueError("R_op supported for max only for one axis!")
-        if self.axis[0] > 1:
-            raise ValueError("R_op supported for max only when axis is 0 or 1")
+            return [None]
+        axis = tuple(range(x.ndim) if self.axis is None else self.axis)
+        if isinstance(axis, int):
+            axis = [axis]
+        if len(axis) != 1:
+            raise NotImplementedError("R_op supported for max only for one axis!")
+        if axis[0] > 1:
+            raise NotImplementedError("R_op supported for max only when axis is 0 or 1")
         if inputs[0].ndim != 2:
-            raise ValueError("R_op supported for max only when input is a matrix")
-        max_pos = Argmax(self.axis).make_node(*inputs).outputs
-        # print(eval_points[0].eval())
+            raise NotImplementedError(
+                "R_op supported for max only when input is a matrix"
+            )
+        max_pos = Argmax(self.axis)(*inputs)
         if self.axis[0] == 0:
-            return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None]
+            return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]]
         else:
-            return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None]
+            return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]]
 
 
 class Min(NonZeroDimsCAReduce):
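
As a side note, here is a minimal sketch (not part of this diff) of exercising the fixed Max R_op through the public Rop API; the shapes, seed, and variable names below are purely illustrative:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Rop

    x = pt.matrix("x")
    v = pt.matrix("v")  # perturbation, same shape as x
    y = pt.max(x, axis=0)  # column-wise max of a matrix
    jv = Rop(y, x, v)  # Jacobian-vector product of max w.r.t. x, applied to v

    f = pytensor.function([x, v], jv)
    rng = np.random.default_rng(0)
    vx = rng.normal(size=(3, 4)).astype(pytensor.config.floatX)
    vv = rng.normal(size=(3, 4)).astype(pytensor.config.floatX)
    # The R_op of max picks v at the row where each column of x attains its maximum.
    expected = vv[vx.argmax(axis=0), np.arange(vx.shape[1])]
    np.testing.assert_allclose(f(vx, vv), expected, rtol=1e-5)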

tests/scan/test_basic.py

Lines changed: 11 additions & 13 deletions
@@ -1992,9 +1992,9 @@ def rnn_fn(_u, _y, _W):
         vnu, vnh0, vnW = fn_rop(v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
         tnu, tnh0, tnW = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
 
-        utt.assert_allclose(vnu, tnu, atol=1e-6)
-        utt.assert_allclose(vnh0, tnh0, atol=1e-6)
-        utt.assert_allclose(vnW, tnW, atol=1e-6)
+        np.testing.assert_allclose(vnu, tnu, atol=1e-6)
+        np.testing.assert_allclose(vnh0, tnh0, atol=1e-6)
+        np.testing.assert_allclose(vnW, tnW, atol=1e-6)
 
     @pytest.mark.slow
     def test_R_op_2(self):
@@ -2074,9 +2074,9 @@ def rnn_fn(_u, _y, _W):
         )
 
         tnu, tnh0, tnW, tno = fn_test(v_u, v_h0, v_W, v_eu, v_eh0, v_eW)
-        utt.assert_allclose(vnu, tnu, atol=1e-6)
-        utt.assert_allclose(vnh0, tnh0, atol=1e-6)
-        utt.assert_allclose(vnW, tnW, atol=2e-6)
+        np.testing.assert_allclose(vnu, tnu, atol=1e-6)
+        np.testing.assert_allclose(vnh0, tnh0, atol=1e-6)
+        np.testing.assert_allclose(vnW, tnW, atol=2e-6)
 
     def test_R_op_mitmot(self):
         # this test is a copy paste from the script given by Justin Bayer to
@@ -2094,13 +2094,10 @@ def test_R_op_mitmot(self):
         W1 = pars[:3].reshape(W1shape)
         W2 = pars[3:].reshape(W2shape)
 
-        # Define recurrent model. We are using a model where each input is a
-        # tensor
-        # of shape (T, B, D) where T is the number of timesteps, B is the
-        # number of
-        # sequences iterated over in parallel and D is the dimensionality of
-        # each
-        # item at a timestep.
+        # Define recurrent model. We are using a model where each input
+        # is a tensor of shape (T, B, D) where T is the number of timesteps,
+        # B is the number of sequences iterated over in parallel and
+        # D is the dimensionality of each item at a timestep.
 
         inpt = tensor3("inpt")
         target = tensor3("target")
@@ -2128,6 +2125,7 @@ def test_R_op_mitmot(self):
         d_cost_wrt_pars = grad(cost, pars)
 
         p = dvector()
+        # TODO: We should test something about the Rop!
         Rop(d_cost_wrt_pars, pars, p)
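
The assertion change above is mechanical: np.testing.assert_allclose takes the same pair of arrays plus explicit tolerances, so it serves as a drop-in replacement for the old utt helper here. A tiny standalone illustration (values are arbitrary):

    import numpy as np

    a = np.array([1.0, 2.0, 3.0])
    b = a + 5e-7  # still within the atol=1e-6 used in the tests above
    np.testing.assert_allclose(a, b, atol=1e-6)  # passes; raises AssertionError beyond the tolerance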

tests/tensor/rewriting/test_linalg.py

Lines changed: 9 additions & 18 deletions
@@ -14,7 +14,7 @@
 from pytensor.tensor import swapaxes
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import _allclose, dot, matmul
+from pytensor.tensor.math import dot, matmul
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -42,7 +42,8 @@
 from tests.test_rop import break_op
 
 
-def test_rop_lop():
+def test_matrix_inverse_rop_lop():
+    rtol = 1e-7 if config.floatX == "float64" else 1e-5
     mx = matrix("mx")
     mv = matrix("mv")
     v = vector("v")
@@ -62,23 +63,13 @@ def test_rop_lop():
     vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX)
     vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX)
 
-    v1 = rop_f(vx, vv)
-    v2 = scan_f(vx, vv)
+    v_ref = scan_f(vx, vv)
+    np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
 
-    assert _allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
-
-    raised = False
-    try:
+    with pytest.raises(ValueError):
         pytensor.gradient.Rop(
             pytensor.clone_replace(y, replace={mx: break_op(mx)}), mx, mv
         )
-    except ValueError:
-        raised = True
-    if not raised:
-        raise Exception(
-            "Op did not raised an error even though the function"
-            " is not differentiable"
-        )
 
     vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX)
     yv = pytensor.gradient.Lop(y, mx, v)
@@ -87,9 +78,9 @@ def test_rop_lop():
     sy = pytensor.gradient.grad((v * y).sum(), mx)
     scan_f = function([mx, v], sy)
 
-    v1 = lop_f(vx, vv)
-    v2 = scan_f(vx, vv)
-    assert _allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
+    v_ref = scan_f(vx, vv)
+    v = lop_f(vx, vv)
+    np.testing.assert_allclose(v, v_ref, rtol=rtol)
 
 
 def test_transinv_to_invtrans():
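
The raised-flag bookkeeping removed above is replaced by pytest's context manager, which fails the test when the expected exception is not raised. A minimal standalone sketch of the pattern (the helper function is hypothetical):

    import pytest

    def might_raise(x):
        # Hypothetical stand-in for the non-differentiable Rop call in the test above.
        if x < 0:
            raise ValueError("negative input")
        return x

    def test_expected_error():
        # The block fails automatically if ValueError is NOT raised,
        # replacing the manual raised/try/except/Exception bookkeeping.
        with pytest.raises(ValueError):
            might_raise(-1)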

tests/tensor/test_shape.py

Lines changed: 1 addition & 1 deletion
@@ -603,7 +603,7 @@ def test_validation(self):
 
 class TestRopLop(RopLopChecker):
     def test_shape(self):
-        self.check_nondiff_rop(self.x.shape[0])
+        self.check_nondiff_rop(self.x.shape[0], self.x, self.v)
 
     def test_specifyshape(self):
         self.check_rop_lop(specify_shape(self.x, self.in_shape), self.in_shape)
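
check_nondiff_rop now receives the wrt variable and evaluation point explicitly instead of always using self.x and self.v. What the updated test asserts, roughly, is that Rop refuses a non-differentiable output such as a shape; a minimal sketch with vector inputs (mirroring the checker's setup):

    import pytest
    import pytensor.tensor as pt
    from pytensor.gradient import Rop

    x = pt.vector("x")
    v = pt.vector("v")

    # The shape of x does not depend differentiably on the values of x,
    # so Rop is expected to raise ValueError here.
    with pytest.raises(ValueError):
        Rop(x.shape[0], x, v)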

tests/test_rop.py

Lines changed: 85 additions & 43 deletions
@@ -16,8 +16,14 @@
 
 import pytensor
 import pytensor.tensor as pt
-from pytensor import function
-from pytensor.gradient import Lop, Rop, grad, grad_undefined
+from pytensor import config, function
+from pytensor.gradient import (
+    Lop,
+    NullTypeGradError,
+    Rop,
+    grad,
+    grad_undefined,
+)
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor.math import argmax, dot
@@ -61,6 +67,10 @@ class RopLopChecker:
     Rop to class that inherit from it.
     """
 
+    @staticmethod
+    def rtol():
+        return 1e-7 if config.floatX == "float64" else 1e-5
+
     def setup_method(self):
        # Using vectors make things a lot simpler for generating the same
        # computations using scan
@@ -72,13 +82,13 @@ def setup_method(self):
         self.mv = matrix("mv")
         self.mat_in_shape = (5 + self.rng.integers(3), 5 + self.rng.integers(3))
 
-    def check_nondiff_rop(self, y):
+    def check_nondiff_rop(self, y, x, v):
         """
         If your op is not differentiable(so you can't define Rop)
         test that an error is raised.
         """
         with pytest.raises(ValueError):
-            Rop(y, self.x, self.v)
+            Rop(y, x, v)
 
     def check_mat_rop_lop(self, y, out_shape):
         """
@@ -115,13 +125,13 @@ def check_mat_rop_lop(self, y, out_shape):
         )
         scan_f = function([self.mx, self.mv], sy, on_unused_input="ignore")
 
-        v1 = rop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-
-        assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(rop_f(vx, vv), v_ref)
 
         self.check_nondiff_rop(
-            pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)})
+            pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
+            self.mx,
+            self.mv,
         )
 
         vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
@@ -131,15 +141,17 @@ def check_mat_rop_lop(self, y, out_shape):
         sy = grad((self.v * y).sum(), self.mx)
         scan_f = function([self.mx, self.v], sy)
 
-        v1 = lop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
+        v = lop_f(vx, vv)
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(v, v_ref)
 
-    def check_rop_lop(self, y, out_shape):
+    def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True):
         """
         As check_mat_rop_lop, except the input is self.x which is a
         vector. The output is still a vector.
         """
+        rtol = self.rtol()
+
         # TEST ROP
         vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
         vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
@@ -152,24 +164,17 @@ def check_rop_lop(self, y, out_shape):
             non_sequences=[y, self.x],
         )
         sy = dot(J, self.v)
-
         scan_f = function([self.x, self.v], sy, on_unused_input="ignore")
 
-        v1 = rop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
 
-        try:
-            Rop(
+        if check_nondiff_rop:
+            self.check_nondiff_rop(
                 pytensor.clone_replace(y, replace={self.x: break_op(self.x)}),
                 self.x,
                 self.v,
             )
-        except ValueError:
-            pytest.skip(
-                "Rop does not handle non-differentiable inputs "
-                "correctly. Bug exposed by fixing Add.grad method."
-            )
 
         vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
         vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
@@ -182,22 +187,20 @@ def check_rop_lop(self, y, out_shape):
             non_sequences=[y, self.x],
         )
         sy = dot(self.v, J)
-
         scan_f = function([self.x, self.v], sy)
 
-        v1 = lop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
+        v = lop_f(vx, vv)
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(v, v_ref, rtol=rtol)
 
 
 class TestRopLop(RopLopChecker):
     def test_max(self):
-        # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ())
         self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],))
         self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],))
 
     def test_argmax(self):
-        self.check_nondiff_rop(argmax(self.mx, axis=1))
+        self.check_nondiff_rop(argmax(self.mx, axis=1), self.mx, self.mv)
 
     def test_subtensor(self):
         self.check_rop_lop(self.x[:4], (4,))
@@ -252,10 +255,14 @@ def test_dot(self):
         insh = self.in_shape[0]
         vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX)
         W = pytensor.shared(vW)
-        self.check_rop_lop(dot(self.x, W), self.in_shape)
+        # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
+        # See: test_Rop_partially_differentiable_paths
+        self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False)
 
     def test_elemwise0(self):
-        self.check_rop_lop((self.x + 1) ** 2, self.in_shape)
+        # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
+        # See: test_Rop_partially_differentiable_paths
+        self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False)
 
     def test_elemwise1(self):
         self.check_rop_lop(self.x + pt.cast(self.x, "int32"), self.in_shape)
@@ -288,15 +295,8 @@ def test_alloc(self):
         )
 
     def test_invalid_input(self):
-        success = False
-
-        try:
+        with pytest.raises(ValueError):
             Rop(0.0, [matrix()], [vector()])
-            success = True
-        except ValueError:
-            pass
-
-        assert not success
 
     def test_multiple_outputs(self):
         m = matrix("m")
@@ -322,12 +322,54 @@ def test_multiple_outputs(self):
         f = pytensor.function([m, v, m_, v_], all_outs)
         f(mval, vval, m_val, v_val)
 
-    def test_Rop_dot_bug_18Oct2013_Jeremiah(self):
+    @pytest.mark.xfail()
+    def test_Rop_partially_differentiable_paths(self):
         # This test refers to a bug reported by Jeremiah Lowin on 18th Oct
         # 2013. The bug consists when through a dot operation there is only
         # one differentiable path (i.e. there is no gradient wrt to one of
         # the inputs).
         x = pt.arange(20.0).reshape([1, 20])
-        v = pytensor.shared(np.ones([20]))
+        v = pytensor.shared(np.ones([20]), name="v")
         d = dot(x, v).sum()
-        Rop(grad(d, v), v, v)
+
+        Rop(
+            grad(d, v),
+            v,
+            v,
+            disconnected_outputs="raise",
+        )
+
+        # 2025: Here is an unambiguous test for the original commented issue:
+        x = pt.matrix("x")
+        y = pt.matrix("y")
+        out = dot(x, break_op(y)).sum()
+        # Should not raise an error
+        Rop(
+            out,
+            [x],
+            [x.type()],
+            disconnected_outputs="raise",
+        )
+
+        # More extensive testing shows that the legacy Rop implementation FAILS to raise when
+        # the cost is linked through strictly non-differentiable paths.
+        # This is not Dot specific, we would observe the same with any operation where the gradient
+        # with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...)
+        out = dot(break_op(x), y).sum()
+        with pytest.raises((ValueError, NullTypeGradError)):
+            Rop(
+                out,
+                [x],
+                [x.type()],
+                disconnected_outputs="raise",
+            )
+
+        # Only when both paths are non-differentiable is an error correctly raised again.
+        out = dot(break_op(x), break_op(y)).sum()
+        with pytest.raises((ValueError, NullTypeGradError)):
+            Rop(
+                out,
+                [x],
+                [x.type()],
+                disconnected_outputs="raise",
+            )
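
The new RopLopChecker.rtol() helper simply loosens the comparison tolerance when running under float32. A standalone sketch of the same pattern (the arrays are illustrative):

    import numpy as np
    from pytensor import config

    def rtol() -> float:
        # Tighter tolerance under float64, looser under float32.
        return 1e-7 if config.floatX == "float64" else 1e-5

    a = np.ones(3, dtype=config.floatX)
    b = a * (1 + 1e-8)  # well within either tolerance
    np.testing.assert_allclose(a, b, rtol=rtol())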
