Skip to content

Commit 7b77e16

Browse files
committed
Adds tests for cumulative_prod and cumulative_logsumexp
Also fixes incorrect TypeError in _accumulation.py
1 parent 07013c3 commit 7b77e16

File tree

2 files changed

+102
-15
lines changed

2 files changed

+102
-15
lines changed

dpctl/tensor/_accumulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
6262
elif inp_kind in "f":
6363
res_dt = inp_dt
6464
elif inp_kind in "c":
65-
raise TypeError("reduction not defined for complex types")
65+
raise ValueError("function not defined for complex types")
6666

6767
return res_dt
6868

dpctl/tests/test_tensor_accumulation.py

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from random import randrange
18+
1719
import pytest
1820
from helper import get_queue_or_skip, skip_if_dtype_not_supported
1921

2022
import dpctl.tensor as dpt
21-
from dpctl.tensor._tensor_impl import default_device_int_type
2223
from dpctl.utils import ExecutionPlacementError
2324

2425
sint_types = [
@@ -156,7 +157,7 @@ def test_cumulative_logsumexp_identity():
156157
assert r[0] == -dpt.inf
157158

158159

159-
def test_accumulate_empty_array():
160+
def test_accumulate_zero_size_dims():
160161
get_queue_or_skip()
161162

162163
n0, n1, n2 = 3, 0, 5
@@ -224,33 +225,32 @@ def test_accumulator_out_kwarg():
224225
q = get_queue_or_skip()
225226

226227
n = 100
227-
default_int = default_device_int_type(q)
228228

229-
expected = dpt.arange(1, n + 1, dtype=default_int, sycl_queue=q)
229+
expected = dpt.arange(1, n + 1, dtype="i4", sycl_queue=q)
230230
x = dpt.ones(n, dtype="i4", sycl_queue=q)
231-
out = dpt.empty_like(x, dtype=default_int)
232-
dpt.cumulative_sum(x, out=out)
231+
out = dpt.empty_like(x, dtype="i4")
232+
dpt.cumulative_sum(x, dtype="i4", out=out)
233233
assert dpt.all(expected == out)
234234

235235
# overlap
236-
x = dpt.ones(n, dtype=default_int, sycl_queue=q)
237-
dpt.cumulative_sum(x, out=x)
236+
x = dpt.ones(n, dtype="i4", sycl_queue=q)
237+
dpt.cumulative_sum(x, dtype="i4", out=x)
238238
assert dpt.all(x == expected)
239239

240240
# axis before final axis
241241
expected = dpt.broadcast_to(
242-
dpt.arange(1, n + 1, dtype=default_int, sycl_queue=q), (n, n)
242+
dpt.arange(1, n + 1, dtype="i4", sycl_queue=q), (n, n)
243243
).mT
244244
x = dpt.ones((n, n), dtype="i4", sycl_queue=q)
245-
out = dpt.empty_like(x, dtype=default_int)
246-
dpt.cumulative_sum(x, axis=0, out=out)
245+
out = dpt.empty_like(x, dtype="i4")
246+
dpt.cumulative_sum(x, axis=0, dtype="i4", out=out)
247247
assert dpt.all(expected == out)
248248

249249
# scalar
250250
x = dpt.asarray(3, dtype="i4")
251-
out = dpt.empty((), dtype=default_int)
252-
expected = dpt.asarray(3, dtype=default_int)
253-
dpt.cumulative_sum(x, out=out)
251+
out = dpt.empty((), dtype="i4")
252+
expected = 3
253+
dpt.cumulative_sum(x, dtype="i4", out=out)
254254
assert expected == out
255255

256256

@@ -294,3 +294,90 @@ def test_accumulator_arg_validation():
294294
out_wrong_queue = dpt.empty_like(x2, sycl_queue=q2)
295295
with pytest.raises(ExecutionPlacementError):
296296
dpt.cumulative_sum(x2, out=out_wrong_queue)
297+
298+
299+
def test_cumsum_nan_propagation():
300+
get_queue_or_skip()
301+
302+
n = 100
303+
x = dpt.ones(n, dtype="f4")
304+
i = randrange(n)
305+
x[i] = dpt.nan
306+
307+
r = dpt.cumulative_sum(x)
308+
assert dpt.all(dpt.isnan(r[i:]))
309+
310+
311+
def test_cumprod_nan_propagation():
312+
get_queue_or_skip()
313+
314+
n = 100
315+
x = dpt.ones(n, dtype="f4")
316+
i = randrange(n)
317+
x[i] = dpt.nan
318+
319+
r = dpt.cumulative_prod(x)
320+
assert dpt.all(dpt.isnan(r[i:]))
321+
322+
323+
def test_logcumsumexp_nan_propagation():
324+
get_queue_or_skip()
325+
326+
n = 100
327+
x = dpt.ones(n, dtype="f4")
328+
i = randrange(n)
329+
x[i] = dpt.nan
330+
331+
r = dpt.cumulative_logsumexp(x)
332+
assert dpt.all(dpt.isnan(r[i:]))
333+
334+
335+
@pytest.mark.parametrize("arg_dtype", no_complex_types)
336+
def test_logcumsumexp_arg_dtype_default_output_dtype_matrix(arg_dtype):
337+
q = get_queue_or_skip()
338+
skip_if_dtype_not_supported(arg_dtype, q)
339+
340+
x = dpt.ones(10, dtype=arg_dtype, sycl_queue=q)
341+
r = dpt.cumulative_logsumexp(x)
342+
343+
if arg_dtype.kind in "biu":
344+
assert r.dtype.kind == "f"
345+
else:
346+
assert r.dtype == arg_dtype
347+
348+
349+
def test_logcumsumexp_complex_error():
350+
get_queue_or_skip()
351+
352+
x = dpt.ones(10, dtype="c8")
353+
with pytest.raises(ValueError):
354+
dpt.cumulative_logsumexp(x)
355+
356+
357+
def test_cumprod_basic():
358+
get_queue_or_skip()
359+
360+
n = 50
361+
val = 2
362+
x = dpt.full(n, val, dtype="i8")
363+
r = dpt.cumulative_prod(x)
364+
expected = dpt.pow(val, dpt.arange(1, n + 1, dtype="i8"))
365+
366+
assert dpt.all(r == expected)
367+
368+
x = dpt.tile(dpt.asarray([2, 0.5], dtype="f4"), 10000)
369+
expected = dpt.tile(dpt.asarray([2, 1], dtype="f4"), 10000)
370+
r = dpt.cumulative_prod(x)
371+
assert dpt.all(r == expected)
372+
373+
374+
def test_cumlogsumexp_basic():
375+
get_queue_or_skip()
376+
377+
dt = dpt.float32
378+
x = dpt.ones(100, dtype=dt)
379+
r = dpt.cumulative_logsumexp(x)
380+
expected = dpt.log(dpt.cumulative_sum(dpt.exp(x)))
381+
382+
tol = 7 * dpt.finfo(dt).resolution
383+
assert dpt.allclose(r, expected, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)