From d00bb64decfec9712950a339c716ef6ce8392f79 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Fri, 1 Nov 2024 22:03:21 -0700 Subject: [PATCH 1/2] Increase tolerance of flaky test --- tests/tensor/test_blockwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index bd69d809a3..51b381861a 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -538,7 +538,7 @@ def core_scipy_fn(A, b): A_val_copy, b_val_copy ) np.testing.assert_allclose( - out, expected_out, atol=1e-6 if config.floatX == "float32" else 0 + out, expected_out, atol=1e-5 if config.floatX == "float32" else 0 ) # Confirm input was destroyed From 9b6555f21d559ed9e1a14ff923d5a34a5507cfa5 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Tue, 5 Nov 2024 11:49:51 +0100 Subject: [PATCH 2/2] Implement Blockwise in PyTorch backend --- pytensor/link/pytorch/dispatch/__init__.py | 1 + pytensor/link/pytorch/dispatch/blockwise.py | 32 +++++++++++++ tests/link/pytorch/test_blockwise.py | 53 +++++++++++++++++++++ 3 files changed, 86 insertions(+) create mode 100644 pytensor/link/pytorch/dispatch/blockwise.py create mode 100644 tests/link/pytorch/test_blockwise.py diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index fddded525a..4caabf3e03 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -11,4 +11,5 @@ import pytensor.link.pytorch.dispatch.shape import pytensor.link.pytorch.dispatch.sort import pytensor.link.pytorch.dispatch.subtensor +import pytensor.link.pytorch.dispatch.blockwise # isort: on diff --git a/pytensor/link/pytorch/dispatch/blockwise.py b/pytensor/link/pytorch/dispatch/blockwise.py new file mode 100644 index 0000000000..524e706633 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/blockwise.py @@ -0,0 +1,32 @@ +import torch +import torch.compiler + +from pytensor.graph import FunctionGraph +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.blockwise import Blockwise + + +@pytorch_funcify.register(Blockwise) +def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): + batched_dims = op.batch_ndim(node) + core_node = op._create_dummy_core_node(node.inputs) + core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) + inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1) + + for _ in range(batched_dims): + inner_func = torch.vmap(inner_func) + + @torch.compiler.disable(recursive=False) + def batcher(*inputs): + op._check_runtime_broadcast(node, inputs) + # broadcast on batched_dims + all_batched_dims = tuple(t.shape[:batched_dims] for t in inputs) + batched_shape = torch.broadcast_shapes(*all_batched_dims) + broadcast_inputs = [ + torch.broadcast_to(i, batched_shape + i.shape[batched_dims:]) + for i in inputs + ] + res = inner_func(*broadcast_inputs) + return res + + return batcher diff --git a/tests/link/pytorch/test_blockwise.py b/tests/link/pytorch/test_blockwise.py new file mode 100644 index 0000000000..75f207e544 --- /dev/null +++ b/tests/link/pytorch/test_blockwise.py @@ -0,0 +1,53 @@ +import numpy as np +import pytest + +import pytensor +import pytensor.tensor as pt +from pytensor.graph.basic import Apply +from pytensor.graph.op import Op +from pytensor.tensor.blockwise import Blockwise + + +torch = pytest.importorskip("torch") +basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic") + + +class TestOp(Op): + gufunc_signature = "(m,n),(n,p)->(m,p)" + + def __init__(self, final_shape): + super().__init__() + self.final_shape = final_shape + self.call_shapes = [] + + def make_node(self, *args): + return Apply(self, list(args), [pt.matrix("_", shape=self.final_shape)]) + + def perform(self, *_): + raise RuntimeError("In perform") + + +@basic.pytorch_funcify.register(TestOp) +def evaluate_test_op(op, **_): + @torch.compiler.disable(recursive=False) + def func(a, b): + op.call_shapes.extend(map(torch.Tensor.size, [a, b])) + return a @ b + + return func + + +def test_blockwise_broadcast(): + _x = np.random.rand(5, 1, 2, 3) + _y = np.random.rand(3, 3, 2) + + x = pt.tensor4("x", shape=(5, 1, 2, 3)) + y = pt.tensor3("y", shape=(3, 3, 2)) + op = TestOp((2, 2)) + z = Blockwise(op)(x, y) + + f = pytensor.function([x, y], z, mode="PYTORCH") + res = f(_x, _y) + assert tuple(res.shape) == (5, 3, 2, 2) + np.testing.assert_allclose(res, _x @ _y) + assert op.call_shapes == [(2, 3), (3, 2)]