@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
66
66
if index != "x" :
67
67
# Initialize the variables associated to the jth loop
68
68
# jump = stride - adjust
69
- # If the variable has size 1 in that dim, we set the stride to zero to
70
- # emulate broadcasting
71
69
jump = f"({ var } _stride{ index } ) - ({ adjust } )"
72
70
init += f"""
73
71
{ var } _n{ index } = PyArray_DIMS({ var } )[{ index } ];
74
- { var } _stride{ index } = ( { var } _n { index } == 1)? 0 : PyArray_STRIDES({ var } )[{ index } ] / sizeof({ dtype } );
72
+ { var } _stride{ index } = PyArray_STRIDES({ var } )[{ index } ] / sizeof({ dtype } );
75
73
{ var } _jump{ index } _{ j } = { jump } ;
76
74
"""
77
75
adjust = f"{ var } _n{ index } *{ var } _stride{ index } "
@@ -86,88 +84,73 @@ def make_checks(loop_orders, dtypes, sub):
86
84
# This loop builds multiple if conditions to verify that the
87
85
# dimensions of the inputs match, and the first one that is true
88
86
# raises an informative error message
87
+
88
+ runtime_broadcast_error_msg = (
89
+ "Runtime broadcasting not allowed. "
90
+ "One input had a distinct dimension length of 1, but was not marked as broadcastable: "
91
+ "(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
92
+ "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
93
+ )
94
+
89
95
for matches in zip (* loop_orders ):
90
96
to_compare = [(j , x ) for j , x in enumerate (matches ) if x != "x" ]
91
97
92
98
# elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx )
93
99
if len (to_compare ) < 2 :
94
100
continue
95
101
96
- # Find first dimension size that is != 1
97
- jl , xl = to_compare [- 1 ]
98
- non1size_dim_check = f"""
99
- npy_intp non1size_dim{ xl } ;
100
- non1size_dim{ xl } = """
101
- for j , x in to_compare [:- 1 ]:
102
- non1size_dim_check += f"(%(lv{ j } )s_n{ x } != 1) ? %(lv{ j } )s_n{ x } : "
103
- non1size_dim_check += f"%(lv{ jl } )s_n{ xl } ;"
104
- check += non1size_dim_check
105
-
106
- # Check the nonsize1 dims match
107
- # TODO: This is a bit inefficient because we are comparing one dimension against itself
108
- check += f"""
109
- if (non1size_dim{ xl } != 1)
110
- {{
111
- """
112
- for j , x in to_compare :
102
+ j0 , x0 = to_compare [0 ]
103
+ for j , x in to_compare [1 :]:
113
104
check += f"""
114
- if ((%(lv{ j } )s_n{ x } != non1size_dim{ x } ) && (%(lv{ j } )s_n{ x } != 1))
105
+ if (%(lv{ j0 } )s_n{ x0 } != %(lv{ j } )s_n{ x } )
106
+ {{
107
+ if (%(lv{ j0 } )s_n{ x0 } == 1 || %(lv{ j } )s_n{ x } == 1)
115
108
{{
116
- PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.",
117
- { x } ,
118
- (long long int) non1size_dim{ x } ,
109
+ PyErr_Format(PyExc_ValueError, "{ runtime_broadcast_error_msg } ",
110
+ { j0 } ,
111
+ { x0 } ,
112
+ (long long int) %(lv{ j0 } )s_n{ x0 } ,
113
+ { j } ,
114
+ { x } ,
115
+ (long long int) %(lv{ j } )s_n{ x }
116
+ );
117
+ }} else {{
118
+ PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
119
+ { j0 } ,
120
+ { x0 } ,
121
+ (long long int) %(lv{ j0 } )s_n{ x0 } ,
119
122
{ j } ,
120
123
{ x } ,
121
124
(long long int) %(lv{ j } )s_n{ x }
122
125
);
123
- %(fail)s
124
126
}}
125
- """
126
- check += """
127
- }
127
+ %(fail)s
128
+ }}
128
129
"""
129
130
130
131
return init % sub + check % sub
131
132
132
133
133
def compute_outputs_dims(array_name: str, loop_orders, sub) -> str:
    """Create C code that computes the output dimensions of an Elemwise operation.

    The returned code populates the C array ``array_name`` with one assignment
    per output dimension, but does not declare or initialize the array itself.

    Parameters
    ----------
    array_name
        Name of the C array that receives the dimension lengths.
    loop_orders
        One sequence per input, all of the same length. Entry ``i`` of each
        sequence is either ``"x"`` (the input is broadcastable along output
        dimension ``i``) or the index of the input dimension that maps onto
        output dimension ``i``.
    sub
        Substitution mapping; ``sub[f"lv{j}"]`` must give the C variable name
        of the ``j``-th input.

    Returns
    -------
    str
        C statements of the form ``array_name[i] = <var>_n<dim>;`` or
        ``array_name[i] = 1;``, one per line.

    Notes
    -----
    The length is taken from the first input that is not marked ``"x"`` in
    that dimension. No cross-input consistency check is emitted here.

    We could specialize the C code even further with the known static output
    shapes.
    """
    statements = []
    for i, candidates in enumerate(zip(*loop_orders)):
        # Borrow the length of the first non-broadcastable input dimension.
        for j, candidate in enumerate(candidates):
            if candidate != "x":
                var = sub[f"lv{j}"]
                statements.append(f"{array_name}[{i}] = {var}_n{candidate};\n")
                break
        else:  # no-break: every input is broadcastable in this dimension
            statements.append(f"{array_name}[{i}] = 1;\n")

    return "".join(statements)
172
155
173
156
@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
186
169
if type .startswith ("PYTENSOR_COMPLEX" ):
187
170
type = type .replace ("PYTENSOR_COMPLEX" , "NPY_COMPLEX" )
188
171
nd = len (loop_orders [0 ])
189
- init_dims = compute_broadcast_dimensions ("dims" , loop_orders , sub )
172
+ init_dims = compute_outputs_dims ("dims" , loop_orders , sub )
190
173
191
174
# TODO: it would be interesting to allocate the output in such a
192
175
# way that its contiguous dimensions match one of the input's
@@ -359,7 +342,7 @@ def make_reordered_loop(
359
342
360
343
# Get the (sorted) total number of iterations of each loop
361
344
declare_totals = f"int init_totals[{ nnested } ];\n "
362
- declare_totals += compute_broadcast_dimensions ("init_totals" , init_loop_orders , sub )
345
+ declare_totals += compute_outputs_dims ("init_totals" , init_loop_orders , sub )
363
346
364
347
# Sort totals to match the new order that was computed by sorting
365
348
# the loop vector. One integer variable per loop is declared.
0 commit comments