Skip to content

Commit 1dbc65c

Browse files
committed
More accumulator tests
1 parent 64b83bb commit 1dbc65c

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

dpctl/tests/test_tensor_accumulation.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from helper import get_queue_or_skip, skip_if_dtype_not_supported
1919

2020
import dpctl.tensor as dpt
21+
from dpctl.tensor._tensor_impl import default_device_int_type
22+
from dpctl.utils import ExecutionPlacementError
2123

2224
sint_types = [
2325
dpt.int8,
@@ -216,3 +218,84 @@ def test_cumsum_arg_out_dtype_matrix(arg_dtype, out_dtype):
216218
else:
217219
r_expected = dpt.arange(1, n + 1, dtype=out_dtype)
218220
assert dpt.all(r == r_expected)
221+
222+
223+
def test_accumulator_out_kwarg():
224+
q = get_queue_or_skip()
225+
226+
n = 100
227+
default_int = default_device_int_type(q)
228+
229+
expected = dpt.arange(1, n + 1, dtype=default_int, sycl_queue=q)
230+
x = dpt.ones(n, dtype="i4", sycl_queue=q)
231+
out = dpt.empty_like(x, dtype=default_int)
232+
dpt.cumulative_sum(x, out=out)
233+
assert dpt.all(expected == out)
234+
# overlap
235+
x = dpt.ones(n, dtype=default_int, sycl_queue=q)
236+
dpt.cumulative_sum(x, out=x)
237+
assert dpt.all(x == expected)
238+
239+
# axis before final axis
240+
expected = dpt.broadcast_to(
241+
dpt.arange(1, n + 1, dtype=default_int, sycl_queue=q), (n, n)
242+
).mT
243+
x = dpt.ones((n, n), dtype="i4", sycl_queue=q)
244+
out = dpt.empty_like(x, dtype=default_int)
245+
dpt.cumulative_sum(x, axis=0, out=out)
246+
assert dpt.all(expected == out)
247+
248+
# scalar
249+
x = dpt.asarray(3, dtype="i4")
250+
out = dpt.empty((), dtype=default_int)
251+
expected = dpt.asarray(3, dtype=default_int)
252+
dpt.cumulative_sum(x, out=out)
253+
assert expected == out
254+
255+
# overlapping and unimplemented
256+
x = dpt.ones(n, dtype="?", sycl_queue=q)
257+
x[20:] = False
258+
dpt.cumulative_sum(x, dtype="?", out=x)
259+
assert dpt.all(x)
260+
261+
262+
def test_accumulator_arg_validation():
263+
q1 = get_queue_or_skip()
264+
q2 = get_queue_or_skip()
265+
266+
n = 5
267+
x1 = dpt.ones((n, n), dtype="f4", sycl_queue=q1)
268+
x2 = dpt.ones(n, dtype="f4", sycl_queue=q1)
269+
270+
# must be usm_ndarray
271+
with pytest.raises(TypeError):
272+
dpt.cumulative_sum(dict())
273+
274+
# axis must be specified when input not 1D
275+
with pytest.raises(ValueError):
276+
dpt.cumulative_sum(x1)
277+
278+
# out must be usm_ndarray
279+
with pytest.raises(TypeError):
280+
dpt.cumulative_sum(x2, out=dict())
281+
282+
# out must be writable
283+
out_not_writable = dpt.empty_like(x2)
284+
out_not_writable.flags.writable = False
285+
with pytest.raises(ValueError):
286+
dpt.cumulative_sum(x2, out=out_not_writable)
287+
288+
# out must be expected shape
289+
out_wrong_shape = dpt.ones(n + 1, dtype=x2.dtype, sycl_queue=q1)
290+
with pytest.raises(ValueError):
291+
dpt.cumulative_sum(x2, out=out_wrong_shape)
292+
293+
# out must be expected dtype
294+
out_wrong_dtype = dpt.empty_like(x2, dtype="i4")
295+
with pytest.raises(ValueError):
296+
dpt.cumulative_sum(x2, out=out_wrong_dtype)
297+
298+
# compute follows data
299+
out_wrong_queue = dpt.empty_like(x2, sycl_queue=q2)
300+
with pytest.raises(ExecutionPlacementError):
301+
dpt.cumulative_sum(x2, out=out_wrong_queue)

0 commit comments

Comments
 (0)