@@ -488,14 +488,14 @@ class _CustomDist(Distribution):
488
488
def dist (
489
489
cls ,
490
490
* dist_params ,
491
- class_name : str ,
492
491
logp : Optional [Callable ] = None ,
493
492
logcdf : Optional [Callable ] = None ,
494
493
random : Optional [Callable ] = None ,
495
494
moment : Optional [Callable ] = None ,
496
495
ndim_supp : int = 0 ,
497
496
ndims_params : Optional [Sequence [int ]] = None ,
498
497
dtype : str = "floatX" ,
498
+ class_name : str = "CustomDist" ,
499
499
** kwargs ,
500
500
):
501
501
dist_params = [as_tensor_variable (param ) for param in dist_params ]
@@ -523,36 +523,36 @@ def dist(
523
523
524
524
return super ().dist (
525
525
dist_params ,
526
- class_name = class_name ,
527
526
logp = logp ,
528
527
logcdf = logcdf ,
529
528
random = random ,
530
529
moment = moment ,
531
530
ndim_supp = ndim_supp ,
532
531
ndims_params = ndims_params ,
533
532
dtype = dtype ,
533
+ class_name = class_name ,
534
534
** kwargs ,
535
535
)
536
536
537
537
@classmethod
538
538
def rv_op (
539
539
cls ,
540
540
* dist_params ,
541
- class_name : str ,
542
541
logp : Optional [Callable ],
543
542
logcdf : Optional [Callable ],
544
543
random : Optional [Callable ],
545
544
moment : Optional [Callable ],
546
545
ndim_supp : int ,
547
546
ndims_params : Optional [Sequence [int ]],
548
547
dtype : str ,
548
+ class_name : str ,
549
549
** kwargs ,
550
550
):
551
551
rv_type = type (
552
- f"CustomDistRV_ { class_name } " ,
552
+ class_name ,
553
553
(CustomDistRV ,),
554
554
dict (
555
- name = f"CustomDist_ { class_name } " ,
555
+ name = class_name ,
556
556
inplace = False ,
557
557
ndim_supp = ndim_supp ,
558
558
ndims_params = ndims_params ,
@@ -613,20 +613,15 @@ class _CustomSymbolicDist(Distribution):
613
613
def dist (
614
614
cls ,
615
615
* dist_params ,
616
- class_name : str ,
617
616
dist : Callable ,
618
617
logp : Optional [Callable ] = None ,
619
618
logcdf : Optional [Callable ] = None ,
620
619
moment : Optional [Callable ] = None ,
621
620
ndim_supp : int = 0 ,
622
621
dtype : str = "floatX" ,
622
+ class_name : str = "CustomSymbolicDist" ,
623
623
** kwargs ,
624
624
):
625
- warnings .warn (
626
- "CustomDist with dist function is still experimental. Expect bugs!" ,
627
- UserWarning ,
628
- )
629
-
630
625
dist_params = [as_tensor_variable (param ) for param in dist_params ]
631
626
632
627
if logcdf is None :
@@ -655,13 +650,13 @@ def dist(
655
650
def rv_op (
656
651
cls ,
657
652
* dist_params ,
658
- class_name : str ,
659
653
dist : Callable ,
660
654
logp : Optional [Callable ],
661
655
logcdf : Optional [Callable ],
662
656
moment : Optional [Callable ],
663
657
size = None ,
664
658
ndim_supp : int ,
659
+ class_name : str ,
665
660
):
666
661
size = normalize_size_param (size )
667
662
dummy_size_param = size .type ()
@@ -674,7 +669,7 @@ def rv_op(
674
669
dummy_updates_dict = collect_default_updates (dummy_params , (dummy_rv ,))
675
670
676
671
rv_type = type (
677
- f"CustomSymbolicDistRV_ { class_name } " ,
672
+ class_name ,
678
673
(CustomSymbolicDistRV ,),
679
674
# If logp is not provided, we try to infer it from the dist graph
680
675
dict (
@@ -758,15 +753,6 @@ class CustomDist:
758
753
dist_params : Tuple
759
754
A sequence of the distribution's parameter. These will be converted into
760
755
Pytensor tensor variables internally.
761
- class_name : str
762
- Name for the class which will wrap the CustomDist methods. When not specified,
763
- it will be given the name of the model variable.
764
-
765
- .. warning:: New CustomDists created with the same class_name will override the
766
- methods dispatched onto the previous classes. If using CustomDists with
767
- different methods across separate models, be sure to use distinct
768
- class_names.
769
-
770
756
dist: Optional[Callable]
771
757
A callable that returns a PyTensor graph built from simpler PyMC distributions
772
758
which represents the distribution. This can be used by PyMC to take random draws
@@ -831,6 +817,9 @@ class CustomDist:
831
817
The dtype of the distribution. All draws and observations passed into the
832
818
distribution will be cast onto this dtype. This is not needed if an PyTensor
833
819
dist function is provided, which should already return the right dtype!
820
+ class_name : str
821
+ Name for the class which will wrap the CustomDist methods. When not specified,
822
+ it will be given the name of the model variable.
834
823
kwargs :
835
824
Extra keyword arguments are passed to the parent's class ``__new__`` method.
836
825
@@ -979,36 +968,36 @@ def __new__(
979
968
dist_params = cls .parse_dist_params (dist_params )
980
969
cls .check_valid_dist_random (dist , random , dist_params )
981
970
if dist is not None :
971
+ kwargs .setdefault ("class_name" , f"CustomSymbolicDist_{ name } " )
982
972
return _CustomSymbolicDist (
983
973
name ,
984
974
* dist_params ,
985
- class_name = name ,
986
975
dist = dist ,
987
976
logp = logp ,
988
977
logcdf = logcdf ,
989
978
moment = moment ,
990
979
ndim_supp = ndim_supp ,
991
980
** kwargs ,
992
981
)
993
- return _CustomDist (
994
- name ,
995
- * dist_params ,
996
- class_name = name ,
997
- random = random ,
998
- logp = logp ,
999
- logcdf = logcdf ,
1000
- moment = moment ,
1001
- ndim_supp = ndim_supp ,
1002
- ndims_params = ndims_params ,
1003
- dtype = dtype ,
1004
- ** kwargs ,
1005
- )
982
+ else :
983
+ kwargs .setdefault ("class_name" , f"CustomDist_{ name } " )
984
+ return _CustomDist (
985
+ name ,
986
+ * dist_params ,
987
+ random = random ,
988
+ logp = logp ,
989
+ logcdf = logcdf ,
990
+ moment = moment ,
991
+ ndim_supp = ndim_supp ,
992
+ ndims_params = ndims_params ,
993
+ dtype = dtype ,
994
+ ** kwargs ,
995
+ )
1006
996
1007
997
@classmethod
1008
998
def dist (
1009
999
cls ,
1010
1000
* dist_params ,
1011
- class_name : str ,
1012
1001
dist : Optional [Callable ] = None ,
1013
1002
random : Optional [Callable ] = None ,
1014
1003
logp : Optional [Callable ] = None ,
@@ -1024,7 +1013,6 @@ def dist(
1024
1013
if dist is not None :
1025
1014
return _CustomSymbolicDist .dist (
1026
1015
* dist_params ,
1027
- class_name = class_name ,
1028
1016
dist = dist ,
1029
1017
logp = logp ,
1030
1018
logcdf = logcdf ,
@@ -1035,7 +1023,6 @@ def dist(
1035
1023
else :
1036
1024
return _CustomDist .dist (
1037
1025
* dist_params ,
1038
- class_name = class_name ,
1039
1026
random = random ,
1040
1027
logp = logp ,
1041
1028
logcdf = logcdf ,
0 commit comments