@@ -313,12 +313,18 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
313
313
# For now, just generate stacks of diagonal matrices.
314
314
n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ),)
315
315
stack_shape = draw (stack_shapes )
316
- d = draw (arrays (dtypes , shape = (* stack_shape , 1 , n ),
317
- elements = dict (allow_nan = False , allow_infinity = False )))
316
+ dtype = draw (dtypes )
317
+ elements = one_of (
318
+ from_dtype (dtype , min_value = 0.5 , allow_nan = False , allow_infinity = False ),
319
+ from_dtype (dtype , max_value = - 0.5 , allow_nan = False , allow_infinity = False ),
320
+ )
321
+ d = draw (arrays (dtype , shape = (* stack_shape , 1 , n ), elements = elements ))
322
+
318
323
# Functions that require invertible matrices may do anything when it is
319
324
# singular, including raising an exception, so we make sure the diagonals
320
325
# are sufficiently nonzero to avoid any numerical issues.
321
- assume (xp .all (xp .abs (d ) > 0.5 ))
326
+ assert xp .all (xp .abs (d ) >= 0.5 )
327
+
322
328
diag_mask = xp .arange (n ) == xp .reshape (xp .arange (n ), (n , 1 ))
323
329
return xp .where (diag_mask , d , xp .zeros_like (d ))
324
330
0 commit comments