Skip to content

Commit 957b93a

Browse files
committed
Add tests for new fill method
1 parent 9691cf0 commit 957b93a

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

tests/test_fill.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import dpctl
2+
import numpy as np
3+
import pytest
4+
from dpctl.utils import ExecutionPlacementError
5+
from numpy.testing import assert_array_equal
6+
7+
import dpnp as dnp
8+
9+
10+
def test_fill_non_scalar():
11+
a = dnp.ones(5, dtype="i4")
12+
val = dnp.ones(2, dtype="i4")
13+
14+
with pytest.raises(ValueError):
15+
a.fill(val)
16+
17+
with pytest.raises(TypeError):
18+
a.fill(dict())
19+
20+
21+
def test_fill_compute_follows_data():
22+
q1 = dpctl.SyclQueue()
23+
q2 = dpctl.SyclQueue()
24+
25+
a = dnp.ones(5, dtype="i4", sycl_queue=q1)
26+
val = dnp.ones((), dtype=a.dtype, sycl_queue=q2)
27+
28+
with pytest.raises(ExecutionPlacementError):
29+
a.fill(val)
30+
31+
32+
def test_fill_strided_array():
33+
a = dnp.zeros(100, dtype="i4")
34+
b = a[::-2]
35+
36+
expected = dnp.tile(dnp.asarray([0, 1], dtype=a.dtype), 50)
37+
38+
b.fill(1)
39+
assert_array_equal(b, 1)
40+
assert_array_equal(a, expected)
41+
42+
43+
@pytest.mark.parametrize("order", ["C", "F"])
44+
def test_fill_strided_2d_array(order):
45+
a = dnp.zeros((10, 10), dtype="i4", order=order)
46+
b = a[::-2, ::2]
47+
48+
expected = dnp.copy(a)
49+
expected[::-2, ::2] = 1
50+
51+
b.fill(1)
52+
assert_array_equal(b, 1)
53+
assert_array_equal(a, expected)
54+
55+
56+
@pytest.mark.parametrize("order", ["C", "F"])
57+
def test_fill_memset(order):
58+
a = dnp.ones((10, 10), dtype="i4", order=order)
59+
a.fill(0)
60+
61+
assert_array_equal(a, 0)

0 commit comments

Comments
 (0)