Skip to content

Commit 69e17be

Browse files
Tests to exercise specialized code-paths for as_c_contig/as_f_contig
1 parent 2ba9829 commit 69e17be

File tree

1 file changed

+146
-1
lines changed

1 file changed

+146
-1
lines changed

dpctl/tests/test_tensor_asarray.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import dpctl
2121
import dpctl.tensor as dpt
2222

23-
from .helper import get_queue_or_skip
23+
from .helper import get_queue_or_skip, skip_if_dtype_not_supported
2424

2525

2626
@pytest.mark.parametrize(
@@ -411,3 +411,148 @@ def test_orderK_gh_1350():
411411
assert c.strides == b.strides
412412
assert c._element_offset == 0
413413
assert not c._pointer == b._pointer
414+
415+
416+
def _typesafe_arange(n: int, dtype_: dpt.dtype, device: object):
417+
n_half = n // 2
418+
if dtype_.kind in "ui":
419+
ii = dpt.iinfo(dtype_)
420+
m0 = max(ii.min, -n_half)
421+
m1 = min(m0 + n, ii.max)
422+
n_tiles = (n + m1 - m0 - 1) // (m1 - m0)
423+
res = dpt.arange(m0, m1, dtype=dtype_, device=device)
424+
elif dtype_.kind == "b":
425+
n_tiles = (n + 1) // 2
426+
res = dpt.asarray([False, True], dtype=dtype_, device=device)
427+
else:
428+
m0 = -n_half
429+
m1 = m0 + n
430+
n_tiles = 1
431+
res = dpt.linspace(m0, m1, num=n, dtype=dtype_, device=device)
432+
if n_tiles > 1:
433+
res = dpt.tile(res, n_tiles)[:n]
434+
return res
435+
436+
437+
_all_dtypes = [
438+
"b1",
439+
"i1",
440+
"u1",
441+
"i2",
442+
"u2",
443+
"i4",
444+
"u4",
445+
"i8",
446+
"u8",
447+
"f2",
448+
"f4",
449+
"f8",
450+
"c8",
451+
"c16",
452+
]
453+
454+
455+
@pytest.mark.parametrize("dt", _all_dtypes)
456+
def test_as_c_contig_rect(dt):
457+
q = get_queue_or_skip()
458+
skip_if_dtype_not_supported(dt, q)
459+
460+
dtype_ = dpt.dtype(dt)
461+
n0, n1, n2 = 6, 35, 37
462+
463+
arr_flat = _typesafe_arange(n0 * n1 * n2, dtype_, q)
464+
x = dpt.reshape(arr_flat, (n0, n1, n2)).mT
465+
466+
y = dpt.asarray(x, order="C")
467+
assert dpt.all(x == y)
468+
469+
x2 = x[0]
470+
y2 = dpt.asarray(x2, order="C")
471+
assert dpt.all(x2 == y2)
472+
473+
x3 = dpt.flip(x, axis=1)
474+
y3 = dpt.asarray(x3, order="C")
475+
assert dpt.all(x3 == y3)
476+
477+
x4 = dpt.reshape(arr_flat, (2, 3, n1, n2)).mT
478+
x5 = x4[:, :2]
479+
y5 = dpt.asarray(x5, order="C")
480+
assert dpt.all(x5 == y5)
481+
482+
x6 = dpt.reshape(arr_flat, (n0, n1, n2), order="F")
483+
y6 = dpt.asarray(x6, order="C")
484+
assert dpt.all(x6 == y6)
485+
486+
487+
@pytest.mark.parametrize("dt", _all_dtypes)
488+
def test_as_f_contig_rect(dt):
489+
q = get_queue_or_skip()
490+
skip_if_dtype_not_supported(dt, q)
491+
492+
dtype_ = dpt.dtype(dt)
493+
n0, n1, n2 = 6, 35, 37
494+
495+
arr_flat = _typesafe_arange(n0 * n1 * n2, dtype_, q)
496+
x = dpt.reshape(arr_flat, (n0, n1, n2))
497+
498+
y = dpt.asarray(x, order="F")
499+
assert dpt.all(x == y)
500+
501+
x2 = x[0]
502+
y2 = dpt.asarray(x2, order="F")
503+
assert dpt.all(x2 == y2)
504+
505+
x3 = dpt.flip(x, axis=1)
506+
y3 = dpt.asarray(x3, order="F")
507+
assert dpt.all(x3 == y3)
508+
509+
x4 = dpt.reshape(arr_flat, (2, 3, n1, n2))
510+
x5 = dpt.moveaxis(x4[:, :2], (2, 3), (0, 1))
511+
y5 = dpt.asarray(x5, order="F")
512+
assert dpt.all(x5 == y5)
513+
514+
515+
@pytest.mark.parametrize("dt", _all_dtypes)
516+
def test_as_c_contig_square(dt):
517+
q = get_queue_or_skip()
518+
skip_if_dtype_not_supported(dt, q)
519+
520+
dtype_ = dpt.dtype(dt)
521+
n0, n1 = 4, 53
522+
523+
arr_flat = _typesafe_arange(n0 * n1 * n1, dtype_, q)
524+
x = dpt.reshape(arr_flat, (n0, n1, n1)).mT
525+
526+
y = dpt.asarray(x, order="C")
527+
assert dpt.all(x == y)
528+
529+
x2 = x[0]
530+
y2 = dpt.asarray(x2, order="C")
531+
assert dpt.all(x2 == y2)
532+
533+
x3 = dpt.flip(x, axis=1)
534+
y3 = dpt.asarray(x3, order="C")
535+
assert dpt.all(x3 == y3)
536+
537+
538+
@pytest.mark.parametrize("dt", _all_dtypes)
539+
def test_as_f_contig_square(dt):
540+
q = get_queue_or_skip()
541+
skip_if_dtype_not_supported(dt, q)
542+
543+
dtype_ = dpt.dtype(dt)
544+
n0, n1 = 6, 53
545+
546+
arr_flat = _typesafe_arange(n0 * n1 * n1, dtype_, q)
547+
x = dpt.moveaxis(dpt.reshape(arr_flat, (n0, n1, n1)), (1, 2), (0, 1))
548+
549+
y = dpt.asarray(x, order="F")
550+
assert dpt.all(x == y)
551+
552+
x2 = x[..., 0]
553+
y2 = dpt.asarray(x2, order="F")
554+
assert dpt.all(x2 == y2)
555+
556+
x3 = dpt.flip(x, axis=1)
557+
y3 = dpt.asarray(x3, order="F")
558+
assert dpt.all(x3 == y3)

0 commit comments

Comments
 (0)