Skip to content

Commit 8fd37d0

Browse files
author
Ian Schweer
committed
Use testop in test
1 parent a10e6f4 commit 8fd37d0

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

tests/link/pytorch/test_blockwise.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,49 @@
33

44
import pytensor
55
import pytensor.tensor as pt
6+
from pytensor.graph.basic import Apply
7+
from pytensor.graph.op import Op
8+
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
9+
from pytensor.tensor.blockwise import Blockwise
610

711

812
torch = pytest.importorskip("torch")
913

1014

15+
class TestOp(Op):
16+
gufunc_signature = "(m,n),(n,p)->(m,p)"
17+
18+
def __init__(self, final_shape):
19+
self.final_shape = final_shape
20+
self.call_shapes = []
21+
22+
def make_node(self, *args):
23+
return Apply(self, list(args), [pt.matrix("_", shape=self.final_shape)])
24+
25+
def perform(self, *_):
26+
raise RuntimeError("In perform")
27+
28+
29+
@pytorch_funcify.register(TestOp)
30+
def evaluate_test_op(op, **_):
31+
def func(a, b):
32+
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
33+
return a @ b
34+
35+
return func
36+
37+
1138
def test_blockwise_broadcast():
1239
_x = np.random.rand(5, 1, 2, 3)
1340
_y = np.random.rand(3, 3, 2)
1441

1542
x = pt.tensor4("x", shape=(5, 1, 2, 3))
1643
y = pt.tensor3("y", shape=(3, 3, 2))
44+
op = TestOp((2, 2))
45+
z = Blockwise(op)(x, y)
1746

18-
f = pytensor.function([x, y], x @ y, mode="PYTORCH")
47+
f = pytensor.function([x, y], z, mode="PYTORCH")
1948
res = f(_x, _y)
2049
assert tuple(res.shape) == (5, 3, 2, 2)
2150
np.testing.assert_allclose(res, _x @ _y)
51+
assert op.call_shapes == [(2, 3), (3, 2)]

0 commit comments

Comments
 (0)