Skip to content

Commit 977bcef

Browse files
authored
Merge pull request #35 from honno/operator-tests
Operator tests
2 parents d906712 + e4f331e commit 977bcef

10 files changed

+1185
-663
lines changed

array_api_tests/algos.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
__all__ = ["broadcast_shapes"]
2+
3+
4+
from .typing import Shape
5+
6+
7+
# We use a custom exception to differentiate from potential bugs
8+
class BroadcastError(ValueError):
9+
pass
10+
11+
12+
def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
13+
"""Broadcasts `shape1` and `shape2`"""
14+
N1 = len(shape1)
15+
N2 = len(shape2)
16+
N = max(N1, N2)
17+
shape = [None for _ in range(N)]
18+
i = N - 1
19+
while i >= 0:
20+
n1 = N1 - N + i
21+
if N1 - N + i >= 0:
22+
d1 = shape1[n1]
23+
else:
24+
d1 = 1
25+
n2 = N2 - N + i
26+
if N2 - N + i >= 0:
27+
d2 = shape2[n2]
28+
else:
29+
d2 = 1
30+
31+
if d1 == 1:
32+
shape[i] = d2
33+
elif d2 == 1:
34+
shape[i] = d1
35+
elif d1 == d2:
36+
shape[i] = d1
37+
else:
38+
raise BroadcastError
39+
40+
i = i - 1
41+
42+
return tuple(shape)
43+
44+
45+
def broadcast_shapes(*shapes: Shape):
46+
if len(shapes) == 0:
47+
raise ValueError("shapes=[] must be non-empty")
48+
elif len(shapes) == 1:
49+
return shapes[0]
50+
result = _broadcast_shapes(shapes[0], shapes[1])
51+
for i in range(2, len(shapes)):
52+
result = _broadcast_shapes(result, shapes[i])
53+
return result

array_api_tests/dtype_helpers.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'dtype_to_scalars',
2020
'is_int_dtype',
2121
'is_float_dtype',
22+
'get_scalar_type',
2223
'dtype_ranges',
2324
'default_int',
2425
'default_float',
@@ -30,6 +31,7 @@
3031
'binary_op_to_symbol',
3132
'unary_op_to_symbol',
3233
'inplace_op_to_symbol',
34+
'op_to_func',
3335
'fmt_types',
3436
]
3537

@@ -74,6 +76,15 @@ def is_float_dtype(dtype):
7476
return dtype in float_dtypes
7577

7678

79+
def get_scalar_type(dtype: DataType) -> ScalarType:
80+
if is_int_dtype(dtype):
81+
return int
82+
elif is_float_dtype(dtype):
83+
return float
84+
else:
85+
return bool
86+
87+
7788
class MinMax(NamedTuple):
7889
min: int
7990
max: int
@@ -332,7 +343,7 @@ def result_type(*dtypes: DataType):
332343
}
333344

334345

