20
20
@given (hh .mutually_promotable_dtypes (None ))
21
21
def test_result_type (dtypes ):
22
22
out = xp .result_type (* dtypes )
23
- ph .assert_dtype ('result_type' , dtypes , ' out' , out , dh . result_type ( * dtypes ) )
23
+ ph .assert_dtype ('result_type' , dtypes , out , out_name = ' out' )
24
24
25
25
26
26
@given (
@@ -34,9 +34,8 @@ def test_meshgrid(dtypes, data):
34
34
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f'x{ i } ' )
35
35
arrays .append (x )
36
36
out = xp .meshgrid (* arrays )
37
- expected = dh .result_type (* dtypes )
38
37
for i , x in enumerate (out ):
39
- ph .assert_dtype ('meshgrid' , dtypes , f'out[{ i } ].dtype' , x . dtype , expected )
38
+ ph .assert_dtype ('meshgrid' , dtypes , x . dtype , out_name = f'out[{ i } ].dtype' )
40
39
41
40
42
41
@given (
@@ -50,7 +49,7 @@ def test_concat(shape, dtypes, data):
50
49
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f'x{ i } ' )
51
50
arrays .append (x )
52
51
out = xp .concat (arrays )
53
- ph .assert_dtype ('concat' , dtypes , ' out.dtype' , out . dtype , dh . result_type ( * dtypes ) )
52
+ ph .assert_dtype ('concat' , dtypes , out .dtype )
54
53
55
54
56
55
@given (
@@ -64,7 +63,7 @@ def test_stack(shape, dtypes, data):
64
63
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f'x{ i } ' )
65
64
arrays .append (x )
66
65
out = xp .stack (arrays )
67
- ph .assert_dtype ('stack' , dtypes , ' out.dtype' , out . dtype , dh . result_type ( * dtypes ) )
66
+ ph .assert_dtype ('stack' , dtypes , out .dtype )
68
67
69
68
70
69
bitwise_shift_funcs = [
@@ -150,7 +149,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
150
149
out = func (* arrays )
151
150
except OverflowError :
152
151
reject ()
153
- ph .assert_dtype (func_name , in_dtypes , 'out.dtype' , out .dtype , out_dtype )
152
+ ph .assert_dtype (func_name , in_dtypes , out .dtype , out_dtype )
154
153
155
154
156
155
promotion_params : List [Param [Tuple [DataType , DataType ], DataType ]] = []
@@ -170,7 +169,7 @@ def test_where(in_dtypes, out_dtype, shapes, data):
170
169
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
171
170
cond = data .draw (xps .arrays (dtype = xp .bool , shape = shapes [2 ]), label = 'condition' )
172
171
out = xp .where (cond , x1 , x2 )
173
- ph .assert_dtype ('where' , in_dtypes , 'out.dtype' , out .dtype , out_dtype )
172
+ ph .assert_dtype ('where' , in_dtypes , out .dtype , out_dtype )
174
173
175
174
176
175
numeric_promotion_params = promotion_params [1 :]
@@ -182,7 +181,7 @@ def test_tensordot(in_dtypes, out_dtype, shapes, data):
182
181
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
183
182
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
184
183
out = xp .tensordot (x1 , x2 )
185
- ph .assert_dtype ('tensordot' , in_dtypes , 'out.dtype' , out .dtype , out_dtype )
184
+ ph .assert_dtype ('tensordot' , in_dtypes , out .dtype , out_dtype )
186
185
187
186
188
187
@pytest .mark .parametrize ('in_dtypes, out_dtype' , numeric_promotion_params )
@@ -191,7 +190,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
191
190
x1 = data .draw (xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]), label = 'x1' )
192
191
x2 = data .draw (xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]), label = 'x2' )
193
192
out = xp .vecdot (x1 , x2 )
194
- ph .assert_dtype ('vecdot' , in_dtypes , 'out.dtype' , out .dtype , out_dtype )
193
+ ph .assert_dtype ('vecdot' , in_dtypes , out .dtype , out_dtype )
195
194
196
195
197
196
op_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
@@ -259,7 +258,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
259
258
out = eval (expr , locals_ )
260
259
except OverflowError :
261
260
reject ()
262
- ph .assert_dtype (op , in_dtypes , 'out.dtype' , out .dtype , out_dtype )
261
+ ph .assert_dtype (op , in_dtypes , out .dtype , out_dtype )
263
262
264
263
265
264
inplace_params : List [Param [str , str , Tuple [DataType , ...], DataType ]] = []
@@ -300,7 +299,7 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
300
299
except OverflowError :
301
300
reject ()
302
301
x1 = locals_ ['x1' ]
303
- ph .assert_dtype (op , in_dtypes , ' x1.dtype' , x1 .dtype , out_dtype )
302
+ ph .assert_dtype (op , in_dtypes , x1 .dtype , out_dtype , out_name = ' x1.dtype' )
304
303
305
304
306
305
op_scalar_params : List [Param [str , str , DataType , ScalarType , DataType ]] = []
@@ -334,7 +333,7 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
334
333
out = eval (expr , {'x' : x , 's' : s })
335
334
except OverflowError :
336
335
reject ()
337
- ph .assert_dtype (op , (in_dtype , in_stype ), 'out.dtype' , out .dtype , out_dtype )
336
+ ph .assert_dtype (op , (in_dtype , in_stype ), out .dtype , out_dtype )
338
337
339
338
340
339
inplace_scalar_params : List [Param [str , str , DataType , ScalarType ]] = []
@@ -369,7 +368,7 @@ def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
369
368
reject ()
370
369
x = locals_ ['x' ]
371
370
assert x .dtype == dtype , f'{ x .dtype = !s} , but should be { dtype } '
372
- ph .assert_dtype (op , (dtype , in_stype ), ' x.dtype' , x . dtype , dtype )
371
+ ph .assert_dtype (op , (dtype , in_stype ), x .dtype , dtype , out_name = 'x. dtype' )
373
372
374
373
375
374
if __name__ == '__main__' :
0 commit comments