Skip to content

Commit 5c73991

Browse files
Fix Subtensor and AdvSubtensor Ops
1 parent 218a5fe commit 5c73991

File tree

4 files changed

+102
-35
lines changed

4 files changed

+102
-35
lines changed

pytensor/compile/mode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
471471
"BlasOpt",
472472
"fusion",
473473
"inplace",
474+
"local_uint_constant_indices",
474475
],
475476
),
476477
)

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,40 @@
11
from functools import singledispatch
22
from types import NoneType
33

4+
import numpy as np
45
import torch
56

67
from pytensor.compile.ops import DeepCopyOp
78
from pytensor.graph.fg import FunctionGraph
89
from pytensor.link.utils import fgraph_to_python
910
from pytensor.raise_op import CheckAndRaise
10-
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
11+
from pytensor.tensor.basic import (
12+
Alloc,
13+
AllocEmpty,
14+
ARange,
15+
Eye,
16+
Join,
17+
MakeVector,
18+
TensorFromScalar,
19+
)
1120

1221

1322
@singledispatch
14-
def pytorch_typify(data, dtype=None, **kwargs):
15-
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
16-
if isinstance(data, NoneType):
17-
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
18-
if isinstance(data, slice):
19-
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
20-
else:
21-
return torch.as_tensor(data, dtype=dtype)
23+
def pytorch_typify(data, **kwargs):
24+
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
2225

2326

27+
@pytorch_typify.register(np.ndarray)
28+
@pytorch_typify.register(torch.Tensor)
29+
def pytorch_typify_tensor(data, dtype=None, **kwargs):
30+
return torch.as_tensor(data, dtype=dtype)
31+
32+
33+
@pytorch_typify.register(slice)
2434
@pytorch_typify.register(NoneType)
25-
def pytorch_typify_None(data, **kwargs):
26-
return None
35+
@pytorch_typify.register(np.number)
36+
def pytorch_typify_scalar(data, **kwargs):
37+
return data
2738

2839

2940
@singledispatch
@@ -137,3 +148,11 @@ def makevector(*x):
137148
return torch.tensor(x, dtype=torch_dtype)
138149

139150
return makevector
151+
152+
153+
@pytorch_funcify.register(TensorFromScalar)
154+
def pytorch_funcify_TensorFromScalar(op, **kwargs):
155+
def tensorfromscalar(x):
156+
return x
157+
158+
return tensorfromscalar

pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,48 @@
77
Subtensor,
88
indices_from_subtensor,
99
)
10+
from pytensor.tensor.type_other import MakeSlice
1011

1112

1213
@pytorch_funcify.register(Subtensor)
13-
@pytorch_funcify.register(AdvancedSubtensor)
14-
@pytorch_funcify.register(AdvancedSubtensor1)
1514
def pytorch_funcify_Subtensor(op, node, **kwargs):
1615
idx_list = getattr(op, "idx_list", None)
1716

1817
def subtensor(x, *ilists):
1918
indices = indices_from_subtensor(ilists, idx_list)
20-
new_indices = []
19+
for i in indices:
20+
if isinstance(i, slice):
21+
if i.step and i.step < 0:
22+
raise NotImplementedError(
23+
"Negative step sizes are not supported in Pytorch"
24+
)
25+
26+
return x[indices]
2127

28+
return subtensor
29+
30+
31+
@pytorch_funcify.register(AdvancedSubtensor1)
32+
@pytorch_funcify.register(AdvancedSubtensor)
33+
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
34+
def advsubtensor(x, *indices):
2235
for i in indices:
2336
if isinstance(i, slice):
2437
if i.step and i.step < 0:
2538
raise NotImplementedError(
2639
"Negative step sizes are not supported in Pytorch"
2740
)
28-
new_indices.append(i)
29-
else:
30-
new_indices.append(i.tolist())
41+
return x[indices]
3142

32-
if len(indices) == 1:
33-
indices = indices[0]
43+
return advsubtensor
3444

35-
return x[tuple(new_indices)]
3645

37-
return subtensor
46+
@pytorch_funcify.register(MakeSlice)
47+
def pytorch_funcify_makeslice(op, **kwargs):
48+
def makeslice(*x):
49+
return slice(x)
50+
51+
return makeslice
3852

