2
2
from operator import mul
3
3
from math import sqrt
4
4
import itertools
5
- from typing import Tuple , Optional
5
+ from typing import Tuple , Optional , List
6
6
7
7
from hypothesis import assume
8
8
from hypothesis .strategies import (lists , integers , sampled_from ,
11
11
12
12
from .pytest_helpers import nargs
13
13
from .array_helpers import ndindex
14
+ from .typing import DataType , Shape
14
15
from . import dtype_helpers as dh
15
16
from ._array_module import (full , float32 , float64 , bool as bool_dtype ,
16
17
_UndefinedStub , eye , broadcast_to )
49
50
_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .float_dtypes ]
50
51
_sorted_dtypes = [d for category in _dtype_categories for d in category ]
51
52
52
- def _dtypes_sorter (dtype_pair ):
53
+ def _dtypes_sorter (dtype_pair : Tuple [ DataType , DataType ] ):
53
54
dtype1 , dtype2 = dtype_pair
54
55
if dtype1 == dtype2 :
55
56
return _sorted_dtypes .index (dtype1 )
@@ -66,7 +67,7 @@ def _dtypes_sorter(dtype_pair):
66
67
key += 1
67
68
return key
68
69
69
- promotable_dtypes = sorted (dh .promotion_table .keys (), key = _dtypes_sorter )
70
+ promotable_dtypes : List [ Tuple [ DataType , DataType ]] = sorted (dh .promotion_table .keys (), key = _dtypes_sorter )
70
71
71
72
if FILTER_UNDEFINED_DTYPES :
72
73
promotable_dtypes = [
@@ -79,8 +80,8 @@ def _dtypes_sorter(dtype_pair):
79
80
def mutually_promotable_dtypes (
80
81
max_size : Optional [int ] = 2 ,
81
82
* ,
82
- dtypes = dh .all_dtypes ,
83
- ) -> SearchStrategy [Tuple ]:
83
+ dtypes : Tuple [ DataType , ...] = dh .all_dtypes ,
84
+ ) -> SearchStrategy [Tuple [ DataType , ...] ]:
84
85
if max_size == 2 :
85
86
return sampled_from (
86
87
[(i , j ) for i , j in promotable_dtypes if i in dtypes and j in dtypes ]
@@ -164,7 +165,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
164
165
165
166
def mutually_broadcastable_shapes (
166
167
num_shapes : int , ** kw
167
- ) -> SearchStrategy [Tuple [Tuple [ int , ...] , ...]]:
168
+ ) -> SearchStrategy [Tuple [Shape , ...]]:
168
169
return (
169
170
xps .mutually_broadcastable_shapes (num_shapes , ** kw )
170
171
.map (lambda BS : BS .input_shapes )
@@ -347,8 +348,9 @@ def multiaxis_indices(draw, shapes):
347
348
348
349
349
350
def two_mutual_arrays (
350
- dtypes = dh .all_dtypes , two_shapes = two_mutually_broadcastable_shapes
351
- ):
351
+ dtypes : Tuple [DataType , ...] = dh .all_dtypes ,
352
+ two_shapes : SearchStrategy [Tuple [Shape , Shape ]] = two_mutually_broadcastable_shapes ,
353
+ ) -> SearchStrategy :
352
354
mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtypes ))
353
355
mutual_shapes = shared (two_shapes )
354
356
arrays1 = xps .arrays (
0 commit comments