Skip to content

Commit ea30077

Browse files
committed
Align logsumexp and reduce_hypot tests
1 parent 9e597fb commit ea30077

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

tests/test_mathematical.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,9 @@ class TestLogSumExp:
18841884
def test_logsumexp(self, dtype, axis, keepdims):
18851885
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
18861886
res = dpnp.logsumexp(a, axis=axis, keepdims=keepdims)
1887-
exp_dtype = dpnp.default_float_type(a.device)
1887+
exp_dtype = (
1888+
dpnp.default_float_type(a.device) if dtype == dpnp.bool else None
1889+
)
18881890
exp = numpy.logaddexp.reduce(
18891891
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
18901892
)
@@ -1896,11 +1898,17 @@ def test_logsumexp(self, dtype, axis, keepdims):
18961898
@pytest.mark.parametrize("keepdims", [True, False])
18971899
def test_logsumexp_out(self, dtype, axis, keepdims):
18981900
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
1899-
exp_dtype = dpnp.default_float_type(a.device)
1901+
exp_dtype = (
1902+
dpnp.default_float_type(a.device) if dtype == dpnp.bool else None
1903+
)
19001904
exp = numpy.logaddexp.reduce(
19011905
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
19021906
)
1903-
dpnp_out = dpnp.empty(exp.shape, dtype=exp_dtype)
1907+
1908+
exp_dtype = exp.dtype
1909+
if exp_dtype == numpy.float64 and not has_support_aspect64():
1910+
exp_dtype = numpy.float32
1911+
dpnp_out = dpnp.empty_like(a, shape=exp.shape, dtype=exp_dtype)
19041912
res = dpnp.logsumexp(a, axis=axis, out=dpnp_out, keepdims=keepdims)
19051913

19061914
assert res is dpnp_out
@@ -1926,7 +1934,9 @@ class TestReduceHypot:
19261934
def test_reduce_hypot(self, dtype, axis, keepdims):
19271935
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
19281936
res = dpnp.reduce_hypot(a, axis=axis, keepdims=keepdims)
1929-
exp_dtype = dpnp.default_float_type(a.device)
1937+
exp_dtype = (
1938+
dpnp.default_float_type(a.device) if dtype == dpnp.bool else None
1939+
)
19301940
exp = numpy.hypot.reduce(
19311941
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
19321942
)
@@ -1938,11 +1948,17 @@ def test_reduce_hypot(self, dtype, axis, keepdims):
19381948
@pytest.mark.parametrize("keepdims", [True, False])
19391949
def test_reduce_hypot_out(self, dtype, axis, keepdims):
19401950
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
1941-
exp_dtype = dpnp.default_float_type(a.device)
1951+
exp_dtype = (
1952+
dpnp.default_float_type(a.device) if dtype == dpnp.bool else None
1953+
)
19421954
exp = numpy.hypot.reduce(
19431955
dpnp.asnumpy(a), axis=axis, keepdims=keepdims, dtype=exp_dtype
19441956
)
1945-
dpnp_out = dpnp.empty(exp.shape, dtype=exp_dtype)
1957+
1958+
exp_dtype = exp.dtype
1959+
if exp_dtype == numpy.float64 and not has_support_aspect64():
1960+
exp_dtype = numpy.float32
1961+
dpnp_out = dpnp.empty_like(a, shape=exp.shape, dtype=exp_dtype)
19461962
res = dpnp.reduce_hypot(a, axis=axis, out=dpnp_out, keepdims=keepdims)
19471963

19481964
assert res is dpnp_out

0 commit comments

Comments
 (0)