2
2
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
3
3
"""
4
4
from collections import defaultdict
5
- from functools import lru_cache
6
- from typing import Tuple , Type , Union , List
5
+ from typing import Tuple , Union , List
7
6
8
7
import pytest
9
8
from hypothesis import assume , given , reject
12
11
from . import _array_module as xp
13
12
from . import dtype_helpers as dh
14
13
from . import hypothesis_helpers as hh
14
+ from . import pytest_helpers as ph
15
15
from . import xps
16
+ from .typing import DataType , ScalarType , Param
16
17
from .function_stubs import elementwise_functions
17
- from .pytest_helpers import nargs
18
-
19
-
20
- DT = Type
21
- ScalarType = Union [Type [bool ], Type [int ], Type [float ]]
22
- Param = Tuple
23
-
24
-
25
- @lru_cache
26
- def fmt_types (types : Tuple [Union [DT , ScalarType ], ...]) -> str :
27
- f_types = []
28
- for type_ in types :
29
- try :
30
- f_types .append (dh .dtype_to_name [type_ ])
31
- except KeyError :
32
- # i.e. dtype is bool, int, or float
33
- f_types .append (type_ .__name__ )
34
- return ', ' .join (f_types )
35
-
36
-
37
- def assert_dtype (test_case : str , result_name : str , dtype : DT , expected : DT ):
38
- msg = (
39
- f'{ result_name } ={ dh .dtype_to_name [dtype ]} , '
40
- f'but should be { dh .dtype_to_name [expected ]} [{ test_case } ]'
41
- )
42
- assert dtype == expected , msg
43
18
44
19
45
20
@given (hh .mutually_promotable_dtypes (None ))
46
21
def test_result_type (dtypes ):
47
22
out = xp .result_type (* dtypes )
48
- assert_dtype (
49
- f'result_type({ fmt_types (dtypes )} )' , 'out' , out , dh .result_type (* dtypes )
23
+ ph . assert_dtype (
24
+ f'result_type({ dh . fmt_types (dtypes )} )' , 'out' , out , dh .result_type (* dtypes )
50
25
)
51
26
52
27
@@ -62,9 +37,9 @@ def test_meshgrid(dtypes, data):
62
37
arrays .append (x )
63
38
out = xp .meshgrid (* arrays )
64
39
expected = dh .result_type (* dtypes )
65
- test_case = f'meshgrid({ fmt_types (dtypes )} )'
40
+ test_case = f'meshgrid({ dh . fmt_types (dtypes )} )'
66
41
for i , x in enumerate (out ):
67
- assert_dtype (test_case , f'out[{ i } ].dtype' , x .dtype , expected )
42
+ ph . assert_dtype (test_case , f'out[{ i } ].dtype' , x .dtype , expected )
68
43
69
44
70
45
@given (
@@ -78,8 +53,11 @@ def test_concat(shape, dtypes, data):
78
53
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f'x{ i } ' )
79
54
arrays .append (x )
80
55
out = xp .concat (arrays )
81
- assert_dtype (
82
- f'concat({ fmt_types (dtypes )} )' , 'out.dtype' , out .dtype , dh .result_type (* dtypes )
56
+ ph .assert_dtype (
57
+ f'concat({ dh .fmt_types (dtypes )} )' ,
58
+ 'out.dtype' ,
59
+ out .dtype ,
60
+ dh .result_type (* dtypes ),
83
61
)
84
62
85
63
@@ -94,8 +72,11 @@ def test_stack(shape, dtypes, data):
94
72
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f'x{ i } ' )
95
73
arrays .append (x )
96
74
out = xp .stack (arrays )
97
- assert_dtype (
98
- f'stack({ fmt_types (dtypes )} )' , 'out.dtype' , out .dtype , dh .result_type (* dtypes )
75
+ ph .assert_dtype (
76
+ f'stack({ dh .fmt_types (dtypes )} )' ,
77
+ 'out.dtype' ,
78
+ out .dtype ,
79
+ dh .result_type (* dtypes ),
99
80
)
100
81
101
82
@@ -117,17 +98,19 @@ def test_stack(shape, dtypes, data):
117
98
118
99
119
100
def make_id (
120
- func_name : str , in_dtypes : Tuple [Union [DT , ScalarType ], ...], out_dtype : DT
101
+ func_name : str ,
102
+ in_dtypes : Tuple [Union [DataType , ScalarType ], ...],
103
+ out_dtype : DataType ,
121
104
) -> str :
122
- f_args = fmt_types (in_dtypes )
105
+ f_args = dh . fmt_types (in_dtypes )
123
106
f_out_dtype = dh .dtype_to_name [out_dtype ]
124
107
return f'{ func_name } ({ f_args } ) -> { f_out_dtype } '
125
108
126
109
127
- func_params : List [Param [str , Tuple [DT , ...], DT ]] = []
110
+ func_params : List [Param [str , Tuple [DataType , ...], DataType ]] = []
128
111
for func_name in elementwise_functions .__all__ :
129
112
valid_in_dtypes = dh .func_in_dtypes [func_name ]
130
- ndtypes = nargs (func_name )
113
+ ndtypes = ph . nargs (func_name )
131
114
if ndtypes == 1 :
132
115
for in_dtype in valid_in_dtypes :
133
116
out_dtype = xp .bool if dh .func_returns_bool [func_name ] else in_dtype
@@ -180,12 +163,12 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
180
163
out = func (* arrays )
181
164
except OverflowError :
182
165
reject ()
183
- assert_dtype (
184
- f'{ func_name } ({ fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
166
+ ph . assert_dtype (
167
+ f'{ func_name } ({ dh . fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
185
168
)
186
169
187
170
188
- promotion_params : List [Param [Tuple [DT , DT ], DT ]] = []
171
+ promotion_params : List [Param [Tuple [DataType , DataType ], DataType ]] = []
189
172
for (dtype1 , dtype2 ), promoted_dtype in dh .promotion_table .items ():
190
173
p = pytest .param (
191
174
(dtype1 , dtype2 ),
@@ -202,7 +185,9 @@ def test_where(in_dtypes, out_dtype, shapes, data):
202
185
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
203
186
cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = 'condition' )
204
187
out = xp .where (cond , x1 , x2 )
205
- assert_dtype (f'where({ fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype )
188
+ ph .assert_dtype (
189
+ f'where({ dh .fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
190
+ )
206
191
207
192
208
193
numeric_promotion_params = promotion_params [1 :]
@@ -214,8 +199,8 @@ def test_tensordot(in_dtypes, out_dtype, shapes, data):
214
199
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
215
200
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
216
201
out = xp .tensordot (x1 , x2 )
217
- assert_dtype (
218
- f'tensordot({ fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
202
+ ph . assert_dtype (
203
+ f'tensordot({ dh . fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
219
204
)
220
205
221
206
@@ -225,16 +210,18 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
225
210
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
226
211
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
227
212
out = xp .vecdot (x1 , x2 )
228
- assert_dtype (f'vecdot({ fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype )
213
+ ph .assert_dtype (
214
+ f'vecdot({ dh .fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
215
+ )
229
216
230
217
231
- op_params : List [Param [str , str , Tuple [DT , ...], DT ]] = []
218
+ op_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
232
219
op_to_symbol = {** dh .unary_op_to_symbol , ** dh .binary_op_to_symbol }
233
220
for op , symbol in op_to_symbol .items ():
234
221
if op == '__matmul__' :
235
222
continue
236
223
valid_in_dtypes = dh .func_in_dtypes [op ]
237
- ndtypes = nargs (op )
224
+ ndtypes = ph . nargs (op )
238
225
if ndtypes == 1 :
239
226
for in_dtype in valid_in_dtypes :
240
227
out_dtype = xp .bool if dh .func_returns_bool [op ] else in_dtype
@@ -293,10 +280,12 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
293
280
out = eval (expr , locals_ )
294
281
except OverflowError :
295
282
reject ()
296
- assert_dtype (f'{ op } ({ fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype )
283
+ ph .assert_dtype (
284
+ f'{ op } ({ dh .fmt_types (in_dtypes )} )' , 'out.dtype' , out .dtype , out_dtype
285
+ )
297
286
298
287
299
- inplace_params : List [Param [str , str , Tuple [DT , ...], DT ]] = []
288
+ inplace_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
300
289
for op , symbol in dh .inplace_op_to_symbol .items ():
301
290
if op == '__imatmul__' :
302
291
continue
@@ -334,10 +323,10 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
334
323
except OverflowError :
335
324
reject ()
336
325
x1 = locals_ ['x1' ]
337
- assert_dtype (f'{ op } ({ fmt_types (in_dtypes )} )' , 'x1.dtype' , x1 .dtype , out_dtype )
326
+ ph . assert_dtype (f'{ op } ({ dh . fmt_types (in_dtypes )} )' , 'x1.dtype' , x1 .dtype , out_dtype )
338
327
339
328
340
- op_scalar_params : List [Param [str , str , DT , ScalarType , DT ]] = []
329
+ op_scalar_params : List [Param [str , str , DataType , ScalarType , DataType ]] = []
341
330
for op , symbol in dh .binary_op_to_symbol .items ():
342
331
if op == '__matmul__' :
343
332
continue
@@ -368,12 +357,12 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
368
357
out = eval (expr , {'x' : x , 's' : s })
369
358
except OverflowError :
370
359
reject ()
371
- assert_dtype (
372
- f'{ op } ({ fmt_types ((in_dtype , in_stype ))} )' , 'out.dtype' , out .dtype , out_dtype
360
+ ph . assert_dtype (
361
+ f'{ op } ({ dh . fmt_types ((in_dtype , in_stype ))} )' , 'out.dtype' , out .dtype , out_dtype
373
362
)
374
363
375
364
376
- inplace_scalar_params : List [Param [str , str , DT , ScalarType ]] = []
365
+ inplace_scalar_params : List [Param [str , str , DataType , ScalarType ]] = []
377
366
for op , symbol in dh .inplace_op_to_symbol .items ():
378
367
if op == '__imatmul__' :
379
368
continue
@@ -405,7 +394,9 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
405
394
reject ()
406
395
x = locals_ ['x' ]
407
396
assert x .dtype == dtype , f'{ x .dtype = !s} , but should be { dtype } '
408
- assert_dtype (f'{ op } ({ fmt_types ((dtype , in_stype ))} )' , 'x.dtype' , x .dtype , dtype )
397
+ ph .assert_dtype (
398
+ f'{ op } ({ dh .fmt_types ((dtype , in_stype ))} )' , 'x.dtype' , x .dtype , dtype
399
+ )
409
400
410
401
411
402
if __name__ == '__main__' :
0 commit comments