From d4e5dcde2001ff39c87fd8980ecab791584e6896 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 6 Dec 2022 16:52:04 -0600 Subject: [PATCH 01/14] Initial version of llvm elemwise impl --- pytensor/link/numba/dispatch/elemwise.py | 258 ++++++++++++++---- .../link/numba/dispatch/elemwise_codegen.py | 231 ++++++++++++++++ pytensor/link/numba/dispatch/helpers.py | 43 +++ 3 files changed, 481 insertions(+), 51 deletions(-) create mode 100644 pytensor/link/numba/dispatch/elemwise_codegen.py create mode 100644 pytensor/link/numba/dispatch/helpers.py diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 0595191da0..1cc293e1b6 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,11 +1,17 @@ -import inspect from functools import singledispatch from numbers import Number +import pickle from textwrap import indent -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Literal, Optional, Union +import base64 import numba import numpy as np +from llvmlite import ir +from numba import TypingError, literal_unroll, types, literally +from numba.core import cgutils +from numba.cpython.unsafe.tuple import tuple_setitem +from numba.np import arrayobj from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from pytensor import config @@ -16,13 +22,12 @@ create_numba_signature, create_tuple_creator, numba_funcify, + numba_njit, use_optimized_cheap_pass, ) -from pytensor.link.utils import ( - compile_function_src, - get_name_for_object, - unique_name_generator, -) +from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper +from pytensor.link.numba.dispatch import elemwise_codegen +from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.scalar.basic import ( AND, OR, @@ -431,6 +436,170 @@ def axis_apply_fn(x): return axis_apply_fn +_jit_options = { + "fastmath": { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # TODO Do we want this one? + } +} + +@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) +def _vectorized( + typingctx, + scalar_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + inputs, +): + #if not isinstance(scalar_func, types.Literal): + # raise TypingError("scalar func must be literal.") + #scalar_func = scalar_func.literal_value + + arg_types = [ + scalar_func, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + inputs, + ] + + if not isinstance(input_bc_patterns, types.Literal): + raise TypingError("input_bc_patterns must be literal.") + input_bc_patterns = input_bc_patterns.literal_value + input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode())) + + if not isinstance(output_bc_patterns, types.Literal): + raise TypeError("output_bc_patterns must be literal.") + output_bc_patterns = output_bc_patterns.literal_value + output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode())) + + if not isinstance(output_dtypes, types.Literal): + raise TypeError("output_dtypes must be literal.") + output_dtypes = output_dtypes.literal_value + output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode())) + + if not isinstance(inplace_pattern, types.Literal): + raise TypeError("inplace_pattern must be literal.") + inplace_pattern = inplace_pattern.literal_value + inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) + + n_inputs = len(inputs) + n_outputs = len(output_bc_patterns) + + if not len(inputs) > 0: + raise TypingError("Empty argument list to elemwise op.") + + if not n_outputs > 0: + raise TypingError("Empty list of outputs for elemwise op.") + + if not all(isinstance(input, types.Array) for input in inputs): + raise TypingError("Inputs to elemwise must be arrays.") + ndim = inputs[0].ndim + + if not all(input.ndim == ndim for input in inputs): + raise TypingError("Inputs to elemwise must have the same rank.") + + if not all(len(pattern) == ndim for pattern in output_bc_patterns): + raise TypingError("Invalid output broadcasting pattern.") + + scalar_signature = typingctx.resolve_function_type( + scalar_func, [in_type.dtype for in_type in inputs], {} + ) + + # So we can access the constant values in codegen... + input_bc_patterns_val = input_bc_patterns + output_bc_patterns_val = output_bc_patterns + output_dtypes_val = output_dtypes + inplace_pattern_val = inplace_pattern + input_types = inputs + + #assert not inplace_pattern_val + + def codegen( + ctx, + builder, + sig, + args, + ): + + [_, _, _, _, _, inputs] = args + inputs = cgutils.unpack_tuple(builder, inputs) + inputs = [arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs)] + in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] + + iter_shape = elemwise_codegen.compute_itershape( + ctx, + builder, + in_shapes, + input_bc_patterns_val, + ) + + outputs, output_types = elemwise_codegen.make_outputs( + ctx, + builder, + iter_shape, + output_bc_patterns_val, + output_dtypes_val, + inplace_pattern_val, + inputs, + input_types, + ) + + def _check_input_shapes(*_): + # TODO impl + return + + _check_input_shapes( + ctx, + builder, + iter_shape, + inputs, + input_bc_patterns_val, + ) + + elemwise_codegen.make_loop_call( + typingctx, + ctx, + builder, + scalar_func, + scalar_signature, + iter_shape, + inputs, + outputs, + input_bc_patterns_val, + output_bc_patterns_val, + input_types, + output_types, + ) + + if len(outputs) == 1: + if inplace_pattern: + assert inplace_pattern[0][0] == 0 + ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue()) + return outputs[0]._getvalue() + + for inplace_idx in dict(inplace_pattern): + ctx.nrt.incref(builder, sig.return_type.types[inplace_idx], outputs[inplace_idx]._get_value()) + return ctx.make_tuple(builder, sig.return_type, [out._getvalue() for out in outputs]) + + # TODO check inplace_pattern + ret_type = types.Tuple([ + types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") + for dtype in output_dtypes + ]) + if len(output_dtypes) == 1: + ret_type = ret_type.types[0] + sig = ret_type(*arg_types) + + return sig, codegen + + @numba_funcify.register(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): # Creating a new scalar node is more involved and unnecessary @@ -441,55 +610,42 @@ def numba_funcify_Elemwise(op, node, **kwargs): scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs] scalar_node = op.scalar_op.make_node(*scalar_inputs) + flags = { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # TODO Do we want this one? + } + scalar_op_fn = numba_funcify( - op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs + op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs ) - elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False) - elemwise_fn_name = elemwise_fn.__name__ - - if op.inplace_pattern: - input_idx = op.inplace_pattern[0] - sign_obj = inspect.signature(elemwise_fn.py_scalar_func) - input_names = list(sign_obj.parameters.keys()) - - unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_") - input_names = [unique_names(i, force_unique=True) for i in input_names] - updated_input_name = input_names[input_idx] - - inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np} - - inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace" - - input_signature_str = ", ".join(input_names) - - if node.inputs[input_idx].ndim > 0: - inplace_elemwise_src = f""" -def {inplace_elemwise_fn_name}({input_signature_str}): - return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}) - """ - else: - # We can't perform in-place updates on Numba scalars, so we need to - # convert them to NumPy scalars. - # TODO: We should really prevent the rewrites from creating - # in-place updates on scalars when the Numba mode is selected (or - # in general?). - inplace_elemwise_src = f""" -def {inplace_elemwise_fn_name}({input_signature_str}): - {updated_input_name}_scalar = np.asarray({updated_input_name}) - return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item() - """ - - inplace_elemwise_fn = compile_function_src( - inplace_elemwise_src, - inplace_elemwise_fn_name, - {**globals(), **inplace_global_env}, - ) - return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)( - inplace_elemwise_fn + ndim = node.outputs[0].ndim + output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs]) + input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs]) + output_dtypes = tuple(variable.dtype for variable in node.outputs) + inplace_pattern = tuple(op.inplace_pattern.items()) + + # numba doesn't support nested literals right now... + input_bc_patterns = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode() + output_bc_patterns = base64.encodebytes(pickle.dumps(output_bc_patterns)).decode() + output_dtypes = base64.encodebytes(pickle.dumps(output_dtypes)).decode() + inplace_pattern = base64.encodebytes(pickle.dumps(inplace_pattern)).decode() + + @numba_njit + def elemwise_wrapper(*inputs): + return _vectorized( + scalar_op_fn, + input_bc_patterns, + output_bc_patterns, + output_dtypes, + inplace_pattern, + inputs, ) - return elemwise_fn + return elemwise_wrapper @numba_funcify.register(CAReduce) diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py new file mode 100644 index 0000000000..a8bc0b3629 --- /dev/null +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -0,0 +1,231 @@ +from llvmlite import ir +from numba import types +from numba.np import arrayobj +from numba.core import cgutils +import numba +import numpy as np + + +def compute_itershape( + ctx, + builder: ir.IRBuilder, + in_shapes, + broadcast_pattern, +): + one = ir.IntType(64)(1) + ndim = len(in_shapes[0]) + #shape = [ir.IntType(64)(1) for _ in range(ndim)] + shape = [None] * ndim + for i in range(ndim): + # TODO Error checking... + # What if all shapes are 0? + for bc, in_shape in zip(broadcast_pattern, in_shapes): + if bc[i]: + # TODO + # raise error if length != 1 + pass + else: + # TODO + # if shape[i] is not None: + # raise Error if != + shape[i] = in_shape[i] + for i in range(ndim): + if shape[i] is None: + shape[i] = one + return shape + + +def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types): + arrays = [] + ar_types: list[types.Array] = [] + one = ir.IntType(64)(1) + inplace = dict(inplace) + for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)): + if i in inplace: + arrays.append(inputs[inplace[i]]) + ar_types.append(input_types[inplace[i]]) + # We need to incref once we return the inplace objects + continue + dtype = numba.from_dtype(np.dtype(dtype)) + arrtype = types.Array(dtype, len(iter_shape), "C") + ar_types.append(arrtype) + # This is actually an interal numba function, I guess we could + # call `numba.nd.unsafe.ndarray` instead? + shape = [ + length if not bc_dim else one + for length, bc_dim in zip(iter_shape, bc) + ] + array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) + arrays.append(array) + + # If there is no inplace operation, we know that all output arrays + # don't alias. Informing llvm can make it easier to vectorize. + if not inplace: + # The first argument is the output pointer + arg = builder.function.args[0] + arg.add_attribute("noalias") + return arrays, ar_types + + +def make_loop_call( + typingctx, + context: numba.core.base.BaseContext, + builder: ir.IRBuilder, + scalar_func, + scalar_signature, + iter_shape, + inputs, + outputs, + input_bc, + output_bc, + input_types, + output_types, +): + safe = (False, False) + n_outputs = len(outputs) + + #context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape) + + # Lower the code of the scalar function so that we can use it in the inner loop + # Caching is set to false to avoid a numba bug TODO ref? + inner_func = context.compile_subroutine( + builder, + # I don't quite understand why we need to access `dispatcher` here. + # The object does seem to be a dispatcher already? But it is missing + # attributes... + scalar_func.dispatcher, + scalar_signature, + caching=False, + ) + inner = inner_func.fndesc + + # Extract shape and stride information from the array. + # For later use in the loop body to do the indexing + def extract_array(aryty, obj): + shape = cgutils.unpack_tuple(builder, obj.shape) + strides = cgutils.unpack_tuple(builder, obj.strides) + data = obj.data + layout = aryty.layout + return (data, shape, strides, layout) + + # TODO I think this is better than the noalias attribute + # for the input, but self_ref isn't supported in a released + # llvmlite version yet + # mod = builder.module + # domain = mod.add_metadata([], self_ref=True) + # input_scope = mod.add_metadata([domain], self_ref=True) + # output_scope = mod.add_metadata([domain], self_ref=True) + # input_scope_set = mod.add_metadata([input_scope, output_scope]) + # output_scope_set = mod.add_metadata([input_scope, output_scope]) + + inputs = [ + extract_array(aryty, ary) + for aryty, ary in zip(input_types, inputs, strict=True) + ] + + outputs = [ + extract_array(aryty, ary) + for aryty, ary in zip(output_types, outputs, strict=True) + ] + + zero = ir.Constant(ir.IntType(64), 0) + + # Setup loops and initialize accumulators for outputs + # This part corresponds to opening the loops + loop_stack = [] + loops = [] + output_accumulator = [(None, None)] * n_outputs + for dim, length in enumerate(iter_shape): + # Find outputs that only have accumulations left + for output in range(n_outputs): + if output_accumulator[output][0] is not None: + continue + if all(output_bc[output][dim:]): + value = outputs[output][0].type.pointee(0) + accu = cgutils.alloca_once_value(builder, value) + output_accumulator[output] = (accu, dim) + + loop = cgutils.for_range(builder, length) + loop_stack.append(loop) + loops.append(loop.__enter__()) + + # Code in the inner most loop... + idxs = [loopval.index for loopval in loops] + + # Load values from input arrays + input_vals = [] + for array_info, bc in zip(inputs, input_bc, strict=True): + idxs_bc = [ + zero if bc else idx for idx, bc in zip(idxs, bc, strict=True) + ] + ptr = cgutils.get_item_pointer2( + context, builder, *array_info, idxs_bc, *safe + ) + val = builder.load(ptr) + # val.set_metadata("alias.scope", input_scope_set) + # val.set_metadata("noalias", output_scope_set) + input_vals.append(val) + + # Call scalar function + output_values = context.call_internal( + builder, + inner, + scalar_signature, + input_vals, + ) + if isinstance(scalar_signature.return_type, types.Tuple): + output_values = cgutils.unpack_tuple(builder, output_values) + else: + output_values = [output_values] + + # Update output value or accumulators respectively + for i, ((accu, _), value) in enumerate( + zip(output_accumulator, output_values, strict=True) + ): + if accu is not None: + load = builder.load(accu) + # load.set_metadata("alias.scope", output_scope_set) + # load.set_metadata("noalias", input_scope_set) + new_value = builder.fadd(load, value) + builder.store(new_value, accu) + # TODO belongs to noalias scope + # store.set_metadata("alias.scope", output_scope_set) + # store.set_metadata("noalias", input_scope_set) + else: + idxs_bc = [ + zero if bc else idx + for idx, bc in zip(idxs, output_bc[i], strict=True) + ] + ptr = cgutils.get_item_pointer2( + context, builder, *outputs[i], idxs_bc + ) + # store = builder.store(value, ptr) + arrayobj.store_item(context, builder, output_types[i], value, ptr) + # store.set_metadata("alias.scope", output_scope_set) + # store.set_metadata("noalias", input_scope_set) + + # Close the loops and write accumulator values to the output arrays + for depth, loop in enumerate(loop_stack[::-1]): + for output, (accu, accu_depth) in enumerate(output_accumulator): + if accu_depth == depth: + idxs_bc = [ + zero if bc else idx + for idx, bc in zip( + idxs, output_bc[output], strict=True + ) + ] + ptr = cgutils.get_item_pointer2( + context, builder, *outputs[output], idxs_bc + ) + load = builder.load(accu) + # load.set_metadata("alias.scope", output_scope_set) + # load.set_metadata("noalias", input_scope_set) + # store = builder.store(load, ptr) + arrayobj.store_item( + context, builder, output_types[output], load, ptr + ) + # store.set_metadata("alias.scope", output_scope_set) + # store.set_metadata("noalias", input_scope_set) + loop.__exit__(None, None, None) + + return diff --git a/pytensor/link/numba/dispatch/helpers.py b/pytensor/link/numba/dispatch/helpers.py new file mode 100644 index 0000000000..4cc2a96ca3 --- /dev/null +++ b/pytensor/link/numba/dispatch/helpers.py @@ -0,0 +1,43 @@ +from numba import njit, types +from numba.core import cgutils +from numba.extending import intrinsic + + +def tuple_mapper(item_map_func): + @intrinsic + def map_tuple(typingctx, *input_tuples): + signatures = [ + typingctx.resolve_function_type(item_map_func, args, {}) + for args in zip(*[in_type.types for in_type in input_tuples], strict=True) + ] + + output_type = types.Tuple([sig.return_type for sig in signatures]) + signature = output_type(types.StarArgTuple(input_tuples)) + + def codegen(context, builder, signature, args): + (input_tuples,) = args + input_values = [] + for val in cgutils.unpack_tuple(builder, input_tuples): + input_values.append(cgutils.unpack_tuple(builder, val)) + + mapped_values = [] + for values, sig in zip(zip(*input_values), signatures, strict=True): + func = context.compile_subroutine(builder, item_map_func, sig) + output = context.call_internal(builder, func.fndesc, sig, values) + mapped_values.append(output) + + return context.make_tuple(builder, output_type, mapped_values) + + return signature, codegen + + return map_tuple + + +@njit +def check_broadcasting(array, bcs, shape): + assert array.ndim == len(shape) + for bc, array_length, length in zip(bcs, array.shape, shape): + if bc: + assert array_length == 1 + else: + assert array_length == length From 61e7a67ff9f2ea3dfbfaa0e227a04edf613b1623 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 19 Dec 2022 23:19:53 -0600 Subject: [PATCH 02/14] numba reshape should always return an array --- pytensor/link/numba/dispatch/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 02936ccdb1..8a83e119dc 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -698,7 +698,7 @@ def numba_funcify_Reshape(op, **kwargs): @numba_njit def reshape(x, shape): - return x.item() + return np.asarray(x.item()) else: From 7ecbac70921433186f7de3c662b8e5d6578b1cd2 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 19 Dec 2022 23:20:12 -0600 Subject: [PATCH 03/14] numba careduce should return an array --- pytensor/link/numba/dispatch/elemwise.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 1cc293e1b6..855878c343 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -376,6 +376,7 @@ def careduce_maximum(input): careduce_def_src = f""" def {careduce_fn_name}({input_name}): {careduce_assign_lines} + #return np.asarray({var_name}) return {var_name} """ From 6936d2d61a7696114efd095bfa9f80eb5160447d Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 19 Dec 2022 23:20:33 -0600 Subject: [PATCH 04/14] Specialized numba sum impl --- pytensor/link/numba/dispatch/elemwise.py | 79 ++++++++++++++----- .../link/numba/dispatch/elemwise_codegen.py | 38 ++++----- 2 files changed, 74 insertions(+), 43 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 855878c343..4c1dcafb06 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,16 +1,14 @@ +import base64 +import pickle from functools import singledispatch from numbers import Number -import pickle from textwrap import indent -from typing import Any, Callable, Literal, Optional, Union -import base64 +from typing import Any, Callable, Optional, Union import numba import numpy as np -from llvmlite import ir -from numba import TypingError, literal_unroll, types, literally +from numba import TypingError, types from numba.core import cgutils -from numba.cpython.unsafe.tuple import tuple_setitem from numba.np import arrayobj from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple @@ -18,6 +16,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch import elemwise_codegen from pytensor.link.numba.dispatch.basic import ( create_numba_signature, create_tuple_creator, @@ -25,8 +24,6 @@ numba_njit, use_optimized_cheap_pass, ) -from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper -from pytensor.link.numba.dispatch import elemwise_codegen from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.scalar.basic import ( AND, @@ -45,7 +42,7 @@ from pytensor.scalar.basic import add as add_as from pytensor.scalar.basic import scalar_maximum from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros +from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.type import scalar @@ -376,8 +373,7 @@ def careduce_maximum(input): careduce_def_src = f""" def {careduce_fn_name}({input_name}): {careduce_assign_lines} - #return np.asarray({var_name}) - return {var_name} + return np.asarray({var_name}) """ careduce_fn = compile_function_src( @@ -447,6 +443,7 @@ def axis_apply_fn(x): } } + @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True) def _vectorized( typingctx, @@ -490,7 +487,6 @@ def _vectorized( inplace_pattern = inplace_pattern.literal_value inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) - n_inputs = len(inputs) n_outputs = len(output_bc_patterns) if not len(inputs) > 0: @@ -531,7 +527,10 @@ def codegen( [_, _, _, _, _, inputs] = args inputs = cgutils.unpack_tuple(builder, inputs) - inputs = [arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs)] + inputs = [ + arrayobj.make_array(ty)(ctx, builder, val) + for ty, val in zip(input_types, inputs) + ] in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs] iter_shape = elemwise_codegen.compute_itershape( @@ -586,14 +585,22 @@ def _check_input_shapes(*_): return outputs[0]._getvalue() for inplace_idx in dict(inplace_pattern): - ctx.nrt.incref(builder, sig.return_type.types[inplace_idx], outputs[inplace_idx]._get_value()) - return ctx.make_tuple(builder, sig.return_type, [out._getvalue() for out in outputs]) + ctx.nrt.incref( + builder, + sig.return_type.types[inplace_idx], + outputs[inplace_idx]._get_value(), + ) + return ctx.make_tuple( + builder, sig.return_type, [out._getvalue() for out in outputs] + ) # TODO check inplace_pattern - ret_type = types.Tuple([ - types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") - for dtype in output_dtypes - ]) + ret_type = types.Tuple( + [ + types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") + for dtype in output_dtypes + ] + ) if len(output_dtypes) == 1: ret_type = ret_type.types[0] sig = ret_type(*arg_types) @@ -649,6 +656,40 @@ def elemwise_wrapper(*inputs): return elemwise_wrapper +@numba_funcify.register(Sum) +def numba_funcify_Sum(op, node, **kwargs): + axes = op.axis + if axes is None: + axes = list(range(node.inputs[0].ndim)) + + axes = list(axes) + + ndim_input = node.inputs[0].ndim + + if hasattr(op, "acc_dtype") and op.acc_dtype is not None: + acc_dtype = op.acc_dtype + else: + acc_dtype = node.outputs[0].type.dtype + + np_acc_dtype = np.dtype(acc_dtype) + + if ndim_input == len(axes): + + @numba_njit(fastmath=True) + def impl_sum(array): + # TODO The accumulation itself should happen in acc_dtype... + return np.asarray(array.sum()).astype(np_acc_dtype) + + else: + + @numba_njit(fastmath=True) + def impl_sum(array): + # TODO The accumulation itself should happen in acc_dtype... + return array.sum(axes).astype(np_acc_dtype) + + return impl_sum + + @numba_funcify.register(CAReduce) def numba_funcify_CAReduce(op, node, **kwargs): axes = op.axis diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py index a8bc0b3629..6c2ccd11a5 100644 --- a/pytensor/link/numba/dispatch/elemwise_codegen.py +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -1,9 +1,9 @@ +import numba +import numpy as np from llvmlite import ir from numba import types -from numba.np import arrayobj from numba.core import cgutils -import numba -import numpy as np +from numba.np import arrayobj def compute_itershape( @@ -35,7 +35,9 @@ def compute_itershape( return shape -def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types): +def make_outputs( + ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types +): arrays = [] ar_types: list[types.Array] = [] one = ir.IntType(64)(1) @@ -52,8 +54,7 @@ def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace # This is actually an interal numba function, I guess we could # call `numba.nd.unsafe.ndarray` instead? shape = [ - length if not bc_dim else one - for length, bc_dim in zip(iter_shape, bc) + length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc) ] array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) arrays.append(array) @@ -84,7 +85,7 @@ def make_loop_call( safe = (False, False) n_outputs = len(outputs) - #context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape) + # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape) # Lower the code of the scalar function so that we can use it in the inner loop # Caching is set to false to avoid a numba bug TODO ref? @@ -155,12 +156,8 @@ def extract_array(aryty, obj): # Load values from input arrays input_vals = [] for array_info, bc in zip(inputs, input_bc, strict=True): - idxs_bc = [ - zero if bc else idx for idx, bc in zip(idxs, bc, strict=True) - ] - ptr = cgutils.get_item_pointer2( - context, builder, *array_info, idxs_bc, *safe - ) + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe) val = builder.load(ptr) # val.set_metadata("alias.scope", input_scope_set) # val.set_metadata("noalias", output_scope_set) @@ -193,12 +190,9 @@ def extract_array(aryty, obj): # store.set_metadata("noalias", input_scope_set) else: idxs_bc = [ - zero if bc else idx - for idx, bc in zip(idxs, output_bc[i], strict=True) + zero if bc else idx for idx, bc in zip(idxs, output_bc[i], strict=True) ] - ptr = cgutils.get_item_pointer2( - context, builder, *outputs[i], idxs_bc - ) + ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc) # store = builder.store(value, ptr) arrayobj.store_item(context, builder, output_types[i], value, ptr) # store.set_metadata("alias.scope", output_scope_set) @@ -210,9 +204,7 @@ def extract_array(aryty, obj): if accu_depth == depth: idxs_bc = [ zero if bc else idx - for idx, bc in zip( - idxs, output_bc[output], strict=True - ) + for idx, bc in zip(idxs, output_bc[output], strict=True) ] ptr = cgutils.get_item_pointer2( context, builder, *outputs[output], idxs_bc @@ -221,9 +213,7 @@ def extract_array(aryty, obj): # load.set_metadata("alias.scope", output_scope_set) # load.set_metadata("noalias", input_scope_set) # store = builder.store(load, ptr) - arrayobj.store_item( - context, builder, output_types[output], load, ptr - ) + arrayobj.store_item(context, builder, output_types[output], load, ptr) # store.set_metadata("alias.scope", output_scope_set) # store.set_metadata("noalias", input_scope_set) loop.__exit__(None, None, None) From ae122aa094345d0f4ca4003f203c0b4ca54a6c94 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 20 Dec 2022 10:47:08 -0600 Subject: [PATCH 05/14] Add shape checking in numba elemwise --- pytensor/link/numba/dispatch/elemwise.py | 19 -------- .../link/numba/dispatch/elemwise_codegen.py | 47 ++++++++++++++----- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 4c1dcafb06..31fc05f741 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -454,10 +454,6 @@ def _vectorized( inplace_pattern, inputs, ): - #if not isinstance(scalar_func, types.Literal): - # raise TypingError("scalar func must be literal.") - #scalar_func = scalar_func.literal_value - arg_types = [ scalar_func, input_bc_patterns, @@ -516,8 +512,6 @@ def _vectorized( inplace_pattern_val = inplace_pattern input_types = inputs - #assert not inplace_pattern_val - def codegen( ctx, builder, @@ -551,18 +545,6 @@ def codegen( input_types, ) - def _check_input_shapes(*_): - # TODO impl - return - - _check_input_shapes( - ctx, - builder, - iter_shape, - inputs, - input_bc_patterns_val, - ) - elemwise_codegen.make_loop_call( typingctx, ctx, @@ -594,7 +576,6 @@ def _check_input_shapes(*_): builder, sig.return_type, [out._getvalue() for out in outputs] ) - # TODO check inplace_pattern ret_type = types.Tuple( [ types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py index 6c2ccd11a5..6460986b3e 100644 --- a/pytensor/link/numba/dispatch/elemwise_codegen.py +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -3,32 +3,55 @@ from llvmlite import ir from numba import types from numba.core import cgutils +from numba.core.base import BaseContext from numba.np import arrayobj def compute_itershape( - ctx, + ctx: BaseContext, builder: ir.IRBuilder, in_shapes, broadcast_pattern, ): one = ir.IntType(64)(1) ndim = len(in_shapes[0]) - #shape = [ir.IntType(64)(1) for _ in range(ndim)] shape = [None] * ndim for i in range(ndim): - # TODO Error checking... - # What if all shapes are 0? - for bc, in_shape in zip(broadcast_pattern, in_shapes): + for j, (bc, in_shape) in enumerate( + zip(broadcast_pattern, in_shapes, strict=True) + ): + length = in_shape[i] if bc[i]: - # TODO - # raise error if length != 1 - pass + with builder.if_then( + builder.icmp_unsigned("!=", length, one), likely=False + ): + msg = ( + f"Input {j} to elemwise is expected to have shape 1 in axis {i}" + ) + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) + elif shape[i] is not None: + with builder.if_then( + builder.icmp_unsigned("!=", length, shape[i]), likely=False + ): + with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( + then, + otherwise, + ): + with then: + msg = ( + f"Incompative shapes for input {j} and axis {i} of " + f"elemwise. Input {j} has shape 1, but is not statically " + "known to have shape 1, and thus not broadcastable." + ) + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) + with otherwise: + msg = ( + f"Input {j} to elemwise has an incompatible " + f"shape in axis {i}." + ) + ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) else: - # TODO - # if shape[i] is not None: - # raise Error if != - shape[i] = in_shape[i] + shape[i] = length for i in range(ndim): if shape[i] is None: shape[i] = one From 0ed277fd61d018e4cc593af4e507a7d2de514247 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 20 Dec 2022 11:24:11 -0600 Subject: [PATCH 06/14] Add typing for some numba elemwise --- .../link/numba/dispatch/elemwise_codegen.py | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py index 6460986b3e..6bb51c1451 100644 --- a/pytensor/link/numba/dispatch/elemwise_codegen.py +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -1,3 +1,5 @@ +from typing import Any, List, Optional, Tuple + import numba import numpy as np from llvmlite import ir @@ -10,8 +12,8 @@ def compute_itershape( ctx: BaseContext, builder: ir.IRBuilder, - in_shapes, - broadcast_pattern, + in_shapes: Tuple[ir.Instruction, ...], + broadcast_pattern: Tuple[Tuple[bool, ...], ...], ): one = ir.IntType(64)(1) ndim = len(in_shapes[0]) @@ -59,16 +61,23 @@ def compute_itershape( def make_outputs( - ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types + ctx: numba.core.base.BaseContext, + builder: ir.IRBuilder, + iter_shape: Tuple[ir.Instruction, ...], + out_bc: Tuple[Tuple[bool, ...], ...], + dtypes: Tuple[Any, ...], + inplace: Tuple[Tuple[int, int], ...], + inputs: Tuple[Any, ...], + input_types: Tuple[Any, ...], ): arrays = [] ar_types: list[types.Array] = [] one = ir.IntType(64)(1) - inplace = dict(inplace) + inplace_dict = dict(inplace) for i, (bc, dtype) in enumerate(zip(out_bc, dtypes)): - if i in inplace: - arrays.append(inputs[inplace[i]]) - ar_types.append(input_types[inplace[i]]) + if i in inplace_dict: + arrays.append(inputs[inplace_dict[i]]) + ar_types.append(input_types[inplace_dict[i]]) # We need to incref once we return the inplace objects continue dtype = numba.from_dtype(np.dtype(dtype)) @@ -95,15 +104,15 @@ def make_loop_call( typingctx, context: numba.core.base.BaseContext, builder: ir.IRBuilder, - scalar_func, - scalar_signature, - iter_shape, - inputs, - outputs, - input_bc, - output_bc, - input_types, - output_types, + scalar_func: Any, + scalar_signature: types.FunctionType, + iter_shape: Tuple[ir.Instruction, ...], + inputs: Tuple[ir.Instruction, ...], + outputs: Tuple[ir.Instruction, ...], + input_bc: Tuple[Tuple[bool, ...], ...], + output_bc: Tuple[Tuple[bool, ...], ...], + input_types: Tuple[Any, ...], + output_types: Tuple[Any, ...], ): safe = (False, False) n_outputs = len(outputs) @@ -142,15 +151,15 @@ def extract_array(aryty, obj): # input_scope_set = mod.add_metadata([input_scope, output_scope]) # output_scope_set = mod.add_metadata([input_scope, output_scope]) - inputs = [ + inputs = tuple( extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs, strict=True) - ] + ) - outputs = [ + outputs = tuple( extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs, strict=True) - ] + ) zero = ir.Constant(ir.IntType(64), 0) @@ -158,7 +167,9 @@ def extract_array(aryty, obj): # This part corresponds to opening the loops loop_stack = [] loops = [] - output_accumulator = [(None, None)] * n_outputs + output_accumulator: List[Tuple[Optional[Any], Optional[int]]] = [ + (None, None) + ] * n_outputs for dim, length in enumerate(iter_shape): # Find outputs that only have accumulations left for output in range(n_outputs): From 79b10370d4577c6b20db6267979fc6faeb5ca68b Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 20 Dec 2022 11:36:06 -0600 Subject: [PATCH 07/14] Remove py310 only strict arg to zip --- .../link/numba/dispatch/elemwise_codegen.py | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py index 6bb51c1451..871a21cfb5 100644 --- a/pytensor/link/numba/dispatch/elemwise_codegen.py +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -19,9 +19,7 @@ def compute_itershape( ndim = len(in_shapes[0]) shape = [None] * ndim for i in range(ndim): - for j, (bc, in_shape) in enumerate( - zip(broadcast_pattern, in_shapes, strict=True) - ): + for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)): length = in_shape[i] if bc[i]: with builder.if_then( @@ -151,14 +149,10 @@ def extract_array(aryty, obj): # input_scope_set = mod.add_metadata([input_scope, output_scope]) # output_scope_set = mod.add_metadata([input_scope, output_scope]) - inputs = tuple( - extract_array(aryty, ary) - for aryty, ary in zip(input_types, inputs, strict=True) - ) + inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs)) outputs = tuple( - extract_array(aryty, ary) - for aryty, ary in zip(output_types, outputs, strict=True) + extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs) ) zero = ir.Constant(ir.IntType(64), 0) @@ -189,8 +183,8 @@ def extract_array(aryty, obj): # Load values from input arrays input_vals = [] - for array_info, bc in zip(inputs, input_bc, strict=True): - idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + for array_info, bc in zip(inputs, input_bc): + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe) val = builder.load(ptr) # val.set_metadata("alias.scope", input_scope_set) @@ -210,9 +204,7 @@ def extract_array(aryty, obj): output_values = [output_values] # Update output value or accumulators respectively - for i, ((accu, _), value) in enumerate( - zip(output_accumulator, output_values, strict=True) - ): + for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)): if accu is not None: load = builder.load(accu) # load.set_metadata("alias.scope", output_scope_set) @@ -223,9 +215,7 @@ def extract_array(aryty, obj): # store.set_metadata("alias.scope", output_scope_set) # store.set_metadata("noalias", input_scope_set) else: - idxs_bc = [ - zero if bc else idx for idx, bc in zip(idxs, output_bc[i], strict=True) - ] + idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])] ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc) # store = builder.store(value, ptr) arrayobj.store_item(context, builder, output_types[i], value, ptr) @@ -237,8 +227,7 @@ def extract_array(aryty, obj): for output, (accu, accu_depth) in enumerate(output_accumulator): if accu_depth == depth: idxs_bc = [ - zero if bc else idx - for idx, bc in zip(idxs, output_bc[output], strict=True) + zero if bc else idx for idx, bc in zip(idxs, output_bc[output]) ] ptr = cgutils.get_item_pointer2( context, builder, *outputs[output], idxs_bc From e9ab63c0c4f15b588b951e00a0b143c682192385 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 20 Dec 2022 21:16:17 -0600 Subject: [PATCH 08/14] Fix tests and fix scalar numba return types --- pytensor/link/numba/dispatch/basic.py | 6 +- pytensor/link/numba/dispatch/elemwise.py | 115 ++++++++++++++---- .../link/numba/dispatch/elemwise_codegen.py | 30 ++--- pytensor/link/numba/dispatch/extra_ops.py | 1 + pytensor/link/numba/dispatch/scalar.py | 3 + pytensor/link/numba/dispatch/scan.py | 24 +++- pytensor/link/numba/linker.py | 4 +- tests/link/numba/test_basic.py | 52 ++++++-- tests/link/numba/test_elemwise.py | 6 +- tests/link/numba/test_extra_ops.py | 1 + 10 files changed, 171 insertions(+), 71 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 8a83e119dc..fe66207d73 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -204,7 +204,7 @@ def in_seq_empty_tuple(x, y): def to_scalar(x): - raise NotImplementedError() + return np.asarray(x).item() @numba.extending.overload(to_scalar) @@ -543,7 +543,7 @@ def {fn_name}({", ".join(input_names)}): {index_prologue} {indices_creation_src} {index_body} - return z + return np.asarray(z) """ return subtensor_def_src @@ -665,7 +665,7 @@ def numba_funcify_Shape_i(op, **kwargs): @numba_njit def shape_i(x): - return np.shape(x)[i] + return np.asarray(np.shape(x)[i]) return shape_i diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 31fc05f741..3bab657fb0 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -9,6 +9,7 @@ import numpy as np from numba import TypingError, types from numba.core import cgutils +from numba.core.extending import overload from numba.np import arrayobj from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple @@ -174,6 +175,7 @@ def create_axis_reducer( ndim: int, dtype: numba.types.Type, keepdims: bool = False, + return_scalar=False, ) -> numba.core.dispatcher.Dispatcher: r"""Create Python function that performs a NumPy-like reduction on a given axis. @@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x): inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2) return_expr = "res" if keepdims else "res.item()" + if not return_scalar: + return_expr = f"np.asarray({return_expr})" reduce_elemwise_def_src = f""" def {reduce_elemwise_fn_name}(x): @@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x): def create_multiaxis_reducer( - scalar_op, identity, axes, ndim, dtype, input_name="input" + scalar_op, + identity, + axes, + ndim, + dtype, + input_name="input", + return_scalar=False, ): r"""Construct a function that reduces multiple axes. @@ -336,6 +346,8 @@ def careduce_maximum(input): The number of dimensions of the result. dtype: The data type of the result. + return_scalar: + If True, return a scalar, otherwise an array. Returns ======= @@ -370,10 +382,17 @@ def careduce_maximum(input): ) careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4) + if not return_scalar: + pre_result = "np.asarray" + post_result = "" + else: + pre_result = "np.asarray" + post_result = ".item()" + careduce_def_src = f""" def {careduce_fn_name}({input_name}): {careduce_assign_lines} - return np.asarray({var_name}) + return {pre_result}({var_name}){post_result} """ careduce_fn = compile_function_src( @@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}): return careduce_fn -def jit_compile_reducer(node, fn, **kwds): +def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds): """Compile Python source for reduction loops using additional optimizations. Parameters @@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds): A :func:`numba.njit`-compiled function. """ - signature = create_numba_signature(node, reduce_to_scalar=True) + signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar) # Eagerly compile the function using increased optimizations. This should # help improve nested loop reductions. @@ -618,23 +637,58 @@ def numba_funcify_Elemwise(op, node, **kwargs): inplace_pattern = tuple(op.inplace_pattern.items()) # numba doesn't support nested literals right now... - input_bc_patterns = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode() - output_bc_patterns = base64.encodebytes(pickle.dumps(output_bc_patterns)).decode() - output_dtypes = base64.encodebytes(pickle.dumps(output_dtypes)).decode() - inplace_pattern = base64.encodebytes(pickle.dumps(inplace_pattern)).decode() + input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode() + output_bc_patterns_enc = base64.encodebytes( + pickle.dumps(output_bc_patterns) + ).decode() + output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode() + inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode() - @numba_njit def elemwise_wrapper(*inputs): return _vectorized( scalar_op_fn, - input_bc_patterns, - output_bc_patterns, - output_dtypes, - inplace_pattern, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, inputs, ) - return elemwise_wrapper + # Pure python implementation, that will be used in tests + def elemwise(*inputs): + inputs = [np.asarray(input) for input in inputs] + inputs_bc = np.broadcast_arrays(*inputs) + shape = inputs[0].shape + for input, bc in zip(inputs, input_bc_patterns): + for length, allow_bc, iter_length in zip(input.shape, bc, shape): + if length == 1 and shape and iter_length != 1 and not allow_bc: + raise ValueError("Broadcast not allowed.") + + outputs = [] + for dtype in output_dtypes: + outputs.append(np.empty(shape, dtype=dtype)) + + for idx in np.ndindex(shape): + vals = [input[idx] for input in inputs_bc] + outs = scalar_op_fn(*vals) + if not isinstance(outs, tuple): + outs = (outs,) + for out, out_val in zip(outputs, outs): + out[idx] = out_val + + outputs_summed = [] + for output, bc in zip(outputs, output_bc_patterns): + axes = tuple(np.nonzero(bc)[0]) + outputs_summed.append(output.sum(axes, keepdims=True)) + if len(outputs_summed) != 1: + return tuple(outputs_summed) + return outputs_summed[0] + + @overload(elemwise) + def ov_elemwise(*inputs): + return elemwise_wrapper + + return elemwise @numba_funcify.register(Sum) @@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs): if axes is None: axes = list(range(node.inputs[0].ndim)) - axes = list(axes) + axes = tuple(axes) ndim_input = node.inputs[0].ndim @@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs): @numba_njit(fastmath=True) def impl_sum(array): - # TODO The accumulation itself should happen in acc_dtype... - return np.asarray(array.sum()).astype(np_acc_dtype) + return np.asarray(array.sum(), dtype=np_acc_dtype) - else: + elif len(axes) == 0: @numba_njit(fastmath=True) def impl_sum(array): - # TODO The accumulation itself should happen in acc_dtype... - return array.sum(axes).astype(np_acc_dtype) + return array + + else: + impl_sum = numba_funcify_CAReduce(op, node, **kwargs) return impl_sum @@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs): input_name=input_name, ) - careduce_fn = jit_compile_reducer(node, careduce_py_fn) + careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False) return careduce_fn @@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): if axis is not None: axis = normalize_axis_index(axis, x_at.ndim) reduce_max_py = create_axis_reducer( - scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True + scalar_maximum, + -np.inf, + axis, + x_at.ndim, + x_dtype, + keepdims=True, ) reduce_sum_py = create_axis_reducer( add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True @@ -935,10 +995,17 @@ def maxandargmax(x): keep_axes = tuple(i for i in range(x_ndim) if i not in axes) reduce_max_py_fn = create_multiaxis_reducer( - scalar_maximum, -np.inf, axes, x_ndim, x_dtype + scalar_maximum, + -np.inf, + axes, + x_ndim, + x_dtype, + return_scalar=False, ) reduce_max = jit_compile_reducer( - Apply(node.op, node.inputs, [node.outputs[0].clone()]), reduce_max_py_fn + Apply(node.op, node.inputs, [node.outputs[0].clone()]), + reduce_max_py_fn, + reduce_to_scalar=False, ) reduced_x_ndim = x_ndim - len(axes) + 1 diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py index 871a21cfb5..a8a808785c 100644 --- a/pytensor/link/numba/dispatch/elemwise_codegen.py +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -117,19 +117,6 @@ def make_loop_call( # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape) - # Lower the code of the scalar function so that we can use it in the inner loop - # Caching is set to false to avoid a numba bug TODO ref? - inner_func = context.compile_subroutine( - builder, - # I don't quite understand why we need to access `dispatcher` here. - # The object does seem to be a dispatcher already? But it is missing - # attributes... - scalar_func.dispatcher, - scalar_signature, - caching=False, - ) - inner = inner_func.fndesc - # Extract shape and stride information from the array. # For later use in the loop body to do the indexing def extract_array(aryty, obj): @@ -191,14 +178,15 @@ def extract_array(aryty, obj): # val.set_metadata("noalias", output_scope_set) input_vals.append(val) - # Call scalar function - output_values = context.call_internal( - builder, - inner, - scalar_signature, - input_vals, - ) - if isinstance(scalar_signature.return_type, types.Tuple): + inner_codegen = context.get_function(scalar_func, scalar_signature) + + if isinstance( + scalar_signature.args[0], (types.StarArgTuple, types.StarArgUniTuple) + ): + input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)] + output_values = inner_codegen(builder, input_vals) + + if isinstance(scalar_signature.return_type, (types.Tuple, types.UniTuple)): output_values = cgutils.unpack_tuple(builder, output_values) else: output_values = [output_values] diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 9871584454..33fac601a5 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -364,6 +364,7 @@ def numba_funcify_BroadcastTo(op, node, **kwargs): lambda _: 0, len(node.inputs) - 1 ) + # TODO broadcastable checks @numba_basic.numba_njit def broadcast_to(x, *shape): scalars_shape = create_zeros_tuple() diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index d72277b2f5..8cd57c6765 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -38,6 +38,9 @@ def numba_funcify_ScalarOp(op, node, **kwargs): # TODO: Do we need to cache these functions so that we don't end up # compiling the same Numba function over and over again? + if not hasattr(op, "nfunc_spec"): + return generate_fallback_impl(op, node, **kwargs) + scalar_func_path = op.nfunc_spec[0] scalar_func_numba = None diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c26cd9aa6c..2da07174c9 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -17,7 +17,11 @@ def idx_to_str( - array_name: str, offset: int, size: Optional[str] = None, idx_symbol: str = "i" + array_name: str, + offset: int, + size: Optional[str] = None, + idx_symbol: str = "i", + allow_scalar=False, ) -> str: if offset < 0: indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}" @@ -32,7 +36,10 @@ def idx_to_str( # compensate for this poor `Op`/rewrite design and implementation. indices = f"({indices}) % {size}" - return f"{array_name}[{indices}]" + if allow_scalar: + return f"{array_name}[{indices}]" + else: + return f"np.asarray({array_name}[{indices}])" @overload(range) @@ -115,7 +122,9 @@ def add_inner_in_expr( indexed_inner_in_str = ( storage_name if tap_offset is None - else idx_to_str(storage_name, tap_offset, size=storage_size_var) + else idx_to_str( + storage_name, tap_offset, size=storage_size_var, allow_scalar=False + ) ) inner_in_exprs.append(indexed_inner_in_str) @@ -232,7 +241,12 @@ def add_output_storage_post_proc_stmt( ) for out_tap in output_taps: inner_out_to_outer_in_stmts.append( - idx_to_str(storage_name, out_tap, size=storage_size_name) + idx_to_str( + storage_name, + out_tap, + size=storage_size_name, + allow_scalar=True, + ) ) add_output_storage_post_proc_stmt( @@ -269,7 +283,7 @@ def add_output_storage_post_proc_stmt( storage_size_name = f"{outer_in_name}_len" inner_out_to_outer_in_stmts.append( - idx_to_str(storage_name, 0, size=storage_size_name) + idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True) ) add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name) diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 7cddedbc58..3f0e35543f 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -27,9 +27,9 @@ def fgraph_convert(self, fgraph, **kwargs): return numba_funcify(fgraph, **kwargs) def jit_compile(self, fn): - import numba + from pytensor.link.numba.dispatch.basic import numba_njit - jitted_fn = numba.njit(fn) + jitted_fn = numba_njit(fn) return jitted_fn def create_thunk_inputs(self, storage_map): diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 1dbf416f24..5686951c1b 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -27,6 +27,7 @@ from pytensor.link.numba.dispatch import numba_typify from pytensor.link.numba.linker import NumbaLinker from pytensor.raise_op import assert_op +from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.tensor import blas from pytensor.tensor import subtensor as at_subtensor from pytensor.tensor.elemwise import Elemwise @@ -63,6 +64,33 @@ def perform(self, node, inputs, outputs): outputs[0][0] = res +class ScalarMyMultiOut(ScalarOp): + nin = 2 + nout = 2 + + @staticmethod + def impl(a, b): + res1 = 2 * a + res2 = 2 * b + return [res1, res2] + + def make_node(self, a, b): + a = as_scalar(a) + b = as_scalar(b) + return Apply(self, [a, b], [a.type(), b.type()]) + + def perform(self, node, inputs, outputs): + res1, res2 = self.impl(inputs[0], inputs[1]) + outputs[0][0] = res1 + outputs[1][0] = res2 + + +scalar_my_multi_out = Elemwise(ScalarMyMultiOut()) +scalar_my_multi_out.ufunc = ScalarMyMultiOut.impl +scalar_my_multi_out.ufunc.nin = 2 +scalar_my_multi_out.ufunc.nout = 2 + + class MyMultiOut(Op): nin = 2 nout = 2 @@ -86,7 +114,6 @@ def perform(self, node, inputs, outputs): my_multi_out.ufunc = MyMultiOut.impl my_multi_out.ufunc.nin = 2 my_multi_out.ufunc.nout = 2 - opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) numba_mode = Mode(NumbaLinker(), opts) py_mode = Mode("py", opts) @@ -988,8 +1015,8 @@ def test_config_options_parallel(): x = at.dvector() with config.change_flags(numba__vectorize_target="parallel"): - pytensor_numba_fn = function([x], x * 2, mode=numba_mode) - numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"] + pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode) + numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] assert numba_mul_fn.targetoptions["parallel"] is True @@ -997,8 +1024,9 @@ def test_config_options_fastmath(): x = at.dvector() with config.change_flags(numba__fastmath=True): - pytensor_numba_fn = function([x], x * 2, mode=numba_mode) - numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"] + pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode) + print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__.keys())) + numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] assert numba_mul_fn.targetoptions["fastmath"] is True @@ -1006,16 +1034,14 @@ def test_config_options_cached(): x = at.dvector() with config.change_flags(numba__cache=True): - pytensor_numba_fn = function([x], x * 2, mode=numba_mode) - numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"] - assert not isinstance( - numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache - ) + pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode) + numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] + assert not isinstance(numba_mul_fn._cache, numba.core.caching.NullCache) with config.change_flags(numba__cache=False): - pytensor_numba_fn = function([x], x * 2, mode=numba_mode) - numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"] - assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache) + pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode) + numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] + assert isinstance(numba_mul_fn._cache, numba.core.caching.NullCache) def test_scalar_return_value_conversion(): diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 4d4186e4b6..2f3846416c 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -16,7 +16,7 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from tests.link.numba.test_basic import ( compare_numba_and_py, - my_multi_out, + scalar_my_multi_out, set_test_value, ) @@ -99,8 +99,8 @@ rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX), ], - lambda x, y: my_multi_out(x, y), - NotImplementedError, + lambda x, y: scalar_my_multi_out(x, y), + None, ), ], ) diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py index 8cf1fdc6bd..0570ef2996 100644 --- a/tests/link/numba/test_extra_ops.py +++ b/tests/link/numba/test_extra_ops.py @@ -32,6 +32,7 @@ def test_Bartlett(val): for i in g_fg.inputs if not isinstance(i, (SharedVariable, Constant)) ], + assert_fn=lambda x, y: np.testing.assert_allclose(x, y, atol=1e-15), ) From 05f6ab20ca9c2a900d97a35bd26c550a5df7ae15 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 20 Dec 2022 21:56:24 -0600 Subject: [PATCH 09/14] Add cast in numba elemwise between func type and output type --- pytensor/link/numba/dispatch/elemwise_codegen.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py index a8a808785c..c13b756e52 100644 --- a/pytensor/link/numba/dispatch/elemwise_codegen.py +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -188,8 +188,10 @@ def extract_array(aryty, obj): if isinstance(scalar_signature.return_type, (types.Tuple, types.UniTuple)): output_values = cgutils.unpack_tuple(builder, output_values) + func_output_types = scalar_signature.return_type.types else: output_values = [output_values] + func_output_types = [scalar_signature.return_type] # Update output value or accumulators respectively for i, ((accu, _), value) in enumerate(zip(output_accumulator, output_values)): @@ -206,6 +208,9 @@ def extract_array(aryty, obj): idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, output_bc[i])] ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc) # store = builder.store(value, ptr) + value = context.cast( + builder, value, func_output_types[i], output_types[i].dtype + ) arrayobj.store_item(context, builder, output_types[i], value, ptr) # store.set_metadata("alias.scope", output_scope_set) # store.set_metadata("noalias", input_scope_set) @@ -224,6 +229,9 @@ def extract_array(aryty, obj): # load.set_metadata("alias.scope", output_scope_set) # load.set_metadata("noalias", input_scope_set) # store = builder.store(load, ptr) + load = context.cast( + builder, load, func_output_types[output], output_types[output].dtype + ) arrayobj.store_item(context, builder, output_types[output], load, ptr) # store.set_metadata("alias.scope", output_scope_set) # store.set_metadata("noalias", input_scope_set) From b93a4f799f728050db8b0838b3460a401b073bfe Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 21 Dec 2022 11:54:55 -0600 Subject: [PATCH 10/14] numba while condition must be tensor --- pytensor/link/numba/dispatch/scan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 2da07174c9..a307d29c5e 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -351,8 +351,8 @@ def scan({", ".join(outer_in_names)}): {indent(input_storage_block, " " * 4)} i = 0 - cond = False - while i < n_steps and not cond: + cond = np.array(False) + while i < n_steps and not cond.item(): {inner_outputs} = scan_inner_func({inner_in_args}) {indent(inner_out_post_processing_block, " " * 8)} {indent(inner_out_to_outer_out_stmts, " " * 8)} From bfb91f9e1b3a9387d8d6b0aeec2129b93744f298 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 21 Dec 2022 16:27:13 -0600 Subject: [PATCH 11/14] Fix some floatX issues --- pytensor/link/numba/dispatch/elemwise.py | 6 ++++-- tests/link/numba/test_scalar.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 3bab657fb0..aad741c67d 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -708,17 +708,19 @@ def numba_funcify_Sum(op, node, **kwargs): np_acc_dtype = np.dtype(acc_dtype) + out_dtype = np.dtype(node.outputs[0].dtype) + if ndim_input == len(axes): @numba_njit(fastmath=True) def impl_sum(array): - return np.asarray(array.sum(), dtype=np_acc_dtype) + return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) elif len(axes) == 0: @numba_njit(fastmath=True) def impl_sum(array): - return array + return np.asarray(array, dtype=out_dtype) else: impl_sum = numba_funcify_CAReduce(op, node, **kwargs) diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index dde20f5f19..7676b1bf40 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -97,7 +97,7 @@ def test_Clip(v, min, max): ], ) def test_Composite(inputs, input_values, scalar_fn): - composite_inputs = [aes.float64(i.name) for i in inputs] + composite_inputs = [aes.ScalarType(config.floatX)(name=i.name) for i in inputs] comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)])) out_fg = FunctionGraph(inputs, [comp_op(*inputs)]) compare_numba_and_py(out_fg, input_values) From 84d6fd5e4c7af546f17224aeeff7b96f6577de73 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 22 Dec 2022 12:52:45 -0600 Subject: [PATCH 12/14] Add benchmark for numba elemwise --- tests/link/numba/test_elemwise.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 2f3846416c..0958e90034 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -6,7 +6,7 @@ import pytensor.tensor as at import pytensor.tensor.inplace as ati import pytensor.tensor.math as aem -from pytensor import config +from pytensor import config, function from pytensor.compile.ops import deep_copy_op from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant @@ -117,6 +117,25 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): compare_numba_and_py(out_fg, input_vals) +def test_elemwise_speed(benchmark): + x = at.dmatrix("y") + y = at.dvector("z") + + out = np.exp(2 * x * y + y) + + rng = np.random.default_rng(42) + + x_val = rng.normal(size=(200, 500)) + y_val = rng.normal(size=500) + + func = function([x, y], out, mode="NUMBA") + func = func.vm.jit_fn + (out,) = func(x_val, y_val) + np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out) + + benchmark(func, x_val, y_val) + + @pytest.mark.parametrize( "v, new_order", [ From 4fedcbef712dff0754a57e4883a980efb40bb65f Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 3 Jan 2023 14:27:36 -0600 Subject: [PATCH 13/14] Fix typo in error message Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/link/numba/dispatch/elemwise_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py index c13b756e52..d3d1ff1df1 100644 --- a/pytensor/link/numba/dispatch/elemwise_codegen.py +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -39,7 +39,7 @@ def compute_itershape( ): with then: msg = ( - f"Incompative shapes for input {j} and axis {i} of " + f"Incompatible shapes for input {j} and axis {i} of " f"elemwise. Input {j} has shape 1, but is not statically " "known to have shape 1, and thus not broadcastable." ) From 462b957f3082ca991a69fc55a030b53d4eac9bce Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 3 Jan 2023 14:31:56 -0600 Subject: [PATCH 14/14] Remove unused helper file --- pytensor/link/numba/dispatch/helpers.py | 43 ------------------------- 1 file changed, 43 deletions(-) delete mode 100644 pytensor/link/numba/dispatch/helpers.py diff --git a/pytensor/link/numba/dispatch/helpers.py b/pytensor/link/numba/dispatch/helpers.py deleted file mode 100644 index 4cc2a96ca3..0000000000 --- a/pytensor/link/numba/dispatch/helpers.py +++ /dev/null @@ -1,43 +0,0 @@ -from numba import njit, types -from numba.core import cgutils -from numba.extending import intrinsic - - -def tuple_mapper(item_map_func): - @intrinsic - def map_tuple(typingctx, *input_tuples): - signatures = [ - typingctx.resolve_function_type(item_map_func, args, {}) - for args in zip(*[in_type.types for in_type in input_tuples], strict=True) - ] - - output_type = types.Tuple([sig.return_type for sig in signatures]) - signature = output_type(types.StarArgTuple(input_tuples)) - - def codegen(context, builder, signature, args): - (input_tuples,) = args - input_values = [] - for val in cgutils.unpack_tuple(builder, input_tuples): - input_values.append(cgutils.unpack_tuple(builder, val)) - - mapped_values = [] - for values, sig in zip(zip(*input_values), signatures, strict=True): - func = context.compile_subroutine(builder, item_map_func, sig) - output = context.call_internal(builder, func.fndesc, sig, values) - mapped_values.append(output) - - return context.make_tuple(builder, output_type, mapped_values) - - return signature, codegen - - return map_tuple - - -@njit -def check_broadcasting(array, bcs, shape): - assert array.ndim == len(shape) - for bc, array_length, length in zip(bcs, array.shape, shape): - if bc: - assert array_length == 1 - else: - assert array_length == length