We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5121a85 commit 2721c5aCopy full SHA for 2721c5a
pytensor/link/pytorch/dispatch/nlinalg.py
@@ -8,7 +8,7 @@
8
def pytorch_funcify_Dot(op, **kwargs):
9
def dot(x, y):
10
# Case 1: Vector Product/Matrix Multiplication/1-D Broadcastable Vector
11
- if x.shape < 3 and y.shape < 3:
+ if x.shape == 1 or y.shape == 1 or (x.shape < 3 and y.shape < 3):
12
return torch.matmul(x, y)
13
else:
14
# Case 2: Stackable batch dimension
0 commit comments