diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 4cbb1ab3..a0adc8c9 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -165,10 +165,27 @@ def matrix_shapes(draw, stack_shapes=shapes()): allow_infinity=False)) def mutually_broadcastable_shapes( - num_shapes: int, **kw + num_shapes: int, + *, + base_shape: Shape = (), + min_dims: int = 0, + max_dims: Optional[int] = None, + min_side: int = 0, + max_side: Optional[int] = None, ) -> SearchStrategy[Tuple[Shape, ...]]: + if max_dims is None: + max_dims = min(max(len(base_shape), min_dims) + 5, 32) + if max_side is None: + max_side = max(base_shape[-max_dims:] + (min_side,)) + 5 return ( - xps.mutually_broadcastable_shapes(num_shapes, **kw) + xps.mutually_broadcastable_shapes( + num_shapes, + base_shape=base_shape, + min_dims=min_dims, + max_dims=max_dims, + min_side=min_side, + max_side=max_side, + ) .map(lambda BS: BS.input_shapes) .filter(lambda shapes: all( prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes