diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 4cbb1ab3..7045586c 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -4,7 +4,7 @@ from operator import mul from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union -from hypothesis import assume +from hypothesis import assume, note from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, just, lists, none, one_of, sampled_from, shared) @@ -164,16 +164,39 @@ def matrix_shapes(draw, stack_shapes=shapes()): elements=dict(allow_nan=False, allow_infinity=False)) -def mutually_broadcastable_shapes( - num_shapes: int, **kw -) -> SearchStrategy[Tuple[Shape, ...]]: - return ( - xps.mutually_broadcastable_shapes(num_shapes, **kw) - .map(lambda BS: BS.input_shapes) - .filter(lambda shapes: all( - prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes - )) - ) +@composite +def mutually_broadcastable_shapes(draw, num_shapes, **kwargs): + # mutually_broadcastable_shapes() with the default inputs doesn't generate + # very interesting examples (see + # https://github.com/HypothesisWorks/hypothesis/issues/3170). It's very + # difficult to get it to do so by tweaking the max_* parameters, because + # making them too big leads to generating too large shapes and filtering + # too much. So instead, we trick it into generating more interesting + # examples by telling it to create shapes that broadcast against some base + # shape. + kwargs.setdefault('min_side', 0) + if 'base_shape' not in kwargs: + base_shape = draw(shapes(**kwargs)) + kwargs['base_shape'] = base_shape + + input_shapes, result_shape = draw(xps.mutually_broadcastable_shapes(num_shapes, **kwargs)) + + # result_shape is input_shapes broadcasted with base_shape, but base_shape + # itself is not part of input_shapes. We "really" want our base shape to + # be (). We are only using it here to trick mutually_broadcastable_shapes + # into giving more interesting examples. + final_result_shape = broadcast_shapes(*input_shapes) + + # The broadcast compatible shapes can be bigger than the base shape. This + # is already somewhat limited by the mutually_broadcastable_shapes + # defaults, and pretty unlikely, but we filter again here just to be safe. + assume(prod(i for i in final_result_shape if i) < SQRT_MAX_ARRAY_SIZE) + + # The hypothesis strategy would return this. We don't actually need the + # result shape in most cases (if we end up needing it, we can uncomment + # the below line). + # return BroadcastableShapes(input_shapes, final_result_shape) + return input_shapes two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2)