Skip to content

Commit 822d23d

Browse files
author
Ian Schweer
committed
Fix imports
1 parent f852075 commit 822d23d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import pytest
66

7-
import pytensor.tensor as pt
87
import pytensor.tensor.basic as ptb
98
from pytensor.compile.builders import OpFromGraph
109
from pytensor.compile.function import function
@@ -18,7 +17,7 @@
1817
from pytensor.ifelse import ifelse
1918
from pytensor.link.pytorch.linker import PytorchLinker
2019
from pytensor.raise_op import CheckAndRaise
21-
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
20+
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
2221
from pytensor.tensor.type import matrices, matrix, scalar, vector
2322

2423

@@ -348,13 +347,13 @@ def test_pytorch_OpFromGraph():
348347

349348
def test_pytorch_scipy():
350349
x = vector("a", shape=(3,))
351-
out = pt.expit(x)
350+
out = expit(x)
352351
f = FunctionGraph([x], [out])
353352
compare_pytorch_and_py(f, [np.random.rand(3)])
354353

355354

356355
def test_pytorch_softplus():
357356
x = vector("a", shape=(3,))
358-
out = pt.softplus(x)
357+
out = softplus(x)
359358
f = FunctionGraph([x], [out])
360359
compare_pytorch_and_py(f, [np.random.rand(3)])

0 commit comments

Comments
 (0)