Skip to content

Commit 9e597fb

Browse files
committed
Align sum and prod tests
1 parent 0ca4e10 commit 9e597fb

File tree

3 files changed

+26
-52
lines changed

3 files changed

+26
-52
lines changed

tests/test_arithmetic.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,6 @@
66
from tests.third_party.cupy import testing
77

88

9-
# Note: numpy.sum() always upcast integers to (u)int64 and float32 to
10-
# float64 for dtype=None. `np.sum` does that too for integers, but not for
11-
# float32, so we need to special-case it for these tests
12-
def _get_dtype_kwargs(xp, dtype):
13-
if xp is numpy and dtype == numpy.float32 and has_support_aspect64():
14-
return {"dtype": numpy.float64}
15-
return {}
16-
17-
189
class TestArithmetic(unittest.TestCase):
1910
@testing.for_float_dtypes()
2011
@testing.numpy_cupy_allclose()
@@ -42,7 +33,7 @@ def test_nanprod(self, xp, dtype):
4233
@testing.numpy_cupy_allclose()
4334
def test_nansum(self, xp, dtype):
4435
a = xp.array([-2.5, -1.5, xp.nan, 10.5, 1.5, xp.nan], dtype=dtype)
45-
return xp.nansum(a, **_get_dtype_kwargs(xp, a.dtype))
36+
return xp.nansum(a)
4637

4738
@testing.for_float_dtypes()
4839
@testing.numpy_cupy_allclose()

tests/test_sum.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,7 @@ def test_sum_axis():
6161
ia = dpnp.array(a)
6262

6363
result = dpnp.sum(ia, axis=1)
64-
if has_support_aspect64():
65-
expected = numpy.sum(a, axis=1, dtype=numpy.float64)
66-
else:
67-
expected = numpy.sum(a, axis=1)
64+
expected = numpy.sum(a, axis=1)
6865
assert_array_equal(expected, result)
6966

7067

tests/third_party/cupy/math_tests/test_sumprod.py

Lines changed: 24 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,6 @@
66
from tests.third_party.cupy import testing
77

88

9-
# Note: numpy.sum() always upcast integers to (u)int64 and float32 to
10-
# float64 for dtype=None. `np.sum` does that too for integers, but not for
11-
# float32, so we need to special-case it for these tests
12-
def _get_dtype_kwargs(xp, dtype):
13-
if xp is numpy and dtype == numpy.float32 and has_support_aspect64():
14-
return {"dtype": numpy.float64}
15-
return {}
16-
17-
189
class TestSumprod:
1910
def tearDown(self):
2011
# Free huge memory for slow test
@@ -26,43 +17,43 @@ def tearDown(self):
2617
@testing.numpy_cupy_allclose()
2718
def test_sum_all(self, xp, dtype):
2819
a = testing.shaped_arange((2, 3, 4), xp, dtype)
29-
return a.sum(**_get_dtype_kwargs(xp, dtype))
20+
return a.sum()
3021

3122
@testing.for_all_dtypes()
3223
@testing.numpy_cupy_allclose()
3324
def test_sum_all_keepdims(self, xp, dtype):
3425
a = testing.shaped_arange((2, 3, 4), xp, dtype)
35-
return a.sum(**_get_dtype_kwargs(xp, dtype), keepdims=True)
26+
return a.sum(keepdims=True)
3627

3728
@testing.for_all_dtypes()
3829
@testing.numpy_cupy_allclose()
3930
def test_external_sum_all(self, xp, dtype):
4031
a = testing.shaped_arange((2, 3, 4), xp, dtype)
41-
return xp.sum(a, **_get_dtype_kwargs(xp, dtype))
32+
return xp.sum(a)
4233

4334
@testing.for_all_dtypes()
4435
@testing.numpy_cupy_allclose(rtol=1e-06)
4536
def test_sum_all2(self, xp, dtype):
4637
a = testing.shaped_arange((20, 30, 40), xp, dtype)
47-
return a.sum(**_get_dtype_kwargs(xp, dtype))
38+
return a.sum()
4839

