Skip to content

Commit 2cf0ed2

Browse files
committed
Reverted folder structure and added BatchedDot
1 parent 03bb3a8 commit 2cf0ed2

File tree

5 files changed

+53
-0
lines changed

5 files changed

+53
-0
lines changed

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
33

44
# # Load dispatch specializations
5+
import pytensor.link.pytorch.dispatch.blas
56
import pytensor.link.pytorch.dispatch.scalar
67
import pytensor.link.pytorch.dispatch.elemwise
8+
import pytensor.link.pytorch.dispatch.math
9+
710
# isort: on
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch import pytorch_funcify
4+
from pytensor.tensor.blas import BatchedDot
5+
6+
7+
@pytorch_funcify.register(BatchedDot)
8+
def pytorch_funcify_BatchedDot(op, **kwargs):
9+
def batched_dot(a, b):
10+
if a.shape[0] != b.shape[0]:
11+
raise TypeError("Shapes must match in the 0-th dimension")
12+
return torch.bmm(a, b)
13+
14+
return batched_dot

tests/link/pytorch/test_blas.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor.compile.function import function
5+
from pytensor.compile.mode import Mode
6+
from pytensor.configdefaults import config
7+
from pytensor.graph.fg import FunctionGraph
8+
from pytensor.graph.op import get_test_value
9+
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
10+
from pytensor.link.pytorch import PytorchLinker
11+
from pytensor.tensor import blas as pt_blas
12+
from pytensor.tensor.type import tensor3
13+
from tests.link.pytorch.test_basic import compare_pytorch_and_py
14+
15+
16+
def test_pytorch_BatchedDot():
17+
# tensor3 . tensor3
18+
a = tensor3("a")
19+
a.tag.test_value = (
20+
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
21+
)
22+
b = tensor3("b")
23+
b.tag.test_value = (
24+
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
25+
)
26+
out = pt_blas.BatchedDot()(a, b)
27+
fgraph = FunctionGraph([a, b], [out])
28+
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
29+
30+
# A dimension mismatch should raise a TypeError for compatibility
31+
inputs = [get_test_value(a)[:-1], get_test_value(b)]
32+
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
33+
pytorch_mode = Mode(PytorchLinker(), opts)
34+
pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=pytorch_mode)
35+
with pytest.raises(TypeError):
36+
pytensor_jax_fn(*inputs)
File renamed without changes.

0 commit comments

Comments
 (0)