Skip to content

Commit ffad937

Browse files
committed
Added PyTorch link and unit tests for normal dot
1 parent a8d7638 commit ffad937

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch import pytorch_funcify
4+
from pytensor.tensor.math import Dot
5+
6+
7+
@pytorch_funcify.register(Dot)
8+
def pytorch_funcify_Dot(op, **kwargs):
9+
def dot(x, y):
10+
return torch.matmul(x, y)
11+
12+
return dot

tests/link/pytorch/test_nlinalg.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
3+
from pytensor.configdefaults import config
4+
from pytensor.graph.fg import FunctionGraph
5+
from pytensor.graph.op import get_test_value
6+
from pytensor.tensor.type import matrix, scalar, vector
7+
from tests.link.pytorch.test_basic import compare_pytorch_and_py
8+
9+
10+
def test_tensor_basics():
11+
y = vector("y")
12+
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
13+
x = vector("x")
14+
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
15+
A = matrix("A")
16+
A.tag.test_value = np.array([[6, 3], [3, 0]], dtype=config.floatX)
17+
alpha = scalar("alpha")
18+
alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
19+
beta = scalar("beta")
20+
beta.tag.test_value = np.array(5.0, dtype=config.floatX)
21+
22+
# 1D * 2D * 1D
23+
out = y.dot(alpha * A).dot(x) + beta * y
24+
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
25+
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
26+
27+
# 2D * 2D
28+
out = A.dot(A * alpha) + beta * A
29+
fgraph = FunctionGraph([A, alpha, beta], [out])
30+
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

0 commit comments

Comments
 (0)