Skip to content

Commit 79208c8

Browse files
committed
Adjusts tests for in-place element-wise operations to account for "same_kind" casting
1 parent bab3571 commit 79208c8

11 files changed

+12
-12
lines changed

dpctl/tests/elementwise/test_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
358358
dev = q.sycl_device
359359
_fp16 = dev.has_aspect_fp16
360360
_fp64 = dev.has_aspect_fp64
361-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
361+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
362362
ar1 += ar2
363363
assert (
364364
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)

dpctl/tests/elementwise/test_bitwise_and.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_bitwise_and_inplace_dtype_matrix(op1_dtype, op2_dtype):
114114
dev = q.sycl_device
115115
_fp16 = dev.has_aspect_fp16
116116
_fp64 = dev.has_aspect_fp64
117-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
117+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
118118
ar1 &= ar2
119119
assert dpt.all(ar1 == 1)
120120

dpctl/tests/elementwise/test_bitwise_left_shift.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def test_bitwise_left_shift_inplace_dtype_matrix(op1_dtype, op2_dtype):
122122
dev = q.sycl_device
123123
_fp16 = dev.has_aspect_fp16
124124
_fp64 = dev.has_aspect_fp64
125-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
125+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
126126
ar1 <<= ar2
127127
assert dpt.all(ar1 == 2)
128128

dpctl/tests/elementwise/test_bitwise_or.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_bitwise_or_inplace_dtype_matrix(op1_dtype, op2_dtype):
114114
dev = q.sycl_device
115115
_fp16 = dev.has_aspect_fp16
116116
_fp64 = dev.has_aspect_fp64
117-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
117+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
118118
ar1 |= ar2
119119
assert dpt.all(ar1 == 1)
120120

dpctl/tests/elementwise/test_bitwise_xor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_bitwise_xor_inplace_dtype_matrix(op1_dtype, op2_dtype):
114114
dev = q.sycl_device
115115
_fp16 = dev.has_aspect_fp16
116116
_fp64 = dev.has_aspect_fp64
117-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
117+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
118118
ar1 ^= ar2
119119
assert dpt.all(ar1 == 0)
120120

dpctl/tests/elementwise/test_divide.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
226226
_fp64 = dev.has_aspect_fp64
227227
# out array only valid if it is inexact
228228
if (
229-
_can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64)
229+
_can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind")
230230
and dpt.dtype(op1_dtype).kind in "fc"
231231
):
232232
ar1 /= ar2
@@ -276,7 +276,7 @@ def test_divide_gh_1711():
276276

277277

278278
# don't test for overflowing double as Python won't cast
279-
# an Python integer of that size to a Python float
279+
# a Python integer of that size to a Python float
280280
@pytest.mark.parametrize("fp_dt", [dpt.float16, dpt.float32])
281281
def test_divide_by_scalar_overflow(fp_dt):
282282
q = get_queue_or_skip()

dpctl/tests/elementwise/test_floor_divide.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_floor_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
290290
_fp16 = dev.has_aspect_fp16
291291
_fp64 = dev.has_aspect_fp64
292292
# out array only valid if it is inexact
293-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
293+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
294294
ar1 //= ar2
295295
assert dpt.all(ar1 == 1)
296296

dpctl/tests/elementwise/test_multiply.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def test_multiply_inplace_dtype_matrix(op1_dtype, op2_dtype):
205205
dev = q.sycl_device
206206
_fp16 = dev.has_aspect_fp16
207207
_fp64 = dev.has_aspect_fp64
208-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
208+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
209209
ar1 *= ar2
210210
assert (
211211
dpt.asnumpy(ar1) == np.full(ar1.shape, 1, dtype=ar1.dtype)

dpctl/tests/elementwise/test_pow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_pow_inplace_dtype_matrix(op1_dtype, op2_dtype):
183183
dev = q.sycl_device
184184
_fp16 = dev.has_aspect_fp16
185185
_fp64 = dev.has_aspect_fp64
186-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
186+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
187187
ar1 **= ar2
188188
assert (
189189
dpt.asnumpy(ar1) == np.full(ar1.shape, 1, dtype=ar1.dtype)

dpctl/tests/elementwise/test_remainder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_remainder_inplace_dtype_matrix(op1_dtype, op2_dtype):
235235
dev = q.sycl_device
236236
_fp16 = dev.has_aspect_fp16
237237
_fp64 = dev.has_aspect_fp64
238-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
238+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
239239
ar1 %= ar2
240240
assert dpt.all(ar1 == dpt.zeros(ar1.shape, dtype=ar1.dtype))
241241

dpctl/tests/elementwise/test_subtract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def test_subtract_inplace_dtype_matrix(op1_dtype, op2_dtype):
208208
dev = q.sycl_device
209209
_fp16 = dev.has_aspect_fp16
210210
_fp64 = dev.has_aspect_fp64
211-
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
211+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64, casting="same_kind"):
212212
ar1 -= ar2
213213
assert (dpt.asnumpy(ar1) == np.zeros(ar1.shape, dtype=ar1.dtype)).all()
214214

0 commit comments

Comments
 (0)