
Commit c643494 (parent: 398383a)

Added PyTorch links for Max and Argmax. Tests are failing

File tree: 2 files changed (+82, -1 lines)
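PyTensor's backends map graph Ops to callables through a singledispatch registry: registering a converter for an Op type, as the diff below does for Max and Argmax, lets the linker call pytorch_funcify(op) and get back a plain function over torch tensors. A minimal, self-contained sketch of that pattern, with generic stand-in names rather than pytensor's actual internals:

    import torch
    from functools import singledispatch

    @singledispatch
    def pytorch_funcify_demo(op, **kwargs):
        # Fallback when no converter has been registered for type(op)
        raise NotImplementedError(f"No PyTorch conversion for: {op}")

    class DemoDot:
        """Stand-in for a graph Op; hypothetical, for illustration only."""

    @pytorch_funcify_demo.register(DemoDot)
    def demo_dot(op, **kwargs):
        # A converter closes over any Op attributes it needs and
        # returns a callable that operates on torch tensors
        def dot(x, y):
            return torch.matmul(x, y)

        return dot

    fn = pytorch_funcify_demo(DemoDot())
    print(fn(torch.eye(2), torch.ones(2, 2)))  # identity @ ones -> tensor of ones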
Lines changed: 48 additions & 1 deletion
@@ -1,7 +1,8 @@
+import numpy as np
 import torch
 
 from pytensor.link.pytorch.dispatch import pytorch_funcify
-from pytensor.tensor.math import Dot
+from pytensor.tensor.math import Argmax, Dot, Max
 
 
 @pytorch_funcify.register(Dot)
@@ -10,3 +11,49 @@ def dot(x, y):
         return torch.matmul(x, y)
 
     return dot
+
+
+@pytorch_funcify.register(Max)
+def pytorch_funcify_Max(op, **kwargs):
+    axis = op.axis
+    keepdims = op.keepdims
+
+    def max(x):
+        return torch.max(x, dim=axis, keepdim=keepdims)
+
+    return max
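The commit message already flags failing tests, and this Max converter has several likely culprits: torch.max(x, dim=...) returns a (values, indices) namedtuple rather than the values alone, its dim must be an int (it accepts neither None nor a tuple, while pytensor reduction Ops usually store axis as None or a tuple of ints), and the Max Op may not expose a keepdims attribute at all, since pytensor typically applies keepdims in the graph. A minimal sketch of a converter that sidesteps these issues, assuming op.axis is None or a tuple of ints; this is not necessarily what the PR ended up merging:

    @pytorch_funcify.register(Max)
    def pytorch_funcify_Max(op, **kwargs):
        axis = op.axis

        def max_fn(x):
            if axis is None:
                # torch.max without `dim` reduces over all elements to a 0-d tensor
                return torch.max(x)
            # torch.amax accepts a tuple of dims, unlike torch.max,
            # and returns only the values
            return torch.amax(x, dim=axis)

        return max_fn

The diff then continues with the Argmax converter: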
+@pytorch_funcify.register(Argmax)
+def pytorch_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x):
+        if axis is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axis)
+
+        # NumPy does not support multiple axes for argmax; this is a
+        # work-around
+        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
+        # Not-reduced axes in front
+        transposed_x = np.transpose(
+            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
+        )
+        kept_shape = transposed_x.shape[: len(keep_axes)]
+        reduced_shape = transposed_x.shape[len(keep_axes) :]
+
+        # np.prod returns 1.0 when its argument is empty, so cast it to int64;
+        # otherwise reshape would complain about a float argument
+        new_shape = (
+            *kept_shape,
+            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
+        )
+        reshaped_x = torch.tensor(transposed_x.reshape(tuple(new_shape)))
+
+        max_idx_res = torch.argmax(reshaped_x, dim=-1).to(torch.long)
+
+        return max_idx_res
+
+    return argmax
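A separate problem here: at runtime x is a torch tensor, yet the body calls np.transpose and ndarray reshape on it. That round-trips through NumPy (and fails outright for CUDA tensors or tensors that require grad), and torch.tensor(...) then copies the data back. A hedged pure-torch sketch of the same multi-axis trick, again under the assumption that op.axis is None or a tuple of ints: permute the kept axes to the front, flatten the reduced axes into one, and take argmax over the last dim.

    @pytorch_funcify.register(Argmax)
    def pytorch_funcify_Argmax(op, **kwargs):
        axis = op.axis

        def argmax(x):
            if axis is None:
                axes = tuple(range(x.ndim))
            else:
                axes = tuple(int(ax) for ax in axis)

            # torch.argmax reduces over a single dim, so move the reduced
            # axes to the end, flatten them into one, and reduce over that
            keep_axes = [i for i in range(x.ndim) if i not in axes]
            transposed_x = x.permute(*keep_axes, *axes)
            kept_shape = transposed_x.shape[: len(keep_axes)]
            reshaped_x = transposed_x.reshape(*kept_shape, -1)

            # torch.argmax already returns int64 indices, so no cast is needed
            return torch.argmax(reshaped_x, dim=-1)

        return argmax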

tests/link/pytorch/test_math.py

Lines changed: 34 additions & 0 deletions
@@ -1,8 +1,10 @@
 import numpy as np
+import pytest
 
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
+from pytensor.tensor.math import argmax, max
 from pytensor.tensor.type import matrix, scalar, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
 
@@ -28,3 +30,35 @@ def test_pytorch_dot():
     out = y.dot(alpha * A).dot(x) + beta * y
     fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
     compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+
+@pytest.mark.parametrize(
+    "keepdims",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "axis",
+    [None, 1, 0],
+)
+def test_pytorch_max(axis, keepdims):
+    a = matrix("a", dtype=config.floatX)
+    a.tag.test_value = np.random.randn(4, 4).astype(config.floatX)
+    amx = max(a, axis=axis, keepdims=keepdims)
+    fgraph = FunctionGraph([a], amx, clone=False)
+    compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+
+@pytest.mark.parametrize(
+    "keepdims",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "axis",
+    [None, 1, (0,)],
+)
+def test_pytorch_argmax(axis, keepdims):
+    a = matrix("a", dtype=config.floatX)
+    a.tag.test_value = np.random.randn(4, 4).astype(config.floatX)
+    amx = argmax(a, axis=axis, keepdims=keepdims)
+    fgraph = FunctionGraph([a], amx)
+    compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
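One more likely failure sits in the tests themselves: FunctionGraph takes its outputs as a sequence of variables, so passing the bare amx (as both new tests do) would fail before the backend is even exercised. A corrected version of the first test under that assumption, with imports as in the diff above:

    @pytest.mark.parametrize("keepdims", [True, False])
    @pytest.mark.parametrize("axis", [None, 1, 0])
    def test_pytorch_max(axis, keepdims):
        a = matrix("a", dtype=config.floatX)
        a.tag.test_value = np.random.randn(4, 4).astype(config.floatX)
        amx = max(a, axis=axis, keepdims=keepdims)
        # Outputs must be wrapped in a list; FunctionGraph iterates over them
        fgraph = FunctionGraph([a], [amx], clone=False)
        compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])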
