|
3 | 3 |
|
4 | 4 | import pytensor
|
5 | 5 | import pytensor.tensor as pt
|
| 6 | +from pytensor.graph.basic import Apply |
| 7 | +from pytensor.graph.op import Op |
| 8 | +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify |
| 9 | +from pytensor.tensor.blockwise import Blockwise |
6 | 10 |
|
7 | 11 |
|
8 | 12 | torch = pytest.importorskip("torch")
|
9 | 13 |
|
10 | 14 |
|
| 15 | +class TestOp(Op): |
| 16 | + gufunc_signature = "(m,n),(n,p)->(m,p)" |
| 17 | + |
| 18 | + def __init__(self, final_shape): |
| 19 | + self.final_shape = final_shape |
| 20 | + self.call_shapes = [] |
| 21 | + |
| 22 | + def make_node(self, *args): |
| 23 | + return Apply(self, list(args), [pt.matrix("_", shape=self.final_shape)]) |
| 24 | + |
| 25 | + def perform(self, *_): |
| 26 | + raise RuntimeError("In perform") |
| 27 | + |
| 28 | + |
| 29 | +@pytorch_funcify.register(TestOp) |
| 30 | +def evaluate_test_op(op, **_): |
| 31 | + def func(a, b): |
| 32 | + op.call_shapes.extend(map(torch.Tensor.size, [a, b])) |
| 33 | + return a @ b |
| 34 | + |
| 35 | + return func |
| 36 | + |
| 37 | + |
11 | 38 | def test_blockwise_broadcast():
|
12 | 39 | _x = np.random.rand(5, 1, 2, 3)
|
13 | 40 | _y = np.random.rand(3, 3, 2)
|
14 | 41 |
|
15 | 42 | x = pt.tensor4("x", shape=(5, 1, 2, 3))
|
16 | 43 | y = pt.tensor3("y", shape=(3, 3, 2))
|
| 44 | + op = TestOp((2, 2)) |
| 45 | + z = Blockwise(op)(x, y) |
17 | 46 |
|
18 |
| - f = pytensor.function([x, y], x @ y, mode="PYTORCH") |
| 47 | + f = pytensor.function([x, y], z, mode="PYTORCH") |
19 | 48 | res = f(_x, _y)
|
20 | 49 | assert tuple(res.shape) == (5, 3, 2, 2)
|
21 | 50 | np.testing.assert_allclose(res, _x @ _y)
|
| 51 | + assert op.call_shapes == [(2, 3), (3, 2)] |
0 commit comments