Skip to content

Commit e0b6f0e

Browse files
Update test_sum_float in test_sum.py
1 parent 42b199d commit e0b6f0e

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

tests/test_sum.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,17 @@ def test_sum_float(dtype):
2727
)
2828
ia = dpnp.array(a)
2929

30+
# Flag for type check in special cases
31+
# Skip dtype checks when dpnp handles float32 arrays
32+
# as `dpnp.sum()` and `numpy.sum()` return different dtypes
33+
check_dtype = dtype != dpnp.float32
3034
for axis in range(len(a)):
3135
result = dpnp.sum(ia, axis=axis)
3236
expected = numpy.sum(a, axis=axis)
33-
assert_dtype_allclose(result, expected)
37+
assert_dtype_allclose(result, expected, check_type=check_dtype)
38+
if not check_dtype:
39+
# Ensure dtype kind matches when check_dtype is False
40+
assert result.dtype.kind == expected.dtype.kind
3441

3542

3643
def test_sum_int():

0 commit comments

Comments
 (0)