Skip to content

Commit 00f63bb

Browse files
authored
Merge pull request #75 from honno/mbs
Change `hh.mutually_broadcastable_shapes()` defaults to generate more interesting shapes
2 parents b1e3f2e + b0f2bca commit 00f63bb

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,27 @@ def matrix_shapes(draw, stack_shapes=shapes()):
165165
allow_infinity=False))
166166

167167
def mutually_broadcastable_shapes(
168-
num_shapes: int, **kw
168+
num_shapes: int,
169+
*,
170+
base_shape: Shape = (),
171+
min_dims: int = 0,
172+
max_dims: Optional[int] = None,
173+
min_side: int = 0,
174+
max_side: Optional[int] = None,
169175
) -> SearchStrategy[Tuple[Shape, ...]]:
176+
if max_dims is None:
177+
max_dims = min(max(len(base_shape), min_dims) + 5, 32)
178+
if max_side is None:
179+
max_side = max(base_shape[-max_dims:] + (min_side,)) + 5
170180
return (
171-
xps.mutually_broadcastable_shapes(num_shapes, **kw)
181+
xps.mutually_broadcastable_shapes(
182+
num_shapes,
183+
base_shape=base_shape,
184+
min_dims=min_dims,
185+
max_dims=max_dims,
186+
min_side=min_side,
187+
max_side=max_side,
188+
)
172189
.map(lambda BS: BS.input_shapes)
173190
.filter(lambda shapes: all(
174191
prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes

0 commit comments

Comments
 (0)