Skip to content

Commit d98e68e

Browse files
committed
Update tests
1 parent fdd5d5c commit d98e68e

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -307,20 +307,14 @@ def test_pytorch_MakeVector():
307307
def test_pytorch_OpFromGraph():
308308
x, y, z = matrices("xyz")
309309
ofg_1 = OpFromGraph([x, y], [x + y])
310-
OpFromGraph([x, y], [x * y, x - y])
310+
ofg_2 = OpFromGraph([x, y], [x * y, x - y])
311311

312-
# o1, o2 = ofg_2(y, z)
313-
# out = ofg_1(x, o1) + o2
314-
315-
out = ofg_1(y, z)
312+
o1, o2 = ofg_2(y, z)
313+
out = ofg_1(x, o1) + o2
316314

317315
xv = np.ones((2, 2), dtype=config.floatX)
318-
np.ones((2, 2), dtype=config.floatX) * 3
316+
yv = np.ones((2, 2), dtype=config.floatX) * 3
319317
zv = np.ones((2, 2), dtype=config.floatX) * 5
320318

321-
f = FunctionGraph([y, z], [out])
322-
import pytensor.printing
323-
324-
pytensor.printing.debugprint(f)
325-
326-
compare_pytorch_and_py(f, [xv, zv])
319+
f = FunctionGraph([x, y, z], [out])
320+
compare_pytorch_and_py(f, [xv, yv, zv])

0 commit comments

Comments
 (0)