@@ -4,7 +4,7 @@
 from contextlib import contextmanager
 from functools import singledispatch
 from textwrap import dedent
-from typing import Union
+from typing import TYPE_CHECKING, Callable, Optional, Union, cast

 import numba
 import numba.np.unsafe.ndarray as numba_ndarray
@@ -22,6 +22,7 @@
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Apply, NoParams
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.op import Op
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import (
@@ -48,6 +49,10 @@
 from pytensor.tensor.type_other import MakeSlice, NoneConst


+if TYPE_CHECKING:
+    from pytensor.graph.op import StorageMapType
+
+
 def numba_njit(*args, **kwargs):

     kwargs = kwargs.copy()
@@ -353,8 +358,43 @@ def numba_const_convert(data, dtype=None, **kwargs):
     return data


-def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
-    """Create a Numba compatible function from an Aesara `Op`."""
+def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
+    """Convert `obj` to a Numba-JITable object."""
+    return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
+
+
+@singledispatch
+def _numba_funcify(
+    obj,
+    node: Optional[Apply] = None,
+    storage_map: Optional["StorageMapType"] = None,
+    **kwargs,
+) -> Callable:
+    r"""Dispatch on PyTensor object types to perform Numba conversions.
+
+    Arguments
+    ---------
+    obj
+        The object used to determine the appropriate conversion function based
+        on its type. This is generally an `Op` instance, but `FunctionGraph`\s
+        are also supported.
+    node
+        When `obj` is an `Op`, this value should be the corresponding `Apply` node.
+    storage_map
+        A storage map with, for example, the constant and `SharedVariable` values
+        of the graph being converted.
+
+    Returns
+    -------
+    A `Callable` that can be JIT-compiled in Numba using `numba.jit`.
+
+    """
+    raise NotImplementedError(f"Numba funcify for obj {obj} not implemented")
+
+
+@_numba_funcify.register(Op)
+def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable:
+    """Create a Numba compatible function from a PyTensor `Op.perform`."""

     warnings.warn(
         f"Numba will use object mode to run {op}'s perform method",
@@ -405,16 +445,10 @@ def perform(*inputs):
         ret = py_perform_return(inputs)
         return ret

-    return perform
-
-
-@singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
-    """Generate a numba function for a given op and apply node."""
-    return generate_fallback_impl(op, node, storage_map, **kwargs)
+    return cast(Callable, perform)


-@numba_funcify.register(OpFromGraph)
+@_numba_funcify.register(OpFromGraph)
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):

     _ = kwargs.pop("storage_map", None)
@@ -436,7 +470,7 @@ def opfromgraph(*inputs):
     return opfromgraph


-@numba_funcify.register(FunctionGraph)
+@_numba_funcify.register(FunctionGraph)
 def numba_funcify_FunctionGraph(
     fgraph,
     node=None,
@@ -544,8 +578,8 @@ def {fn_name}({", ".join(input_names)}):
     return subtensor_def_src


-@numba_funcify.register(Subtensor)
-@numba_funcify.register(AdvancedSubtensor1)
+@_numba_funcify.register(Subtensor)
+@_numba_funcify.register(AdvancedSubtensor1)
 def numba_funcify_Subtensor(op, node, **kwargs):

     subtensor_def_src = create_index_func(
@@ -561,7 +595,7 @@ def numba_funcify_Subtensor(op, node, **kwargs):
     return numba_njit(subtensor_fn, boundscheck=True)


-@numba_funcify.register(IncSubtensor)
+@_numba_funcify.register(IncSubtensor)
 def numba_funcify_IncSubtensor(op, node, **kwargs):

     incsubtensor_def_src = create_index_func(
@@ -577,7 +611,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
     return numba_njit(incsubtensor_fn, boundscheck=True)


-@numba_funcify.register(AdvancedIncSubtensor1)
+@_numba_funcify.register(AdvancedIncSubtensor1)
 def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
     inplace = op.inplace
     set_instead_of_inc = op.set_instead_of_inc
@@ -610,7 +644,7 @@ def advancedincsubtensor1(x, vals, idxs):
     return advancedincsubtensor1


-@numba_funcify.register(DeepCopyOp)
+@_numba_funcify.register(DeepCopyOp)
 def numba_funcify_DeepCopyOp(op, node, **kwargs):

     # Scalars are apparently returned as actual Python scalar types and not
@@ -632,26 +666,26 @@ def deepcopyop(x):
     return deepcopyop


-@numba_funcify.register(MakeSlice)
-def numba_funcify_MakeSlice(op, **kwargs):
+@_numba_funcify.register(MakeSlice)
+def numba_funcify_MakeSlice(op, node, **kwargs):
     @numba_njit
     def makeslice(*x):
         return slice(*x)

     return makeslice


-@numba_funcify.register(Shape)
-def numba_funcify_Shape(op, **kwargs):
+@_numba_funcify.register(Shape)
+def numba_funcify_Shape(op, node, **kwargs):
     @numba_njit(inline="always")
     def shape(x):
         return np.asarray(np.shape(x))

     return shape


-@numba_funcify.register(Shape_i)
-def numba_funcify_Shape_i(op, **kwargs):
+@_numba_funcify.register(Shape_i)
+def numba_funcify_Shape_i(op, node, **kwargs):
     i = op.i

     @numba_njit(inline="always")
@@ -681,8 +715,8 @@ def codegen(context, builder, signature, args):
     return sig, codegen


-@numba_funcify.register(Reshape)
-def numba_funcify_Reshape(op, **kwargs):
+@_numba_funcify.register(Reshape)
+def numba_funcify_Reshape(op, node, **kwargs):
     ndim = op.ndim

     if ndim == 0:
@@ -704,7 +738,7 @@ def reshape(x, shape):
     return reshape


-@numba_funcify.register(SpecifyShape)
+@_numba_funcify.register(SpecifyShape)
 def numba_funcify_SpecifyShape(op, node, **kwargs):
     shape_inputs = node.inputs[1:]
     shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
@@ -751,7 +785,7 @@ def inputs_cast(x):
     return inputs_cast


-@numba_funcify.register(Dot)
+@_numba_funcify.register(Dot)
 def numba_funcify_Dot(op, node, **kwargs):
     # Numba's `np.dot` does not support integer dtypes, so we need to cast to
     # float.
@@ -766,7 +800,7 @@ def dot(x, y):
     return dot


-@numba_funcify.register(Softplus)
+@_numba_funcify.register(Softplus)
 def numba_funcify_Softplus(op, node, **kwargs):

     x_dtype = np.dtype(node.inputs[0].dtype)
@@ -785,7 +819,7 @@ def softplus(x):
     return softplus


-@numba_funcify.register(Cholesky)
+@_numba_funcify.register(Cholesky)
 def numba_funcify_Cholesky(op, node, **kwargs):
     lower = op.lower

@@ -821,7 +855,7 @@ def cholesky(a):
     return cholesky


-@numba_funcify.register(Solve)
+@_numba_funcify.register(Solve)
 def numba_funcify_Solve(op, node, **kwargs):

     assume_a = op.assume_a
@@ -868,7 +902,7 @@ def solve(a, b):
     return solve


-@numba_funcify.register(BatchedDot)
+@_numba_funcify.register(BatchedDot)
 def numba_funcify_BatchedDot(op, node, **kwargs):
     dtype = node.outputs[0].type.numpy_dtype

@@ -889,7 +923,7 @@ def batched_dot(x, y):
     # optimizations are apparently already performed by Numba


-@numba_funcify.register(IfElse)
+@_numba_funcify.register(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs

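
For callers, the public entry point is unchanged: hand `numba_funcify` a PyTensor object and JIT the result. A rough usage sketch, assuming the package re-exports `numba_funcify` from `pytensor.link.numba.dispatch`:

```python
import numba
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.link.numba.dispatch import numba_funcify

x = pt.vector("x")
fgraph = FunctionGraph(inputs=[x], outputs=[2 * x])

# Dispatch happens on the first argument's type: a `FunctionGraph` here,
# so the `FunctionGraph`-registered conversion is used.
fn = numba_funcify(fgraph)
jitted_fn = numba.njit(fn)
```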