|
| 1 | +import pytest |
| 2 | + |
| 3 | +import dpctl.tensor as dpt |
| 4 | + |
| 5 | +sint_types = [ |
| 6 | + dpt.int8, |
| 7 | + dpt.int16, |
| 8 | + dpt.int32, |
| 9 | + dpt.int64, |
| 10 | +] |
| 11 | +uint_types = [ |
| 12 | + dpt.uint8, |
| 13 | + dpt.uint16, |
| 14 | + dpt.uint32, |
| 15 | + dpt.uint64, |
| 16 | +] |
| 17 | +rfp_types = [ |
| 18 | + dpt.float16, |
| 19 | + dpt.float32, |
| 20 | + dpt.float64, |
| 21 | +] |
| 22 | +cfp_types = [ |
| 23 | + dpt.complex64, |
| 24 | + dpt.complex128, |
| 25 | +] |
| 26 | + |
| 27 | + |
| 28 | +@pytest.mark.parametrize("dt", sint_types[2:]) |
| 29 | +def test_contig_cumsum_sint(dt): |
| 30 | + n = 10000 |
| 31 | + x = dpt.repeat(dpt.asarray([1, -1], dtype=dt), n) |
| 32 | + |
| 33 | + res = dpt.cumulative_sum(x, dtype=dt) |
| 34 | + |
| 35 | + ar = dpt.arange(n, dtype=dt) |
| 36 | + expected = dpt.concat((1 + ar, dpt.flip(ar))) |
| 37 | + assert dpt.all(res == expected) |
| 38 | + |
| 39 | + |
| 40 | +@pytest.mark.parametrize("dt", sint_types[2:]) |
| 41 | +def test_strided_cumsum_sint(dt): |
| 42 | + n = 10000 |
| 43 | + x = dpt.repeat(dpt.asarray([1, -1], dtype=dt), 2 * n)[1::2] |
| 44 | + |
| 45 | + res = dpt.cumulative_sum(x, dtype=dt) |
| 46 | + |
| 47 | + ar = dpt.arange(n, dtype=dt) |
| 48 | + expected = dpt.concat((1 + ar, dpt.flip(ar))) |
| 49 | + assert dpt.all(res == expected) |
| 50 | + |
| 51 | + x2 = dpt.repeat(dpt.asarray([-1, 1], dtype=dt), 2 * n)[-1::-2] |
| 52 | + |
| 53 | + res = dpt.cumulative_sum(x2, dtype=dt) |
| 54 | + |
| 55 | + ar = dpt.arange(n, dtype=dt) |
| 56 | + expected = dpt.concat((1 + ar, dpt.flip(ar))) |
| 57 | + assert dpt.all(res == expected) |
| 58 | + |
| 59 | + |
| 60 | +@pytest.mark.parametrize("dt", sint_types[2:]) |
| 61 | +def test_contig_cumsum_axis_sint(dt): |
| 62 | + n0, n1 = 1000, 173 |
| 63 | + x = dpt.repeat(dpt.asarray([1, -1], dtype=dt), n0) |
| 64 | + m = dpt.tile(dpt.expand_dims(x, axis=1), (1, n1)) |
| 65 | + |
| 66 | + res = dpt.cumulative_sum(m, dtype=dt, axis=0) |
| 67 | + |
| 68 | + ar = dpt.arange(n0, dtype=dt) |
| 69 | + expected = dpt.concat((1 + ar, dpt.flip(ar))) |
| 70 | + assert dpt.all(res == dpt.expand_dims(expected, axis=1)) |
| 71 | + |
| 72 | + |
| 73 | +@pytest.mark.parametrize("dt", sint_types[2:]) |
| 74 | +def test_strided_cumsum_axis_sint(dt): |
| 75 | + n0, n1 = 1000, 173 |
| 76 | + x = dpt.repeat(dpt.asarray([1, -1], dtype=dt), 2 * n0) |
| 77 | + m = dpt.tile(dpt.expand_dims(x, axis=1), (1, n1))[1::2, ::-1] |
| 78 | + |
| 79 | + res = dpt.cumulative_sum(m, dtype=dt, axis=0) |
| 80 | + |
| 81 | + ar = dpt.arange(n0, dtype=dt) |
| 82 | + expected = dpt.concat((1 + ar, dpt.flip(ar))) |
| 83 | + assert dpt.all(res == dpt.expand_dims(expected, axis=1)) |
0 commit comments