Commit a0c64b5

Reduce overhead of JITLinker
1 parent d1c5ae2 commit a0c64b5

File tree

6 files changed: +52 −91 lines changed


pytensor/link/basic.py

Lines changed: 10 additions & 42 deletions
@@ -653,41 +653,36 @@ def create_jitable_thunk(
         )

         thunk_inputs = self.create_thunk_inputs(storage_map)
-
-        thunks = []
-
         thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
-
         fgraph_jit = self.jit_compile(converted_fgraph)

         def thunk(
-            fgraph=self.fgraph,
             fgraph_jit=fgraph_jit,
             thunk_inputs=thunk_inputs,
             thunk_outputs=thunk_outputs,
         ):
-            outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
+            try:
+                outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
+            except Exception:
+                # TODO: Should we add a fake node that combines all outputs,
+                # since the error may come from any of them?
+                raise_with_op(self.fgraph, output_nodes[0], thunk)

             # strict=False because we are in a hot loop
-            for o_var, o_storage, o_val in zip(
-                fgraph.outputs, thunk_outputs, outputs, strict=False
-            ):
-                compute_map[o_var][0] = True
-                o_storage[0] = self.output_filter(o_var, o_val)
-            return outputs
+            for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
+                o_storage[0] = o_val

         thunk.inputs = thunk_inputs
         thunk.outputs = thunk_outputs
         thunk.lazy = False

-        thunks.append(thunk)
+        thunks = [thunk]

         return thunks, output_nodes, fgraph_jit

     def make_all(self, input_storage=None, output_storage=None, storage_map=None):
         fgraph = self.fgraph
         nodes = self.schedule(fgraph)
-        no_recycling = self.no_recycling

         input_storage, output_storage, storage_map = map_storage(
             fgraph, nodes, input_storage, output_storage, storage_map
@@ -701,34 +696,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
             compute_map, nodes, input_storage, output_storage, storage_map
         )

-        computed, last_user = gc_helper(nodes)
-
-        if self.allow_gc:
-            post_thunk_old_storage = [
-                [
-                    storage_map[input]
-                    for input in node.inputs
-                    if (input in computed)
-                    and (input not in fgraph.outputs)
-                    and (node == last_user[input])
-                ]
-                for node in nodes
-            ]
-        else:
-            post_thunk_old_storage = None
-
-        if no_recycling is True:
-            no_recycling = list(storage_map.values())
-            no_recycling = difference(no_recycling, input_storage)
-        else:
-            no_recycling = [
-                storage_map[r] for r in no_recycling if r not in fgraph.inputs
-            ]
-
-        fn = streamline(
-            fgraph, thunks, nodes, post_thunk_old_storage, no_recycling=no_recycling
-        )
-
+        [fn] = thunks
         fn.jit_fn = jit_fn
         fn.allow_gc = self.allow_gc
         fn.storage_map = storage_map
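
Net effect of the hunks above: the JIT thunk no longer routes each output through self.output_filter or updates compute_map; it writes the raw JIT results straight into the output storage cells, and make_all uses that single thunk directly instead of wrapping it with streamline. A minimal standalone sketch of the storage-cell pattern (hypothetical names, not PyTensor's actual classes):

import numpy as np

def make_thunk(jit_fn, input_storage, output_storage):
    # Storage cells are one-element lists, as in PyTensor's storage_map.
    def thunk():
        outputs = jit_fn(*(cell[0] for cell in input_storage))
        for cell, value in zip(output_storage, outputs):
            cell[0] = value
    return thunk

# Usage: the thunk mutates the output cells in place.
x_cell, out_cell = [np.arange(3.0)], [None]
thunk = make_thunk(lambda x: (x * 2,), [x_cell], [out_cell])
thunk()
print(out_cell[0])  # [0. 2. 4.]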

pytensor/link/numba/linker.py

Lines changed: 0 additions & 17 deletions
@@ -1,26 +1,9 @@
-from typing import TYPE_CHECKING, Any
-
-import numpy as np
-
-import pytensor
 from pytensor.link.basic import JITLinker


-if TYPE_CHECKING:
-    from pytensor.graph.basic import Variable
-
-
 class NumbaLinker(JITLinker):
     """A `Linker` that JIT-compiles NumPy-based operations using Numba."""

-    def output_filter(self, var: "Variable", out: Any) -> Any:
-        if not isinstance(var, np.ndarray) and isinstance(
-            var.type, pytensor.tensor.TensorType
-        ):
-            return var.type.filter(out, allow_downcast=True)
-
-        return out
-
     def fgraph_convert(self, fgraph, **kwargs):
         from pytensor.link.numba.dispatch import numba_funcify

pytensor/link/pytorch/linker.py

Lines changed: 11 additions & 15 deletions
@@ -1,7 +1,3 @@
-import copy
-from typing import Any
-
-from pytensor.graph.basic import Variable
 from pytensor.link.basic import JITLinker
 from pytensor.link.utils import unique_name_generator

@@ -13,14 +9,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.gen_functors = []

-    def input_filter(self, inp: Any) -> Any:
-        from pytensor.link.pytorch.dispatch import pytorch_typify
-
-        return pytorch_typify(inp)
-
-    def output_filter(self, var: Variable, out: Any) -> Any:
-        return out.cpu()
-
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.pytorch.dispatch import pytorch_funcify

@@ -49,6 +37,8 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch

+        from pytensor.link.pytorch.dispatch import pytorch_typify
+
         class wrapper:
             """
             Pytorch would fail compiling our method when trying
@@ -62,7 +52,7 @@ class wrapper:

             def __init__(self, fn, gen_functors):
                 self.fn = torch.compile(fn)
-                self.gen_functors = copy.copy(gen_functors)
+                self.gen_functors = gen_functors.copy()

             def __call__(self, *args, **kwargs):
                 import pytensor.link.utils
@@ -83,9 +73,15 @@ def __call__(self, *args, **kwargs):
             def __del__(self):
                 del self.gen_functors

-        res = wrapper(fn, self.gen_functors)
+        inner_fn = wrapper(fn, self.gen_functors)
         self.gen_functors = []
-        return res
+
+        # Torch does not accept numpy inputs and may return GPU objects
+        def fn(*inputs, inner_fn=inner_fn):
+            outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
+            return tuple(out.cpu().numpy() for out in outs)
+
+        return fn

     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
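
With input_filter/output_filter gone, the NumPy↔torch conversion happens exactly once, inside the callable that jit_compile returns: inputs are converted with pytorch_typify and outputs are moved back to CPU NumPy arrays before they reach the thunk. A rough standalone sketch of that wrapping pattern, using torch.as_tensor in place of the real pytorch_typify dispatch:

import numpy as np
import torch

def wrap_torch_fn(torch_fn):
    # Callers pass and receive NumPy arrays; torch objects never leak out.
    def fn(*inputs):
        outs = torch_fn(*(torch.as_tensor(x) for x in inputs))
        return tuple(out.cpu().numpy() for out in outs)
    return fn

doubled = wrap_torch_fn(lambda x: (x * 2,))
print(doubled(np.arange(3)))  # (array([0, 2, 4]),)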

tests/link/numba/test_basic.py

Lines changed: 17 additions & 0 deletions
@@ -889,3 +889,20 @@ def test_cache_warning_suppressed():

     x_test = np.random.uniform(size=5)
     np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)
+
+
+@pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))
+def test_function_overhead(mode, benchmark):
+    x = pt.vector("x")
+    out = pt.exp(x)
+
+    fn = function([x], out, mode="NUMBA")
+    if mode == "trust_input":
+        fn.trust_input = True
+    elif mode == "direct":
+        fn = fn.vm.jit_fn
+
+    test_x = np.zeros(1000)
+    assert np.sum(fn(test_x)) == 1000
+
+    benchmark(fn, test_x)
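
The three parametrizations benchmark progressively thinner call paths: the regular Function.__call__, the same call with input validation skipped via trust_input, and the raw JIT-compiled callable fn.vm.jit_fn. Outside of the pytest-benchmark fixture, roughly the same comparison could be made with timeit (illustrative sketch, assuming the NUMBA mode is installed):

import timeit
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
fn = pytensor.function([x], pt.exp(x), mode="NUMBA")
test_x = np.zeros(1000)
fn(test_x)  # warm up so JIT compilation is not timed

print("default    ", timeit.timeit(lambda: fn(test_x), number=10_000))
fn.trust_input = True
print("trust_input", timeit.timeit(lambda: fn(test_x), number=10_000))
print("direct     ", timeit.timeit(lambda: fn.vm.jit_fn(test_x), number=10_000))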

tests/link/pytorch/test_basic.py

Lines changed: 13 additions & 16 deletions
@@ -53,8 +53,6 @@ def compare_pytorch_and_py(
     assert_fn: func, opt
         Assert function used to check for equality between python and pytorch. If not
         provided uses np.testing.assert_allclose
-    must_be_device_array: Bool
-        Checks if torch.device.type is cuda


     """
@@ -66,20 +64,19 @@ def compare_pytorch_and_py(
     pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode)
     pytorch_res = pytensor_torch_fn(*test_inputs)

-    if must_be_device_array:
-        if isinstance(pytorch_res, list):
-            assert all(isinstance(res, torch.Tensor) for res in pytorch_res)
-        else:
-            assert pytorch_res.device.type == "cuda"
+    if isinstance(pytorch_res, list):
+        assert all(isinstance(res, np.ndarray) for res in pytorch_res)
+    else:
+        assert isinstance(pytorch_res, np.ndarray)

     pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
     py_res = pytensor_py_fn(*test_inputs)

     if len(fgraph.outputs) > 1:
         for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True):
-            assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
+            assert_fn(pytorch_res_i, py_res_i)
     else:
-        assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
+        assert_fn(pytorch_res[0], py_res[0])

     return pytensor_torch_fn, pytorch_res

@@ -162,23 +159,23 @@ def test_shared(device):
     pytensor_torch_fn = function([], a, mode="PYTORCH")
     pytorch_res = pytensor_torch_fn()

-    assert isinstance(pytorch_res, torch.Tensor)
+    assert isinstance(pytorch_res, np.ndarray)
     assert isinstance(a.get_value(), np.ndarray)
-    np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())
+    np.testing.assert_allclose(pytorch_res, a.get_value())

     pytensor_torch_fn = function([], a * 2, mode="PYTORCH")
     pytorch_res = pytensor_torch_fn()

-    assert isinstance(pytorch_res, torch.Tensor)
+    assert isinstance(pytorch_res, np.ndarray)
     assert isinstance(a.get_value(), np.ndarray)
-    np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)
+    np.testing.assert_allclose(pytorch_res, a.get_value() * 2)

     new_a_value = np.array([3, 4, 5], dtype=config.floatX)
     a.set_value(new_a_value)

     pytorch_res = pytensor_torch_fn()
-    assert isinstance(pytorch_res, torch.Tensor)
-    np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)
+    assert isinstance(pytorch_res, np.ndarray)
+    np.testing.assert_allclose(pytorch_res, new_a_value * 2)


 @pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -225,7 +222,7 @@ def test_alloc_and_empty():
     fn = function([dim1], out, mode=pytorch_mode)
     res = fn(7)
     assert res.shape == (5, 7, 3)
-    assert res.dtype == torch.float32
+    assert res.dtype == np.float32

     v = vector("v", shape=(3,), dtype="float64")
     out = alloc(v, dim0, dim1, 3)
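
These test updates reflect the user-visible consequence of the linker change: functions compiled with mode="PYTORCH" now hand back NumPy arrays rather than torch tensors, so callers no longer need .detach().cpu().numpy(). For example (illustrative, assuming torch is installed):

import numpy as np
from pytensor import function, shared

a = shared(np.array([1.0, 2.0, 3.0]))
fn = function([], a * 2, mode="PYTORCH")
res = fn()
assert isinstance(res, np.ndarray)  # previously a torch.Tensor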

tests/link/pytorch/test_elemwise.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def test_cast():
     _, [res] = compare_pytorch_and_py(
         fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
     )
-    assert res.dtype == torch.int32
+    assert res.dtype == np.int32


 def test_vmap_elemwise():
