Skip to content

Commit 286afae

Browse files
Merge pull request #1877 from IntelPython/bugfix/gh-1874-result_type
Bugfix/gh 1874 result type
2 parents 2023622 + dc1887e commit 286afae

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
* Improved performance of `tensor.sort` and `tensor.argsort` for short arrays in the range [16, 64] elements [gh-1866](https://github.com/IntelPython/dpctl/pull/1866)
1717

1818
### Fixed
19+
* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1877](https://github.com/IntelPython/dpctl/pull/1877)
1920

2021
### Maintenance
2122

dpctl/tensor/_type_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,9 @@ def result_type(*arrays_and_dtypes):
767767
target_dev = d
768768
inspected = True
769769

770+
if not dtypes and weak_dtypes:
771+
dtypes.append(weak_dtypes[0].get())
772+
770773
if not (has_fp16 and has_fp64):
771774
for dt in dtypes:
772775
if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64):

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616

1717

18+
import itertools
19+
1820
import numpy as np
1921
import pytest
2022
from numpy.testing import assert_, assert_array_equal, assert_raises_regex
@@ -1555,3 +1557,26 @@ def test_repeat_0_size():
15551557
res = dpt.repeat(x, repetitions, axis=1)
15561558
axis_sz = 2 * x.shape[1]
15571559
assert res.shape == (0, axis_sz, 0)
1560+
1561+
1562+
def test_result_type_bug_1874():
1563+
py_sc = True
1564+
np_sc = np.asarray([py_sc])[0]
1565+
dts_bool = [py_sc, np_sc]
1566+
py_sc = int(1)
1567+
np_sc = np.asarray([py_sc])[0]
1568+
dts_ints = [py_sc, np_sc]
1569+
dts_floats = [float(1), np.float64(1)]
1570+
dts_complexes = [complex(1), np.complex128(1)]
1571+
1572+
# iterate over two categories
1573+
for dts1, dts2 in itertools.product(
1574+
[dts_bool, dts_ints, dts_floats, dts_complexes], repeat=2
1575+
):
1576+
res_dts = []
1577+
# iterate over Python scalar/NumPy scalar choices within categories
1578+
for dt1, dt2 in itertools.product(dts1, dts2):
1579+
res_dt = dpt.result_type(dt1, dt2)
1580+
res_dts.append(res_dt)
1581+
# check that all results are the same
1582+
assert res_dts and all(res_dts[0] == el for el in res_dts[1:])

0 commit comments

Comments
 (0)