|
1 | 1 | import torch
|
| 2 | +import torch.compiler |
2 | 3 |
|
3 | 4 | from pytensor.graph import FunctionGraph
|
4 | 5 | from pytensor.link.pytorch.dispatch import pytorch_funcify
|
5 | 6 | from pytensor.tensor.blockwise import Blockwise
|
6 |
| -from pytensor.tensor.random.utils import params_broadcast_shapes |
7 | 7 |
|
8 | 8 |
|
9 | 9 | @pytorch_funcify.register(Blockwise)
|
10 | 10 | def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
|
11 | 11 | batched_dims = op.batch_ndim(node)
|
12 | 12 | core_node = op._create_dummy_core_node(node.inputs)
|
13 | 13 | core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
|
14 |
| - core_func = pytorch_funcify(core_fgraph) |
15 |
| - if len(node.outputs) == 1: |
16 |
| - |
17 |
| - def inner_func(*inputs): |
18 |
| - return core_func(*inputs)[0] |
19 |
| - else: |
20 |
| - inner_func = core_func |
| 14 | + inner_func = pytorch_funcify(core_fgraph) |
21 | 15 |
|
22 | 16 | for _ in range(batched_dims):
|
23 | 17 | inner_func = torch.vmap(inner_func)
|
24 | 18 |
|
| 19 | + @torch.compiler.disable(recursive=False) |
25 | 20 | def batcher(*inputs):
|
26 | 21 | op._check_runtime_broadcast(node, inputs)
|
27 | 22 | # broadcast on batched_dims
|
28 |
| - all_batched_dims = tuple(tuple(t.shape) for t in inputs) |
29 |
| - new_shapes = params_broadcast_shapes( |
30 |
| - all_batched_dims, |
31 |
| - ndims_params=[batched_dims] * len(inputs), |
32 |
| - use_pytensor=False, |
33 |
| - ) |
| 23 | + all_batched_dims = tuple(t.shape[:batched_dims] for t in inputs) |
| 24 | + batched_shape = torch.broadcast_shapes(*all_batched_dims) |
34 | 25 | broadcast_inputs = [
|
35 |
| - torch.broadcast_to(i, s) for i, s in zip(inputs, new_shapes) |
| 26 | + torch.broadcast_to(i, batched_shape + i.shape[batched_dims:]) |
| 27 | + for i in inputs |
36 | 28 | ]
|
37 |
| - return inner_func(*broadcast_inputs) |
| 29 | + res = inner_func(*broadcast_inputs) |
| 30 | + if len(node.outputs) == 1: |
| 31 | + return res[0] |
| 32 | + else: |
| 33 | + return res |
38 | 34 |
|
39 | 35 | return batcher
|
0 commit comments