4940
@testing.for_all_dtypes()
5041
@testing.numpy_cupy_allclose()
5142
def test_sum_all_transposed(self, xp, dtype):
5243
a = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(2, 0, 1)
53-
return a.sum(**_get_dtype_kwargs(xp, dtype))
44+
return a.sum()
5445

5546
@testing.for_all_dtypes()
5647
@testing.numpy_cupy_allclose(rtol=1e-06)
5748
def test_sum_all_transposed2(self, xp, dtype):
5849
a = testing.shaped_arange((20, 30, 40), xp, dtype).transpose(2, 0, 1)
59-
return a.sum(**_get_dtype_kwargs(xp, dtype))
50+
return a.sum()
6051

6152
@testing.for_all_dtypes()
6253
@testing.numpy_cupy_allclose()
6354
def test_sum_axis(self, xp, dtype):
6455
a = testing.shaped_arange((2, 3, 4), xp, dtype)
65-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=1)
56+
return a.sum(axis=1)
6657

6758
@testing.slow
6859
@testing.numpy_cupy_allclose()
@@ -74,57 +65,57 @@ def test_sum_axis_huge(self, xp):
7465
@testing.numpy_cupy_allclose()
7566
def test_external_sum_axis(self, xp, dtype):
7667
a = testing.shaped_arange((2, 3, 4), xp, dtype)
77-
return xp.sum(a, **_get_dtype_kwargs(xp, dtype), axis=1)
68+
return xp.sum(a, axis=1)
7869

7970
# float16 is omitted, since NumPy's sum on float16 arrays has more error
8071
# than CuPy's.
8172
@testing.for_all_dtypes(no_float16=True)
8273
@testing.numpy_cupy_allclose()
8374
def test_sum_axis2(self, xp, dtype):
8475
a = testing.shaped_arange((20, 30, 40), xp, dtype)
85-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=1)
76+
return a.sum(axis=1)
8677

8778
@testing.for_all_dtypes()
8879
@testing.numpy_cupy_allclose(contiguous_check=False)
8980
def test_sum_axis_transposed(self, xp, dtype):
9081
a = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(2, 0, 1)
91-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=1)
82+
return a.sum(axis=1)
9283

9384
@testing.for_all_dtypes()
9485
@testing.numpy_cupy_allclose(contiguous_check=False)
9586
def test_sum_axis_transposed2(self, xp, dtype):
9687
a = testing.shaped_arange((20, 30, 40), xp, dtype).transpose(2, 0, 1)
97-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=1)
88+
return a.sum(axis=1)
9889

9990
@testing.for_all_dtypes()
10091
@testing.numpy_cupy_allclose()
10192
def test_sum_axes(self, xp, dtype):
10293
a = testing.shaped_arange((2, 3, 4, 5), xp, dtype)
103-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=(1, 3))
94+
return a.sum(axis=(1, 3))
10495

10596
@testing.for_all_dtypes()
10697
@testing.numpy_cupy_allclose(rtol=1e-4)
10798
def test_sum_axes2(self, xp, dtype):
10899
a = testing.shaped_arange((20, 30, 40, 50), xp, dtype)
109-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=(1, 3))
100+
return a.sum(axis=(1, 3))
110101

111102
@testing.for_all_dtypes()
112103
@testing.numpy_cupy_allclose(rtol=1e-6)
113104
def test_sum_axes3(self, xp, dtype):
114105
a = testing.shaped_arange((2, 3, 4, 5), xp, dtype)
115-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=(0, 2, 3))
106+
return a.sum(axis=(0, 2, 3))
116107

117108
@testing.for_all_dtypes()
118109
@testing.numpy_cupy_allclose(rtol=1e-6)
119110
def test_sum_axes4(self, xp, dtype):
120111
a = testing.shaped_arange((20, 30, 40, 50), xp, dtype)
121-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=(0, 2, 3))
112+
return a.sum(axis=(0, 2, 3))
122113

123114
@testing.for_all_dtypes()
124115
@testing.numpy_cupy_allclose()
125116
def test_sum_empty_axis(self, xp, dtype):
126117
a = testing.shaped_arange((2, 3, 4, 5), xp, dtype)
127-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=())
118+
return a.sum(axis=())
128119

