Skip to content

Add a better mutually_broadcastable_shapes strategy #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from operator import mul
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union

from hypothesis import assume
from hypothesis import assume, note
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
integers, just, lists, none, one_of,
sampled_from, shared)
Expand Down Expand Up @@ -164,16 +164,39 @@ def matrix_shapes(draw, stack_shapes=shapes()):
elements=dict(allow_nan=False,
allow_infinity=False))

def mutually_broadcastable_shapes(
num_shapes: int, **kw
) -> SearchStrategy[Tuple[Shape, ...]]:
return (
xps.mutually_broadcastable_shapes(num_shapes, **kw)
.map(lambda BS: BS.input_shapes)
.filter(lambda shapes: all(
prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes
))
)
@composite
def mutually_broadcastable_shapes(draw, num_shapes, **kwargs):
# mutually_broadcastable_shapes() with the default inputs doesn't generate
# very interesting examples (see
# https://github.com/HypothesisWorks/hypothesis/issues/3170). It's very
# difficult to get it to do so by tweaking the max_* parameters, because
# making them too big leads to generating too large shapes and filtering
# too much. So instead, we trick it into generating more interesting
# examples by telling it to create shapes that broadcast against some base
# shape.
kwargs.setdefault('min_side', 0)
if 'base_shape' not in kwargs:
base_shape = draw(shapes(**kwargs))
kwargs['base_shape'] = base_shape

input_shapes, result_shape = draw(xps.mutually_broadcastable_shapes(num_shapes, **kwargs))

# result_shape is input_shapes broadcasted with base_shape, but base_shape
# itself is not part of input_shapes. We "really" want our base shape to
# be (). We are only using it here to trick mutually_broadcastable_shapes
# into giving more interesting examples.
final_result_shape = broadcast_shapes(*input_shapes)

# The broadcast compatible shapes can be bigger than the base shape. This
# is already somewhat limited by the mutually_broadcastable_shapes
# defaults, and pretty unlikely, but we filter again here just to be safe.
assume(prod(i for i in final_result_shape if i) < SQRT_MAX_ARRAY_SIZE)

# The hypothesis strategy would return this. We don't actually need the
# result shape in most cases (if we end up needing it, we can uncomment
# the below line).
# return BroadcastableShapes(input_shapes, final_result_shape)
return input_shapes

two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2)

Expand Down