 
 import pytensor
 import pytensor.tensor as pt
-from pytensor import function
-from pytensor.gradient import Lop, Rop, grad, grad_undefined
+from pytensor import config, function
+from pytensor.gradient import (
+    Lop,
+    NullTypeGradError,
+    Rop,
+    grad,
+    grad_undefined,
+)
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor.math import argmax, dot
@@ -61,6 +67,10 @@ class RopLopChecker:
     Rop to class that inherit from it.
     """
 
+    @staticmethod
+    def rtol():
+        return 1e-7 if config.floatX == "float64" else 1e-5
+
     def setup_method(self):
         # Using vectors make things a lot simpler for generating the same
         # computations using scan
@@ -72,13 +82,13 @@ def setup_method(self):
         self.mv = matrix("mv")
         self.mat_in_shape = (5 + self.rng.integers(3), 5 + self.rng.integers(3))
 
-    def check_nondiff_rop(self, y):
+    def check_nondiff_rop(self, y, x, v):
         """
         If your op is not differentiable(so you can't define Rop)
         test that an error is raised.
         """
         with pytest.raises(ValueError):
-            Rop(y, self.x, self.v)
+            Rop(y, x, v)
 
     def check_mat_rop_lop(self, y, out_shape):
         """
@@ -115,13 +125,13 @@ def check_mat_rop_lop(self, y, out_shape):
         )
         scan_f = function([self.mx, self.mv], sy, on_unused_input="ignore")
 
-        v1 = rop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-
-        assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(rop_f(vx, vv), v_ref)
 
         self.check_nondiff_rop(
-            pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)})
+            pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
+            self.mx,
+            self.mv,
         )
 
         vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
@@ -131,15 +141,17 @@ def check_mat_rop_lop(self, y, out_shape):
         sy = grad((self.v * y).sum(), self.mx)
         scan_f = function([self.mx, self.v], sy)
 
-        v1 = lop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
+        v = lop_f(vx, vv)
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(v, v_ref)
 
-    def check_rop_lop(self, y, out_shape):
+    def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True):
         """
         As check_mat_rop_lop, except the input is self.x which is a
         vector. The output is still a vector.
         """
+        rtol = self.rtol()
+
         # TEST ROP
         vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
         vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
@@ -152,24 +164,17 @@ def check_rop_lop(self, y, out_shape):
             non_sequences=[y, self.x],
         )
         sy = dot(J, self.v)
-
         scan_f = function([self.x, self.v], sy, on_unused_input="ignore")
 
-        v1 = rop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
 
-        try:
-            Rop(
+        if check_nondiff_rop:
+            self.check_nondiff_rop(
                 pytensor.clone_replace(y, replace={self.x: break_op(self.x)}),
                 self.x,
                 self.v,
             )
-        except ValueError:
-            pytest.skip(
-                "Rop does not handle non-differentiable inputs "
-                "correctly. Bug exposed by fixing Add.grad method."
-            )
 
         vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
         vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
@@ -182,22 +187,20 @@ def check_rop_lop(self, y, out_shape):
             non_sequences=[y, self.x],
         )
         sy = dot(self.v, J)
-
         scan_f = function([self.x, self.v], sy)
 
-        v1 = lop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
+        v = lop_f(vx, vv)
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(v, v_ref, rtol=rtol)
 
 
 class TestRopLop(RopLopChecker):
     def test_max(self):
-        # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ())
         self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],))
         self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],))
 
     def test_argmax(self):
-        self.check_nondiff_rop(argmax(self.mx, axis=1))
+        self.check_nondiff_rop(argmax(self.mx, axis=1), self.mx, self.mv)
 
     def test_subtensor(self):
         self.check_rop_lop(self.x[:4], (4,))
@@ -252,10 +255,14 @@ def test_dot(self):
         insh = self.in_shape[0]
         vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX)
         W = pytensor.shared(vW)
-        self.check_rop_lop(dot(self.x, W), self.in_shape)
+        # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
+        # See: test_Rop_partially_differentiable_paths
+        self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False)
 
     def test_elemwise0(self):
-        self.check_rop_lop((self.x + 1) ** 2, self.in_shape)
+        # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
+        # See: test_Rop_partially_differentiable_paths
+        self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False)
 
     def test_elemwise1(self):
         self.check_rop_lop(self.x + pt.cast(self.x, "int32"), self.in_shape)
@@ -288,15 +295,8 @@ def test_alloc(self):
         )
 
     def test_invalid_input(self):
-        success = False
-
-        try:
+        with pytest.raises(ValueError):
             Rop(0.0, [matrix()], [vector()])
-            success = True
-        except ValueError:
-            pass
-
-        assert not success
 
     def test_multiple_outputs(self):
         m = matrix("m")
@@ -322,12 +322,54 @@ def test_multiple_outputs(self):
         f = pytensor.function([m, v, m_, v_], all_outs)
         f(mval, vval, m_val, v_val)
 
-    def test_Rop_dot_bug_18Oct2013_Jeremiah(self):
+    @pytest.mark.xfail()
+    def test_Rop_partially_differentiable_paths(self):
         # This test refers to a bug reported by Jeremiah Lowin on 18th Oct
         # 2013. The bug consists when through a dot operation there is only
         # one differentiable path (i.e. there is no gradient wrt to one of
         # the inputs).
         x = pt.arange(20.0).reshape([1, 20])
-        v = pytensor.shared(np.ones([20]))
+        v = pytensor.shared(np.ones([20]), name="v")
         d = dot(x, v).sum()
-        Rop(grad(d, v), v, v)
+
+        Rop(
+            grad(d, v),
+            v,
+            v,
+            disconnected_outputs="raise",
+        )
+
+        # 2025: Here is an unambiguous test for the original commented issue:
+        x = pt.matrix("x")
+        y = pt.matrix("y")
+        out = dot(x, break_op(y)).sum()
+        # Should not raise an error
+        Rop(
+            out,
+            [x],
+            [x.type()],
+            disconnected_outputs="raise",
+        )
+
+        # More extensive testing shows that the legacy Rop implementation FAILS to raise when
+        # the cost is linked through strictly non-differentiable paths.
+        # This is not Dot specific; we would observe the same with any operation where the gradient
+        # with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...).
+        out = dot(break_op(x), y).sum()
+        with pytest.raises((ValueError, NullTypeGradError)):
+            Rop(
+                out,
+                [x],
+                [x.type()],
+                disconnected_outputs="raise",
+            )
+
+        # Only when both paths are non-differentiable is an error correctly raised again.
+        out = dot(break_op(x), break_op(y)).sum()
+        with pytest.raises((ValueError, NullTypeGradError)):
+            Rop(
+                out,
+                [x],
+                [x.type()],
+                disconnected_outputs="raise",
+            )
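
Note (not part of the diff): the property that `check_mat_rop_lop`/`check_rop_lop` assert is that `Rop(y, x, v)` builds a Jacobian-vector product and `Lop(y, x, v)` a vector-Jacobian product, both of which the suite compares against a Jacobian assembled row by row with `pytensor.scan`. Below is a minimal sketch of the same equivalence on a toy `tanh` graph; the example graph, seed, and tolerance are illustrative choices and do not come from this PR.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Lop, Rop

x = pt.vector("x")
v = pt.vector("v")
y = pt.tanh(x)  # elementwise graph, so the Jacobian is diagonal

# R-operator: Jacobian of y wrt x, applied to v (Jacobian-vector product)
jv = Rop(y, x, v)
# L-operator: v applied to the Jacobian from the left (vector-Jacobian product)
vj = Lop(y, x, v)

f = pytensor.function([x, v], [jv, vj])

rng = np.random.default_rng(42)
xv = rng.uniform(size=5).astype(pytensor.config.floatX)
vv = rng.uniform(size=5).astype(pytensor.config.floatX)
jv_val, vj_val = f(xv, vv)

# With a diagonal Jacobian, both products reduce to d(tanh)/dx * v
expected = (1 - np.tanh(xv) ** 2) * vv
np.testing.assert_allclose(jv_val, expected, rtol=1e-5)
np.testing.assert_allclose(vj_val, expected, rtol=1e-5)
```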