Skip to content

Commit 0a68c34

Browse files
Add test that result_types(dtypes) works the same for Python/NumPy scalars
1 parent 4f0f1a3 commit 0a68c34

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616

1717

18+
import itertools
19+
1820
import numpy as np
1921
import pytest
2022
from numpy.testing import assert_, assert_array_equal, assert_raises_regex
@@ -1555,3 +1557,22 @@ def test_repeat_0_size():
15551557
res = dpt.repeat(x, repetitions, axis=1)
15561558
axis_sz = 2 * x.shape[1]
15571559
assert res.shape == (0, axis_sz, 0)
1560+
1561+
1562+
def test_result_type_bug_1874():
1563+
dts_bool = [True, np.bool_(True)]
1564+
dts_ints = [int(1), np.int64(1)]
1565+
dts_floats = [float(1), np.float64(1)]
1566+
dts_complexes = [complex(1), np.complex128(1)]
1567+
1568+
# iterate over two categories
1569+
for dts1, dts2 in itertools.product(
1570+
[dts_bool, dts_ints, dts_floats, dts_complexes], repeat=2
1571+
):
1572+
res_dts = []
1573+
# iterate over Python scalar/NumPy scalar choices within categories
1574+
for dt1, dt2 in itertools.product(dts1, dts2):
1575+
res_dt = dpt.result_type(dt1, dt2)
1576+
res_dts.append(res_dt)
1577+
# check that all results are the same
1578+
assert res_dts and all(res_dts[0] == el for el in res_dts[1:])

0 commit comments

Comments
 (0)