|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
| 17 | +from random import randrange |
| 18 | + |
17 | 19 | import pytest
|
18 | 20 | from helper import get_queue_or_skip, skip_if_dtype_not_supported
|
19 | 21 |
|
20 | 22 | import dpctl.tensor as dpt
|
21 |
| -from dpctl.tensor._tensor_impl import default_device_int_type |
22 | 23 | from dpctl.utils import ExecutionPlacementError
|
23 | 24 |
|
24 | 25 | sint_types = [
|
@@ -156,7 +157,7 @@ def test_cumulative_logsumexp_identity():
|
156 | 157 | assert r[0] == -dpt.inf
|
157 | 158 |
|
158 | 159 |
|
159 |
| -def test_accumulate_empty_array(): |
| 160 | +def test_accumulate_zero_size_dims(): |
160 | 161 | get_queue_or_skip()
|
161 | 162 |
|
162 | 163 | n0, n1, n2 = 3, 0, 5
|
@@ -224,33 +225,32 @@ def test_accumulator_out_kwarg():
|
224 | 225 | q = get_queue_or_skip()
|
225 | 226 |
|
226 | 227 | n = 100
|
227 |
| - default_int = default_device_int_type(q) |
228 | 228 |
|
229 |
| - expected = dpt.arange(1, n + 1, dtype=default_int, sycl_queue=q) |
| 229 | + expected = dpt.arange(1, n + 1, dtype="i4", sycl_queue=q) |
230 | 230 | x = dpt.ones(n, dtype="i4", sycl_queue=q)
|
231 |
| - out = dpt.empty_like(x, dtype=default_int) |
232 |
| - dpt.cumulative_sum(x, out=out) |
| 231 | + out = dpt.empty_like(x, dtype="i4") |
| 232 | + dpt.cumulative_sum(x, dtype="i4", out=out) |
233 | 233 | assert dpt.all(expected == out)
|
234 | 234 |
|
235 | 235 | # overlap
|
236 |
| - x = dpt.ones(n, dtype=default_int, sycl_queue=q) |
237 |
| - dpt.cumulative_sum(x, out=x) |
| 236 | + x = dpt.ones(n, dtype="i4", sycl_queue=q) |
| 237 | + dpt.cumulative_sum(x, dtype="i4", out=x) |
238 | 238 | assert dpt.all(x == expected)
|
239 | 239 |
|
240 | 240 | # axis before final axis
|
241 | 241 | expected = dpt.broadcast_to(
|
242 |
| - dpt.arange(1, n + 1, dtype=default_int, sycl_queue=q), (n, n) |
| 242 | + dpt.arange(1, n + 1, dtype="i4", sycl_queue=q), (n, n) |
243 | 243 | ).mT
|
244 | 244 | x = dpt.ones((n, n), dtype="i4", sycl_queue=q)
|
245 |
| - out = dpt.empty_like(x, dtype=default_int) |
246 |
| - dpt.cumulative_sum(x, axis=0, out=out) |
| 245 | + out = dpt.empty_like(x, dtype="i4") |
| 246 | + dpt.cumulative_sum(x, axis=0, dtype="i4", out=out) |
247 | 247 | assert dpt.all(expected == out)
|
248 | 248 |
|
249 | 249 | # scalar
|
250 | 250 | x = dpt.asarray(3, dtype="i4")
|
251 |
| - out = dpt.empty((), dtype=default_int) |
252 |
| - expected = dpt.asarray(3, dtype=default_int) |
253 |
| - dpt.cumulative_sum(x, out=out) |
| 251 | + out = dpt.empty((), dtype="i4") |
| 252 | + expected = 3 |
| 253 | + dpt.cumulative_sum(x, dtype="i4", out=out) |
254 | 254 | assert expected == out
|
255 | 255 |
|
256 | 256 |
|
@@ -294,3 +294,90 @@ def test_accumulator_arg_validation():
|
294 | 294 | out_wrong_queue = dpt.empty_like(x2, sycl_queue=q2)
|
295 | 295 | with pytest.raises(ExecutionPlacementError):
|
296 | 296 | dpt.cumulative_sum(x2, out=out_wrong_queue)
|
| 297 | + |
| 298 | + |
| 299 | +def test_cumsum_nan_propagation(): |
| 300 | + get_queue_or_skip() |
| 301 | + |
| 302 | + n = 100 |
| 303 | + x = dpt.ones(n, dtype="f4") |
| 304 | + i = randrange(n) |
| 305 | + x[i] = dpt.nan |
| 306 | + |
| 307 | + r = dpt.cumulative_sum(x) |
| 308 | + assert dpt.all(dpt.isnan(r[i:])) |
| 309 | + |
| 310 | + |
| 311 | +def test_cumprod_nan_propagation(): |
| 312 | + get_queue_or_skip() |
| 313 | + |
| 314 | + n = 100 |
| 315 | + x = dpt.ones(n, dtype="f4") |
| 316 | + i = randrange(n) |
| 317 | + x[i] = dpt.nan |
| 318 | + |
| 319 | + r = dpt.cumulative_prod(x) |
| 320 | + assert dpt.all(dpt.isnan(r[i:])) |
| 321 | + |
| 322 | + |
| 323 | +def test_logcumsumexp_nan_propagation(): |
| 324 | + get_queue_or_skip() |
| 325 | + |
| 326 | + n = 100 |
| 327 | + x = dpt.ones(n, dtype="f4") |
| 328 | + i = randrange(n) |
| 329 | + x[i] = dpt.nan |
| 330 | + |
| 331 | + r = dpt.cumulative_logsumexp(x) |
| 332 | + assert dpt.all(dpt.isnan(r[i:])) |
| 333 | + |
| 334 | + |
| 335 | +@pytest.mark.parametrize("arg_dtype", no_complex_types) |
| 336 | +def test_logcumsumexp_arg_dtype_default_output_dtype_matrix(arg_dtype): |
| 337 | + q = get_queue_or_skip() |
| 338 | + skip_if_dtype_not_supported(arg_dtype, q) |
| 339 | + |
| 340 | + x = dpt.ones(10, dtype=arg_dtype, sycl_queue=q) |
| 341 | + r = dpt.cumulative_logsumexp(x) |
| 342 | + |
| 343 | + if arg_dtype.kind in "biu": |
| 344 | + assert r.dtype.kind == "f" |
| 345 | + else: |
| 346 | + assert r.dtype == arg_dtype |
| 347 | + |
| 348 | + |
| 349 | +def test_logcumsumexp_complex_error(): |
| 350 | + get_queue_or_skip() |
| 351 | + |
| 352 | + x = dpt.ones(10, dtype="c8") |
| 353 | + with pytest.raises(ValueError): |
| 354 | + dpt.cumulative_logsumexp(x) |
| 355 | + |
| 356 | + |
| 357 | +def test_cumprod_basic(): |
| 358 | + get_queue_or_skip() |
| 359 | + |
| 360 | + n = 50 |
| 361 | + val = 2 |
| 362 | + x = dpt.full(n, val, dtype="i8") |
| 363 | + r = dpt.cumulative_prod(x) |
| 364 | + expected = dpt.pow(val, dpt.arange(1, n + 1, dtype="i8")) |
| 365 | + |
| 366 | + assert dpt.all(r == expected) |
| 367 | + |
| 368 | + x = dpt.tile(dpt.asarray([2, 0.5], dtype="f4"), 10000) |
| 369 | + expected = dpt.tile(dpt.asarray([2, 1], dtype="f4"), 10000) |
| 370 | + r = dpt.cumulative_prod(x) |
| 371 | + assert dpt.all(r == expected) |
| 372 | + |
| 373 | + |
| 374 | +def test_cumlogsumexp_basic(): |
| 375 | + get_queue_or_skip() |
| 376 | + |
| 377 | + dt = dpt.float32 |
| 378 | + x = dpt.ones(100, dtype=dt) |
| 379 | + r = dpt.cumulative_logsumexp(x) |
| 380 | + expected = dpt.log(dpt.cumulative_sum(dpt.exp(x))) |
| 381 | + |
| 382 | + tol = 7 * dpt.finfo(dt).resolution |
| 383 | + assert dpt.allclose(r, expected, atol=tol, rtol=tol) |
0 commit comments