Skip to content

Commit e2eb26d

Browse files
committed
Do not require class_name for CustomDist and Simulator dists
* Also remove the experimental warning when using CustomSymbolicDists
1 parent aae97a2 commit e2eb26d

File tree

4 files changed

+74
-95
lines changed

4 files changed

+74
-95
lines changed

pymc/distributions/distribution.py

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -488,14 +488,14 @@ class _CustomDist(Distribution):
488488
def dist(
489489
cls,
490490
*dist_params,
491-
class_name: str,
492491
logp: Optional[Callable] = None,
493492
logcdf: Optional[Callable] = None,
494493
random: Optional[Callable] = None,
495494
moment: Optional[Callable] = None,
496495
ndim_supp: int = 0,
497496
ndims_params: Optional[Sequence[int]] = None,
498497
dtype: str = "floatX",
498+
class_name: str = "CustomDist",
499499
**kwargs,
500500
):
501501
dist_params = [as_tensor_variable(param) for param in dist_params]
@@ -523,36 +523,36 @@ def dist(
523523

524524
return super().dist(
525525
dist_params,
526-
class_name=class_name,
527526
logp=logp,
528527
logcdf=logcdf,
529528
random=random,
530529
moment=moment,
531530
ndim_supp=ndim_supp,
532531
ndims_params=ndims_params,
533532
dtype=dtype,
533+
class_name=class_name,
534534
**kwargs,
535535
)
536536

537537
@classmethod
538538
def rv_op(
539539
cls,
540540
*dist_params,
541-
class_name: str,
542541
logp: Optional[Callable],
543542
logcdf: Optional[Callable],
544543
random: Optional[Callable],
545544
moment: Optional[Callable],
546545
ndim_supp: int,
547546
ndims_params: Optional[Sequence[int]],
548547
dtype: str,
548+
class_name: str,
549549
**kwargs,
550550
):
551551
rv_type = type(
552-
f"CustomDistRV_{class_name}",
552+
class_name,
553553
(CustomDistRV,),
554554
dict(
555-
name=f"CustomDist_{class_name}",
555+
name=class_name,
556556
inplace=False,
557557
ndim_supp=ndim_supp,
558558
ndims_params=ndims_params,
@@ -613,20 +613,15 @@ class _CustomSymbolicDist(Distribution):
613613
def dist(
614614
cls,
615615
*dist_params,
616-
class_name: str,
617616
dist: Callable,
618617
logp: Optional[Callable] = None,
619618
logcdf: Optional[Callable] = None,
620619
moment: Optional[Callable] = None,
621620
ndim_supp: int = 0,
622621
dtype: str = "floatX",
622+
class_name: str = "CustomSymbolicDist",
623623
**kwargs,
624624
):
625-
warnings.warn(
626-
"CustomDist with dist function is still experimental. Expect bugs!",
627-
UserWarning,
628-
)
629-
630625
dist_params = [as_tensor_variable(param) for param in dist_params]
631626

632627
if logcdf is None:
@@ -655,13 +650,13 @@ def dist(
655650
def rv_op(
656651
cls,
657652
*dist_params,
658-
class_name: str,
659653
dist: Callable,
660654
logp: Optional[Callable],
661655
logcdf: Optional[Callable],
662656
moment: Optional[Callable],
663657
size=None,
664658
ndim_supp: int,
659+
class_name: str,
665660
):
666661
size = normalize_size_param(size)
667662
dummy_size_param = size.type()
@@ -674,7 +669,7 @@ def rv_op(
674669
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
675670

676671
rv_type = type(
677-
f"CustomSymbolicDistRV_{class_name}",
672+
class_name,
678673
(CustomSymbolicDistRV,),
679674
# If logp is not provided, we try to infer it from the dist graph
680675
dict(
@@ -758,15 +753,6 @@ class CustomDist:
758753
dist_params : Tuple
759754
A sequence of the distribution's parameter. These will be converted into
760755
Pytensor tensor variables internally.
761-
class_name : str
762-
Name for the class which will wrap the CustomDist methods. When not specified,
763-
it will be given the name of the model variable.
764-
765-
.. warning:: New CustomDists created with the same class_name will override the
766-
methods dispatched onto the previous classes. If using CustomDists with
767-
different methods across separate models, be sure to use distinct
768-
class_names.
769-
770756
dist: Optional[Callable]
771757
A callable that returns a PyTensor graph built from simpler PyMC distributions
772758
which represents the distribution. This can be used by PyMC to take random draws
@@ -831,6 +817,9 @@ class CustomDist:
831817
The dtype of the distribution. All draws and observations passed into the
832818
distribution will be cast onto this dtype. This is not needed if an PyTensor
833819
dist function is provided, which should already return the right dtype!
820+
class_name : str
821+
Name for the class which will wrap the CustomDist methods. When not specified,
822+
it will be given the name of the model variable.
834823
kwargs :
835824
Extra keyword arguments are passed to the parent's class ``__new__`` method.
836825
@@ -979,36 +968,36 @@ def __new__(
979968
dist_params = cls.parse_dist_params(dist_params)
980969
cls.check_valid_dist_random(dist, random, dist_params)
981970
if dist is not None:
971+
kwargs.setdefault("class_name", f"CustomSymbolicDist_{name}")
982972
return _CustomSymbolicDist(
983973
name,
984974
*dist_params,
985-
class_name=name,
986975
dist=dist,
987976
logp=logp,
988977
logcdf=logcdf,
989978
moment=moment,
990979
ndim_supp=ndim_supp,
991980
**kwargs,
992981
)
993-
return _CustomDist(
994-
name,
995-
*dist_params,
996-
class_name=name,
997-
random=random,
998-
logp=logp,
999-
logcdf=logcdf,
1000-
moment=moment,
1001-
ndim_supp=ndim_supp,
1002-
ndims_params=ndims_params,
1003-
dtype=dtype,
1004-
**kwargs,
1005-
)
982+
else:
983+
kwargs.setdefault("class_name", f"CustomDist_{name}")
984+
return _CustomDist(
985+
name,
986+
*dist_params,
987+
random=random,
988+
logp=logp,
989+
logcdf=logcdf,
990+
moment=moment,
991+
ndim_supp=ndim_supp,
992+
ndims_params=ndims_params,
993+
dtype=dtype,
994+
**kwargs,
995+
)
1006996

1007997
@classmethod
1008998
def dist(
1009999
cls,
10101000
*dist_params,
1011-
class_name: str,
10121001
dist: Optional[Callable] = None,
10131002
random: Optional[Callable] = None,
10141003
logp: Optional[Callable] = None,
@@ -1024,7 +1013,6 @@ def dist(
10241013
if dist is not None:
10251014
return _CustomSymbolicDist.dist(
10261015
*dist_params,
1027-
class_name=class_name,
10281016
dist=dist,
10291017
logp=logp,
10301018
logcdf=logcdf,
@@ -1035,7 +1023,6 @@ def dist(
10351023
else:
10361024
return _CustomDist.dist(
10371025
*dist_params,
1038-
class_name=class_name,
10391026
random=random,
10401027
logp=logp,
10411028
logcdf=logcdf,

pymc/distributions/simulator.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,6 @@ class Simulator(Distribution):
8686
Keyword form of ''unnamed_params''.
8787
One of unnamed_params or params must be provided.
8888
If passed both unnamed_params and params, an error is raised.
89-
class_name : str
90-
Name for the RandomVariable class which will wrap the Simulator methods.
91-
When not specified, it will be given the name of the variable.
92-
93-
.. warning:: New Simulators created with the same class_name will override the
94-
methods dispatched onto the previous classes. If using Simulators with
95-
different methods across separate models, be sure to use distinct
96-
class_names.
97-
9889
distance : PyTensor_Op, callable or str, default "gaussian"
9990
Distance function. Available options are ``"gaussian"``, ``"laplace"``,
10091
``"kullback_leibler"`` or a user defined function (or PyTensor_Op) that takes
@@ -123,6 +114,8 @@ class Simulator(Distribution):
123114
Number of minimum dimensions of each parameter of the RV. For example,
124115
if the Simulator accepts two scalar inputs, it should be ``[0, 0]``.
125116
Default to list of 0 with length equal to the number of parameters.
117+
class_name : str, optional
118+
Suffix name for the RandomVariable class which will wrap the Simulator methods.
126119
127120
Examples
128121
--------
@@ -149,7 +142,7 @@ def simulator_fn(rng, loc, scale, size):
149142
rv_type = SimulatorRV
150143

151144
def __new__(cls, name, *args, **kwargs):
152-
kwargs.setdefault("class_name", name)
145+
kwargs.setdefault("class_name", f"Simulator_{name}")
153146
return super().__new__(cls, name, *args, **kwargs)
154147

155148
@classmethod
@@ -158,13 +151,13 @@ def dist( # type: ignore
158151
fn,
159152
*unnamed_params,
160153
params=None,
161-
class_name: str,
162154
distance="gaussian",
163155
sum_stat="identity",
164156
epsilon=1,
165157
ndim_supp=0,
166158
ndims_params=None,
167159
dtype="floatX",
160+
class_name: str = "Simulator",
168161
**kwargs,
169162
):
170163
if not isinstance(distance, Op):
@@ -213,36 +206,36 @@ def dist( # type: ignore
213206

214207
return super().dist(
215208
params,
216-
class_name=class_name,
217209
fn=fn,
218210
ndim_supp=ndim_supp,
219211
ndims_params=ndims_params,
220212
dtype=dtype,
221213
distance=distance,
222214
sum_stat=sum_stat,
223215
epsilon=epsilon,
216+
class_name=class_name,
224217
**kwargs,
225218
)
226219

227220
@classmethod
228221
def rv_op(
229222
cls,
230223
*params,
231-
class_name,
232224
fn,
233225
ndim_supp,
234226
ndims_params,
235227
dtype,
236228
distance,
237229
sum_stat,
238230
epsilon,
231+
class_name,
239232
**kwargs,
240233
):
241234
sim_op = type(
242-
f"Simulator_{class_name}",
235+
class_name,
243236
(SimulatorRV,),
244237
dict(
245-
name=f"Simulator_{class_name}",
238+
name=class_name,
246239
ndim_supp=ndim_supp,
247240
ndims_params=ndims_params,
248241
dtype=dtype,

0 commit comments

Comments
 (0)