335-
_op_to_func = {
346+
op_to_func = {
336347
'__abs__': 'abs',
337348
'__add__': 'add',
338349
'__and__': 'bitwise_and',
@@ -341,14 +352,14 @@ def result_type(*dtypes: DataType):
341352
'__ge__': 'greater_equal',
342353
'__gt__': 'greater',
343354
'__le__': 'less_equal',
344-
'__lshift__': 'bitwise_left_shift',
345355
'__lt__': 'less',
346356
# '__matmul__': 'matmul', # TODO: support matmul
347357
'__mod__': 'remainder',
348358
'__mul__': 'multiply',
349359
'__ne__': 'not_equal',
350360
'__or__': 'bitwise_or',
351361
'__pow__': 'pow',
362+
'__lshift__': 'bitwise_left_shift',
352363
'__rshift__': 'bitwise_right_shift',
353364
'__sub__': 'subtract',
354365
'__truediv__': 'divide',
@@ -359,7 +370,7 @@ def result_type(*dtypes: DataType):
359370
}
360371

361372

362-
for op, elwise_func in _op_to_func.items():
373+
for op, elwise_func in op_to_func.items():
363374
func_in_dtypes[op] = func_in_dtypes[elwise_func]
364375
func_returns_bool[op] = func_returns_bool[elwise_func]
365376

array_api_tests/hypothesis_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from .array_helpers import ndindex
1919
from .function_stubs import elementwise_functions
2020
from .pytest_helpers import nargs
21-
from .typing import DataType, Shape, Array
21+
from .typing import Array, DataType, Shape
22+
from .algos import broadcast_shapes
2223

2324
# Set this to True to not fail tests just because a dtype isn't implemented.
2425
# If no compatible dtype is implemented for a given test, the test will fail
@@ -218,7 +219,6 @@ def two_broadcastable_shapes(draw):
218219
This will produce two shapes (shape1, shape2) such that shape2 can be
219220
broadcast to shape1.
220221
"""
221-
from .test_broadcasting import broadcast_shapes
222222
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
223223
assume(broadcast_shapes(shape1, shape2) == shape1)
224224
return (shape1, shape2)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
2+
https://github.com/data-apis/array-api/blob/master/spec/API_specification/broadcasting.md
3+
"""
4+
5+
import pytest
6+
7+
from ..algos import BroadcastError, _broadcast_shapes
8+
9+
10+
@pytest.mark.parametrize(
11+
"shape1, shape2, expected",
12+
[
13+
[(8, 1, 6, 1), (7, 1, 5), (8, 7, 6, 5)],
14+
[(5, 4), (1,), (5, 4)],
15+
[(5, 4), (4,), (5, 4)],
16+
[(15, 3, 5), (15, 1, 5), (15, 3, 5)],
17+
[(15, 3, 5), (3, 5), (15, 3, 5)],
18+
[(15, 3, 5), (3, 1), (15, 3, 5)],
19+
],
20+
)
21+
def test_broadcast_shapes(shape1, shape2, expected):
22+
assert _broadcast_shapes(shape1, shape2) == expected
23+
24+
25+
@pytest.mark.parametrize(
26+
"shape1, shape2",
27+
[
28+
[(3,), (4,)], # dimension does not match
29+
[(2, 1), (8, 4, 3)], # second dimension does not match
30+
[(15, 3, 5), (15, 3)], # singleton dimensions can only be prepended
31+
],
32+
)
33+
def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
34+
with pytest.raises(BroadcastError):
35+
_broadcast_shapes(shape1, shape2)

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from math import prod
22

33
import pytest
4-
from hypothesis import given, strategies as st, settings
4+
from hypothesis import given, settings
5+
from hypothesis import strategies as st
56

67
from .. import _array_module as xp
7-
from .. import xps
8-
from .._array_module import _UndefinedStub
98
from .. import array_helpers as ah
109
from .. import dtype_helpers as dh
1110
from .. import hypothesis_helpers as hh
12-
from ..test_broadcasting import broadcast_shapes
11+
from .. import xps
12+
from .._array_module import _UndefinedStub
13+
from ..algos import broadcast_shapes
1314

1415
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1516
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]

array_api_tests/pytest_helpers.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from array_api_tests.algos import broadcast_shapes
12
import math
23
from inspect import getfullargspec
34
from typing import Any, Dict, Optional, Tuple, Union
@@ -17,6 +18,7 @@
1718
"assert_default_float",
1819
"assert_default_int",
1920
"assert_shape",
21+
"assert_result_shape",
2022
"assert_fill",
2123
]
2224

@@ -69,15 +71,15 @@ def assert_dtype(
6971
out_dtype: DataType,
7072
expected: Optional[DataType] = None,
7173
*,
72-
out_name: str = "out.dtype",
74+
repr_name: str = "out.dtype",
7375
):
7476
f_in_dtypes = dh.fmt_types(in_dtypes)
7577
f_out_dtype = dh.dtype_to_name[out_dtype]
7678
if expected is None:
7779
expected = dh.result_type(*in_dtypes)
7880
f_expected = dh.dtype_to_name[expected]
7981
msg = (
80-
f"{out_name}={f_out_dtype}, but should be {f_expected} "
82+
f"{repr_name}={f_out_dtype}, but should be {f_expected} "
8183
f"[{func_name}({f_in_dtypes})]"
8284
)
8385
assert out_dtype == expected, msg
@@ -114,14 +116,41 @@ def assert_default_int(func_name: str, dtype: DataType):
114116

115117

116118
def assert_shape(
117-
func_name: str, out_shape: Union[int, Shape], expected: Union[int, Shape], /, **kw
119+
func_name: str,
120+
out_shape: Union[int, Shape],
121+
expected: Union[int, Shape],
122+
/,
123+
repr_name="out.shape",
124+
**kw,
118125
):
119126
if isinstance(out_shape, int):
120127
out_shape = (out_shape,)
121128
if isinstance(expected, int):
122129
expected = (expected,)
123130
msg = (
124-
f"out.shape={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
131+
f"{repr_name}={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
132+
)
133+
assert out_shape == expected, msg
134+
135+
136+
def assert_result_shape(
137+
func_name: str,
138+
in_shapes: Tuple[Shape],
139+
out_shape: Shape,
140+
/,
141+
expected: Optional[Shape] = None,
142+
*,
143+
repr_name="out.shape",
144+
**kw,
145+
):
146+
if expected is None:
147+
expected = broadcast_shapes(*in_shapes)
148+
f_in_shapes = " . ".join(str(s) for s in in_shapes)
149+
f_sig = f" {f_in_shapes} "
150+
if kw:
151+
f_sig += f", {fmt_kw(kw)}"
152+
msg = (
153+
f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]"
125154
)
126155
assert out_shape == expected, msg
127156

0 commit comments

Comments
 (0)