@@ -526,6 +526,7 @@ def check_icdf(
526
526
pymc_dist : Distribution ,
527
527
paramdomains : Dict [str , Domain ],
528
528
scipy_icdf : Callable ,
529
+ skip_paramdomain_outside_edge_test = False ,
529
530
decimal : Optional [int ] = None ,
530
531
n_samples : int = 100 ,
531
532
) -> None :
@@ -548,7 +549,7 @@ def check_icdf(
548
549
paramdomains : Dictionary of Parameter : Domain pairs
549
550
Supported domains of distribution parameters
550
551
scipy_icdf : Scipy icdf method
551
- Scipy icdf (ppp ) method of equivalent pymc_dist distribution
552
+ Scipy icdf (ppf ) method of equivalent pymc_dist distribution
552
553
decimal : int, optional
553
554
Level of precision with which pymc_dist and scipy_icdf are compared.
554
555
Defaults to 6 for float64 and 3 for float32
@@ -557,6 +558,9 @@ def check_icdf(
557
558
are compared between pymc and scipy methods. If n_samples is below the
558
559
total number of combinations, a random subset is evaluated. Setting
559
560
n_samples = -1, will return all possible combinations. Defaults to 100
561
+ skip_paradomain_outside_edge_test : Bool
562
+ Whether to run test 2., which checks that pymc distribution icdf
563
+ returns nan for invalid parameter values outside the supported domain edge
560
564
561
565
"""
562
566
if decimal is None :
@@ -586,19 +590,20 @@ def check_icdf(
586
590
valid_params = {param : paramdomain .vals [0 ] for param , paramdomain in paramdomains .items ()}
587
591
valid_params ["q" ] = valid_value
588
592
589
- # Test pymc distribution raises ParameterValueError for parameters outside the
590
- # supported domain edges (excluding edges)
591
- invalid_params = find_invalid_scalar_params (paramdomains )
592
- for invalid_param , invalid_edges in invalid_params .items ():
593
- for invalid_edge in invalid_edges :
594
- if invalid_edge is None :
595
- continue
593
+ if not skip_paramdomain_outside_edge_test :
594
+ # Test pymc distribution raises ParameterValueError for parameters outside the
595
+ # supported domain edges (excluding edges)
596
+ invalid_params = find_invalid_scalar_params (paramdomains )
597
+ for invalid_param , invalid_edges in invalid_params .items ():
598
+ for invalid_edge in invalid_edges :
599
+ if invalid_edge is None :
600
+ continue
596
601
597
- point = valid_params .copy ()
598
- point [invalid_param ] = invalid_edge
599
- with pytest .raises (ParameterValueError ):
600
- pymc_icdf (** point )
601
- pytest .fail (f"test_params={ point } " )
602
+ point = valid_params .copy ()
603
+ point [invalid_param ] = invalid_edge
604
+ with pytest .raises (ParameterValueError ):
605
+ pymc_icdf (** point )
606
+ pytest .fail (f"test_params={ point } " )
602
607
603
608
# Test that values below 0 or above 1 evaluate to nan
604
609
invalid_values = find_invalid_scalar_params ({"q" : domain })["q" ]
0 commit comments