1
1
"""
2
2
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
3
3
"""
4
+ import math
4
5
from collections import defaultdict
5
6
from typing import Tuple , Union , List
6
7
24
25
@given (hh .mutually_promotable_dtypes (None ))
25
26
def test_result_type (dtypes ):
26
27
out = xp .result_type (* dtypes )
27
- ph .assert_dtype (' result_type' , dtypes , out , out_name = ' out' )
28
+ ph .assert_dtype (" result_type" , dtypes , out , out_name = " out" )
28
29
29
30
31
+ # The number and size of generated arrays is arbitrarily limited to prevent
32
+ # meshgrid() running out of memory.
30
33
@given (
31
- dtypes = hh .mutually_promotable_dtypes (None , dtypes = dh .numeric_dtypes ),
34
+ dtypes = hh .mutually_promotable_dtypes (5 , dtypes = dh .numeric_dtypes ),
32
35
data = st .data (),
33
36
)
34
37
def test_meshgrid (dtypes , data ):
35
38
arrays = []
36
- shapes = data .draw (hh .mutually_broadcastable_shapes (len (dtypes )), label = 'shapes' )
39
+ shapes = data .draw (
40
+ hh .mutually_broadcastable_shapes (
41
+ len (dtypes ), min_dims = 1 , max_dims = 1 , max_side = 5
42
+ ),
43
+ label = "shapes" ,
44
+ )
37
45
for i , (dtype , shape ) in enumerate (zip (dtypes , shapes ), 1 ):
38
- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
46
+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
39
47
arrays .append (x )
48
+ assert math .prod (x .size for x in arrays ) <= hh .MAX_ARRAY_SIZE # sanity check
40
49
out = xp .meshgrid (* arrays )
41
50
for i , x in enumerate (out ):
42
- ph .assert_dtype (' meshgrid' , dtypes , x .dtype , out_name = f' out[{ i } ].dtype' )
51
+ ph .assert_dtype (" meshgrid" , dtypes , x .dtype , out_name = f" out[{ i } ].dtype" )
43
52
44
53
45
54
@given (
@@ -50,10 +59,10 @@ def test_meshgrid(dtypes, data):
50
59
def test_concat (shape , dtypes , data ):
51
60
arrays = []
52
61
for i , dtype in enumerate (dtypes , 1 ):
53
- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
62
+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
54
63
arrays .append (x )
55
64
out = xp .concat (arrays )
56
- ph .assert_dtype (' concat' , dtypes , out .dtype )
65
+ ph .assert_dtype (" concat" , dtypes , out .dtype )
57
66
58
67
59
68
@given (
@@ -64,26 +73,26 @@ def test_concat(shape, dtypes, data):
64
73
def test_stack (shape , dtypes , data ):
65
74
arrays = []
66
75
for i , dtype in enumerate (dtypes , 1 ):
67
- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f' x{ i } ' )
76
+ x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f" x{ i } " )
68
77
arrays .append (x )
69
78
out = xp .stack (arrays )
70
- ph .assert_dtype (' stack' , dtypes , out .dtype )
79
+ ph .assert_dtype (" stack" , dtypes , out .dtype )
71
80
72
81
73
82
bitwise_shift_funcs = [
74
- ' bitwise_left_shift' ,
75
- ' bitwise_right_shift' ,
76
- ' __lshift__' ,
77
- ' __rshift__' ,
78
- ' __ilshift__' ,
79
- ' __irshift__' ,
83
+ " bitwise_left_shift" ,
84
+ " bitwise_right_shift" ,
85
+ " __lshift__" ,
86
+ " __rshift__" ,
87
+ " __ilshift__" ,
88
+ " __irshift__" ,
80
89
]
81
90
82
91
83
92
# We pass kwargs to the elements strategy used by xps.arrays() so that we don't
84
93
# generate array elements that are erroneous or undefined for a function.
85
94
func_elements = defaultdict (
86
- lambda : None , {func : {' min_value' : 1 } for func in bitwise_shift_funcs }
95
+ lambda : None , {func : {" min_value" : 1 } for func in bitwise_shift_funcs }
87
96
)
88
97
89
98
@@ -94,7 +103,7 @@ def make_id(
94
103
) -> str :
95
104
f_args = dh .fmt_types (in_dtypes )
96
105
f_out_dtype = dh .dtype_to_name [out_dtype ]
97
- return f' { func_name } ({ f_args } ) -> { f_out_dtype } '
106
+ return f" { func_name } ({ f_args } ) -> { f_out_dtype } "
98
107
99
108
100
109
func_params : List [Param [str , Tuple [DataType , ...], DataType ]] = []
@@ -128,25 +137,25 @@ def make_id(
128
137
raise NotImplementedError ()
129
138
130
139
131
- @pytest .mark .parametrize (' func_name, in_dtypes, out_dtype' , func_params )
140
+ @pytest .mark .parametrize (" func_name, in_dtypes, out_dtype" , func_params )
132
141
@given (data = st .data ())
133
142
def test_func_promotion (func_name , in_dtypes , out_dtype , data ):
134
143
func = getattr (xp , func_name )
135
144
elements = func_elements [func_name ]
136
145
if len (in_dtypes ) == 1 :
137
146
x = data .draw (
138
147
xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
139
- label = 'x' ,
148
+ label = "x" ,
140
149
)
141
150
out = func (x )
142
151
else :
143
152
arrays = []
144
153
shapes = data .draw (
145
- hh .mutually_broadcastable_shapes (len (in_dtypes )), label = ' shapes'
154
+ hh .mutually_broadcastable_shapes (len (in_dtypes )), label = " shapes"
146
155
)
147
156
for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
148
157
x = data .draw (
149
- xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f' x{ i } '
158
+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f" x{ i } "
150
159
)
151
160
arrays .append (x )
152
161
try :
@@ -161,46 +170,46 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
161
170
p = pytest .param (
162
171
(dtype1 , dtype2 ),
163
172
promoted_dtype ,
164
- id = make_id ('' , (dtype1 , dtype2 ), promoted_dtype ),
173
+ id = make_id ("" , (dtype1 , dtype2 ), promoted_dtype ),
165
174
)
166
175
promotion_params .append (p )
167
176
168
177
169
- @pytest .mark .parametrize (' in_dtypes, out_dtype' , promotion_params )
178
+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , promotion_params )
170
179
@given (shapes = hh .mutually_broadcastable_shapes (3 ), data = st .data ())
171
180
def test_where (in_dtypes , out_dtype , shapes , data ):
172
- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
173
- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
174
- cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = ' condition' )
181
+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
182
+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
183
+ cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = " condition" )
175
184
out = xp .where (cond , x1 , x2 )
176
- ph .assert_dtype (' where' , in_dtypes , out .dtype , out_dtype )
185
+ ph .assert_dtype (" where" , in_dtypes , out .dtype , out_dtype )
177
186
178
187
179
188
numeric_promotion_params = promotion_params [1 :]
180
189
181
190
182
- @pytest .mark .parametrize (' in_dtypes, out_dtype' , numeric_promotion_params )
191
+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , numeric_promotion_params )
183
192
@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 2 ), data = st .data ())
184
193
def test_tensordot (in_dtypes , out_dtype , shapes , data ):
185
- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
186
- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
194
+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
195
+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
187
196
out = xp .tensordot (x1 , x2 )
188
- ph .assert_dtype (' tensordot' , in_dtypes , out .dtype , out_dtype )
197
+ ph .assert_dtype (" tensordot" , in_dtypes , out .dtype , out_dtype )
189
198
190
199
191
- @pytest .mark .parametrize (' in_dtypes, out_dtype' , numeric_promotion_params )
200
+ @pytest .mark .parametrize (" in_dtypes, out_dtype" , numeric_promotion_params )
192
201
@given (shapes = hh .mutually_broadcastable_shapes (2 , min_dims = 1 ), data = st .data ())
193
202
def test_vecdot (in_dtypes , out_dtype , shapes , data ):
194
- x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
195
- x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
203
+ x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = "x1" )
204
+ x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = "x2" )
196
205
out = xp .vecdot (x1 , x2 )
197
- ph .assert_dtype (' vecdot' , in_dtypes , out .dtype , out_dtype )
206
+ ph .assert_dtype (" vecdot" , in_dtypes , out .dtype , out_dtype )
198
207
199
208
200
209
op_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
201
210
op_to_symbol = {** dh .unary_op_to_symbol , ** dh .binary_op_to_symbol }
202
211
for op , symbol in op_to_symbol .items ():
203
- if op == ' __matmul__' :
212
+ if op == " __matmul__" :
204
213
continue
205
214
valid_in_dtypes = dh .func_in_dtypes [op ]
206
215
ndtypes = ph .nargs (op )
@@ -209,7 +218,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
209
218
out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
210
219
p = pytest .param (
211
220
op ,
212
- f' { symbol } x' ,
221
+ f" { symbol } x" ,
213
222
(in_dtype ,),
214
223
out_dtype ,
215
224
id = make_id (op , (in_dtype ,), out_dtype ),
@@ -221,42 +230,42 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
221
230
out_dtype = xp .bool if dh .func_returns_bool [op ] else promoted_dtype
222
231
p = pytest .param (
223
232
op ,
224
- f' x1 { symbol } x2' ,
233
+ f" x1 { symbol } x2" ,
225
234
(in_dtype1 , in_dtype2 ),
226
235
out_dtype ,
227
236
id = make_id (op , (in_dtype1 , in_dtype2 ), out_dtype ),
228
237
)
229
238
op_params .append (p )
230
239
# We generate params for abs seperately as it does not have an associated symbol
231
- for in_dtype in dh .func_in_dtypes [' __abs__' ]:
240
+ for in_dtype in dh .func_in_dtypes [" __abs__" ]:
232
241
p = pytest .param (
233
- ' __abs__' ,
234
- ' abs(x)' ,
242
+ " __abs__" ,
243
+ " abs(x)" ,
235
244
(in_dtype ,),
236
245
in_dtype ,
237
- id = make_id (' __abs__' , (in_dtype ,), in_dtype ),
246
+ id = make_id (" __abs__" , (in_dtype ,), in_dtype ),
238
247
)
239
248
op_params .append (p )
240
249
241
250
242
- @pytest .mark .parametrize (' op, expr, in_dtypes, out_dtype' , op_params )
251
+ @pytest .mark .parametrize (" op, expr, in_dtypes, out_dtype" , op_params )
243
252
@given (data = st .data ())
244
253
def test_op_promotion (op , expr , in_dtypes , out_dtype , data ):
245
254
elements = func_elements [func_name ]
246
255
if len (in_dtypes ) == 1 :
247
256
x = data .draw (
248
257
xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
249
- label = 'x' ,
258
+ label = "x" ,
250
259
)
251
- out = eval (expr , {'x' : x })
260
+ out = eval (expr , {"x" : x })
252
261
else :
253
262
locals_ = {}
254
263
shapes = data .draw (
255
- hh .mutually_broadcastable_shapes (len (in_dtypes )), label = ' shapes'
264
+ hh .mutually_broadcastable_shapes (len (in_dtypes )), label = " shapes"
256
265
)
257
266
for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
258
- locals_ [f' x{ i } ' ] = data .draw (
259
- xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f' x{ i } '
267
+ locals_ [f" x{ i } " ] = data .draw (
268
+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f" x{ i } "
260
269
)
261
270
try :
262
271
out = eval (expr , locals_ )
@@ -267,7 +276,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
267
276
268
277
inplace_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
269
278
for op , symbol in dh .inplace_op_to_symbol .items ():
270
- if op == ' __imatmul__' :
279
+ if op == " __imatmul__" :
271
280
continue
272
281
valid_in_dtypes = dh .func_in_dtypes [op ]
273
282
for (in_dtype1 , in_dtype2 ), promoted_dtype in dh .promotion_table .items ():
@@ -278,44 +287,44 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
278
287
):
279
288
p = pytest .param (
280
289
op ,
281
- f' x1 { symbol } x2' ,
290
+ f" x1 { symbol } x2" ,
282
291
(in_dtype1 , in_dtype2 ),
283
292
promoted_dtype ,
284
293
id = make_id (op , (in_dtype1 , in_dtype2 ), promoted_dtype ),
285
294
)
286
295
inplace_params .append (p )
287
296
288
297
289
- @pytest .mark .parametrize (' op, expr, in_dtypes, out_dtype' , inplace_params )
298
+ @pytest .mark .parametrize (" op, expr, in_dtypes, out_dtype" , inplace_params )
290
299
@given (shapes = hh .mutually_broadcastable_shapes (2 ), data = st .data ())
291
300
def test_inplace_op_promotion (op , expr , in_dtypes , out_dtype , shapes , data ):
292
301
assume (len (shapes [0 ]) >= len (shapes [1 ]))
293
302
elements = func_elements [func_name ]
294
303
x1 = data .draw (
295
- xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = 'x1'
304
+ xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = "x1"
296
305
)
297
306
x2 = data .draw (
298
- xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = 'x2'
307
+ xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = "x2"
299
308
)
300
- locals_ = {'x1' : x1 , 'x2' : x2 }
309
+ locals_ = {"x1" : x1 , "x2" : x2 }
301
310
try :
302
311
exec (expr , locals_ )
303
312
except OverflowError :
304
313
reject ()
305
- x1 = locals_ ['x1' ]
306
- ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = ' x1.dtype' )
314
+ x1 = locals_ ["x1" ]
315
+ ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = " x1.dtype" )
307
316
308
317
309
318
op_scalar_params : List [Param [str , str , DataType , ScalarType , DataType ]] = []
310
319
for op , symbol in dh .binary_op_to_symbol .items ():
311
- if op == ' __matmul__' :
320
+ if op == " __matmul__" :
312
321
continue
313
322
for in_dtype in dh .func_in_dtypes [op ]:
314
323
out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
315
324
for in_stype in dh .dtype_to_scalars [in_dtype ]:
316
325
p = pytest .param (
317
326
op ,
318
- f' x { symbol } s' ,
327
+ f" x { symbol } s" ,
319
328
in_dtype ,
320
329
in_stype ,
321
330
out_dtype ,
@@ -324,57 +333,57 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
324
333
op_scalar_params .append (p )
325
334
326
335
327
- @pytest .mark .parametrize (' op, expr, in_dtype, in_stype, out_dtype' , op_scalar_params )
336
+ @pytest .mark .parametrize (" op, expr, in_dtype, in_stype, out_dtype" , op_scalar_params )
328
337
@given (data = st .data ())
329
338
def test_op_scalar_promotion (op , expr , in_dtype , in_stype , out_dtype , data ):
330
339
elements = func_elements [func_name ]
331
- kw = {k : in_stype is float for k in (' allow_nan' , ' allow_infinity' )}
332
- s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = ' scalar' )
340
+ kw = {k : in_stype is float for k in (" allow_nan" , " allow_infinity" )}
341
+ s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = " scalar" )
333
342
x = data .draw (
334
- xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = 'x'
343
+ xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = "x"
335
344
)
336
345
try :
337
- out = eval (expr , {'x' : x , 's' : s })
346
+ out = eval (expr , {"x" : x , "s" : s })
338
347
except OverflowError :
339
348
reject ()
340
349
ph .assert_dtype (op , (in_dtype , in_stype ), out .dtype , out_dtype )
341
350
342
351
343
352
inplace_scalar_params : List [Param [str , str , DataType , ScalarType ]] = []
344
353
for op , symbol in dh .inplace_op_to_symbol .items ():
345
- if op == ' __imatmul__' :
354
+ if op == " __imatmul__" :
346
355
continue
347
356
for dtype in dh .func_in_dtypes [op ]:
348
357
for in_stype in dh .dtype_to_scalars [dtype ]:
349
358
p = pytest .param (
350
359
op ,
351
- f' x { symbol } s' ,
360
+ f" x { symbol } s" ,
352
361
dtype ,
353
362
in_stype ,
354
363
id = make_id (op , (dtype , in_stype ), dtype ),
355
364
)
356
365
inplace_scalar_params .append (p )
357
366
358
367
359
- @pytest .mark .parametrize (' op, expr, dtype, in_stype' , inplace_scalar_params )
368
+ @pytest .mark .parametrize (" op, expr, dtype, in_stype" , inplace_scalar_params )
360
369
@given (data = st .data ())
361
370
def test_inplace_op_scalar_promotion (op , expr , dtype , in_stype , data ):
362
371
elements = func_elements [func_name ]
363
- kw = {k : in_stype is float for k in (' allow_nan' , ' allow_infinity' )}
364
- s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = ' scalar' )
372
+ kw = {k : in_stype is float for k in (" allow_nan" , " allow_infinity" )}
373
+ s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = " scalar" )
365
374
x = data .draw (
366
- xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = 'x'
375
+ xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = "x"
367
376
)
368
- locals_ = {'x' : x , 's' : s }
377
+ locals_ = {"x" : x , "s" : s }
369
378
try :
370
379
exec (expr , locals_ )
371
380
except OverflowError :
372
381
reject ()
373
- x = locals_ ['x' ]
374
- assert x .dtype == dtype , f' { x .dtype = !s} , but should be { dtype } '
375
- ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = ' x.dtype' )
382
+ x = locals_ ["x" ]
383
+ assert x .dtype == dtype , f" { x .dtype = !s} , but should be { dtype } "
384
+ ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = " x.dtype" )
376
385
377
386
378
- if __name__ == ' __main__' :
387
+ if __name__ == " __main__" :
379
388
for (i , j ), p in dh .promotion_table .items ():
380
- print (f' ({ i } , { j } ) -> { p } ' )
389
+ print (f" ({ i } , { j } ) -> { p } " )
0 commit comments