@@ -319,7 +319,7 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
319
319
return from_dtype
320
320
321
321
322
- def parse_cond (cond_str : str ) -> Tuple [UnaryCheck , str , FromDtypeFunc ]:
322
+ def parse_cond (cond_str : str ) -> Tuple [UnaryCheck , str , BoundFromDtype ]:
323
323
"""
324
324
Parses a Sphinx-formatted condition string to return:
325
325
@@ -348,22 +348,30 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
348
348
124.978
349
349
350
350
"""
351
+ # We first identify whether the condition starts with "not". If so, we note
352
+ # this but parse the condition as if it was not negated.
351
353
if m := r_not .match (cond_str ):
352
354
cond_str = m .group (1 )
353
355
not_cond = True
354
356
else :
355
357
not_cond = False
356
358
359
+ # We parse the condition to identify the condition function, expression
360
+ # template, and xps.from_dtype()-like condition strategy.
357
361
kwargs = {}
358
362
filter_ = None
359
363
from_dtype = None # type: ignore
360
- strat = None
361
364
if m := r_code .match (cond_str ):
362
365
value = parse_value (m .group (1 ))
363
366
cond = make_strict_eq (value )
364
367
expr_template = "{} == " + m .group (1 )
365
- if not not_cond :
366
- strat = st .just (value )
368
+ from_dtype = wrap_strat_as_from_dtype (st .just (value ))
369
+ elif m := r_either_code .match (cond_str ):
370
+ v1 = parse_value (m .group (1 ))
371
+ v2 = parse_value (m .group (2 ))
372
+ cond = make_or (make_strict_eq (v1 ), make_strict_eq (v2 ))
373
+ expr_template = "({} == " + m .group (1 ) + " or {} == " + m .group (2 ) + ")"
374
+ from_dtype = wrap_strat_as_from_dtype (st .sampled_from ([v1 , v2 ]))
367
375
elif m := r_equal_to .match (cond_str ):
368
376
value = parse_value (m .group (1 ))
369
377
if math .isnan (value ):
@@ -374,97 +382,73 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
374
382
value = parse_value (m .group (1 ))
375
383
cond = make_gt (value )
376
384
expr_template = "{} > " + m .group (1 )
377
- if not not_cond :
378
- kwargs = {"min_value" : value , "exclude_min" : True }
385
+ kwargs = {"min_value" : value , "exclude_min" : True }
379
386
elif m := r_lt .match (cond_str ):
380
387
value = parse_value (m .group (1 ))
381
388
cond = make_lt (value )
382
389
expr_template = "{} < " + m .group (1 )
383
- if not not_cond :
384
- kwargs = {"max_value" : value , "exclude_max" : True }
385
- elif m := r_either_code .match (cond_str ):
386
- v1 = parse_value (m .group (1 ))
387
- v2 = parse_value (m .group (2 ))
388
- cond = make_or (make_strict_eq (v1 ), make_strict_eq (v2 ))
389
- expr_template = "({} == " + m .group (1 ) + " or {} == " + m .group (2 ) + ")"
390
- if not not_cond :
391
- strat = st .sampled_from ([v1 , v2 ])
390
+ kwargs = {"max_value" : value , "exclude_max" : True }
392
391
elif cond_str in ["finite" , "a finite number" ]:
393
392
cond = math .isfinite
394
393
expr_template = "isfinite({})"
395
- if not not_cond :
396
- kwargs = {"allow_nan" : False , "allow_infinity" : False }
394
+ kwargs = {"allow_nan" : False , "allow_infinity" : False }
397
395
elif cond_str in "a positive (i.e., greater than ``0``) finite number" :
398
396
cond = lambda i : math .isfinite (i ) and i > 0
399
397
expr_template = "isfinite({}) and {} > 0"
400
- if not not_cond :
401
- kwargs = {
402
- "allow_nan" : False ,
403
- "allow_infinity" : False ,
404
- "min_value" : 0 ,
405
- "exclude_min" : True ,
406
- }
398
+ kwargs = {
399
+ "allow_nan" : False ,
400
+ "allow_infinity" : False ,
401
+ "min_value" : 0 ,
402
+ "exclude_min" : True ,
403
+ }
407
404
elif cond_str == "a negative (i.e., less than ``0``) finite number" :
408
405
cond = lambda i : math .isfinite (i ) and i < 0
409
406
expr_template = "isfinite({}) and {} < 0"
410
- if not not_cond :
411
- kwargs = {
412
- "allow_nan" : False ,
413
- "allow_infinity" : False ,
414
- "max_value" : 0 ,
415
- "exclude_max" : True ,
416
- }
407
+ kwargs = {
408
+ "allow_nan" : False ,
409
+ "allow_infinity" : False ,
410
+ "max_value" : 0 ,
411
+ "exclude_max" : True ,
412
+ }
417
413
elif cond_str == "positive" :
418
414
cond = lambda i : math .copysign (1 , i ) == 1
419
415
expr_template = "copysign(1, {}) == 1"
420
- if not not_cond :
421
- # We assume (positive) zero is special cased seperately
422
- kwargs = {"min_value" : 0 , "exclude_min" : True }
416
+ # We assume (positive) zero is special cased seperately
417
+ kwargs = {"min_value" : 0 , "exclude_min" : True }
423
418
elif cond_str == "negative" :
424
419
cond = lambda i : math .copysign (1 , i ) == - 1
425
420
expr_template = "copysign(1, {}) == -1"
426
- if not not_cond :
427
- # We assume (negative) zero is special cased seperately
428
- kwargs = {"max_value" : 0 , "exclude_max" : True }
421
+ # We assume (negative) zero is special cased seperately
422
+ kwargs = {"max_value" : 0 , "exclude_max" : True }
429
423
elif "nonzero finite" in cond_str :
430
424
cond = lambda i : math .isfinite (i ) and i != 0
431
425
expr_template = "isfinite({}) and {} != 0"
432
- if not not_cond :
433
- kwargs = {"allow_nan" : False , "allow_infinity" : False }
434
- filter_ = lambda n : n != 0
426
+ kwargs = {"allow_nan" : False , "allow_infinity" : False }
427
+ filter_ = lambda n : n != 0
435
428
elif cond_str == "an integer value" :
436
429
cond = lambda i : i .is_integer ()
437
430
expr_template = "{}.is_integer()"
438
- if not not_cond :
439
- from_dtype = integers_from_dtype # type: ignore
431
+ from_dtype = integers_from_dtype # type: ignore
440
432
elif cond_str == "an odd integer value" :
441
433
cond = lambda i : i .is_integer () and i % 2 == 1
442
434
expr_template = "{}.is_integer() and {} % 2 == 1"
443
- if not not_cond :
435
+ from_dtype = integers_from_dtype # type: ignore
444
436
445
- def from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
446
- return integers_from_dtype (dtype , ** kw ).filter (lambda n : n % 2 == 1 )
437
+ def from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
438
+ return integers_from_dtype (dtype , ** kw ).filter (lambda n : n % 2 == 1 )
447
439
448
440
else :
449
441
raise ParseError (cond_str )
450
442
451
- if strat is not None :
452
- if (
453
- not_cond
454
- or len (kwargs ) != 0
455
- or filter_ is not None
456
- or from_dtype is not None
457
- ):
458
- raise ParseError (cond_str )
459
- return cond , expr_template , wrap_strat_as_from_dtype (strat )
460
-
461
443
if not_cond :
462
- expr_template = f"not { expr_template } "
444
+ # We handle negated conitions by simply negating the condition function
445
+ # and using it as a filter for xps.from_dtype() (or an equivalent).
463
446
cond = make_not_cond (cond )
464
- kwargs = {}
447
+ expr_template = f"not { expr_template } "
465
448
filter_ = cond
466
- assert kwargs is not None
467
- return cond , expr_template , BoundFromDtype (kwargs , filter_ , from_dtype )
449
+ return cond , expr_template , BoundFromDtype (filter_ = filter_ )
450
+ else :
451
+ return cond , expr_template , BoundFromDtype (kwargs , filter_ , from_dtype )
468
452
469
453
470
454
def parse_result (result_str : str ) -> Tuple [UnaryCheck , str ]:
@@ -838,6 +822,9 @@ def check_result(i1: float, i2: float, result: float) -> bool:
838
822
839
823
840
824
def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
825
+ """
826
+ Returns a strategy that generates float-casted integers within the bounds of dtype.
827
+ """
841
828
for k in kw .keys ():
842
829
# sanity check
843
830
assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
@@ -1036,16 +1023,12 @@ def cond(i1: float, i2: float) -> bool:
1036
1023
elif len (x1_cond_from_dtypes ) == 1 :
1037
1024
x1_cond_from_dtype = x1_cond_from_dtypes [0 ]
1038
1025
else :
1039
- if not all (isinstance (fd , BoundFromDtype ) for fd in x1_cond_from_dtypes ):
1040
- raise ParseError (case_str )
1041
1026
x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ())
1042
1027
if len (x2_cond_from_dtypes ) == 0 :
1043
1028
x2_cond_from_dtype = xps .from_dtype
1044
1029
elif len (x2_cond_from_dtypes ) == 1 :
1045
1030
x2_cond_from_dtype = x2_cond_from_dtypes [0 ]
1046
1031
else :
1047
- if not all (isinstance (fd , BoundFromDtype ) for fd in x2_cond_from_dtypes ):
1048
- raise ParseError (case_str )
1049
1032
x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ())
1050
1033
1051
1034
return BinaryCase (
0 commit comments