From 090d0c839a4339ded3982d69891c4356faec35f6 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Sun, 7 Jul 2024 21:34:12 +0200 Subject: [PATCH 1/3] Implemented Sort/Argsort Ops in PyTorch --- pytensor/link/pytorch/dispatch/__init__.py | 1 + pytensor/link/pytorch/dispatch/sort.py | 25 ++++++++++++++++++++++ tests/link/pytorch/test_sort.py | 23 ++++++++++++++++++++ 3 files changed, 49 insertions(+) create mode 100644 pytensor/link/pytorch/dispatch/sort.py create mode 100644 tests/link/pytorch/test_sort.py diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index 7e476aba04..143d6b1bcb 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -5,4 +5,5 @@ import pytensor.link.pytorch.dispatch.scalar import pytensor.link.pytorch.dispatch.elemwise import pytensor.link.pytorch.dispatch.extra_ops +import pytensor.link.pytorch.dispatch.sort # isort: on diff --git a/pytensor/link/pytorch/dispatch/sort.py b/pytensor/link/pytorch/dispatch/sort.py new file mode 100644 index 0000000000..95e24c4fe3 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/sort.py @@ -0,0 +1,25 @@ +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor.sort import ArgSortOp, SortOp + + +@pytorch_funcify.register(SortOp) +def pytorch_funcify_Sort(op, **kwargs): + stable = op.kind == "stable" + + def sort(arr, axis): + sorted, _ = torch.sort(arr, dim=axis, stable=stable) + return sorted + + return sort + + +@pytorch_funcify.register(ArgSortOp) +def pytorch_funcify_ArgSort(op, **kwargs): + stable = op.kind == "stable" + + def argsort(arr, axis): + return torch.argsort(arr, dim=axis, stable=stable) + + return argsort diff --git a/tests/link/pytorch/test_sort.py b/tests/link/pytorch/test_sort.py new file mode 100644 index 0000000000..5d266b603b --- /dev/null +++ b/tests/link/pytorch/test_sort.py @@ -0,0 +1,23 @@ +import numpy as np +import pytest + +from pytensor.graph import FunctionGraph +from pytensor.tensor import matrix +from pytensor.tensor.sort import argsort, sort +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +@pytest.mark.parametrize("axis", [0, 1, None]) +@pytest.mark.parametrize("func", (sort, argsort)) +def test_sort(func, axis): + x = matrix("x", shape=(2, 2), dtype="float64") + out = func(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + arr = np.array([[1.0, 4.0], [5.0, 2.0]]) + + # TODO: remove condition once Reshape is implemented + if axis is None: + with pytest.raises(NotImplementedError): + compare_pytorch_and_py(fgraph, [arr]) + else: + compare_pytorch_and_py(fgraph, [arr]) From 4a33301612fee74777d9565292a6d3bf68951bd2 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Mon, 8 Jul 2024 00:08:55 +0200 Subject: [PATCH 2/3] Marked test for [Arg]Sort Op in PyTorch as xfail --- tests/link/pytorch/test_sort.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/link/pytorch/test_sort.py b/tests/link/pytorch/test_sort.py index 5d266b603b..8595c43303 100644 --- a/tests/link/pytorch/test_sort.py +++ b/tests/link/pytorch/test_sort.py @@ -7,6 +7,7 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py +@pytest.mark.xfail(reason="Reshape not implemented") @pytest.mark.parametrize("axis", [0, 1, None]) @pytest.mark.parametrize("func", (sort, argsort)) def test_sort(func, axis): @@ -14,10 +15,4 @@ def test_sort(func, axis): out = func(x, axis=axis) fgraph = FunctionGraph([x], [out]) arr = np.array([[1.0, 4.0], [5.0, 2.0]]) - - # TODO: remove condition once Reshape is implemented - if axis is None: - with pytest.raises(NotImplementedError): - compare_pytorch_and_py(fgraph, [arr]) - else: - compare_pytorch_and_py(fgraph, [arr]) + compare_pytorch_and_py(fgraph, [arr]) From cff4cf4fd4b81b2fa5fa0a5b8f2cf9007203e786 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Mon, 8 Jul 2024 18:24:43 +0200 Subject: [PATCH 3/3] Marked test for [Arg]Sort Op in PyTorch as xfail when axis is None --- tests/link/pytorch/test_sort.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/link/pytorch/test_sort.py b/tests/link/pytorch/test_sort.py index 8595c43303..386a974cf4 100644 --- a/tests/link/pytorch/test_sort.py +++ b/tests/link/pytorch/test_sort.py @@ -7,9 +7,17 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py -@pytest.mark.xfail(reason="Reshape not implemented") -@pytest.mark.parametrize("axis", [0, 1, None]) @pytest.mark.parametrize("func", (sort, argsort)) +@pytest.mark.parametrize( + "axis", + [ + pytest.param(0), + pytest.param(1), + pytest.param( + None, marks=pytest.mark.xfail(reason="Reshape Op not implemented") + ), + ], +) def test_sort(func, axis): x = matrix("x", shape=(2, 2), dtype="float64") out = func(x, axis=axis)