Commit 02b616c

Fix tests and fix scalar numba return types

1 parent: 0f6dbc8

10 files changed: +171 -71 lines


pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 3 deletions
@@ -195,7 +195,7 @@ def in_seq_empty_tuple(x, y):
 
 
 def to_scalar(x):
-    raise NotImplementedError()
+    return np.asarray(x).item()
 
 
 @numba.extending.overload(to_scalar)
@@ -534,7 +534,7 @@ def {fn_name}({", ".join(input_names)}):
     {index_prologue}
     {indices_creation_src}
     {index_body}
-    return z
+    return np.asarray(z)
 """
 
     return subtensor_def_src
@@ -652,7 +652,7 @@ def numba_funcify_Shape_i(op, **kwargs):
 
     @numba_njit(inline="always")
     def shape_i(x):
-        return np.shape(x)[i]
+        return np.asarray(np.shape(x)[i])
 
     return shape_i
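
The new `to_scalar` fallback relies on `np.asarray(x).item()`, which extracts a plain Python scalar uniformly from Python numbers, NumPy scalars, and 0-d arrays. A minimal standalone sketch of that behavior (plain NumPy, not pytensor code):

    import numpy as np

    for x in (2, np.float32(2.0), np.array(2.0)):
        # asarray is a no-op for arrays and wraps scalars in a 0-d array;
        # .item() then extracts a plain Python int/float in every case.
        print(type(np.asarray(x).item()))
    # -> <class 'int'>, <class 'float'>, <class 'float'>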

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 91 additions & 24 deletions
@@ -9,6 +9,7 @@
 import numpy as np
 from numba import TypingError, types
 from numba.core import cgutils
+from numba.core.extending import overload
 from numba.np import arrayobj
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
 
@@ -174,6 +175,7 @@ def create_axis_reducer(
     ndim: int,
     dtype: numba.types.Type,
     keepdims: bool = False,
+    return_scalar=False,
 ) -> numba.core.dispatcher.Dispatcher:
     r"""Create Python function that performs a NumPy-like reduction on a given axis.
 
@@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x):
         inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)
 
     return_expr = "res" if keepdims else "res.item()"
+    if not return_scalar:
+        return_expr = f"np.asarray({return_expr})"
     reduce_elemwise_def_src = f"""
 def {reduce_elemwise_fn_name}(x):
@@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x):
 
 
 def create_multiaxis_reducer(
-    scalar_op, identity, axes, ndim, dtype, input_name="input"
+    scalar_op,
+    identity,
+    axes,
+    ndim,
+    dtype,
+    input_name="input",
+    return_scalar=False,
 ):
     r"""Construct a function that reduces multiple axes.
 
@@ -336,6 +346,8 @@ def careduce_maximum(input):
         The number of dimensions of the result.
     dtype:
         The data type of the result.
+    return_scalar:
+        If True, return a scalar, otherwise an array.
 
     Returns
     =======
@@ -370,10 +382,17 @@ def careduce_maximum(input):
     )
 
     careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
+    if not return_scalar:
+        pre_result = "np.asarray"
+        post_result = ""
+    else:
+        pre_result = "np.asarray"
+        post_result = ".item()"
+
     careduce_def_src = f"""
 def {careduce_fn_name}({input_name}):
 {careduce_assign_lines}
-    return np.asarray({var_name})
+    return {pre_result}({var_name}){post_result}
 """
 
     careduce_fn = compile_function_src(
@@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}):
     return careduce_fn
 
 
-def jit_compile_reducer(node, fn, **kwds):
+def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
     """Compile Python source for reduction loops using additional optimizations.
 
     Parameters
@@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds):
         A :func:`numba.njit`-compiled function.
 
     """
-    signature = create_numba_signature(node, reduce_to_scalar=True)
+    signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
 
     # Eagerly compile the function using increased optimizations. This should
     # help improve nested loop reductions.
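
The new `return_scalar`/`reduce_to_scalar` flags control whether the generated reducers hand back a 0-d array or a plain scalar. In NumPy terms, the two return conventions differ as follows (a standalone illustration, not the generated source itself):

    import numpy as np

    res = np.float64(3.5)        # what a fully reduced accumulator holds

    np.asarray(res)              # return_scalar=False: a 0-d ndarray
    np.asarray(res).item()       # return_scalar=True: a plain Python float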
@@ -618,23 +637,58 @@ def numba_funcify_Elemwise(op, node, **kwargs):
     inplace_pattern = tuple(op.inplace_pattern.items())
 
     # numba doesn't support nested literals right now...
-    input_bc_patterns = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
-    output_bc_patterns = base64.encodebytes(pickle.dumps(output_bc_patterns)).decode()
-    output_dtypes = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
-    inplace_pattern = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
+    input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
+    output_bc_patterns_enc = base64.encodebytes(
+        pickle.dumps(output_bc_patterns)
+    ).decode()
+    output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
+    inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
 
-    @numba_njit
     def elemwise_wrapper(*inputs):
         return _vectorized(
             scalar_op_fn,
-            input_bc_patterns,
-            output_bc_patterns,
-            output_dtypes,
-            inplace_pattern,
+            input_bc_patterns_enc,
+            output_bc_patterns_enc,
+            output_dtypes_enc,
+            inplace_pattern_enc,
             inputs,
         )
 
-    return elemwise_wrapper
+    # Pure python implementation, that will be used in tests
+    def elemwise(*inputs):
+        inputs = [np.asarray(input) for input in inputs]
+        inputs_bc = np.broadcast_arrays(*inputs)
+        shape = inputs[0].shape
+        for input, bc in zip(inputs, input_bc_patterns):
+            for length, allow_bc, iter_length in zip(input.shape, bc, shape):
+                if length == 1 and shape and iter_length != 1 and not allow_bc:
+                    raise ValueError("Broadcast not allowed.")
+
+        outputs = []
+        for dtype in output_dtypes:
+            outputs.append(np.empty(shape, dtype=dtype))
+
+        for idx in np.ndindex(shape):
+            vals = [input[idx] for input in inputs_bc]
+            outs = scalar_op_fn(*vals)
+            if not isinstance(outs, tuple):
+                outs = (outs,)
+            for out, out_val in zip(outputs, outs):
+                out[idx] = out_val
+
+        outputs_summed = []
+        for output, bc in zip(outputs, output_bc_patterns):
+            axes = tuple(np.nonzero(bc)[0])
+            outputs_summed.append(output.sum(axes, keepdims=True))
+        if len(outputs_summed) != 1:
+            return tuple(outputs_summed)
+        return outputs_summed[0]
+
+    @overload(elemwise)
+    def ov_elemwise(*inputs):
+        return elemwise_wrapper
+
+    return elemwise
 
 
 @numba_funcify.register(Sum)
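
The `@overload(elemwise)` registration is the standard numba pattern for pairing a pure-Python reference implementation with a compiled one: calling `elemwise` directly runs the Python loops (handy in tests), while calling it inside an `njit` function dispatches to `elemwise_wrapper`. A self-contained sketch of the same pattern (the `double` function is illustrative, not part of the codebase):

    import numba
    from numba.core.extending import overload

    def double(x):
        # Pure-Python path, used when called directly (e.g. from tests).
        return 2 * x

    @overload(double)
    def ov_double(x):
        # Compiled path: numba substitutes this implementation whenever
        # `double` is called inside nopython code.
        def impl(x):
            return 2 * x
        return impl

    @numba.njit
    def f(x):
        return double(x)

    assert double(3.0) == f(3.0) == 6.0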
@@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs):
     if axes is None:
         axes = list(range(node.inputs[0].ndim))
 
-    axes = list(axes)
+    axes = tuple(axes)
 
     ndim_input = node.inputs[0].ndim
 
@@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs):
 
         @numba_njit(fastmath=True)
         def impl_sum(array):
-            # TODO The accumulation itself should happen in acc_dtype...
-            return np.asarray(array.sum()).astype(np_acc_dtype)
+            return np.asarray(array.sum(), dtype=np_acc_dtype)
 
-    else:
+    elif len(axes) == 0:
 
         @numba_njit(fastmath=True)
         def impl_sum(array):
-            # TODO The accumulation itself should happen in acc_dtype...
-            return array.sum(axes).astype(np_acc_dtype)
+            return array
+
+    else:
+        impl_sum = numba_funcify_CAReduce(op, node, **kwargs)
 
     return impl_sum
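
`numba_funcify_Sum` now dispatches on the reduced axes: a full reduction takes a direct `sum()` cast to the accumulation dtype, an empty axis tuple is the identity, and anything in between goes through the generic `CAReduce` machinery. The same three cases in plain NumPy (assuming `float64` as the accumulation dtype):

    import numpy as np

    x = np.arange(6, dtype="float32").reshape(2, 3)

    full = np.asarray(x.sum(), dtype="float64")  # len(axes) == ndim: 0-d array
    identity = x                                 # len(axes) == 0: input unchanged
    partial = x.sum(axis=0)                      # otherwise: the CAReduce path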

@@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
         input_name=input_name,
     )
 
-    careduce_fn = jit_compile_reducer(node, careduce_py_fn)
+    careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
     return careduce_fn
 
 
@@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     if axis is not None:
         axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_axis_reducer(
-            scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+            scalar_maximum,
+            -np.inf,
+            axis,
+            x_at.ndim,
+            x_dtype,
+            keepdims=True,
         )
         reduce_sum_py = create_axis_reducer(
             add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
@@ -935,10 +995,17 @@ def maxandargmax(x):
     keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
 
     reduce_max_py_fn = create_multiaxis_reducer(
-        scalar_maximum, -np.inf, axes, x_ndim, x_dtype
+        scalar_maximum,
+        -np.inf,
+        axes,
+        x_ndim,
+        x_dtype,
+        return_scalar=False,
     )
     reduce_max = jit_compile_reducer(
-        Apply(node.op, node.inputs, [node.outputs[0].clone()]), reduce_max_py_fn
+        Apply(node.op, node.inputs, [node.outputs[0].clone()]),
+        reduce_max_py_fn,
+        reduce_to_scalar=False,
     )
 
     reduced_x_ndim = x_ndim - len(axes) + 1

pytensor/link/numba/dispatch/elemwise_codegen.py

Lines changed: 9 additions & 21 deletions
@@ -117,19 +117,6 @@ def make_loop_call(
 
     # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
 
-    # Lower the code of the scalar function so that we can use it in the inner loop
-    # Caching is set to false to avoid a numba bug TODO ref?
-    inner_func = context.compile_subroutine(
-        builder,
-        # I don't quite understand why we need to access `dispatcher` here.
-        # The object does seem to be a dispatcher already? But it is missing
-        # attributes...
-        scalar_func.dispatcher,
-        scalar_signature,
-        caching=False,
-    )
-    inner = inner_func.fndesc
-
     # Extract shape and stride information from the array.
     # For later use in the loop body to do the indexing
     def extract_array(aryty, obj):
@@ -191,14 +178,15 @@ def extract_array(aryty, obj):
             # val.set_metadata("noalias", output_scope_set)
             input_vals.append(val)
 
-    # Call scalar function
-    output_values = context.call_internal(
-        builder,
-        inner,
-        scalar_signature,
-        input_vals,
-    )
-    if isinstance(scalar_signature.return_type, types.Tuple):
+    inner_codegen = context.get_function(scalar_func, scalar_signature)
+
+    if isinstance(
+        scalar_signature.args[0], (types.StarArgTuple, types.StarArgUniTuple)
+    ):
+        input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)]
+    output_values = inner_codegen(builder, input_vals)
+
+    if isinstance(scalar_signature.return_type, (types.Tuple, types.UniTuple)):
         output_values = cgutils.unpack_tuple(builder, output_values)
     else:
         output_values = [output_values]
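
The widened return-type check matters because numba types homogeneous and heterogeneous tuples differently, and a `UniTuple` is not an instance of `Tuple`; multi-output scalar functions whose outputs share a dtype would otherwise never be unpacked. A quick standalone check:

    import numba
    from numba.core import types

    print(numba.typeof((1, 2)))    # UniTuple(int64 x 2)
    print(numba.typeof((1, 2.0)))  # Tuple(int64, float64)

    t = numba.typeof((1, 2))
    assert not isinstance(t, types.Tuple)                # the old check missed this
    assert isinstance(t, (types.Tuple, types.UniTuple))  # the widened check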

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 1 addition & 0 deletions
@@ -364,6 +364,7 @@ def numba_funcify_BroadcastTo(op, node, **kwargs):
         lambda _: 0, len(node.inputs) - 1
     )
 
+    # TODO broadcastable checks
     @numba_basic.numba_njit
     def broadcast_to(x, *shape):
         scalars_shape = create_zeros_tuple()

pytensor/link/numba/dispatch/scalar.py

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
     # TODO: Do we need to cache these functions so that we don't end up
     # compiling the same Numba function over and over again?
 
+    if not hasattr(op, "nfunc_spec"):
+        return generate_fallback_impl(op, node, **kwargs)
+
     scalar_func_path = op.nfunc_spec[0]
     scalar_func_numba = None
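
For context, `nfunc_spec` is the attribute a scalar `Op` uses to advertise its NumPy equivalent, roughly `(numpy_func_name, nin, nout)`; ops without one now take the generic fallback instead of crashing on `op.nfunc_spec[0]`. An illustrative sketch (`FakeOp` and its spec are made up for the example):

    import numpy as np

    class FakeOp:
        nfunc_spec = ("add", 2, 1)  # NumPy function name, nin, nout

    op = FakeOp()
    if hasattr(op, "nfunc_spec"):
        scalar_func = getattr(np, op.nfunc_spec[0])  # np.add
        print(scalar_func(1, 2))                     # 3
    else:
        print("no NumPy counterpart: use the generic fallback implementation")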

pytensor/link/numba/dispatch/scan.py

Lines changed: 19 additions & 5 deletions
@@ -17,7 +17,11 @@
 
 
 def idx_to_str(
-    array_name: str, offset: int, size: Optional[str] = None, idx_symbol: str = "i"
+    array_name: str,
+    offset: int,
+    size: Optional[str] = None,
+    idx_symbol: str = "i",
+    allow_scalar=False,
 ) -> str:
     if offset < 0:
         indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
@@ -32,7 +36,10 @@ def idx_to_str(
         # compensate for this poor `Op`/rewrite design and implementation.
         indices = f"({indices}) % {size}"
 
-    return f"{array_name}[{indices}]"
+    if allow_scalar:
+        return f"{array_name}[{indices}]"
+    else:
+        return f"np.asarray({array_name}[{indices}])"
 
 
 @overload(range)
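
Concretely, `idx_to_str` builds the indexing expressions spliced into the generated scan source, and `allow_scalar` decides whether the indexed element is wrapped in `np.asarray`. Assuming the zero-offset branch emits the bare loop index `i`, a call with illustrative buffer names would now produce:

    idx_to_str("x_buf", 0, size="x_len")
    # -> "np.asarray(x_buf[(i) % x_len])"

    idx_to_str("x_buf", 0, size="x_len", allow_scalar=True)
    # -> "x_buf[(i) % x_len]"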
@@ -115,7 +122,9 @@ def add_inner_in_expr(
     indexed_inner_in_str = (
         storage_name
         if tap_offset is None
-        else idx_to_str(storage_name, tap_offset, size=storage_size_var)
+        else idx_to_str(
+            storage_name, tap_offset, size=storage_size_var, allow_scalar=False
+        )
     )
     inner_in_exprs.append(indexed_inner_in_str)
 
@@ -232,7 +241,12 @@ def add_output_storage_post_proc_stmt(
         )
         for out_tap in output_taps:
             inner_out_to_outer_in_stmts.append(
-                idx_to_str(storage_name, out_tap, size=storage_size_name)
+                idx_to_str(
+                    storage_name,
+                    out_tap,
+                    size=storage_size_name,
+                    allow_scalar=True,
+                )
             )
 
         add_output_storage_post_proc_stmt(
@@ -269,7 +283,7 @@ def add_output_storage_post_proc_stmt(
         storage_size_name = f"{outer_in_name}_len"
 
         inner_out_to_outer_in_stmts.append(
-            idx_to_str(storage_name, 0, size=storage_size_name)
+            idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
         )
         add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)

pytensor/link/numba/linker.py

Lines changed: 2 additions & 2 deletions
@@ -27,9 +27,9 @@ def fgraph_convert(self, fgraph, **kwargs):
         return numba_funcify(fgraph, **kwargs)
 
     def jit_compile(self, fn):
-        import numba
+        from pytensor.link.numba.dispatch.basic import numba_njit
 
-        jitted_fn = numba.njit(fn)
+        jitted_fn = numba_njit(fn)
         return jitted_fn
 
     def create_thunk_inputs(self, storage_map):
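
Routing the linker through the shared `numba_njit` helper keeps JIT options consistent with the rest of the numba dispatch code instead of calling bare `numba.njit`. A hypothetical shape for such a helper (the real one lives in `pytensor/link/numba/dispatch/basic.py` and may set different options):

    import numba

    def numba_njit(*args, **kwargs):
        # Centralize default options so every compiled function behaves the
        # same; the options set here are assumptions for illustration.
        kwargs.setdefault("cache", True)
        if args and callable(args[0]):
            return numba.njit(args[0], **kwargs)  # bare @numba_njit usage
        return numba.njit(*args, **kwargs)        # @numba_njit(...) with options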
