Skip to content

Commit 93a4cb2

Browse files
committed
Adds tests for reduce_hypot and logsumexp
1 parent 725e2d1 commit 93a4cb2

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,32 @@
1818

1919
import numpy as np
2020
import pytest
21+
from numpy.testing import assert_allclose
2122

2223
import dpctl.tensor as dpt
2324
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2425

26+
_no_complex_dtypes = [
27+
"?",
28+
"i1",
29+
"u1",
30+
"i2",
31+
"u2",
32+
"i4",
33+
"u4",
34+
"i8",
35+
"u8",
36+
"f2",
37+
"f4",
38+
"f8",
39+
]
40+
41+
42+
_all_dtypes = _no_complex_dtypes + [
43+
"c8",
44+
"c16",
45+
]
46+
2547

2648
def test_max_min_axis():
2749
get_queue_or_skip()
@@ -234,3 +256,123 @@ def test_reduction_arg_validation():
234256
dpt.max(x)
235257
with pytest.raises(ValueError):
236258
dpt.argmax(x)
259+
260+
261+
@pytest.mark.parametrize("arg_dtype", _no_complex_dtypes[1:])
262+
def test_logsumexp_arg_dtype_default_output_dtype_matrix(arg_dtype):
263+
q = get_queue_or_skip()
264+
skip_if_dtype_not_supported(arg_dtype, q)
265+
266+
m = dpt.ones(100, dtype=arg_dtype)
267+
r = dpt.logsumexp(m)
268+
269+
assert isinstance(r, dpt.usm_ndarray)
270+
assert r.dtype.kind == "f"
271+
tol = dpt.finfo(r.dtype).resolution
272+
assert_allclose(
273+
dpt.asnumpy(r),
274+
np.logaddexp.reduce(dpt.asnumpy(m), dtype=r.dtype),
275+
rtol=tol,
276+
atol=tol,
277+
)
278+
279+
280+
def test_logsumexp_empty():
281+
get_queue_or_skip()
282+
x = dpt.empty((0,), dtype="f4")
283+
y = dpt.logsumexp(x)
284+
assert y.shape == tuple()
285+
assert y == -dpt.inf
286+
287+
288+
def test_logsumexp_axis():
289+
get_queue_or_skip()
290+
291+
m = dpt.ones((3, 4, 5, 6, 7), dtype="f4")
292+
s = dpt.logsumexp(m, axis=(1, 2, -1))
293+
294+
assert isinstance(s, dpt.usm_ndarray)
295+
assert s.shape == (3, 6)
296+
tol = dpt.finfo(s.dtype).resolution
297+
assert_allclose(
298+
dpt.asnumpy(s),
299+
np.logaddexp.reduce(dpt.asnumpy(m), axis=(1, 2, -1), dtype=s.dtype),
300+
rtol=tol,
301+
atol=tol,
302+
)
303+
304+
305+
@pytest.mark.parametrize("arg_dtype", _no_complex_dtypes[1:])
306+
@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])
307+
def test_logsumexp_arg_out_dtype_matrix(arg_dtype, out_dtype):
308+
q = get_queue_or_skip()
309+
skip_if_dtype_not_supported(arg_dtype, q)
310+
skip_if_dtype_not_supported(out_dtype, q)
311+
312+
m = dpt.ones(100, dtype=arg_dtype)
313+
r = dpt.logsumexp(m, dtype=out_dtype)
314+
315+
assert isinstance(r, dpt.usm_ndarray)
316+
assert r.dtype == dpt.dtype(out_dtype)
317+
318+
319+
def test_logsumexp_keepdims():
320+
get_queue_or_skip()
321+
322+
m = dpt.ones((3, 4, 5, 6, 7), dtype="i4")
323+
s = dpt.logsumexp(m, axis=(1, 2, -1), keepdims=True)
324+
325+
assert isinstance(s, dpt.usm_ndarray)
326+
assert s.shape == (3, 1, 1, 6, 1)
327+
328+
329+
def test_logsumexp_scalar():
330+
get_queue_or_skip()
331+
332+
m = dpt.ones(())
333+
s = dpt.logsumexp(m)
334+
335+
assert isinstance(s, dpt.usm_ndarray)
336+
assert m.sycl_queue == s.sycl_queue
337+
assert s.shape == ()
338+
339+
340+
@pytest.mark.parametrize("arg_dtype", _no_complex_dtypes[1:])
341+
def test_hypot_arg_dtype_default_output_dtype_matrix(arg_dtype):
342+
q = get_queue_or_skip()
343+
skip_if_dtype_not_supported(arg_dtype, q)
344+
345+
m = dpt.ones(100, dtype=arg_dtype)
346+
r = dpt.reduce_hypot(m)
347+
348+
assert isinstance(r, dpt.usm_ndarray)
349+
assert r.dtype.kind == "f"
350+
tol = dpt.finfo(r.dtype).resolution
351+
assert_allclose(
352+
dpt.asnumpy(r),
353+
np.hypot.reduce(dpt.asnumpy(m), dtype=r.dtype),
354+
rtol=tol,
355+
atol=tol,
356+
)
357+
358+
359+
def test_hypot_empty():
360+
get_queue_or_skip()
361+
x = dpt.empty((0,), dtype="f4")
362+
y = dpt.reduce_hypot(x)
363+
assert y.shape == tuple()
364+
assert y == 0
365+
366+
367+
@pytest.mark.parametrize("arg_dtype", _no_complex_dtypes[1:])
368+
@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])
369+
def test_hypot_arg_out_dtype_matrix(arg_dtype, out_dtype):
370+
q = get_queue_or_skip()
371+
skip_if_dtype_not_supported(arg_dtype, q)
372+
skip_if_dtype_not_supported(out_dtype, q)
373+
374+
m = dpt.ones(100, dtype=arg_dtype)
375+
r = dpt.reduce_hypot(m, dtype=out_dtype)
376+
377+
assert isinstance(r, dpt.usm_ndarray)
378+
assert r.dtype == dpt.dtype(out_dtype)

0 commit comments

Comments
 (0)