Skip to content

Commit 604be62

Browse files
committed
Type hint some hypothesis helpers
1 parent daa73c2 commit 604be62

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from operator import mul
33
from math import sqrt
44
import itertools
5-
from typing import Tuple, Optional
5+
from typing import Tuple, Optional, List
66

77
from hypothesis import assume
88
from hypothesis.strategies import (lists, integers, sampled_from,
@@ -11,6 +11,7 @@
1111

1212
from .pytest_helpers import nargs
1313
from .array_helpers import ndindex
14+
from .typing import DataType, Shape
1415
from . import dtype_helpers as dh
1516
from ._array_module import (full, float32, float64, bool as bool_dtype,
1617
_UndefinedStub, eye, broadcast_to)
@@ -49,7 +50,7 @@
4950
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes]
5051
_sorted_dtypes = [d for category in _dtype_categories for d in category]
5152

52-
def _dtypes_sorter(dtype_pair):
53+
def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
5354
dtype1, dtype2 = dtype_pair
5455
if dtype1 == dtype2:
5556
return _sorted_dtypes.index(dtype1)
@@ -66,7 +67,7 @@ def _dtypes_sorter(dtype_pair):
6667
key += 1
6768
return key
6869

69-
promotable_dtypes = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
70+
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
7071

7172
if FILTER_UNDEFINED_DTYPES:
7273
promotable_dtypes = [
@@ -79,8 +80,8 @@ def _dtypes_sorter(dtype_pair):
7980
def mutually_promotable_dtypes(
8081
max_size: Optional[int] = 2,
8182
*,
82-
dtypes=dh.all_dtypes,
83-
) -> SearchStrategy[Tuple]:
83+
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
84+
) -> SearchStrategy[Tuple[DataType, ...]]:
8485
if max_size == 2:
8586
return sampled_from(
8687
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
@@ -164,7 +165,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
164165

165166
def mutually_broadcastable_shapes(
166167
num_shapes: int, **kw
167-
) -> SearchStrategy[Tuple[Tuple[int, ...], ...]]:
168+
) -> SearchStrategy[Tuple[Shape, ...]]:
168169
return (
169170
xps.mutually_broadcastable_shapes(num_shapes, **kw)
170171
.map(lambda BS: BS.input_shapes)
@@ -347,8 +348,9 @@ def multiaxis_indices(draw, shapes):
347348

348349

349350
def two_mutual_arrays(
350-
dtypes=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes
351-
):
351+
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
352+
two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
353+
) -> SearchStrategy:
352354
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
353355
mutual_shapes = shared(two_shapes)
354356
arrays1 = xps.arrays(

array_api_tests/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
__all__ = [
44
"DataType",
55
"ScalarType",
6+
"Shape",
67
"Param",
78
]
89

910
DataType = Type[Any]
1011
ScalarType = Union[Type[bool], Type[int], Type[float]]
12+
Shape = Tuple[int, ...]
1113
Param = Tuple

0 commit comments

Comments
 (0)