Skip to content

Commit be799d8

Browse files
committed
Add test for R_Op of OpFromGrah with multiple outputs
1 parent 53ec8e3 commit be799d8

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/compile/test_builders.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,41 @@ def test_rop(self, cls_ofg):
310310
dvval2 = fn(xval, Wval, duval)
311311
np.testing.assert_array_almost_equal(dvval2, dvval, 4)
312312

313+
def test_rop_multiple_outputs(self):
314+
a = vector()
315+
M = matrix()
316+
b = dot(a, M)
317+
op_matmul = OpFromGraph([a, M], [b, -b])
318+
319+
x = vector()
320+
W = matrix()
321+
du = vector()
322+
323+
xval = np.random.random((16,)).astype(config.floatX)
324+
Wval = np.random.random((16, 16)).astype(config.floatX)
325+
duval = np.random.random((16,)).astype(config.floatX)
326+
327+
y = op_matmul(x, W)[0]
328+
dv = Rop(y, x, du)
329+
fn = function([x, W, du], dv)
330+
result_dvval = fn(xval, Wval, duval)
331+
expected_dvval = np.dot(duval, Wval)
332+
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
333+
334+
y = op_matmul(x, W)[1]
335+
dv = Rop(y, x, du)
336+
fn = function([x, W, du], dv)
337+
result_dvval = fn(xval, Wval, duval)
338+
expected_dvval = -np.dot(duval, Wval)
339+
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
340+
341+
y = pt.add(*op_matmul(x, W))
342+
dv = Rop(y, x, du)
343+
fn = function([x, W, du], dv)
344+
result_dvval = fn(xval, Wval, duval)
345+
expected_dvval = np.zeros_like(np.dot(duval, Wval))
346+
np.testing.assert_array_almost_equal(result_dvval, expected_dvval, 4)
347+
313348
@pytest.mark.parametrize(
314349
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
315350
)

0 commit comments

Comments
 (0)