Skip to content

Commit 164f24f

Browse files
committed
Move vectorize wrapper to vectorize_codegen
1 parent d0b12ed commit 164f24f

File tree

2 files changed

+163
-164
lines changed

2 files changed

+163
-164
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 1 addition & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,21 @@
88

99
import numba
1010
import numpy as np
11-
from numba import TypingError, types
12-
from numba.core import cgutils
1311
from numba.core.extending import overload
14-
from numba.np import arrayobj
1512
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
1613

1714
from pytensor import config
1815
from pytensor.graph.basic import Apply
1916
from pytensor.graph.op import Op
2017
from pytensor.link.numba.dispatch import basic as numba_basic
21-
from pytensor.link.numba.dispatch import elemwise_codegen
2218
from pytensor.link.numba.dispatch.basic import (
2319
create_numba_signature,
2420
create_tuple_creator,
2521
numba_funcify,
2622
numba_njit,
2723
use_optimized_cheap_pass,
2824
)
25+
from pytensor.link.numba.dispatch.vectorize_codegen import _vectorized
2926
from pytensor.link.utils import compile_function_src, get_name_for_object
3027
from pytensor.scalar.basic import (
3128
AND,
@@ -463,165 +460,6 @@ def axis_apply_fn(x):
463460
return axis_apply_fn
464461

465462

466-
_jit_options = {
467-
"fastmath": {
468-
"arcp", # Allow Reciprocal
469-
"contract", # Allow floating-point contraction
470-
"afn", # Approximate functions
471-
"reassoc",
472-
"nsz", # TODO Do we want this one?
473-
}
474-
}
475-
476-
477-
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
478-
def _vectorized(
479-
typingctx,
480-
scalar_func,
481-
input_bc_patterns,
482-
output_bc_patterns,
483-
output_dtypes,
484-
inplace_pattern,
485-
inputs,
486-
):
487-
arg_types = [
488-
scalar_func,
489-
input_bc_patterns,
490-
output_bc_patterns,
491-
output_dtypes,
492-
inplace_pattern,
493-
inputs,
494-
]
495-
496-
if not isinstance(input_bc_patterns, types.Literal):
497-
raise TypingError("input_bc_patterns must be literal.")
498-
input_bc_patterns = input_bc_patterns.literal_value
499-
input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode()))
500-
501-
if not isinstance(output_bc_patterns, types.Literal):
502-
raise TypeError("output_bc_patterns must be literal.")
503-
output_bc_patterns = output_bc_patterns.literal_value
504-
output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode()))
505-
506-
if not isinstance(output_dtypes, types.Literal):
507-
raise TypeError("output_dtypes must be literal.")
508-
output_dtypes = output_dtypes.literal_value
509-
output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode()))
510-
511-
if not isinstance(inplace_pattern, types.Literal):
512-
raise TypeError("inplace_pattern must be literal.")
513-
inplace_pattern = inplace_pattern.literal_value
514-
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
515-
516-
n_outputs = len(output_bc_patterns)
517-
518-
if not len(inputs) > 0:
519-
raise TypingError("Empty argument list to elemwise op.")
520-
521-
if not n_outputs > 0:
522-
raise TypingError("Empty list of outputs for elemwise op.")
523-
524-
if not all(isinstance(input, types.Array) for input in inputs):
525-
raise TypingError("Inputs to elemwise must be arrays.")
526-
ndim = inputs[0].ndim
527-
528-
if not all(input.ndim == ndim for input in inputs):
529-
raise TypingError("Inputs to elemwise must have the same rank.")
530-
531-
if not all(len(pattern) == ndim for pattern in output_bc_patterns):
532-
raise TypingError("Invalid output broadcasting pattern.")
533-
534-
scalar_signature = typingctx.resolve_function_type(
535-
scalar_func, [in_type.dtype for in_type in inputs], {}
536-
)
537-
538-
# So we can access the constant values in codegen...
539-
input_bc_patterns_val = input_bc_patterns
540-
output_bc_patterns_val = output_bc_patterns
541-
output_dtypes_val = output_dtypes
542-
inplace_pattern_val = inplace_pattern
543-
input_types = inputs
544-
545-
def codegen(
546-
ctx,
547-
builder,
548-
sig,
549-
args,
550-
):
551-
[_, _, _, _, _, inputs] = args
552-
inputs = cgutils.unpack_tuple(builder, inputs)
553-
inputs = [
554-
arrayobj.make_array(ty)(ctx, builder, val)
555-
for ty, val in zip(input_types, inputs)
556-
]
557-
in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
558-
559-
iter_shape = elemwise_codegen.compute_itershape(
560-
ctx,
561-
builder,
562-
in_shapes,
563-
input_bc_patterns_val,
564-
)
565-
566-
outputs, output_types = elemwise_codegen.make_outputs(
567-
ctx,
568-
builder,
569-
iter_shape,
570-
output_bc_patterns_val,
571-
output_dtypes_val,
572-
inplace_pattern_val,
573-
inputs,
574-
input_types,
575-
)
576-
577-
elemwise_codegen.make_loop_call(
578-
typingctx,
579-
ctx,
580-
builder,
581-
scalar_func,
582-
scalar_signature,
583-
iter_shape,
584-
inputs,
585-
outputs,
586-
input_bc_patterns_val,
587-
output_bc_patterns_val,
588-
input_types,
589-
output_types,
590-
)
591-
592-
if len(outputs) == 1:
593-
if inplace_pattern:
594-
assert inplace_pattern[0][0] == 0
595-
ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
596-
return outputs[0]._getvalue()
597-
598-
for inplace_idx in dict(inplace_pattern):
599-
ctx.nrt.incref(
600-
builder,
601-
sig.return_type.types[inplace_idx],
602-
outputs[inplace_idx]._get_value(),
603-
)
604-
return ctx.make_tuple(
605-
builder, sig.return_type, [out._getvalue() for out in outputs]
606-
)
607-
608-
ret_types = [
609-
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
610-
for dtype in output_dtypes
611-
]
612-
613-
for output_idx, input_idx in inplace_pattern:
614-
ret_types[output_idx] = input_types[input_idx]
615-
616-
ret_type = types.Tuple(ret_types)
617-
618-
if len(output_dtypes) == 1:
619-
ret_type = ret_type.types[0]
620-
sig = ret_type(*arg_types)
621-
622-
return sig, codegen
623-
624-
625463
@numba_funcify.register(Elemwise)
626464
def numba_funcify_Elemwise(op, node, **kwargs):
627465
# Creating a new scalar node is more involved and unnecessary

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,177 @@
11
from __future__ import annotations
22

