Skip to content

Commit 761ecd4

Browse files
committed
Adds the first tests for dpt.cumulative_sum
1 parent 80a77e0 commit 761ecd4

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import pytest
2+
3+
import dpctl.tensor as dpt
4+
5+
sint_types = [
6+
dpt.int8,
7+
dpt.int16,
8+
dpt.int32,
9+
dpt.int64,
10+
]
11+
uint_types = [
12+
dpt.uint8,
13+
dpt.uint16,
14+
dpt.uint32,
15+
dpt.uint64,
16+
]
17+
rfp_types = [
18+
dpt.float16,
19+
dpt.float32,
20+
dpt.float64,
21+
]
22+
cfp_types = [
23+
dpt.complex64,
24+
dpt.complex128,
25+
]
26+
27+
28+
@pytest.mark.parametrize("dt", sint_types[2:])
29+
def test_contig_cumsum_sint(dt):
30+
n = 10000
31+
x = dpt.repeat(dpt.asarray([1, -1], dtype=dt), n)
32+
33+
res = dpt.cumulative_sum(x, dtype=dt)
34+
35+
ar = dpt.arange(n, dtype=dt)
36+
expected = dpt.concat((1 + ar, dpt.flip(ar)))
37+
assert dpt.all(res == expected)
38+
39+
40+
@pytest.mark.parametrize("dt", sint_types[2:])
41+
def test_strided_cumsum_sint(dt):
42+
n = 10000
43+
x = dpt.repeat(dpt.asarray([1, -1], dtype=dt), 2 * n)[1::2]
44+
45+
res = dpt.cumulative_sum(x, dtype=dt)
46+
47+
ar = dpt.arange(n, dtype=dt)
48+
expected = dpt.concat((1 + ar, dpt.flip(ar)))
49+
assert dpt.all(res == expected)
50+
51+
x2 = dpt.repeat(dpt.asarray([-1, 1], dtype=dt), 2 * n)[-1::-2]
52+
53+
res = dpt.cumulative_sum(x2, dtype=dt)
54+
55+
ar = dpt.arange(n, dtype=dt)
56+
expected = dpt.concat((1 + ar, dpt.flip(ar)))
57+
assert dpt.all(res == expected)
58+
59+
60+
@pytest.mark.parametrize("dt", sint_types[2:])
61+
def test_contig_cumsum_axis_sint(dt):
62+
n0, n1 = 1000, 173
63+
x = dpt.repeat(dpt.asarray([1, -1], dtype=dt), n0)
64+
m = dpt.tile(dpt.expand_dims(x, axis=1), (1, n1))
65+
66+
res = dpt.cumulative_sum(m, dtype=dt, axis=0)
67+
68+
ar = dpt.arange(n0, dtype=dt)
69+
expected = dpt.concat((1 + ar, dpt.flip(ar)))
70+
assert dpt.all(res == dpt.expand_dims(expected, axis=1))
71+
72+
73+
@pytest.mark.parametrize("dt", sint_types[2:])
74+
def test_strided_cumsum_axis_sint(dt):
75+
n0, n1 = 1000, 173
76+
x = dpt.repeat(dpt.asarray([1, -1], dtype=dt), 2 * n0)
77+
m = dpt.tile(dpt.expand_dims(x, axis=1), (1, n1))[1::2, ::-1]
78+
79+
res = dpt.cumulative_sum(m, dtype=dt, axis=0)
80+
81+
ar = dpt.arange(n0, dtype=dt)
82+
expected = dpt.concat((1 + ar, dpt.flip(ar)))
83+
assert dpt.all(res == dpt.expand_dims(expected, axis=1))

0 commit comments

Comments
 (0)