
Commit 4e1391c

Removed math Ops Arg[Max] and Dot
1 parent f075b50 commit 4e1391c

File tree

pytensor/link/pytorch/dispatch/nlinalg.py
tests/link/pytorch/test_nlinalg.py

2 files changed: 1 addition & 118 deletions

pytensor/link/pytorch/dispatch/nlinalg.py

Lines changed: 0 additions & 72 deletions
@@ -1,8 +1,6 @@
 import torch

 from pytensor.link.pytorch.dispatch import pytorch_funcify
-from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -85,14 +83,6 @@ def qr_full(x):
     return qr_full


-@pytorch_funcify.register(Dot)
-def pytorch_funcify_Dot(op, **kwargs):
-    def dot(x, y):
-        return torch.dot(x, y)
-
-    return dot
-
-
 @pytorch_funcify.register(MatrixPinv)
 def pytorch_funcify_Pinv(op, **kwargs):
     hermitian = op.hermitian
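
The deleted Dot dispatch forwarded to torch.dot, which accepts only 1-D tensors; general matrix products go through torch.matmul instead. A minimal standalone sketch of the distinction (illustrative values, not project code):

import torch

v = torch.tensor([1.0, 2.0, 3.0])
torch.dot(v, v)                # tensor(14.): torch.dot is 1-D-only
torch.matmul(torch.eye(3), v)  # tensor([1., 2., 3.]): matmul handles matrices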
@@ -103,71 +93,9 @@ def pinv(x):
     return pinv


-@pytorch_funcify.register(BatchedDot)
-def pytorch_funcify_BatchedDot(op, **kwargs):
-    def batched_dot(a, b):
-        if a.shape[0] != b.shape[0]:
-            raise TypeError("Shapes must match in the 0-th dimension")
-        return torch.matmul(a, b)
-
-    return batched_dot
-
-
 @pytorch_funcify.register(KroneckerProduct)
 def pytorch_funcify_KroneckerProduct(op, **kwargs):
     def _kron(x, y):
         return torch.kron(x, y)

     return _kron
-
-
-@pytorch_funcify.register(Max)
-def pytorch_funcify_Max(op, **kwargs):
-    axis = op.axis
-
-    def max(x):
-        if axis is None:
-            max_res = torch.max(x.flatten())
-            return max_res
-
-        # PyTorch doesn't support multiple axes for max;
-        # this is a work-around
-        axes = [int(ax) for ax in axis]
-
-        new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item()
-        keep_axes = [i for i in range(x.ndim) if i not in axes]
-        permute_order = keep_axes + axes
-        permuted_x = x.permute(*permute_order)
-        kept_shape = permuted_x.shape[: len(keep_axes)]
-
-        new_shape = (*kept_shape, new_dim)
-        reshaped_x = permuted_x.reshape(new_shape)
-        max_res, _ = torch.max(reshaped_x, dim=-1)
-        return max_res
-
-    return max
-
-
-@pytorch_funcify.register(Argmax)
-def pytorch_funcify_Argmax(op, **kwargs):
-    axis = op.axis
-
-    def argmax(x):
-        if axis is None:
-            return torch.argmax(x.view(-1))
-
-        # PyTorch doesn't support multiple axes for argmax;
-        # this is a work-around
-        axes = [int(ax) for ax in axis]
-
-        new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item()
-        keep_axes = [i for i in range(x.ndim) if i not in axes]
-        permute_order = keep_axes + axes
-        permuted_x = x.permute(*permute_order)
-        kept_shape = permuted_x.shape[: len(keep_axes)]
-
-        new_shape = (*kept_shape, new_dim)
-        reshaped_x = permuted_x.reshape(new_shape)
-        return torch.argmax(reshaped_x, dim=-1)
-
-    return argmax
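
Both deleted helpers worked around PyTorch's single-axis torch.max / torch.argmax by permuting the reduced axes to the end, flattening them into one trailing dimension, and reducing over that. A self-contained sketch of the trick, with x and axes as illustrative values:

import torch

x = torch.arange(24.0).reshape(2, 3, 4)
axes = [0, 2]  # axes to reduce, as in the removed work-around

keep_axes = [i for i in range(x.ndim) if i not in axes]
permuted = x.permute(*keep_axes, *axes)  # kept axes first, reduced axes last
flat = permuted.reshape(*permuted.shape[: len(keep_axes)], -1)

max_res, argmax_res = torch.max(flat, dim=-1)  # one reduction over the fused axis
assert torch.equal(max_res, x.amax(dim=tuple(axes)))  # matches a direct multi-axis max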

tests/link/pytorch/test_nlinalg.py

Lines changed: 1 addition & 46 deletions
@@ -4,11 +4,8 @@
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
-from pytensor.tensor import blas as pt_blas
 from pytensor.tensor import nlinalg as pt_nla
-from pytensor.tensor.math import argmax, dot, max
-from pytensor.tensor.type import matrix, tensor3, vector
+from pytensor.tensor.type import matrix
 from tests.link.pytorch.test_basic import compare_pytorch_and_py


@@ -23,27 +20,6 @@ def matrix_test():
     return (x, test_value)


-def test_BatchedDot():
-    # tensor3 . tensor3
-    a = tensor3("a")
-    a.tag.test_value = (
-        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
-    )
-    b = tensor3("b")
-    b.tag.test_value = (
-        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
-    )
-    out = pt_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    # A dimension mismatch should raise a TypeError for compatibility
-    inputs = [get_test_value(a)[:-1], get_test_value(b)]
-    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode="PYTORCH")
-    with pytest.raises(TypeError):
-        pytensor_jax_fn(*inputs)
-
-
 @pytest.mark.parametrize(
     "func",
     (
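
The removed test_BatchedDot exercised the shape guard in the deleted dispatch: torch.matmul treats the leading dimension as a batch, so both operands must agree there. A small standalone illustration reusing the test's shapes:

import torch

a = torch.randn(10, 5, 3)
b = torch.randn(10, 3, 2)
print(torch.matmul(a, b).shape)  # torch.Size([10, 5, 2]): one (5, 3) @ (3, 2) product per batch entry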
@@ -147,24 +123,3 @@ def test_kron():
     y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)

     compare_pytorch_and_py(fgraph, [x_np, y_np])
-
-
-@pytest.mark.parametrize("func", (max, argmax))
-@pytest.mark.parametrize("axis", [None, [0], [0, 1], [0, 2], [0, 1, 2]])
-def test_max_and_argmax(func, axis):
-    x = tensor3("x")
-    np.random.seed(42)
-    test_value = np.random.randint(0, 20, (4, 3, 2))
-
-    out = func(x, axis=axis)
-    out_fg = FunctionGraph([x], [out])
-    compare_pytorch_and_py(out_fg, [test_value])
-
-
-def test_dot():
-    x = vector("x")
-    test_value = np.array([1, 2, 3])
-
-    out = dot(x, x)
-    out_fg = FunctionGraph([x], [out])
-    compare_pytorch_and_py(out_fg, [test_value])
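
The deleted tests validated the PyTorch backend against the reference Python backend via compare_pytorch_and_py. What the multi-axis and axis=None cases assert reduces to ordinary NumPy semantics; an illustrative standalone check mirroring the test's shapes:

import numpy as np

x = np.random.default_rng(42).integers(0, 20, (4, 3, 2))
print(np.max(x, axis=(0, 2)))    # reduce axes 0 and 2 at once -> shape (3,)
print(np.argmax(x.reshape(-1)))  # axis=None case: argmax over the flattened array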
