Skip to content

Commit 5d2582d

Browse files
committed
Add shape checking in numba elemwise
1 parent 9fdaeca commit 5d2582d

File tree

2 files changed

+35
-31
lines changed

2 files changed

+35
-31
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -454,10 +454,6 @@ def _vectorized(
454454
inplace_pattern,
455455
inputs,
456456
):
457-
#if not isinstance(scalar_func, types.Literal):
458-
# raise TypingError("scalar func must be literal.")
459-
#scalar_func = scalar_func.literal_value
460-
461457
arg_types = [
462458
scalar_func,
463459
input_bc_patterns,
@@ -516,8 +512,6 @@ def _vectorized(
516512
inplace_pattern_val = inplace_pattern
517513
input_types = inputs
518514

519-
#assert not inplace_pattern_val
520-
521515
def codegen(
522516
ctx,
523517
builder,
@@ -551,18 +545,6 @@ def codegen(
551545
input_types,
552546
)
553547

554-
def _check_input_shapes(*_):
555-
# TODO impl
556-
return
557-
558-
_check_input_shapes(
559-
ctx,
560-
builder,
561-
iter_shape,
562-
inputs,
563-
input_bc_patterns_val,
564-
)
565-
566548
elemwise_codegen.make_loop_call(
567549
typingctx,
568550
ctx,
@@ -594,7 +576,6 @@ def _check_input_shapes(*_):
594576
builder, sig.return_type, [out._getvalue() for out in outputs]
595577
)
596578

597-
# TODO check inplace_pattern
598579
ret_type = types.Tuple(
599580
[
600581
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")

pytensor/link/numba/dispatch/elemwise_codegen.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,55 @@
33
from llvmlite import ir
44
from numba import types
55
from numba.core import cgutils
6+
from numba.core.base import BaseContext
67
from numba.np import arrayobj
78

89

910
def compute_itershape(
10-
ctx,
11+
ctx: BaseContext,
1112
builder: ir.IRBuilder,
1213
in_shapes,
1314
broadcast_pattern,
1415
):
1516
one = ir.IntType(64)(1)
1617
ndim = len(in_shapes[0])
17-
#shape = [ir.IntType(64)(1) for _ in range(ndim)]
1818
shape = [None] * ndim
1919
for i in range(ndim):
20-
# TODO Error checking...
21-
# What if all shapes are 0?
22-
for bc, in_shape in zip(broadcast_pattern, in_shapes):
20+
for j, (bc, in_shape) in enumerate(
21+
zip(broadcast_pattern, in_shapes, strict=True)
22+
):
23+
length = in_shape[i]
2324
if bc[i]:
24-
# TODO
25-
# raise error if length != 1
26-
pass
25+
with builder.if_then(
26+
builder.icmp_unsigned("!=", length, one), likely=False
27+
):
28+
msg = (
29+
f"Input {j} to elemwise is expected to have shape 1 in axis {i}"
30+
)
31+
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
32+
elif shape[i] is not None:
33+
with builder.if_then(
34+
builder.icmp_unsigned("!=", length, shape[i]), likely=False
35+
):
36+
with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
37+
then,
38+
otherwise,
39+
):
40+
with then:
41+
msg = (
42+
f"Incompative shapes for input {j} and axis {i} of "
43+
f"elemwise. Input {j} has shape 1, but is not statically "
44+
"known to have shape 1, and thus not broadcastable."
45+
)
46+
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
47+
with otherwise:
48+
msg = (
49+
f"Input {j} to elemwise has an incompatible "
50+
f"shape in axis {i}."
51+
)
52+
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
2753
else:
28-
# TODO
29-
# if shape[i] is not None:
30-
# raise Error if !=
31-
shape[i] = in_shape[i]
54+
shape[i] = length
3255
for i in range(ndim):
3356
if shape[i] is None:
3457
shape[i] = one

0 commit comments

Comments
 (0)