Skip to content

Commit f852075

Browse files
committed
Add tests
1 parent f1d7852 commit f852075

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import pytest
66

7+
import pytensor.tensor as pt
78
import pytensor.tensor.basic as ptb
89
from pytensor.compile.builders import OpFromGraph
910
from pytensor.compile.function import function
@@ -343,3 +344,17 @@ def test_pytorch_OpFromGraph():
343344

344345
f = FunctionGraph([x, y, z], [out])
345346
compare_pytorch_and_py(f, [xv, yv, zv])
347+
348+
349+
def test_pytorch_scipy():
350+
x = vector("a", shape=(3,))
351+
out = pt.expit(x)
352+
f = FunctionGraph([x], [out])
353+
compare_pytorch_and_py(f, [np.random.rand(3)])
354+
355+
356+
def test_pytorch_softplus():
357+
x = vector("a", shape=(3,))
358+
out = pt.softplus(x)
359+
f = FunctionGraph([x], [out])
360+
compare_pytorch_and_py(f, [np.random.rand(3)])

0 commit comments

Comments
 (0)