Skip to content

Commit 3fefd20

Browse files
committed
Remove need for filtering in invertible_matrices()
1 parent 2d918e4 commit 3fefd20

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,18 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
313313
# For now, just generate stacks of diagonal matrices.
314314
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
315315
stack_shape = draw(stack_shapes)
316-
d = draw(arrays(dtypes, shape=(*stack_shape, 1, n),
317-
elements=dict(allow_nan=False, allow_infinity=False)))
316+
dtype = draw(dtypes)
317+
elements = one_of(
318+
from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False),
319+
from_dtype(dtype, max_value=-0.5, allow_nan=False, allow_infinity=False),
320+
)
321+
d = draw(arrays(dtype, shape=(*stack_shape, 1, n), elements=elements))
322+
318323
# Functions that require invertible matrices may do anything when it is
319324
# singular, including raising an exception, so we make sure the diagonals
320325
# are sufficiently nonzero to avoid any numerical issues.
321-
assume(xp.all(xp.abs(d) > 0.5))
326+
assert xp.all(xp.abs(d) >= 0.5)
327+
322328
diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1))
323329
return xp.where(diag_mask, d, xp.zeros_like(d))
324330

0 commit comments

Comments
 (0)