Skip to content

Commit 0293364

Browse files
committed
Factor out broadcast_shapes
1 parent f3bbc52 commit 0293364

File tree

8 files changed

+158
-145
lines changed

8 files changed

+158
-145
lines changed

array_api_tests/algos.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
__all__ = ["broadcast_shapes"]
2+
3+
4+
from .typing import Shape
5+
6+
7+
# We use a custom exception to differentiate from potential bugs
8+
class BroadcastError(ValueError):
9+
pass
10+
11+
12+
def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
13+
"""Broadcasts `shape1` and `shape2`"""
14+
N1 = len(shape1)
15+
N2 = len(shape2)
16+
N = max(N1, N2)
17+
shape = [None for _ in range(N)]
18+
i = N - 1
19+
while i >= 0:
20+
n1 = N1 - N + i
21+
if N1 - N + i >= 0:
22+
d1 = shape1[n1]
23+
else:
24+
d1 = 1
25+
n2 = N2 - N + i
26+
if N2 - N + i >= 0:
27+
d2 = shape2[n2]
28+
else:
29+
d2 = 1
30+
31+
if d1 == 1:
32+
shape[i] = d2
33+
elif d2 == 1:
34+
shape[i] = d1
35+
elif d1 == d2:
36+
shape[i] = d1
37+
else:
38+
raise BroadcastError
39+
40+
i = i - 1
41+
42+
return tuple(shape)
43+
44+
45+
def broadcast_shapes(*shapes: Shape):
46+
if len(shapes) == 0:
47+
raise ValueError("shapes=[] must be non-empty")
48+
elif len(shapes) == 1:
49+
return shapes[0]
50+
result = _broadcast_shapes(shapes[0], shapes[1])
51+
for i in range(2, len(shapes)):
52+
result = _broadcast_shapes(result, shapes[i])
53+
return result

array_api_tests/hypothesis_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .function_stubs import elementwise_functions
2020
from .pytest_helpers import nargs
2121
from .typing import Array, DataType, Shape
22+
from .algos import broadcast_shapes
2223

2324
# Set this to True to not fail tests just because a dtype isn't implemented.
2425
# If no compatible dtype is implemented for a given test, the test will fail
@@ -218,7 +219,6 @@ def two_broadcastable_shapes(draw):
218219
This will produce two shapes (shape1, shape2) such that shape2 can be
219220
broadcast to shape1.
220221
"""
221-
from .test_broadcasting import broadcast_shapes
222222
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
223223
assume(broadcast_shapes(shape1, shape2) == shape1)
224224
return (shape1, shape2)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""
2+
https://github.com/data-apis/array-api/blob/master/spec/API_specification/broadcasting.md
3+
"""
4+
5+
import pytest
6+
from hypothesis import assume, given
7+
from hypothesis import strategies as st
8+
9+
from .. import _array_module as xp
10+
from .. import dtype_helpers as dh
11+
from .. import hypothesis_helpers as hh
12+
from .. import pytest_helpers as ph
13+
from .._array_module import _UndefinedStub
14+
from ..algos import BroadcastError, _broadcast_shapes
15+
from ..function_stubs import elementwise_functions
16+
17+
18+
@pytest.mark.parametrize(
19+
"shape1, shape2, expected",
20+
[
21+
[(8, 1, 6, 1), (7, 1, 5), (8, 7, 6, 5)],
22+
[(5, 4), (1,), (5, 4)],
23+
[(5, 4), (4,), (5, 4)],
24+
[(15, 3, 5), (15, 1, 5), (15, 3, 5)],
25+
[(15, 3, 5), (3, 5), (15, 3, 5)],
26+
[(15, 3, 5), (3, 1), (15, 3, 5)],
27+
],
28+
)
29+
def test_broadcast_shapes(shape1, shape2, expected):
30+
assert _broadcast_shapes(shape1, shape2) == expected
31+
32+
33+
@pytest.mark.parametrize(
34+
"shape1, shape2",
35+
[
36+
[(3,), (4,)], # dimension does not match
37+
[(2, 1), (8, 4, 3)], # second dimension does not match
38+
[(15, 3, 5), (15, 3)], # singleton dimensions can only be prepended
39+
],
40+
)
41+
def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
42+
with pytest.raises(BroadcastError):
43+
_broadcast_shapes(shape1, shape2)
44+
45+
46+
# TODO: Extend this to all functions (not just elementwise), and handle
47+
# functions that take more than 2 args
48+
@pytest.mark.parametrize(
49+
"func_name", [i for i in elementwise_functions.__all__ if ph.nargs(i) > 1]
50+
)
51+
@given(shape1=hh.shapes(), shape2=hh.shapes(), data=st.data())
52+
def test_broadcasting_hypothesis(func_name, shape1, shape2, data):
53+
dtype = data.draw(st.sampled_from(dh.func_in_dtypes[func_name]), label="dtype")
54+
if hh.FILTER_UNDEFINED_DTYPES:
55+
assume(not isinstance(dtype, _UndefinedStub))
56+
func = getattr(xp, func_name)
57+
if isinstance(func, xp._UndefinedStub):
58+
func._raise()
59+
args = [xp.ones(shape1, dtype=dtype), xp.ones(shape2, dtype=dtype)]
60+
try:
61+
broadcast_shape = _broadcast_shapes(shape1, shape2)
62+
except BroadcastError:
63+
ph.raises(
64+
Exception,
65+
lambda: func(*args),
66+
f"{func_name} should raise an exception from not being able to broadcast inputs with hh.shapes {(shape1, shape2)}",
67+
)
68+
else:
69+
result = ph.doesnt_raise(
70+
lambda: func(*args),
71+
f"{func_name} raised an unexpected exception from broadcastable inputs with hh.shapes {(shape1, shape2)}",
72+
)
73+
assert result.shape == broadcast_shape, "broadcast hh.shapes incorrect"

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from math import prod
22

