|
3 | 3 | """
|
4 | 4 |
|
5 | 5 | import pytest
|
6 |
| -from hypothesis import assume, given |
7 |
| -from hypothesis import strategies as st |
8 | 6 |
|
9 |
| -from .. import _array_module as xp |
10 |
| -from .. import dtype_helpers as dh |
11 |
| -from .. import hypothesis_helpers as hh |
12 |
| -from .. import pytest_helpers as ph |
13 |
| -from .._array_module import _UndefinedStub |
14 | 7 | from ..algos import BroadcastError, _broadcast_shapes
|
15 |
| -from ..function_stubs import elementwise_functions |
16 | 8 |
|
17 | 9 |
|
18 | 10 | @pytest.mark.parametrize(
|
@@ -41,33 +33,3 @@ def test_broadcast_shapes(shape1, shape2, expected):
|
41 | 33 | def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
|
42 | 34 | with pytest.raises(BroadcastError):
|
43 | 35 | _broadcast_shapes(shape1, shape2)
|
44 |
| - |
45 |
| - |
46 |
| -# TODO: Extend this to all functions (not just elementwise), and handle |
47 |
| -# functions that take more than 2 args |
48 |
| -@pytest.mark.parametrize( |
49 |
| - "func_name", [i for i in elementwise_functions.__all__ if ph.nargs(i) > 1] |
50 |
| -) |
51 |
| -@given(shape1=hh.shapes(), shape2=hh.shapes(), data=st.data()) |
52 |
| -def test_broadcasting_hypothesis(func_name, shape1, shape2, data): |
53 |
| - dtype = data.draw(st.sampled_from(dh.func_in_dtypes[func_name]), label="dtype") |
54 |
| - if hh.FILTER_UNDEFINED_DTYPES: |
55 |
| - assume(not isinstance(dtype, _UndefinedStub)) |
56 |
| - func = getattr(xp, func_name) |
57 |
| - if isinstance(func, xp._UndefinedStub): |
58 |
| - func._raise() |
59 |
| - args = [xp.ones(shape1, dtype=dtype), xp.ones(shape2, dtype=dtype)] |
60 |
| - try: |
61 |
| - broadcast_shape = _broadcast_shapes(shape1, shape2) |
62 |
| - except BroadcastError: |
63 |
| - ph.raises( |
64 |
| - Exception, |
65 |
| - lambda: func(*args), |
66 |
| - f"{func_name} should raise an exception from not being able to broadcast inputs with hh.shapes {(shape1, shape2)}", |
67 |
| - ) |
68 |
| - else: |
69 |
| - result = ph.doesnt_raise( |
70 |
| - lambda: func(*args), |
71 |
| - f"{func_name} raised an unexpected exception from broadcastable inputs with hh.shapes {(shape1, shape2)}", |
72 |
| - ) |
73 |
| - assert result.shape == broadcast_shape, "broadcast hh.shapes incorrect" |
0 commit comments