|
18 | 18 | from helper import get_queue_or_skip, skip_if_dtype_not_supported
|
19 | 19 |
|
20 | 20 | import dpctl.tensor as dpt
|
| 21 | +from dpctl.tensor._tensor_impl import default_device_int_type |
| 22 | +from dpctl.utils import ExecutionPlacementError |
21 | 23 |
|
22 | 24 | sint_types = [
|
23 | 25 | dpt.int8,
|
@@ -216,3 +218,84 @@ def test_cumsum_arg_out_dtype_matrix(arg_dtype, out_dtype):
|
216 | 218 | else:
|
217 | 219 | r_expected = dpt.arange(1, n + 1, dtype=out_dtype)
|
218 | 220 | assert dpt.all(r == r_expected)
|
| 221 | + |
| 222 | + |
def test_accumulator_out_kwarg():
    """Check ``dpt.cumulative_sum`` when the result is written to ``out``.

    Every array in this test is allocated on the single queue ``q`` so the
    expected dtype (``default_device_int_type(q)``) always refers to the
    queue the computation actually runs on.

    Cases covered:

    * 1-D ``i4`` input with a separate ``out`` of the default integer type
    * ``out`` being the input array itself (overlap)
    * accumulation along ``axis=0`` of a 2-D array
    * 0-d (scalar) input
    * boolean input with ``dtype="?"`` accumulated in place
    """
    q = get_queue_or_skip()

    n = 100
    default_int = default_device_int_type(q)

    # cumulative sum of n ones is 1, 2, ..., n
    expected = dpt.arange(1, n + 1, dtype=default_int, sycl_queue=q)
    x = dpt.ones(n, dtype="i4", sycl_queue=q)
    out = dpt.empty_like(x, dtype=default_int)
    dpt.cumulative_sum(x, out=out)
    assert dpt.all(expected == out)

    # overlap: the input array is also passed as the output array
    x = dpt.ones(n, dtype=default_int, sycl_queue=q)
    dpt.cumulative_sum(x, out=x)
    assert dpt.all(x == expected)

    # axis before final axis
    # expected[i, j] == i + 1: the 1..n ramp runs along axis 0
    expected = dpt.broadcast_to(
        dpt.arange(1, n + 1, dtype=default_int, sycl_queue=q), (n, n)
    ).mT
    x = dpt.ones((n, n), dtype="i4", sycl_queue=q)
    out = dpt.empty_like(x, dtype=default_int)
    dpt.cumulative_sum(x, axis=0, out=out)
    assert dpt.all(expected == out)

    # scalar
    # Allocate on ``q`` like every other case: ``default_int`` was queried
    # for ``q``, so the arrays must not silently land on the default queue.
    x = dpt.asarray(3, dtype="i4", sycl_queue=q)
    out = dpt.empty((), dtype=default_int, sycl_queue=q)
    expected = dpt.asarray(3, dtype=default_int, sycl_queue=q)
    dpt.cumulative_sum(x, out=out)
    assert expected == out

    # overlapping and unimplemented
    # Only the first 20 elements are True before the call; the test expects
    # every element to be True after accumulating in place with dtype="?".
    x = dpt.ones(n, dtype="?", sycl_queue=q)
    x[20:] = False
    dpt.cumulative_sum(x, dtype="?", out=x)
    assert dpt.all(x)
| 260 | + |
| 261 | + |
def test_accumulator_arg_validation():
    """Check that ``dpt.cumulative_sum`` rejects invalid arguments.

    Expected failures, in order:

    * ``TypeError``  - input is not a usm_ndarray
    * ``ValueError`` - ``axis`` omitted for a 2-D input
    * ``TypeError``  - ``out`` is not a usm_ndarray
    * ``ValueError`` - ``out`` is read-only
    * ``ValueError`` - ``out`` has the wrong shape
    * ``ValueError`` - ``out`` has the wrong dtype
    * ``ExecutionPlacementError`` - ``out`` is allocated on a queue other
      than the input's
    """
    # Two queues from separate calls: one for the inputs, one used only to
    # build an ``out`` array associated with a different queue.
    input_queue = get_queue_or_skip()
    other_queue = get_queue_or_skip()

    size = 5
    vector = dpt.ones(size, dtype="f4", sycl_queue=input_queue)
    matrix = dpt.ones((size, size), dtype="f4", sycl_queue=input_queue)

    # Invalid ``out`` candidates, each wrong in exactly one respect.
    read_only_out = dpt.empty_like(vector)
    read_only_out.flags.writable = False
    misshapen_out = dpt.ones(
        size + 1, dtype=vector.dtype, sycl_queue=input_queue
    )
    mistyped_out = dpt.empty_like(vector, dtype="i4")
    misplaced_out = dpt.empty_like(vector, sycl_queue=other_queue)

    # must be usm_ndarray
    with pytest.raises(TypeError):
        dpt.cumulative_sum({})

    # axis must be specified when input not 1D
    with pytest.raises(ValueError):
        dpt.cumulative_sum(matrix)

    # out must be usm_ndarray
    with pytest.raises(TypeError):
        dpt.cumulative_sum(vector, out={})

    # out must be writable
    with pytest.raises(ValueError):
        dpt.cumulative_sum(vector, out=read_only_out)

    # out must be expected shape
    with pytest.raises(ValueError):
        dpt.cumulative_sum(vector, out=misshapen_out)

    # out must be expected dtype
    with pytest.raises(ValueError):
        dpt.cumulative_sum(vector, out=mistyped_out)

    # compute follows data
    with pytest.raises(ExecutionPlacementError):
        dpt.cumulative_sum(vector, out=misplaced_out)
0 commit comments