Skip to content

Commit 426931b

Browse files
authored
Implements shape Ops and MakeVector in PyTorch (#926)
* Implements shape and MakeVector Ops in PyTorch - Shape - Shape_i - Reshape - SpecifyShape - Unbroadcast - MakeVector
1 parent 6b8df2c commit 426931b

File tree

7 files changed

+140
-23
lines changed

7 files changed

+140
-23
lines changed

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
import pytensor.link.pytorch.dispatch.elemwise
77
import pytensor.link.pytorch.dispatch.extra_ops
88
import pytensor.link.pytorch.dispatch.sort
9+
import pytensor.link.pytorch.dispatch.shape
910
# isort: on

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from functools import singledispatch
2+
from types import NoneType
23

34
import torch
45

56
from pytensor.compile.ops import DeepCopyOp
67
from pytensor.graph.fg import FunctionGraph
78
from pytensor.link.utils import fgraph_to_python
89
from pytensor.raise_op import CheckAndRaise
9-
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
10+
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
1011

1112

1213
@singledispatch
@@ -15,6 +16,11 @@ def pytorch_typify(data, dtype=None, **kwargs):
1516
return torch.as_tensor(data, dtype=dtype)
1617

1718

19+
@pytorch_typify.register(NoneType)
20+
def pytorch_typify_None(data, **kwargs):
21+
return None
22+
23+
1824
@singledispatch
1925
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
2026
"""Create a PyTorch compatible function from an PyTensor `Op`."""
@@ -116,3 +122,13 @@ def eye(N, M, k):
116122
return zeros
117123

118124
return eye
125+
126+
127+
@pytorch_funcify.register(MakeVector)
128+
def pytorch_funcify_MakeVector(op, **kwargs):
129+
torch_dtype = getattr(torch, op.dtype)
130+
131+
def makevector(*x):
132+
return torch.tensor(x, dtype=torch_dtype)
133+
134+
return makevector
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
4+
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
5+
6+
7+
@pytorch_funcify.register(Reshape)
8+
def pytorch_funcify_Reshape(op, node, **kwargs):
9+
def reshape(x, shape):
10+
return torch.reshape(x, tuple(shape))
11+
12+
return reshape
13+
14+
15+
@pytorch_funcify.register(Shape)
16+
def pytorch_funcify_Shape(op, **kwargs):
17+
def shape(x):
18+
return x.shape
19+
20+
return shape
21+
22+
23+
@pytorch_funcify.register(Shape_i)
24+
def pytorch_funcify_Shape_i(op, **kwargs):
25+
i = op.i
26+
27+
def shape_i(x):
28+
return torch.tensor(x.shape[i])
29+
30+
return shape_i
31+
32+
33+
@pytorch_funcify.register(SpecifyShape)
34+
def pytorch_funcify_SpecifyShape(op, node, **kwargs):
35+
def specifyshape(x, *shape):
36+
assert x.ndim == len(shape)
37+
for actual, expected in zip(x.shape, shape):
38+
if expected is None:
39+
continue
40+
if actual != expected:
41+
raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
42+
return x
43+
44+
return specifyshape
45+
46+
47+
@pytorch_funcify.register(Unbroadcast)
48+
def pytorch_funcify_Unbroadcast(op, **kwargs):
49+
def unbroadcast(x):
50+
return x
51+
52+
return unbroadcast

tests/link/pytorch/test_basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,10 @@ def test_eye(dtype):
294294
for _M in range(1, 6):
295295
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
296296
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))
297+
298+
299+
def test_pytorch_MakeVector():
300+
x = ptb.make_vector(1, 2, 3)
301+
x_fg = FunctionGraph([], [x])
302+
303+
compare_pytorch_and_py(x_fg, [])

tests/link/pytorch/test_extra_ops.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,7 @@ def test_pytorch_CumOp(axis, dtype):
4343
compare_pytorch_and_py(fgraph, [test_value])
4444

4545

46-
@pytest.mark.parametrize(
47-
"axis, repeats",
48-
[
49-
(0, (1, 2, 3)),
50-
(1, (3, 3)),
51-
pytest.param(
52-
None,
53-
3,
54-
marks=pytest.mark.xfail(reason="Reshape not implemented"),
55-
),
56-
],
57-
)
46+
@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
5847
def test_pytorch_Repeat(axis, repeats):
5948
a = pt.matrix("a", dtype="float64")
6049

tests/link/pytorch/test_shape.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import numpy as np
2+
3+
import pytensor.tensor as pt
4+
from pytensor.configdefaults import config
5+
from pytensor.graph.fg import FunctionGraph
6+
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
7+
from pytensor.tensor.type import iscalar, vector
8+
from tests.link.pytorch.test_basic import compare_pytorch_and_py
9+
10+
11+
def test_pytorch_shape_ops():
12+
x_np = np.zeros((20, 3))
13+
x = Shape()(pt.as_tensor_variable(x_np))
14+
x_fg = FunctionGraph([], [x])
15+
16+
compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
17+
18+
x = Shape_i(1)(pt.as_tensor_variable(x_np))
19+
x_fg = FunctionGraph([], [x])
20+
21+
compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
22+
23+
24+
def test_pytorch_specify_shape():
25+
in_pt = pt.matrix("in")
26+
x = pt.specify_shape(in_pt, (4, None))
27+
x_fg = FunctionGraph([in_pt], [x])
28+
compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
29+
30+
# When used to assert two arrays have similar shapes
31+
in_pt = pt.matrix("in")
32+
shape_pt = pt.matrix("shape")
33+
x = pt.specify_shape(in_pt, shape_pt.shape)
34+
x_fg = FunctionGraph([in_pt, shape_pt], [x])
35+
compare_pytorch_and_py(
36+
x_fg,
37+
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
38+
)
39+
40+
41+
def test_pytorch_Reshape_constant():
42+
a = vector("a")
43+
x = reshape(a, (2, 2))
44+
x_fg = FunctionGraph([a], [x])
45+
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
46+
47+
48+
def test_pytorch_Reshape_dynamic():
49+
a = vector("a")
50+
shape_pt = iscalar("b")
51+
x = reshape(a, (shape_pt, shape_pt))
52+
x_fg = FunctionGraph([a, shape_pt], [x])
53+
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
54+
55+
56+
def test_pytorch_unbroadcast():
57+
x_np = np.zeros((20, 1, 1))
58+
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
59+
x_fg = FunctionGraph([], [x])
60+
61+
compare_pytorch_and_py(x_fg, [])

tests/link/pytorch/test_sort.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,7 @@
88

99

1010
@pytest.mark.parametrize("func", (sort, argsort))
11-
@pytest.mark.parametrize(
12-
"axis",
13-
[
14-
pytest.param(0),
15-
pytest.param(1),
16-
pytest.param(
17-
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
18-
),
19-
],
20-
)
11+
@pytest.mark.parametrize("axis", [0, 1, None])
2112
def test_sort(func, axis):
2213
x = matrix("x", shape=(2, 2), dtype="float64")
2314
out = func(x, axis=axis)

0 commit comments

Comments
 (0)