@@ -310,6 +310,41 @@ def test_rop(self, cls_ofg):
310
310
dvval2 = fn (xval , Wval , duval )
311
311
np .testing .assert_array_almost_equal (dvval2 , dvval , 4 )
312
312
313
+ def test_rop_multiple_outputs (self ):
314
+ a = vector ()
315
+ M = matrix ()
316
+ b = dot (a , M )
317
+ op_matmul = OpFromGraph ([a , M ], [b , - b ])
318
+
319
+ x = vector ()
320
+ W = matrix ()
321
+ du = vector ()
322
+
323
+ xval = np .random .random ((16 ,)).astype (config .floatX )
324
+ Wval = np .random .random ((16 , 16 )).astype (config .floatX )
325
+ duval = np .random .random ((16 ,)).astype (config .floatX )
326
+
327
+ y = op_matmul (x , W )[0 ]
328
+ dv = Rop (y , x , du )
329
+ fn = function ([x , W , du ], dv )
330
+ result_dvval = fn (xval , Wval , duval )
331
+ expected_dvval = np .dot (duval , Wval )
332
+ np .testing .assert_array_almost_equal (result_dvval , expected_dvval , 4 )
333
+
334
+ y = op_matmul (x , W )[1 ]
335
+ dv = Rop (y , x , du )
336
+ fn = function ([x , W , du ], dv )
337
+ result_dvval = fn (xval , Wval , duval )
338
+ expected_dvval = - np .dot (duval , Wval )
339
+ np .testing .assert_array_almost_equal (result_dvval , expected_dvval , 4 )
340
+
341
+ y = pt .add (* op_matmul (x , W ))
342
+ dv = Rop (y , x , du )
343
+ fn = function ([x , W , du ], dv )
344
+ result_dvval = fn (xval , Wval , duval )
345
+ expected_dvval = np .zeros_like (np .dot (duval , Wval ))
346
+ np .testing .assert_array_almost_equal (result_dvval , expected_dvval , 4 )
347
+
313
348
@pytest .mark .parametrize (
314
349
"cls_ofg" , [OpFromGraph , partial (OpFromGraph , inline = True )]
315
350
)
0 commit comments