Skip to content

Commit 5121a85

Browse files
committed
Changed implementation of dot. Renamed tests
1 parent ffad937 commit 5121a85

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

pytensor/link/pytorch/dispatch/nlinalg.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
@pytorch_funcify.register(Dot)
88
def pytorch_funcify_Dot(op, **kwargs):
99
def dot(x, y):
10-
return torch.matmul(x, y)
10+
# Case 1: Vector Product/Matrix Multiplication/1-D Broadcastable Vector
11+
if x.shape < 3 and y.shape < 3:
12+
return torch.matmul(x, y)
13+
else:
14+
# Case 2: Stackable batch dimension
15+
return torch.tensordot(x, y, dims=([-1], [-2]))
1116

1217
return dot

tests/link/pytorch/test_nlinalg.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
from pytensor.configdefaults import config
44
from pytensor.graph.fg import FunctionGraph
55
from pytensor.graph.op import get_test_value
6-
from pytensor.tensor.type import matrix, scalar, vector
6+
from pytensor.tensor.type import matrix, scalar, tensor3, vector
77
from tests.link.pytorch.test_basic import compare_pytorch_and_py
88

99

10-
def test_tensor_basics():
10+
def test_pytorch_dot():
11+
a = tensor3("a")
12+
a.tag.test_value = np.zeros((3, 2, 4)).astype(config.floatX)
13+
b = tensor3("b")
14+
b.tag.test_value = np.zeros((3, 4, 1)).astype(config.floatX)
1115
y = vector("y")
1216
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
1317
x = vector("x")
@@ -19,12 +23,17 @@ def test_tensor_basics():
1923
beta = scalar("beta")
2024
beta.tag.test_value = np.array(5.0, dtype=config.floatX)
2125

22-
# 1D * 2D * 1D
23-
out = y.dot(alpha * A).dot(x) + beta * y
24-
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
26+
# 3D * 3D
27+
out = a.dot(b * alpha) + beta * b
28+
fgraph = FunctionGraph([a, b, alpha, beta], [out])
2529
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
2630

2731
# 2D * 2D
2832
out = A.dot(A * alpha) + beta * A
2933
fgraph = FunctionGraph([A, alpha, beta], [out])
3034
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
35+
36+
# 1D * 2D and 1D * 1D
37+
out = y.dot(alpha * A).dot(x) + beta * y
38+
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
39+
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

0 commit comments

Comments
 (0)