Skip to content

Commit 74e8957

Browse files
Merge pull request #1904 from IntelPython/backport-gh1874
Backport gh1874
2 parents 3976d5c + bf71ca1 commit 74e8957

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [0.18.2] - Nov. XX, 2024
8+
9+
### Fixed
10+
* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1904](https://github.com/IntelPython/dpctl/pull/1904)
11+
712
## [0.18.1] - Oct. 11, 2024
813

914
### Changed

dpctl/tensor/_type_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,9 @@ def result_type(*arrays_and_dtypes):
767767
target_dev = d
768768
inspected = True
769769

770+
if not dtypes and weak_dtypes:
771+
dtypes.append(weak_dtypes[0].get())
772+
770773
if not (has_fp16 and has_fp64):
771774
for dt in dtypes:
772775
if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64):

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616

1717

18+
import itertools
19+
1820
import numpy as np
1921
import pytest
2022
from numpy.testing import assert_, assert_array_equal, assert_raises_regex
@@ -1531,3 +1533,26 @@ def test_repeat_0_size():
15311533
res = dpt.repeat(x, repetitions, axis=1)
15321534
axis_sz = 2 * x.shape[1]
15331535
assert res.shape == (0, axis_sz, 0)
1536+
1537+
1538+
def test_result_type_bug_1874():
1539+
py_sc = True
1540+
np_sc = np.asarray([py_sc])[0]
1541+
dts_bool = [py_sc, np_sc]
1542+
py_sc = int(1)
1543+
np_sc = np.asarray([py_sc])[0]
1544+
dts_ints = [py_sc, np_sc]
1545+
dts_floats = [float(1), np.float64(1)]
1546+
dts_complexes = [complex(1), np.complex128(1)]
1547+
1548+
# iterate over two categories
1549+
for dts1, dts2 in itertools.product(
1550+
[dts_bool, dts_ints, dts_floats, dts_complexes], repeat=2
1551+
):
1552+
res_dts = []
1553+
# iterate over Python scalar/NumPy scalar choices within categories
1554+
for dt1, dt2 in itertools.product(dts1, dts2):
1555+
res_dt = dpt.result_type(dt1, dt2)
1556+
res_dts.append(res_dt)
1557+
# check that all results are the same
1558+
assert res_dts and all(res_dts[0] == el for el in res_dts[1:])

0 commit comments

Comments
 (0)