NOTE(review): this patch was re-wrapped to one diff line per line, and the code
it adds was edited (docstrings; explicit AssertionError in specifyshape), so the
post-image blob hashes on the affected "index" lines no longer match.

diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py
index 143d6b1bcb..017e57df64 100644
--- a/pytensor/link/pytorch/dispatch/__init__.py
+++ b/pytensor/link/pytorch/dispatch/__init__.py
@@ -6,4 +6,5 @@
 import pytensor.link.pytorch.dispatch.elemwise
 import pytensor.link.pytorch.dispatch.extra_ops
 import pytensor.link.pytorch.dispatch.sort
+import pytensor.link.pytorch.dispatch.shape
 # isort: on
diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index 37622a8294..c71e1606bf 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -1,4 +1,5 @@
 from functools import singledispatch
+from types import NoneType
 
 import torch
 
@@ -6,7 +7,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
+from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
 
 
 @singledispatch
@@ -15,6 +16,12 @@ def pytorch_typify(data, dtype=None, **kwargs):
     return torch.as_tensor(data, dtype=dtype)
 
 
+@pytorch_typify.register(NoneType)
+def pytorch_typify_None(data, **kwargs):
+    """Pass ``None`` through unchanged instead of converting it to a tensor."""
+    return None
+
+
 @singledispatch
 def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
     """Create a PyTorch compatible function from an PyTensor `Op`."""
@@ -116,3 +123,14 @@ def eye(N, M, k):
         return zeros
 
     return eye
+
+
+@pytorch_funcify.register(MakeVector)
+def pytorch_funcify_MakeVector(op, **kwargs):
+    """Return a function packing its inputs into a tensor of dtype ``op.dtype``."""
+    torch_dtype = getattr(torch, op.dtype)
+
+    def makevector(*x):
+        return torch.tensor(x, dtype=torch_dtype)
+
+    return makevector
diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py
new file mode 100644
index 0000000000..7633e28e01
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/shape.py
@@ -0,0 +1,68 @@
+import torch
+
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
+
+
+@pytorch_funcify.register(Reshape)
+def pytorch_funcify_Reshape(op, node, **kwargs):
+    """Return a function reshaping ``x`` to ``shape`` with `torch.reshape`."""
+
+    def reshape(x, shape):
+        return torch.reshape(x, tuple(shape))
+
+    return reshape
+
+
+@pytorch_funcify.register(Shape)
+def pytorch_funcify_Shape(op, **kwargs):
+    """Return a function giving the ``shape`` attribute of its input."""
+
+    def shape(x):
+        return x.shape
+
+    return shape
+
+
+@pytorch_funcify.register(Shape_i)
+def pytorch_funcify_Shape_i(op, **kwargs):
+    """Return a function giving dimension ``op.i`` of its input as a tensor."""
+    i = op.i
+
+    def shape_i(x):
+        return torch.tensor(x.shape[i])
+
+    return shape_i
+
+
+@pytorch_funcify.register(SpecifyShape)
+def pytorch_funcify_SpecifyShape(op, node, **kwargs):
+    """Return a function that validates the shape of ``x`` and passes it through.
+
+    Entries of ``shape`` that are ``None`` are not checked.
+    """
+
+    def specifyshape(x, *shape):
+        # Raise explicitly: a bare ``assert`` is stripped under ``python -O``.
+        if x.ndim != len(shape):
+            raise AssertionError(
+                f"Invalid shape: Expected {len(shape)} dimensions but got {x.ndim}"
+            )
+        for actual, expected in zip(x.shape, shape):
+            if expected is None:
+                continue
+            if actual != expected:
+                raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
+        return x
+
+    return specifyshape
+
+
+@pytorch_funcify.register(Unbroadcast)
+def pytorch_funcify_Unbroadcast(op, **kwargs):
+    """Return an identity function: the input is returned unchanged."""
+
+    def unbroadcast(x):
+        return x
+
+    return unbroadcast
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index 0ccb1c454f..27c1b1bd6a 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -294,3 +294,10 @@ def test_eye(dtype):
         for _M in range(1, 6):
             for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
                 np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))