129120
@testing.for_all_dtypes_combination(names=["src_dtype", "dst_dtype"])
130121
@testing.numpy_cupy_allclose()
@@ -142,7 +133,7 @@ def test_sum_keepdims_and_dtype(self, xp, src_dtype, dst_dtype):
142133
@testing.numpy_cupy_allclose()
143134
def test_sum_keepdims_multiple_axes(self, xp, dtype):
144135
a = testing.shaped_arange((2, 3, 4), xp, dtype)
145-
return a.sum(**_get_dtype_kwargs(xp, dtype), axis=(1, 2), keepdims=True)
136+
return a.sum(axis=(1, 2), keepdims=True)
146137

147138
@testing.for_all_dtypes()
148139
@testing.numpy_cupy_allclose()
@@ -162,25 +153,25 @@ def test_sum_out_wrong_shape(self):
162153
@testing.numpy_cupy_allclose()
163154
def test_prod_all(self, xp, dtype):
164155
a = testing.shaped_arange((2, 3), xp, dtype)
165-
return a.prod(**_get_dtype_kwargs(xp, dtype))
156+
return a.prod()
166157

167158
@testing.for_all_dtypes()
168159
@testing.numpy_cupy_allclose()
169160
def test_external_prod_all(self, xp, dtype):
170161
a = testing.shaped_arange((2, 3), xp, dtype)
171-
return xp.prod(a, **_get_dtype_kwargs(xp, dtype))
162+
return xp.prod(a)
172163

173164
@testing.for_all_dtypes()
174165
@testing.numpy_cupy_allclose()
175166
def test_prod_axis(self, xp, dtype):
176167
a = testing.shaped_arange((2, 3, 4), xp, dtype)
177-
return a.prod(axis=1, **_get_dtype_kwargs(xp, dtype))
168+
return a.prod(axis=1)
178169

179170
@testing.for_all_dtypes()
180171
@testing.numpy_cupy_allclose()
181172
def test_external_prod_axis(self, xp, dtype):
182173
a = testing.shaped_arange((2, 3, 4), xp, dtype)
183-
return xp.prod(a, axis=1, **_get_dtype_kwargs(xp, dtype))
174+
return xp.prod(a, axis=1)
184175

185176
@testing.for_all_dtypes_combination(names=["src_dtype", "dst_dtype"])
186177
@testing.numpy_cupy_allclose()
@@ -228,12 +219,7 @@ def _test(self, xp, dtype):
228219
if not issubclass(dtype, xp.integer):
229220
a[:, 1] = xp.nan
230221
func = getattr(xp, self.func)
231-
return func(
232-
a,
233-
**_get_dtype_kwargs(xp, dtype),
234-
axis=self.axis,
235-
keepdims=self.keepdims,
236-
)
222+
return func(a, axis=self.axis, keepdims=self.keepdims)
237223

238224
@testing.for_all_dtypes(no_bool=True, no_float16=True)
239225
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
@@ -299,15 +285,15 @@ def test_nansum_axes(self, xp, dtype):
299285
a = testing.shaped_arange(self.shape, xp, dtype)
300286
if not issubclass(dtype, xp.integer):
301287
a[:, 1] = xp.nan
302-
return xp.nansum(a, **_get_dtype_kwargs(xp, dtype), axis=self.axis)
288+
return xp.nansum(a, axis=self.axis)
303289

304290

305291
class TestNansumNanprodHuge:
306292
def _test(self, xp, nan_slice):
307293
a = testing.shaped_random((2048, 1, 1024), xp, "f")
308294
a[nan_slice] = xp.nan
309295
a = xp.broadcast_to(a, (2048, 256, 1024))
310-
return xp.nansum(a, **_get_dtype_kwargs(xp, a.dtype), axis=2)
296+
return xp.nansum(a, axis=2)
311297

312298
@testing.slow
313299
@testing.numpy_cupy_allclose(atol=1e-1)

0 commit comments

Comments
 (0)