File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change 4
4
import numpy as np
5
5
import pytest
6
6
7
+ import pytensor .tensor as pt
7
8
import pytensor .tensor .basic as ptb
8
9
from pytensor .compile .builders import OpFromGraph
9
10
from pytensor .compile .function import function
@@ -343,3 +344,17 @@ def test_pytorch_OpFromGraph():
343
344
344
345
f = FunctionGraph ([x , y , z ], [out ])
345
346
compare_pytorch_and_py (f , [xv , yv , zv ])
347
+
348
+
349
+ def test_pytorch_scipy ():
350
+ x = vector ("a" , shape = (3 ,))
351
+ out = pt .expit (x )
352
+ f = FunctionGraph ([x ], [out ])
353
+ compare_pytorch_and_py (f , [np .random .rand (3 )])
354
+
355
+
356
+ def test_pytorch_softplus ():
357
+ x = vector ("a" , shape = (3 ,))
358
+ out = pt .softplus (x )
359
+ f = FunctionGraph ([x ], [out ])
360
+ compare_pytorch_and_py (f , [np .random .rand (3 )])
You can’t perform that action at this time.
0 commit comments