|
20 | 20 | import dpctl
|
21 | 21 | import dpctl.tensor as dpt
|
22 | 22 |
|
23 |
| -from .helper import get_queue_or_skip |
| 23 | +from .helper import get_queue_or_skip, skip_if_dtype_not_supported |
24 | 24 |
|
25 | 25 |
|
26 | 26 | @pytest.mark.parametrize(
|
@@ -411,3 +411,148 @@ def test_orderK_gh_1350():
|
411 | 411 | assert c.strides == b.strides
|
412 | 412 | assert c._element_offset == 0
|
413 | 413 | assert not c._pointer == b._pointer
|
| 414 | + |
| 415 | + |
| 416 | +def _typesafe_arange(n: int, dtype_: dpt.dtype, device: object): |
| 417 | + n_half = n // 2 |
| 418 | + if dtype_.kind in "ui": |
| 419 | + ii = dpt.iinfo(dtype_) |
| 420 | + m0 = max(ii.min, -n_half) |
| 421 | + m1 = min(m0 + n, ii.max) |
| 422 | + n_tiles = (n + m1 - m0 - 1) // (m1 - m0) |
| 423 | + res = dpt.arange(m0, m1, dtype=dtype_, device=device) |
| 424 | + elif dtype_.kind == "b": |
| 425 | + n_tiles = (n + 1) // 2 |
| 426 | + res = dpt.asarray([False, True], dtype=dtype_, device=device) |
| 427 | + else: |
| 428 | + m0 = -n_half |
| 429 | + m1 = m0 + n |
| 430 | + n_tiles = 1 |
| 431 | + res = dpt.linspace(m0, m1, num=n, dtype=dtype_, device=device) |
| 432 | + if n_tiles > 1: |
| 433 | + res = dpt.tile(res, n_tiles)[:n] |
| 434 | + return res |
| 435 | + |
| 436 | + |
| 437 | +_all_dtypes = [ |
| 438 | + "b1", |
| 439 | + "i1", |
| 440 | + "u1", |
| 441 | + "i2", |
| 442 | + "u2", |
| 443 | + "i4", |
| 444 | + "u4", |
| 445 | + "i8", |
| 446 | + "u8", |
| 447 | + "f2", |
| 448 | + "f4", |
| 449 | + "f8", |
| 450 | + "c8", |
| 451 | + "c16", |
| 452 | +] |
| 453 | + |
| 454 | + |
| 455 | +@pytest.mark.parametrize("dt", _all_dtypes) |
| 456 | +def test_as_c_contig_rect(dt): |
| 457 | + q = get_queue_or_skip() |
| 458 | + skip_if_dtype_not_supported(dt, q) |
| 459 | + |
| 460 | + dtype_ = dpt.dtype(dt) |
| 461 | + n0, n1, n2 = 6, 35, 37 |
| 462 | + |
| 463 | + arr_flat = _typesafe_arange(n0 * n1 * n2, dtype_, q) |
| 464 | + x = dpt.reshape(arr_flat, (n0, n1, n2)).mT |
| 465 | + |
| 466 | + y = dpt.asarray(x, order="C") |
| 467 | + assert dpt.all(x == y) |
| 468 | + |
| 469 | + x2 = x[0] |
| 470 | + y2 = dpt.asarray(x2, order="C") |
| 471 | + assert dpt.all(x2 == y2) |
| 472 | + |
| 473 | + x3 = dpt.flip(x, axis=1) |
| 474 | + y3 = dpt.asarray(x3, order="C") |
| 475 | + assert dpt.all(x3 == y3) |
| 476 | + |
| 477 | + x4 = dpt.reshape(arr_flat, (2, 3, n1, n2)).mT |
| 478 | + x5 = x4[:, :2] |
| 479 | + y5 = dpt.asarray(x5, order="C") |
| 480 | + assert dpt.all(x5 == y5) |
| 481 | + |
| 482 | + x6 = dpt.reshape(arr_flat, (n0, n1, n2), order="F") |
| 483 | + y6 = dpt.asarray(x6, order="C") |
| 484 | + assert dpt.all(x6 == y6) |
| 485 | + |
| 486 | + |
| 487 | +@pytest.mark.parametrize("dt", _all_dtypes) |
| 488 | +def test_as_f_contig_rect(dt): |
| 489 | + q = get_queue_or_skip() |
| 490 | + skip_if_dtype_not_supported(dt, q) |
| 491 | + |
| 492 | + dtype_ = dpt.dtype(dt) |
| 493 | + n0, n1, n2 = 6, 35, 37 |
| 494 | + |
| 495 | + arr_flat = _typesafe_arange(n0 * n1 * n2, dtype_, q) |
| 496 | + x = dpt.reshape(arr_flat, (n0, n1, n2)) |
| 497 | + |
| 498 | + y = dpt.asarray(x, order="F") |
| 499 | + assert dpt.all(x == y) |
| 500 | + |
| 501 | + x2 = x[0] |
| 502 | + y2 = dpt.asarray(x2, order="F") |
| 503 | + assert dpt.all(x2 == y2) |
| 504 | + |
| 505 | + x3 = dpt.flip(x, axis=1) |
| 506 | + y3 = dpt.asarray(x3, order="F") |
| 507 | + assert dpt.all(x3 == y3) |
| 508 | + |
| 509 | + x4 = dpt.reshape(arr_flat, (2, 3, n1, n2)) |
| 510 | + x5 = dpt.moveaxis(x4[:, :2], (2, 3), (0, 1)) |
| 511 | + y5 = dpt.asarray(x5, order="F") |
| 512 | + assert dpt.all(x5 == y5) |
| 513 | + |
| 514 | + |
| 515 | +@pytest.mark.parametrize("dt", _all_dtypes) |
| 516 | +def test_as_c_contig_square(dt): |
| 517 | + q = get_queue_or_skip() |
| 518 | + skip_if_dtype_not_supported(dt, q) |
| 519 | + |
| 520 | + dtype_ = dpt.dtype(dt) |
| 521 | + n0, n1 = 4, 53 |
| 522 | + |
| 523 | + arr_flat = _typesafe_arange(n0 * n1 * n1, dtype_, q) |
| 524 | + x = dpt.reshape(arr_flat, (n0, n1, n1)).mT |
| 525 | + |
| 526 | + y = dpt.asarray(x, order="C") |
| 527 | + assert dpt.all(x == y) |
| 528 | + |
| 529 | + x2 = x[0] |
| 530 | + y2 = dpt.asarray(x2, order="C") |
| 531 | + assert dpt.all(x2 == y2) |
| 532 | + |
| 533 | + x3 = dpt.flip(x, axis=1) |
| 534 | + y3 = dpt.asarray(x3, order="C") |
| 535 | + assert dpt.all(x3 == y3) |
| 536 | + |
| 537 | + |
| 538 | +@pytest.mark.parametrize("dt", _all_dtypes) |
| 539 | +def test_as_f_contig_square(dt): |
| 540 | + q = get_queue_or_skip() |
| 541 | + skip_if_dtype_not_supported(dt, q) |
| 542 | + |
| 543 | + dtype_ = dpt.dtype(dt) |
| 544 | + n0, n1 = 6, 53 |
| 545 | + |
| 546 | + arr_flat = _typesafe_arange(n0 * n1 * n1, dtype_, q) |
| 547 | + x = dpt.moveaxis(dpt.reshape(arr_flat, (n0, n1, n1)), (1, 2), (0, 1)) |
| 548 | + |
| 549 | + y = dpt.asarray(x, order="F") |
| 550 | + assert dpt.all(x == y) |
| 551 | + |
| 552 | + x2 = x[..., 0] |
| 553 | + y2 = dpt.asarray(x2, order="F") |
| 554 | + assert dpt.all(x2 == y2) |
| 555 | + |
| 556 | + x3 = dpt.flip(x, axis=1) |
| 557 | + y3 = dpt.asarray(x3, order="F") |
| 558 | + assert dpt.all(x3 == y3) |
0 commit comments