From 4d4abd7350d8219890f905ca7d1932a2ee0e5b16 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Fri, 12 Jul 2024 23:25:35 +0200 Subject: [PATCH 1/3] Implements shape and MakeVector Ops in PyTorch - Shape - Shape_i - Reshape - SpecifyShape - Unbroadcast - MakeVector --- pytensor/link/pytorch/dispatch/__init__.py | 1 + pytensor/link/pytorch/dispatch/basic.py | 16 ++++- pytensor/link/pytorch/dispatch/shape.py | 54 ++++++++++++++++ tests/link/pytorch/test_basic.py | 7 +++ tests/link/pytorch/test_shape.py | 72 ++++++++++++++++++++++ tests/link/pytorch/test_sort.py | 11 +--- 6 files changed, 149 insertions(+), 12 deletions(-) create mode 100644 pytensor/link/pytorch/dispatch/shape.py create mode 100644 tests/link/pytorch/test_shape.py diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index 143d6b1bcb..017e57df64 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -6,4 +6,5 @@ import pytensor.link.pytorch.dispatch.elemwise import pytensor.link.pytorch.dispatch.extra_ops import pytensor.link.pytorch.dispatch.sort +import pytensor.link.pytorch.dispatch.shape # isort: on diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 37622a8294..8e5b3ee4bf 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -6,13 +6,15 @@ from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise -from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join +from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector @singledispatch def pytorch_typify(data, dtype=None, **kwargs): r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" - return torch.as_tensor(data, dtype=dtype) + if data is not None: + return torch.as_tensor(data, dtype=dtype) + return None @singledispatch @@ -116,3 +118,13 @@ def eye(N, M, k): return zeros return eye + + +@pytorch_funcify.register(MakeVector) +def pytorch_funcify_MakeVector(op, **kwargs): + torch_dtype = getattr(torch, op.dtype) + + def makevector(*x): + return torch.tensor(x, dtype=torch_dtype) + + return makevector diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py new file mode 100644 index 0000000000..c6395b6713 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -0,0 +1,54 @@ +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast + + +@pytorch_funcify.register(Reshape) +def pytorch_funcify_Reshape(op, node, **kwargs): + shape = node.inputs[1] + + def reshape(x, shape=shape): + return torch.reshape(x, tuple(shape)) + + return reshape + + +@pytorch_funcify.register(Shape) +def pytorch_funcify_Shape(op, **kwargs): + def shape(x): + return x.shape + + return shape + + +@pytorch_funcify.register(Shape_i) +def pytorch_funcify_Shape_i(op, **kwargs): + i = op.i + + def shape_i(x): + return x.shape[i] + + return shape_i + + +@pytorch_funcify.register(SpecifyShape) +def pytorch_funcify_SpecifyShape(op, node, **kwargs): + def specifyshape(x, *shape): + assert x.ndim == len(shape) + for actual, expected in zip(x.shape, shape): + if expected is None: + continue + if actual != expected: + raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}") + return x + + return specifyshape + + +@pytorch_funcify.register(Unbroadcast) +def pytorch_funcify_Unbroadcast(op, **kwargs): + def unbroadcast(x): + return x + + return unbroadcast diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 0ccb1c454f..27c1b1bd6a 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -294,3 +294,10 @@ def test_eye(dtype): for _M in range(1, 6): for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]: np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k)) + + +def test_pytorch_MakeVector(): + x = ptb.make_vector(1, 2, 3) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) diff --git a/tests/link/pytorch/test_shape.py b/tests/link/pytorch/test_shape.py new file mode 100644 index 0000000000..4f72763cb5 --- /dev/null +++ b/tests/link/pytorch/test_shape.py @@ -0,0 +1,72 @@ +import numpy as np + +import pytensor.tensor as pt +from pytensor.compile.ops import DeepCopyOp, ViewOp +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape +from pytensor.tensor.type import iscalar, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_shape_ops(): + x_np = np.zeros((20, 3)) + x = Shape()(pt.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, [], must_be_device_array=False) + + x = Shape_i(1)(pt.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, [], must_be_device_array=False) + + +def test_pytorch_specify_shape(): + in_pt = pt.matrix("in") + x = pt.specify_shape(in_pt, (4, None)) + x_fg = FunctionGraph([in_pt], [x]) + compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)]) + + # When used to assert two arrays have similar shapes + in_pt = pt.matrix("in") + shape_pt = pt.matrix("shape") + x = pt.specify_shape(in_pt, shape_pt.shape) + x_fg = FunctionGraph([in_pt, shape_pt], [x]) + compare_pytorch_and_py( + x_fg, + [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], + ) + + +def test_pytorch_Reshape_constant(): + a = vector("a") + x = reshape(a, (2, 2)) + x_fg = FunctionGraph([a], [x]) + compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + +def test_pytorch_Reshape_shape_graph_input(): + a = vector("a") + shape_pt = iscalar("b") + x = reshape(a, (shape_pt, shape_pt)) + x_fg = FunctionGraph([a, shape_pt], [x]) + compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) + + +def test_pytorch_compile_ops(): + x = DeepCopyOp()(pt.as_tensor_variable(1.1)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) + + x_np = np.zeros((20, 1, 1)) + x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) + + x = ViewOp()(pt.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_pytorch_and_py(x_fg, []) diff --git a/tests/link/pytorch/test_sort.py b/tests/link/pytorch/test_sort.py index 386a974cf4..7912dd4a03 100644 --- a/tests/link/pytorch/test_sort.py +++ b/tests/link/pytorch/test_sort.py @@ -8,16 +8,7 @@ @pytest.mark.parametrize("func", (sort, argsort)) -@pytest.mark.parametrize( - "axis", - [ - pytest.param(0), - pytest.param(1), - pytest.param( - None, marks=pytest.mark.xfail(reason="Reshape Op not implemented") - ), - ], -) +@pytest.mark.parametrize("axis", [0, 1, None]) def test_sort(func, axis): x = matrix("x", shape=(2, 2), dtype="float64") out = func(x, axis=axis) From 0e455fd10258dc710723a511fefe79c709c70a30 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Tue, 16 Jul 2024 20:21:41 +0200 Subject: [PATCH 2/3] Reworked tests in implementation of Shape in PyTorch --- pytensor/link/pytorch/dispatch/shape.py | 4 +--- tests/link/pytorch/test_extra_ops.py | 3 ++- tests/link/pytorch/test_shape.py | 15 ++------------- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py index c6395b6713..666379452e 100644 --- a/pytensor/link/pytorch/dispatch/shape.py +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -6,9 +6,7 @@ @pytorch_funcify.register(Reshape) def pytorch_funcify_Reshape(op, node, **kwargs): - shape = node.inputs[1] - - def reshape(x, shape=shape): + def reshape(x, shape): return torch.reshape(x, tuple(shape)) return reshape diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py index 221855864a..38dc5cd0e8 100644 --- a/tests/link/pytorch/test_extra_ops.py +++ b/tests/link/pytorch/test_extra_ops.py @@ -51,7 +51,8 @@ def test_pytorch_CumOp(axis, dtype): pytest.param( None, 3, - marks=pytest.mark.xfail(reason="Reshape not implemented"), + marks=pytest.mark.xfail(reason="Issue in Elemwise"), + # TODO: add reference to issue ), ], ) diff --git a/tests/link/pytorch/test_shape.py b/tests/link/pytorch/test_shape.py index 4f72763cb5..152aa8ddf3 100644 --- a/tests/link/pytorch/test_shape.py +++ b/tests/link/pytorch/test_shape.py @@ -1,7 +1,6 @@ import numpy as np import pytensor.tensor as pt -from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape @@ -46,7 +45,7 @@ def test_pytorch_Reshape_constant(): compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) -def test_pytorch_Reshape_shape_graph_input(): +def test_pytorch_Reshape_dynamic(): a = vector("a") shape_pt = iscalar("b") x = reshape(a, (shape_pt, shape_pt)) @@ -54,19 +53,9 @@ def test_pytorch_Reshape_shape_graph_input(): compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) -def test_pytorch_compile_ops(): - x = DeepCopyOp()(pt.as_tensor_variable(1.1)) - x_fg = FunctionGraph([], [x]) - - compare_pytorch_and_py(x_fg, []) - +def test_pytorch_unbroadcast(): x_np = np.zeros((20, 1, 1)) x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) x_fg = FunctionGraph([], [x]) compare_pytorch_and_py(x_fg, []) - - x = ViewOp()(pt.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - - compare_pytorch_and_py(x_fg, []) From 0786f2c9e7ba978d3cc5080169cd17abe8f5eeef Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Wed, 17 Jul 2024 19:35:43 +0200 Subject: [PATCH 3/3] Fixed implementation of Shape Op in PyTorch - Fixed Shape_i - Typified Python NoneType --- pytensor/link/pytorch/dispatch/basic.py | 8 ++++++-- pytensor/link/pytorch/dispatch/shape.py | 2 +- tests/link/pytorch/test_extra_ops.py | 14 +------------- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 8e5b3ee4bf..c71e1606bf 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -1,4 +1,5 @@ from functools import singledispatch +from types import NoneType import torch @@ -12,8 +13,11 @@ @singledispatch def pytorch_typify(data, dtype=None, **kwargs): r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" - if data is not None: - return torch.as_tensor(data, dtype=dtype) + return torch.as_tensor(data, dtype=dtype) + + +@pytorch_typify.register(NoneType) +def pytorch_typify_None(data, **kwargs): return None diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py index 666379452e..7633e28e01 100644 --- a/pytensor/link/pytorch/dispatch/shape.py +++ b/pytensor/link/pytorch/dispatch/shape.py @@ -25,7 +25,7 @@ def pytorch_funcify_Shape_i(op, **kwargs): i = op.i def shape_i(x): - return x.shape[i] + return torch.tensor(x.shape[i]) return shape_i diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py index 38dc5cd0e8..c615176a45 100644 --- a/tests/link/pytorch/test_extra_ops.py +++ b/tests/link/pytorch/test_extra_ops.py @@ -43,19 +43,7 @@ def test_pytorch_CumOp(axis, dtype): compare_pytorch_and_py(fgraph, [test_value]) -@pytest.mark.parametrize( - "axis, repeats", - [ - (0, (1, 2, 3)), - (1, (3, 3)), - pytest.param( - None, - 3, - marks=pytest.mark.xfail(reason="Issue in Elemwise"), - # TODO: add reference to issue - ), - ], -) +@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)]) def test_pytorch_Repeat(axis, repeats): a = pt.matrix("a", dtype="float64")