Skip to content

Commit 375bec2

Browse files
committed
Fixed dpctl.tensor.result_type function for scalars
1 parent 9afb742 commit 375bec2

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -727,12 +727,16 @@ def result_type(*arrays_and_dtypes):
727727
The dtype resulting from an operation involving the
728728
input arrays and dtypes.
729729
"""
730-
dtypes = [
731-
X.dtype if isinstance(X, dpt.usm_ndarray) else dpt.dtype(X)
732-
for X in arrays_and_dtypes
733-
]
734-
735-
_supported_dtype(dtypes)
730+
dtypes = []
731+
for X in arrays_and_dtypes:
732+
if isinstance(X, dpt.usm_ndarray):
733+
dtypes.append(X.dtype)
734+
elif np.isscalar(X) or isinstance(X, (tuple, list, range)):
735+
dtypes.append(X)
736+
else:
737+
dtype = dpt.dtype(X)
738+
_supported_dtype([dtype])
739+
dtypes.append(dtype)
736740

737741
return np.result_type(*dtypes)
738742

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -935,8 +935,8 @@ def test_can_cast():
935935
def test_result_type():
936936
q = get_queue_or_skip()
937937

938-
X = [dpt.ones((2), dtype=dpt.int64, sycl_queue=q), dpt.int32, "float16"]
939-
X_np = [np.ones((2), dtype=np.int64), np.int32, "float16"]
938+
X = [dpt.ones((2), dtype=dpt.int64, sycl_queue=q), dpt.int32, "float16", 2]
939+
X_np = [np.ones((2), dtype=np.int64), np.int32, "float16", 2]
940940

941941
assert dpt.result_type(*X) == np.result_type(*X_np)
942942

0 commit comments

Comments
 (0)