3+
import base64
4+
import pickle
35
from typing import Any
46

57
import numba
68
import numpy as np
79
from llvmlite import ir
8-
from numba import types
10+
from numba import TypingError, types
911
from numba.core import cgutils
1012
from numba.core.base import BaseContext
1113
from numba.np import arrayobj
1214

1315

16+
_jit_options = {
17+
"fastmath": {
18+
"arcp", # Allow Reciprocal
19+
"contract", # Allow floating-point contraction
20+
"afn", # Approximate functions
21+
"reassoc",
22+
"nsz", # TODO Do we want this one?
23+
}
24+
}
25+
26+
27+
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
28+
def _vectorized(
29+
typingctx,
30+
scalar_func,
31+
input_bc_patterns,
32+
output_bc_patterns,
33+
output_dtypes,
34+
inplace_pattern,
35+
inputs,
36+
):
37+
arg_types = [
38+
scalar_func,
39+
input_bc_patterns,
40+
output_bc_patterns,
41+
output_dtypes,
42+
inplace_pattern,
43+
inputs,
44+
]
45+
46+
if not isinstance(input_bc_patterns, types.Literal):
47+
raise TypingError("input_bc_patterns must be literal.")
48+
input_bc_patterns = input_bc_patterns.literal_value
49+
input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode()))
50+
51+
if not isinstance(output_bc_patterns, types.Literal):
52+
raise TypeError("output_bc_patterns must be literal.")
53+
output_bc_patterns = output_bc_patterns.literal_value
54+
output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode()))
55+
56+
if not isinstance(output_dtypes, types.Literal):
57+
raise TypeError("output_dtypes must be literal.")
58+
output_dtypes = output_dtypes.literal_value
59+
output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode()))
60+
61+
if not isinstance(inplace_pattern, types.Literal):
62+
raise TypeError("inplace_pattern must be literal.")
63+
inplace_pattern = inplace_pattern.literal_value
64+
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
65+
66+
n_outputs = len(output_bc_patterns)
67+
68+
if not len(inputs) > 0:
69+
raise TypingError("Empty argument list to elemwise op.")
70+
71+
if not n_outputs > 0:
72+
raise TypingError("Empty list of outputs for elemwise op.")
73+
74+
if not all(isinstance(input, types.Array) for input in inputs):
75+
raise TypingError("Inputs to elemwise must be arrays.")
76+
ndim = inputs[0].ndim
77+
78+
if not all(input.ndim == ndim for input in inputs):
79+
raise TypingError("Inputs to elemwise must have the same rank.")
80+
81+
if not all(len(pattern) == ndim for pattern in output_bc_patterns):
82+
raise TypingError("Invalid output broadcasting pattern.")
83+
84+
scalar_signature = typingctx.resolve_function_type(
85+
scalar_func, [in_type.dtype for in_type in inputs], {}
86+
)
87+
88+
# So we can access the constant values in codegen...
89+
input_bc_patterns_val = input_bc_patterns
90+
output_bc_patterns_val = output_bc_patterns
91+
output_dtypes_val = output_dtypes
92+
inplace_pattern_val = inplace_pattern
93+
input_types = inputs
94+
95+
def codegen(
96+
ctx,
97+
builder,
98+
sig,
99+
args,
100+
):
101+
[_, _, _, _, _, inputs] = args
102+
inputs = cgutils.unpack_tuple(builder, inputs)
103+
inputs = [
104+
arrayobj.make_array(ty)(ctx, builder, val)
105+
for ty, val in zip(input_types, inputs)
106+
]
107+
in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
108+
109+
iter_shape = compute_itershape(
110+
ctx,
111+
builder,
112+
in_shapes,
113+
input_bc_patterns_val,
114+
)
115+
116+
outputs, output_types = make_outputs(
117+
ctx,
118+
builder,
119+
iter_shape,
120+
output_bc_patterns_val,
121+
output_dtypes_val,
122+
inplace_pattern_val,
123+
inputs,
124+
input_types,
125+
)
126+
127+
make_loop_call(
128+
typingctx,
129+
ctx,
130+
builder,
131+
scalar_func,
132+
scalar_signature,
133+
iter_shape,
134+
inputs,
135+
outputs,
136+
input_bc_patterns_val,
137+
output_bc_patterns_val,
138+
input_types,
139+
output_types,
140+
)
141+
142+
if len(outputs) == 1:
143+
if inplace_pattern:
144+
assert inplace_pattern[0][0] == 0
145+
ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
146+
return outputs[0]._getvalue()
147+
148+
for inplace_idx in dict(inplace_pattern):
149+
ctx.nrt.incref(
150+
builder,
151+
sig.return_type.types[inplace_idx],
152+
outputs[inplace_idx]._get_value(),
153+
)
154+
return ctx.make_tuple(
155+
builder, sig.return_type, [out._getvalue() for out in outputs]
156+
)
157+
158+
ret_types = [
159+
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
160+
for dtype in output_dtypes
161+
]
162+
163+
for output_idx, input_idx in inplace_pattern:
164+
ret_types[output_idx] = input_types[input_idx]
165+
166+
ret_type = types.Tuple(ret_types)
167+
168+
if len(output_dtypes) == 1:
169+
ret_type = ret_type.types[0]
170+
sig = ret_type(*arg_types)
171+
172+
return sig, codegen
173+
174+
14175
def compute_itershape(
15176
ctx: BaseContext,
16177
builder: ir.IRBuilder,

0 commit comments

Comments
 (0)