+
+
+def test_pytorch_MakeVector():
+    x = ptb.make_vector(1, 2, 3)
+    x_fg = FunctionGraph([], [x])
+
+    compare_pytorch_and_py(x_fg, [])
diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py
index 221855864a..c615176a45 100644
--- a/tests/link/pytorch/test_extra_ops.py
+++ b/tests/link/pytorch/test_extra_ops.py
@@ -43,18 +43,7 @@ def test_pytorch_CumOp(axis, dtype):
     compare_pytorch_and_py(fgraph, [test_value])
 
 
-@pytest.mark.parametrize(
-    "axis, repeats",
-    [
-        (0, (1, 2, 3)),
-        (1, (3, 3)),
-        pytest.param(
-            None,
-            3,
-            marks=pytest.mark.xfail(reason="Reshape not implemented"),
-        ),
-    ],
-)
+@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
 def test_pytorch_Repeat(axis, repeats):
     a = pt.matrix("a", dtype="float64")
 
diff --git a/tests/link/pytorch/test_shape.py b/tests/link/pytorch/test_shape.py
new file mode 100644
index 0000000000..152aa8ddf3
--- /dev/null
+++ b/tests/link/pytorch/test_shape.py
@@ -0,0 +1,61 @@
+import numpy as np
+
+import pytensor.tensor as pt
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
+from pytensor.tensor.type import iscalar, vector
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+def test_pytorch_shape_ops():
+    x_np = np.zeros((20, 3))
+    x = Shape()(pt.as_tensor_variable(x_np))
+    x_fg = FunctionGraph([], [x])
+
+    compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
+
+    x = Shape_i(1)(pt.as_tensor_variable(x_np))
+    x_fg = FunctionGraph([], [x])
+
+    compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
+
+
+def test_pytorch_specify_shape():
+    in_pt = pt.matrix("in")
+    x = pt.specify_shape(in_pt, (4, None))
+    x_fg = FunctionGraph([in_pt], [x])
+    compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
+
+    # When used to assert two arrays have similar shapes
+    in_pt = pt.matrix("in")
+    shape_pt = pt.matrix("shape")
+    x = pt.specify_shape(in_pt, shape_pt.shape)
+    x_fg = FunctionGraph([in_pt, shape_pt], [x])
+    compare_pytorch_and_py(
+        x_fg,
+        [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
+    )
+
+
+def test_pytorch_Reshape_constant():
+    a = vector("a")
+    x = reshape(a, (2, 2))
+    x_fg = FunctionGraph([a], [x])
+    compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
+
+
+def test_pytorch_Reshape_dynamic():
+    a = vector("a")
+    shape_pt = iscalar("b")
+    x = reshape(a, (shape_pt, shape_pt))
+    x_fg = FunctionGraph([a, shape_pt], [x])
+    compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
+
+
+def test_pytorch_unbroadcast():
+    x_np = np.zeros((20, 1, 1))
+    x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
+    x_fg = FunctionGraph([], [x])
+
+    compare_pytorch_and_py(x_fg, [])
diff --git a/tests/link/pytorch/test_sort.py b/tests/link/pytorch/test_sort.py
index 386a974cf4..7912dd4a03 100644
--- a/tests/link/pytorch/test_sort.py
+++ b/tests/link/pytorch/test_sort.py
@@ -8,16 +8,7 @@
 
 
 @pytest.mark.parametrize("func", (sort, argsort))
-@pytest.mark.parametrize(
-    "axis",
-    [
-        pytest.param(0),
-        pytest.param(1),
-        pytest.param(
-            None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
-        ),
-    ],
-)
+@pytest.mark.parametrize("axis", [0, 1, None])
 def test_sort(func, axis):
     x = matrix("x", shape=(2, 2), dtype="float64")
     out = func(x, axis=axis)