33
import pytest
4-
from hypothesis import given, strategies as st, settings
4+
from hypothesis import given, settings
5+
from hypothesis import strategies as st
56

67
from .. import _array_module as xp
7-
from .. import xps
8-
from .._array_module import _UndefinedStub
98
from .. import array_helpers as ah
109
from .. import dtype_helpers as dh
1110
from .. import hypothesis_helpers as hh
12-
from ..test_broadcasting import broadcast_shapes
11+
from .. import xps
12+
from .._array_module import _UndefinedStub
13+
from ..algos import broadcast_shapes
1314

1415
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1516
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]

array_api_tests/pytest_helpers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from array_api_tests.algos import broadcast_shapes
12
import math
23
from inspect import getfullargspec
34
from typing import Any, Dict, Optional, Tuple, Union
@@ -17,6 +18,7 @@
1718
"assert_default_float",
1819
"assert_default_int",
1920
"assert_shape",
21+
"assert_result_shape",
2022
"assert_fill",
2123
]
2224

@@ -131,6 +133,28 @@ def assert_shape(
131133
assert out_shape == expected, msg
132134

133135

136+
def assert_result_shape(
137+
func_name: str,
138+
in_shapes: Tuple[Shape],
139+
out_shape: Shape,
140+
/,
141+
expected: Optional[Shape] = None,
142+
*,
143+
repr_name="out.shape",
144+
**kw,
145+
):
146+
if expected is None:
147+
expected = broadcast_shapes(*in_shapes)
148+
f_in_shapes = " . ".join(str(s) for s in in_shapes)
149+
f_sig = f" {f_in_shapes} "
150+
if kw:
151+
f_sig += f", {fmt_kw(kw)}"
152+
msg = (
153+
f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]"
154+
)
155+
assert out_shape == expected, msg
156+
157+
134158
def assert_fill(
135159
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
136160
):

array_api_tests/test_broadcasting.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

array_api_tests/test_elementwise_functions.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,9 @@
2424
from . import hypothesis_helpers as hh
2525
from . import pytest_helpers as ph
2626
from . import xps
27+
from .algos import broadcast_shapes
2728
from .typing import Array, DataType, Param, Scalar
2829

29-
# We might as well use this implementation rather than xp.broadcast_shapes()
30-
from .test_broadcasting import broadcast_shapes
31-
32-
3330
# When appropiate, this module tests operators alongside their respective
3431
# elementwise methods. We do this by parametrizing a generalised test method
3532
# with every relevant method and operator.

array_api_tests/test_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from . import dtype_helpers as dh
3030
from . import pytest_helpers as ph
3131

32-
from .test_broadcasting import broadcast_shapes
32+
from .algos import broadcast_shapes
3333

3434
from . import _array_module
3535
from ._array_module import linalg

0 commit comments

Comments
 (0)