3953

4054
@pytorch_funcify.register(IncSubtensor)
@@ -45,12 +59,14 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
4559
if getattr(op, "set_instead_of_inc", False):
4660

4761
def torch_fn(x, indices, y):
48-
return x.at[indices].set(y)
62+
x[indices] = y
63+
return x
4964

5065
else:
5166

5267
def torch_fn(x, indices, y):
53-
return x.at[indices].add(y)
68+
x[indices] += y
69+
return x
5470

5571
def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
5672
indices = indices_from_subtensor(ilist, idx_list)

tests/link/pytorch/test_subtensor.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
import pytensor.scalar as ps
45
import pytensor.tensor as pt
56
from pytensor.configdefaults import config
67
from pytensor.graph.fg import FunctionGraph
@@ -15,6 +16,7 @@ def test_pytorch_Subtensor():
1516

1617
# Basic indices
1718
out_pt = x_pt[1, 2, 0]
19+
1820
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
1921
out_fg = FunctionGraph([x_pt], [out_pt])
2022
compare_pytorch_and_py(out_fg, [x_np])
@@ -34,6 +36,27 @@ def test_pytorch_Subtensor():
3436
out_fg = FunctionGraph([x_pt], [out_pt])
3537
compare_pytorch_and_py(out_fg, [x_np])
3638

39+
a_pt = ps.int64("a")
40+
a_np = 1
41+
out_pt = x_pt[a_pt, 2, a_pt:2]
42+
43+
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
44+
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
45+
compare_pytorch_and_py(out_fg, [x_np, a_np])
46+
47+
# Negative step
48+
with pytest.raises(
49+
NotImplementedError, match="Negative step sizes are not supported in Pytorch"
50+
):
51+
out_pt = x_pt[::-1]
52+
out_fg = FunctionGraph([x_pt], [out_pt])
53+
compare_pytorch_and_py(out_fg, [x_np])
54+
55+
56+
def test_pytorch_AdvSubtensor():
57+
shape = (3, 4, 5)
58+
x_pt = pt.tensor("x", shape=shape, dtype="int")
59+
x_np = np.arange(np.prod(shape)).reshape(shape)
3760
# Advanced indexing
3861
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
3962
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
@@ -51,24 +74,32 @@ def test_pytorch_Subtensor():
5174
out_fg = FunctionGraph([x_pt], [out_pt])
5275
compare_pytorch_and_py(out_fg, [x_np])
5376

54-
with pytest.raises(NotImplementedError):
55-
out_pt = x_pt[[1, 2], :, [3, 4]]
56-
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
57-
out_fg = FunctionGraph([x_pt], [out_pt])
58-
compare_pytorch_and_py(out_fg, [x_np])
77+
out_pt = x_pt[[1, 2], :, [3, 4]]
78+
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
79+
out_fg = FunctionGraph([x_pt], [out_pt])
80+
compare_pytorch_and_py(out_fg, [x_np])
5981

60-
# Flipping
61-
with pytest.raises(
62-
NotImplementedError, match="Negative step sizes are not supported in Pytorch"
63-
):
64-
out_pt = x_pt[::-1]
65-
out_fg = FunctionGraph([x_pt], [out_pt])
66-
compare_pytorch_and_py(out_fg, [x_np])
82+
out_pt = x_pt[[1, 2], None]
83+
out_fg = FunctionGraph([x_pt], [out_pt])
84+
compare_pytorch_and_py(out_fg, [x_np])
85+
86+
a_pt = ps.int64("a")
87+
a_np = 2
88+
89+
out_pt = x_pt[[1, a_pt], a_pt]
90+
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
91+
compare_pytorch_and_py(out_fg, [x_np, a_np])
6792

6893
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
6994
out_fg = FunctionGraph([x_pt], [out_pt])
7095
compare_pytorch_and_py(out_fg, [x_np])
7196

97+
a_pt = pt.tensor3("a", dtype="bool")
98+
a_np = np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)
99+
out_pt = x_pt[a_pt]
100+
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
101+
compare_pytorch_and_py(out_fg, [x_np, a_np])
102+
72103

73104
def test_pytorch_IncSubtensor():
74105
rng = np.random.default_rng(42)

0 commit comments

Comments
 (0)