Skip to content

Commit e5e7404

Browse files
committed
Adds new tests for reductions
1 parent 908f358 commit e5e7404

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,21 @@ def test_largish_reduction(arg_dtype, n):
173173
assert dpt.all(dpt.equal(y1, n * m))
174174

175175

176+
@pytest.mark.parametrize("n", [1023, 1024, 1025])
177+
def test_largish_reduction_axis1_axis0(n):
178+
get_queue_or_skip()
179+
180+
m = 25
181+
x1 = dpt.ones((m, n), dtype="f4")
182+
x2 = dpt.ones((n, m), dtype="f4")
183+
184+
y1 = dpt.sum(x1, axis=1)
185+
y2 = dpt.sum(x2, axis=0)
186+
187+
assert dpt.all(y1 == n)
188+
assert dpt.all(y2 == n)
189+
190+
176191
def test_axis0_bug():
177192
"gh-1391"
178193
get_queue_or_skip()

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,28 @@ def test_logsumexp_keepdims():
326326
assert s.shape == (3, 1, 1, 6, 1)
327327

328328

329+
def test_logsumexp_keepdims_zero_size():
330+
get_queue_or_skip()
331+
n = 10
332+
a = dpt.ones((n, 0, n))
333+
334+
s1 = dpt.logsumexp(a, keepdims=True)
335+
assert s1.shape == (1, 1, 1)
336+
337+
s2 = dpt.logsumexp(a, axis=(0, 1), keepdims=True)
338+
assert s2.shape == (1, 1, n)
339+
340+
s3 = dpt.logsumexp(a, axis=(1, 2), keepdims=True)
341+
assert s3.shape == (n, 1, 1)
342+
343+
s4 = dpt.logsumexp(a, axis=(0, 2), keepdims=True)
344+
assert s4.shape == (1, 0, 1)
345+
346+
a0 = a[0]
347+
s5 = dpt.logsumexp(a0, keepdims=True)
348+
assert s5.shape == (1, 1)
349+
350+
329351
def test_logsumexp_scalar():
330352
get_queue_or_skip()
331353

@@ -337,6 +359,29 @@ def test_logsumexp_scalar():
337359
assert s.shape == ()
338360

339361

362+
def test_logsumexp_complex():
363+
get_queue_or_skip()
364+
365+
x = dpt.zeros(1, dtype="c8")
366+
with pytest.raises(TypeError):
367+
dpt.logsumexp(x)
368+
369+
370+
def test_logsumexp_int_axis():
371+
get_queue_or_skip()
372+
373+
x = dpt.zeros((8, 10), dtype="f4")
374+
res = dpt.logsumexp(x, axis=0)
375+
assert res.ndim == 1
376+
assert res.shape[0] == 10
377+
378+
379+
def test_logsumexp_invalid_arr():
380+
x = dict()
381+
with pytest.raises(TypeError):
382+
dpt.logsumexp(x)
383+
384+
340385
@pytest.mark.parametrize("arg_dtype", _no_complex_dtypes[1:])
341386
def test_hypot_arg_dtype_default_output_dtype_matrix(arg_dtype):
342387
q = get_queue_or_skip()
@@ -376,3 +421,11 @@ def test_hypot_arg_out_dtype_matrix(arg_dtype, out_dtype):
376421

377422
assert isinstance(r, dpt.usm_ndarray)
378423
assert r.dtype == dpt.dtype(out_dtype)
424+
425+
426+
def test_hypot_complex():
427+
get_queue_or_skip()
428+
429+
x = dpt.zeros(1, dtype="c8")
430+
with pytest.raises(TypeError):
431+
dpt.reduce_hypot(x)

0 commit comments

Comments
 (0)