Skip to content

Commit 03bb3a8

Browse files
committed
Reverted logic to correct scope for math.dot
1 parent 2721c5a commit 03bb3a8

File tree

2 files changed

+2
-16
lines changed

2 files changed

+2
-16
lines changed

pytensor/link/pytorch/dispatch/nlinalg.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77
@pytorch_funcify.register(Dot)
88
def pytorch_funcify_Dot(op, **kwargs):
99
def dot(x, y):
10-
# Case 1: Vector Product/Matrix Multiplication/1-D Broadcastable Vector
11-
if x.shape == 1 or y.shape == 1 or (x.shape < 3 and y.shape < 3):
12-
return torch.matmul(x, y)
13-
else:
14-
# Case 2: Stackable batch dimension
15-
return torch.tensordot(x, y, dims=([-1], [-2]))
10+
return torch.matmul(x, y)
1611

1712
return dot

tests/link/pytorch/test_nlinalg.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,11 @@
33
from pytensor.configdefaults import config
44
from pytensor.graph.fg import FunctionGraph
55
from pytensor.graph.op import get_test_value
6-
from pytensor.tensor.type import matrix, scalar, tensor3, vector
6+
from pytensor.tensor.type import matrix, scalar, vector
77
from tests.link.pytorch.test_basic import compare_pytorch_and_py
88

99

1010
def test_pytorch_dot():
11-
a = tensor3("a")
12-
a.tag.test_value = np.zeros((3, 2, 4)).astype(config.floatX)
13-
b = tensor3("b")
14-
b.tag.test_value = np.zeros((3, 4, 1)).astype(config.floatX)
1511
y = vector("y")
1612
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
1713
x = vector("x")
@@ -23,11 +19,6 @@ def test_pytorch_dot():
2319
beta = scalar("beta")
2420
beta.tag.test_value = np.array(5.0, dtype=config.floatX)
2521

26-
# 3D * 3D
27-
out = a.dot(b * alpha) + beta * b
28-
fgraph = FunctionGraph([a, b, alpha, beta], [out])
29-
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
30-
3122
# 2D * 2D
3223
out = A.dot(A * alpha) + beta * A
3324
fgraph = FunctionGraph([A, alpha, beta], [out])

0 commit comments

Comments
 (0)