
Commit dc5e943

brandonwillard authored and ricardoV94 committed
Separate interface and dispatch of numba_funcify
1 parent b1b97ee commit dc5e943

File tree

11 files changed: +169 / -128 lines changed

doc/extending/creating_a_numba_jax_op.rst

Lines changed: 2 additions & 2 deletions
@@ -83,15 +83,15 @@ Here's an example for :class:`IfElse`:
         return res if n_outs > 1 else res[0]
 
 
-Step 3: Register the function with the `jax_funcify` dispatcher
+Step 3: Register the function with the `_jax_funcify` dispatcher
 ---------------------------------------------------------------
 
 With the PyTensor `Op` replicated in JAX, we’ll need to register the
 function with the PyTensor JAX `Linker`. This is done through the use of
 `singledispatch`. If you don't know how `singledispatch` works, see the
 `Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.
 
-The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and
+The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.basic._numba_funcify` and
 :func:`pytensor.link.jax.dispatch.jax_funcify`.
 
 Here’s an example for the `Eye`\ `Op`:
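For orientation, here is a minimal sketch of the registration pattern the updated documentation points to. The `Eye` conversion body below is illustrative only and is not the implementation shipped with PyTensor; it simply shows that, after this commit, conversions are registered on `_numba_funcify` while callers keep using `numba_funcify`:

import numpy as np

from pytensor.link.numba.dispatch import numba_njit
from pytensor.link.numba.dispatch.basic import _numba_funcify
from pytensor.tensor.basic import Eye


@_numba_funcify.register(Eye)
def numba_funcify_Eye(op, node, **kwargs):
    @numba_njit
    def eye(N, M, k):
        # Build the banded identity by hand to stay inside Numba's
        # supported NumPy subset.
        res = np.zeros((N, M))
        for i in range(N):
            j = i + k
            if j >= 0 and j < M:
                res[i, j] = 1.0
        return res

    return eye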

pytensor/link/numba/dispatch/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,9 @@
 # isort: off
-from pytensor.link.numba.dispatch.basic import numba_funcify, numba_const_convert
+from pytensor.link.numba.dispatch.basic import (
+    numba_funcify,
+    numba_const_convert,
+    numba_njit,
+)
 
 # Load dispatch specializations
 import pytensor.link.numba.dispatch.scalar
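A small aside on the expanded import list above: it re-exports `numba_njit`, the thin wrapper around `numba.njit` defined in `basic.py`, from the dispatch package. A hedged usage sketch (the `add_one` function is made up for illustration):

from pytensor.link.numba.dispatch import numba_njit


@numba_njit
def add_one(x):
    # JIT-compiled with the project's default njit options.
    return x + 1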

pytensor/link/numba/dispatch/basic.py

Lines changed: 66 additions & 32 deletions
@@ -4,7 +4,7 @@
 from contextlib import contextmanager
 from functools import singledispatch
 from textwrap import dedent
-from typing import Union
+from typing import TYPE_CHECKING, Callable, Optional, Union, cast
 
 import numba
 import numba.np.unsafe.ndarray as numba_ndarray
@@ -22,6 +22,7 @@
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Apply, NoParams
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import Op
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import (
@@ -48,6 +49,10 @@
 from pytensor.tensor.type_other import MakeSlice, NoneConst
 
 
+if TYPE_CHECKING:
+    from pytensor.graph.op import StorageMapType
+
+
 def numba_njit(*args, **kwargs):
 
     kwargs = kwargs.copy()
@@ -353,8 +358,43 @@ def numba_const_convert(data, dtype=None, **kwargs):
     return data
 
 
-def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
-    """Create a Numba compatible function from an Aesara `Op`."""
+def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
+    """Convert `obj` to a Numba-JITable object."""
+    return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
+
+
+@singledispatch
+def _numba_funcify(
+    obj,
+    node: Optional[Apply] = None,
+    storage_map: Optional["StorageMapType"] = None,
+    **kwargs,
+) -> Callable:
+    r"""Dispatch on PyTensor object types to perform Numba conversions.
+
+    Arguments
+    ---------
+    obj
+        The object used to determine the appropriate conversion function based
+        on its type. This is generally an `Op` instance, but `FunctionGraph`\s
+        are also supported.
+    node
+        When `obj` is an `Op`, this value should be the corresponding `Apply` node.
+    storage_map
+        A storage map with, for example, the constant and `SharedVariable` values
+        of the graph being converted.
+
+    Returns
+    -------
+    A `Callable` that can be JIT-compiled in Numba using `numba.jit`.
+
+    """
+    raise NotImplementedError(f"Numba funcify for obj {obj} not implemented")
+
+
+@_numba_funcify.register(Op)
+def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable:
+    """Create a Numba compatible function from an PyTensor `Op.perform`."""
 
     warnings.warn(
         f"Numba will use object mode to run {op}'s perform method",
@@ -405,16 +445,10 @@ def perform(*inputs):
             ret = py_perform_return(inputs)
         return ret
 
-    return perform
-
-
-@singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
-    """Generate a numba function for a given op and apply node."""
-    return generate_fallback_impl(op, node, storage_map, **kwargs)
+    return cast(Callable, perform)
 
 
-@numba_funcify.register(OpFromGraph)
+@_numba_funcify.register(OpFromGraph)
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):
 
     _ = kwargs.pop("storage_map", None)
@@ -436,7 +470,7 @@ def opfromgraph(*inputs):
     return opfromgraph
 
 
-@numba_funcify.register(FunctionGraph)
+@_numba_funcify.register(FunctionGraph)
 def numba_funcify_FunctionGraph(
     fgraph,
     node=None,
@@ -544,8 +578,8 @@ def {fn_name}({", ".join(input_names)}):
     return subtensor_def_src
 
 
-@numba_funcify.register(Subtensor)
-@numba_funcify.register(AdvancedSubtensor1)
+@_numba_funcify.register(Subtensor)
+@_numba_funcify.register(AdvancedSubtensor1)
 def numba_funcify_Subtensor(op, node, **kwargs):
 
     subtensor_def_src = create_index_func(
@@ -561,7 +595,7 @@ def numba_funcify_Subtensor(op, node, **kwargs):
     return numba_njit(subtensor_fn, boundscheck=True)
 
 
-@numba_funcify.register(IncSubtensor)
+@_numba_funcify.register(IncSubtensor)
 def numba_funcify_IncSubtensor(op, node, **kwargs):
 
     incsubtensor_def_src = create_index_func(
@@ -577,7 +611,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
     return numba_njit(incsubtensor_fn, boundscheck=True)
 
 
-@numba_funcify.register(AdvancedIncSubtensor1)
+@_numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
     set_instead_of_inc = op.set_instead_of_inc
@@ -610,7 +644,7 @@ def advancedincsubtensor1(x, vals, idxs):
     return advancedincsubtensor1
 
 
-@numba_funcify.register(DeepCopyOp)
+@_numba_funcify.register(DeepCopyOp)
 def numba_funcify_DeepCopyOp(op, node, **kwargs):
 
     # Scalars are apparently returned as actual Python scalar types and not
@@ -632,26 +666,26 @@ def deepcopyop(x):
     return deepcopyop
 
 
-@numba_funcify.register(MakeSlice)
-def numba_funcify_MakeSlice(op, **kwargs):
+@_numba_funcify.register(MakeSlice)
+def numba_funcify_MakeSlice(op, node, **kwargs):
     @numba_njit
     def makeslice(*x):
         return slice(*x)
 
     return makeslice
 
 
-@numba_funcify.register(Shape)
-def numba_funcify_Shape(op, **kwargs):
+@_numba_funcify.register(Shape)
+def numba_funcify_Shape(op, node, **kwargs):
     @numba_njit(inline="always")
     def shape(x):
         return np.asarray(np.shape(x))
 
     return shape
 
 
-@numba_funcify.register(Shape_i)
-def numba_funcify_Shape_i(op, **kwargs):
+@_numba_funcify.register(Shape_i)
+def numba_funcify_Shape_i(op, node, **kwargs):
     i = op.i
 
     @numba_njit(inline="always")
@@ -681,8 +715,8 @@ def codegen(context, builder, signature, args):
     return sig, codegen
 
 
-@numba_funcify.register(Reshape)
-def numba_funcify_Reshape(op, **kwargs):
+@_numba_funcify.register(Reshape)
+def numba_funcify_Reshape(op, node, **kwargs):
     ndim = op.ndim
 
     if ndim == 0:
@@ -704,7 +738,7 @@ def reshape(x, shape):
     return reshape
 
 
-@numba_funcify.register(SpecifyShape)
+@_numba_funcify.register(SpecifyShape)
 def numba_funcify_SpecifyShape(op, node, **kwargs):
     shape_inputs = node.inputs[1:]
     shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
@@ -751,7 +785,7 @@ def inputs_cast(x):
     return inputs_cast
 
 
-@numba_funcify.register(Dot)
+@_numba_funcify.register(Dot)
 def numba_funcify_Dot(op, node, **kwargs):
     # Numba's `np.dot` does not support integer dtypes, so we need to cast to
     # float.
@@ -766,7 +800,7 @@ def dot(x, y):
     return dot
 
 
-@numba_funcify.register(Softplus)
+@_numba_funcify.register(Softplus)
 def numba_funcify_Softplus(op, node, **kwargs):
 
     x_dtype = np.dtype(node.inputs[0].dtype)
@@ -785,7 +819,7 @@ def softplus(x):
     return softplus
 
 
-@numba_funcify.register(Cholesky)
+@_numba_funcify.register(Cholesky)
 def numba_funcify_Cholesky(op, node, **kwargs):
     lower = op.lower
 
@@ -821,7 +855,7 @@ def cholesky(a):
     return cholesky
 
 
-@numba_funcify.register(Solve)
+@_numba_funcify.register(Solve)
 def numba_funcify_Solve(op, node, **kwargs):
 
     assume_a = op.assume_a
@@ -868,7 +902,7 @@ def solve(a, b):
     return solve
 
 
-@numba_funcify.register(BatchedDot)
+@_numba_funcify.register(BatchedDot)
 def numba_funcify_BatchedDot(op, node, **kwargs):
     dtype = node.outputs[0].type.numpy_dtype
 
@@ -889,7 +923,7 @@ def batched_dot(x, y):
     # optimizations are apparently already performed by Numba
 
 
-@numba_funcify.register(IfElse)
+@_numba_funcify.register(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs
 
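To make the interface/dispatch split above concrete, here is a hedged usage sketch. `MyOp` is a hypothetical `Op` invented for illustration; the point is that implementations register on the `_numba_funcify` singledispatch function, callers go through the `numba_funcify` wrapper, and unregistered `Op`s fall back to the object-mode `Op.perform` conversion registered for `Op`:

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import _numba_funcify


class MyOp(Op):
    # Hypothetical Op used only to demonstrate where a conversion is registered.
    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = inputs[0]


@_numba_funcify.register(MyOp)
def numba_funcify_MyOp(op, node=None, storage_map=None, **kwargs):
    # Return a function Numba can JIT-compile for this Op.
    def my_op(x):
        return x

    return my_op


# Callers use the public interface, which forwards to the dispatcher.
jitable_fn = numba_funcify(MyOp())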

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 8 additions & 7 deletions
@@ -13,6 +13,7 @@
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
+    _numba_funcify,
     create_numba_signature,
     create_tuple_creator,
     numba_funcify,
@@ -431,7 +432,7 @@ def axis_apply_fn(x):
     return axis_apply_fn
 
 
-@numba_funcify.register(Elemwise)
+@_numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
     # Creating a new scalar node is more involved and unnecessary
     # if the scalar_op is composite, as the fgraph already contains
@@ -492,7 +493,7 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
     return elemwise_fn
 
 
-@numba_funcify.register(CAReduce)
+@_numba_funcify.register(CAReduce)
 def numba_funcify_CAReduce(op, node, **kwargs):
     axes = op.axis
     if axes is None:
@@ -530,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
     return careduce_fn
 
 
-@numba_funcify.register(DimShuffle)
+@_numba_funcify.register(DimShuffle)
 def numba_funcify_DimShuffle(op, node, **kwargs):
     shuffle = tuple(op.shuffle)
     transposition = tuple(op.transposition)
@@ -628,7 +629,7 @@ def dimshuffle(x):
     return dimshuffle
 
 
-@numba_funcify.register(Softmax)
+@_numba_funcify.register(Softmax)
 def numba_funcify_Softmax(op, node, **kwargs):
 
     x_at = node.inputs[0]
@@ -666,7 +667,7 @@ def softmax_py_fn(x):
     return softmax
 
 
-@numba_funcify.register(SoftmaxGrad)
+@_numba_funcify.register(SoftmaxGrad)
 def numba_funcify_SoftmaxGrad(op, node, **kwargs):
 
     sm_at = node.inputs[1]
@@ -698,7 +699,7 @@ def softmax_grad_py_fn(dy, sm):
     return softmax_grad
 
 
-@numba_funcify.register(LogSoftmax)
+@_numba_funcify.register(LogSoftmax)
 def numba_funcify_LogSoftmax(op, node, **kwargs):
 
     x_at = node.inputs[0]
@@ -733,7 +734,7 @@ def log_softmax_py_fn(x):
     return log_softmax
 
 
-@numba_funcify.register(MaxAndArgmax)
+@_numba_funcify.register(MaxAndArgmax)
 def numba_funcify_MaxAndArgmax(op, node, **kwargs):
     axis = op.axis
     x_at = node.inputs[0]
