We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 601c2ab commit a9e2343Copy full SHA for a9e2343
tests/link/pytorch/test_sort.py
@@ -7,17 +7,12 @@
7
from tests.link.pytorch.test_basic import compare_pytorch_and_py
8
9
10
+@pytest.mark.xfail(reason="Reshape not implemented")
11
@pytest.mark.parametrize("axis", [0, 1, None])
12
@pytest.mark.parametrize("func", (sort, argsort))
13
def test_sort(func, axis):
14
x = matrix("x", shape=(2, 2), dtype="float64")
15
out = func(x, axis=axis)
16
fgraph = FunctionGraph([x], [out])
17
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
-
18
- # TODO: remove condition once Reshape is implemented
19
- if axis is None:
20
- with pytest.raises(NotImplementedError):
21
- compare_pytorch_and_py(fgraph, [arr])
22
- else:
23
+ compare_pytorch_and_py(fgraph, [arr])
0 commit comments