1
1
"""
2
2
https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
3
3
"""
4
- import math
5
4
from collections import defaultdict
6
- from typing import Tuple , Union , List
5
+ from typing import List , Tuple , Union
7
6
8
7
import pytest
9
8
from hypothesis import assume , given , reject
14
13
from . import hypothesis_helpers as hh
15
14
from . import pytest_helpers as ph
16
15
from . import xps
17
- from .typing import DataType , ScalarType , Param
18
16
from .function_stubs import elementwise_functions
19
-
17
+ from . typing import DataType , Param , ScalarType
20
18
21
19
# TODO: move tests not covering elementwise funcs/ops into standalone tests
22
20
# result_type, meshgrid, tensordor, vecdot
@@ -28,29 +26,6 @@ def test_result_type(dtypes):
28
26
ph .assert_dtype ("result_type" , dtypes , out , repr_name = "out" )
29
27
30
28
31
- # The number and size of generated arrays is arbitrarily limited to prevent
32
- # meshgrid() running out of memory.
33
- @given (
34
- dtypes = hh .mutually_promotable_dtypes (5 , dtypes = dh .numeric_dtypes ),
35
- data = st .data (),
36
- )
37
- def test_meshgrid (dtypes , data ):
38
- arrays = []
39
- shapes = data .draw (
40
- hh .mutually_broadcastable_shapes (
41
- len (dtypes ), min_dims = 1 , max_dims = 1 , max_side = 5
42
- ),
43
- label = "shapes" ,
44
- )
45
- for i , (dtype , shape ) in enumerate (zip (dtypes , shapes ), 1 ):
46
- x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f"x{ i } " )
47
- arrays .append (x )
48
- assert math .prod (x .size for x in arrays ) <= hh .MAX_ARRAY_SIZE # sanity check
49
- out = xp .meshgrid (* arrays )
50
- for i , x in enumerate (out ):
51
- ph .assert_dtype ("meshgrid" , dtypes , x .dtype , repr_name = f"out[{ i } ].dtype" )
52
-
53
-
54
29
bitwise_shift_funcs = [
55
30
"bitwise_left_shift" ,
56
31
"bitwise_right_shift" ,
@@ -78,6 +53,14 @@ def make_id(
78
53
return f"{ func_name } ({ f_args } ) -> { f_out_dtype } "
79
54
80
55
56
+ def mark_stubbed_dtypes (* dtypes ):
57
+ for dtype in dtypes :
58
+ if isinstance (dtype , xp ._UndefinedStub ):
59
+ return pytest .mark .skip (reason = f"xp.{ dtype .name } not defined" )
60
+ else :
61
+ return ()
62
+
63
+
81
64
func_params : List [Param [str , Tuple [DataType , ...], DataType ]] = []
82
65
for func_name in elementwise_functions .__all__ :
83
66
valid_in_dtypes = dh .func_in_dtypes [func_name ]
@@ -90,6 +73,7 @@ def make_id(
90
73
(in_dtype ,),
91
74
out_dtype ,
92
75
id = make_id (func_name , (in_dtype ,), out_dtype ),
76
+ marks = mark_stubbed_dtypes (in_dtype , out_dtype ),
93
77
)
94
78
func_params .append (p )
95
79
elif ndtypes == 2 :
@@ -103,6 +87,7 @@ def make_id(
103
87
(in_dtype1 , in_dtype2 ),
104
88
out_dtype ,
105
89
id = make_id (func_name , (in_dtype1 , in_dtype2 ), out_dtype ),
90
+ marks = mark_stubbed_dtypes (in_dtype1 , in_dtype2 , out_dtype ),
106
91
)
107
92
func_params .append (p )
108
93
else :
@@ -143,6 +128,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
143
128
(dtype1 , dtype2 ),
144
129
promoted_dtype ,
145
130
id = make_id ("" , (dtype1 , dtype2 ), promoted_dtype ),
131
+ marks = mark_stubbed_dtypes (dtype1 , dtype2 , promoted_dtype ),
146
132
)
147
133
promotion_params .append (p )
148
134
@@ -194,6 +180,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
194
180
(in_dtype ,),
195
181
out_dtype ,
196
182
id = make_id (op , (in_dtype ,), out_dtype ),
183
+ marks = mark_stubbed_dtypes (in_dtype , out_dtype ),
197
184
)
198
185
op_params .append (p )
199
186
else :
@@ -206,6 +193,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
206
193
(in_dtype1 , in_dtype2 ),
207
194
out_dtype ,
208
195
id = make_id (op , (in_dtype1 , in_dtype2 ), out_dtype ),
196
+ marks = mark_stubbed_dtypes (in_dtype1 , in_dtype2 , out_dtype ),
209
197
)
210
198
op_params .append (p )
211
199
# We generate params for abs seperately as it does not have an associated symbol
@@ -216,6 +204,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
216
204
(in_dtype ,),
217
205
in_dtype ,
218
206
id = make_id ("__abs__" , (in_dtype ,), in_dtype ),
207
+ marks = mark_stubbed_dtypes (in_dtype ),
219
208
)
220
209
op_params .append (p )
221
210
@@ -263,6 +252,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
263
252
(in_dtype1 , in_dtype2 ),
264
253
promoted_dtype ,
265
254
id = make_id (op , (in_dtype1 , in_dtype2 ), promoted_dtype ),
255
+ marks = mark_stubbed_dtypes (in_dtype1 , in_dtype2 , promoted_dtype ),
266
256
)
267
257
inplace_params .append (p )
268
258
@@ -301,6 +291,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
301
291
in_stype ,
302
292
out_dtype ,
303
293
id = make_id (op , (in_dtype , in_stype ), out_dtype ),
294
+ marks = mark_stubbed_dtypes (in_dtype , out_dtype ),
304
295
)
305
296
op_scalar_params .append (p )
306
297
@@ -333,6 +324,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
333
324
dtype ,
334
325
in_stype ,
335
326
id = make_id (op , (dtype , in_stype ), dtype ),
327
+ marks = mark_stubbed_dtypes (dtype ),
336
328
)
337
329
inplace_scalar_params .append (p )
338
330
0 commit comments