Skip to content

Commit b1e3f2e

Browse files
committed
Fix right array modify an in-place array in type promotion tests
1 parent bcce8e6 commit b1e3f2e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import List, Tuple, Union
66

77
import pytest
8-
from hypothesis import assume, given, reject
8+
from hypothesis import given, reject
99
from hypothesis import strategies as st
1010

1111
from . import _array_module as xp
@@ -258,15 +258,15 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
258258

259259

260260
@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", inplace_params)
261-
@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data())
262-
def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
263-
assume(len(shapes[0]) >= len(shapes[1]))
261+
@given(shape=hh.shapes(), data=st.data())
262+
def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shape, data):
263+
# TODO: test broadcastable shapes (that don't change x1's shape)
264264
elements = func_elements[func_name]
265265
x1 = data.draw(
266-
xps.arrays(dtype=in_dtypes[0], shape=shapes[0], elements=elements), label="x1"
266+
xps.arrays(dtype=in_dtypes[0], shape=shape, elements=elements), label="x1"
267267
)
268268
x2 = data.draw(
269-
xps.arrays(dtype=in_dtypes[1], shape=shapes[1], elements=elements), label="x2"
269+
xps.arrays(dtype=in_dtypes[1], shape=shape, elements=elements), label="x2"
270270
)
271271
locals_ = {"x1": x1, "x2": x2}
272272
try:

0 commit comments

Comments
 (0)