
Commit c5f9dd4

snexus authored and ColCarroll committed
First implementation of TruncatedNormal distribution (#3052)
* First implementation of TruncatedNormal distribution

* Add proper mean calculation
  Add test_truncated_normal into test_distributions

* Add dist_shape to random(..) to generate samples with proper shape.
  Add test cases to test_distributions_random

* Update RELEASE_NOTES.md
1 parent 7b5cd32 commit c5f9dd4

File tree

5 files changed: +241 -2 lines changed


RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 - Improve error message `NaN occurred in optimization.` during ADVI
 - Save and load traces without `pickle` using `pm.save_trace` and `pm.load_trace`
 - Add `Kumaraswamy` distribution
+- Add 'TruncatedNormal' distribution
 - Rewrite parallel sampling of multiple chains on py3. This resolves
   long standing issues when transferring large traces to the main process,
   avoids pickling issues on UNIX, and allows us to show a progress bar

pymc3/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .continuous import Uniform
 from .continuous import Flat
 from .continuous import HalfFlat
+from .continuous import TruncatedNormal
 from .continuous import Normal
 from .continuous import Beta
 from .continuous import Kumaraswamy
@@ -88,6 +89,7 @@
 __all__ = ['Uniform',
            'Flat',
            'HalfFlat',
+           'TruncatedNormal',
            'Normal',
            'Beta',
            'Kumaraswamy',
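With this export in place, the new class should be reachable both from the package root and from the distributions subpackage. A minimal sketch, assuming a PyMC3 build that contains this change:

import pymc3 as pm
from pymc3.distributions import TruncatedNormal

# Both import paths should resolve to the same class object.
assert pm.TruncatedNormal is TruncatedNormal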

pymc3/distributions/continuous.py

Lines changed: 218 additions & 1 deletion
@@ -24,7 +24,7 @@
 from .dist_math import bound, logpow, gammaln, betaln, std_cdf, alltrue_elemwise, SplineWrapper, i0e
 from .distribution import Continuous, draw_values, generate_samples

-__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'Beta', 'Kumaraswamy', 'Exponential',
+__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'TruncatedNormal', 'Beta', 'Kumaraswamy', 'Exponential',
            'Laplace', 'StudentT', 'Cauchy', 'HalfCauchy', 'Gamma', 'Weibull',
            'HalfStudentT', 'Lognormal', 'ChiSquared', 'HalfNormal', 'Wald',
            'Pareto', 'InverseGamma', 'ExGaussian', 'VonMises', 'SkewNormal',
@@ -434,6 +434,223 @@ def _repr_latex_(self, name=None, dist=None):
                 get_variable_name(mu),
                 get_variable_name(sd))

+class TruncatedNormal(Continuous):
+    R"""
+    Univariate truncated normal log-likelihood.
+
+    The pdf of this distribution is
+
+    .. math::
+
+       f(x; \mu, \sigma, a, b) =
+           \frac{\phi\left(\frac{x-\mu}{\sigma}\right)}
+                {\sigma\left(\Phi\left(\frac{b-\mu}{\sigma}\right) - \Phi\left(\frac{a-\mu}{\sigma}\right)\right)}
+
+    The truncated normal distribution can be parameterized either in terms of precision
+    or standard deviation. The link between the two parametrizations is
+    given by
+
+    .. math::
+
+       \tau = \dfrac{1}{\sigma^2}
+
+    .. plot::
+
+        import matplotlib.pyplot as plt
+        import numpy as np
+        import scipy.stats as st
+        plt.style.use('seaborn-darkgrid')
+        x = np.linspace(-10, 10, 1000)
+        mus = [0., 0., 0.]
+        sds = [3., 5., 7.]
+        a1 = [-3, -5, -5]
+        b1 = [7, 5, 4]
+        for mu, sd, a, b in zip(mus, sds, a1, b1):
+            an, bn = (a - mu) / sd, (b - mu) / sd
+            pdf = st.truncnorm.pdf(x, an, bn, loc=mu, scale=sd)
+            plt.plot(x, pdf, label=r'$\mu$ = {}, $\sigma$ = {}, a={}, b={}'.format(mu, sd, a, b))
+        plt.xlabel('x', fontsize=12)
+        plt.ylabel('f(x)', fontsize=12)
+        plt.legend(loc=1)
+        plt.show()
+
+    ======== ==========================================
+    Support  :math:`x \in [a, b]`
+    Mean     :math:`\mu + \dfrac{\phi(\alpha) - \phi(\beta)}{Z}\sigma`
+    Variance :math:`\sigma^2\left[1 + \dfrac{\alpha\phi(\alpha) - \beta\phi(\beta)}{Z} -
+             \left(\dfrac{\phi(\alpha) - \phi(\beta)}{Z}\right)^2\right]`
+    ======== ==========================================
+
+    where :math:`\alpha = \dfrac{a - \mu}{\sigma}`, :math:`\beta = \dfrac{b - \mu}{\sigma}`
+    and :math:`Z = \Phi(\beta) - \Phi(\alpha)`.
+
+    Parameters
+    ----------
+    mu : float
+        Mean.
+    sd : float
+        Standard deviation (sd > 0).
+    a : float (optional)
+        Left bound.
+    b : float (optional)
+        Right bound.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        with pm.Model():
+            x = pm.TruncatedNormal('x', mu=0, sd=10, a=0)
+
+        with pm.Model():
+            x = pm.TruncatedNormal('x', mu=0, sd=10, b=1)
+
+        with pm.Model():
+            x = pm.TruncatedNormal('x', mu=0, sd=10, a=0, b=1)
+
+    """
+
+    def __init__(self, mu=0, sd=None, tau=None, a=None, b=None, **kwargs):
+        tau, sd = get_tau_sd(tau=tau, sd=sd)
+        self.sd = tt.as_tensor_variable(sd)
+        self.tau = tt.as_tensor_variable(tau)
+        self.a = tt.as_tensor_variable(a) if a is not None else a
+        self.b = tt.as_tensor_variable(b) if b is not None else b
+        self.mu = tt.as_tensor_variable(mu)
+
+        # Calculate the mean from the standardized boundary terms,
+        # mu + sd * (phi(alpha) - phi(beta)) / Z, matching the formula in the docstring
+        pdf_a, pdf_b, cdf_a, cdf_b = self._get_boundary_parameters()
+        z = cdf_b - cdf_a
+        self.mean = self.mu + (pdf_a - pdf_b) / z * self.sd
+
+        assert_negative_support(sd, 'sd', 'TruncatedNormal')
+        assert_negative_support(tau, 'tau', 'TruncatedNormal')
+
+        super(TruncatedNormal, self).__init__(**kwargs)
+
+    def random(self, point=None, size=None):
+        """
+        Draw random values from TruncatedNormal distribution.
+
+        Parameters
+        ----------
+        point : dict, optional
+            Dict of variable values on which random values are to be
+            conditioned (uses default point if not specified).
+        size : int, optional
+            Desired size of random sample (returns one sample if not
+            specified).
+
+        Returns
+        -------
+        array
+        """
+        mu_v, std_v, a_v, b_v = draw_values([self.mu, self.sd, self.a, self.b], point=point, size=size)
+        return generate_samples(stats.truncnorm.rvs,
+                                a=(a_v - mu_v) / std_v,
+                                b=(b_v - mu_v) / std_v,
+                                loc=mu_v,
+                                scale=std_v,
+                                dist_shape=self.shape,
+                                size=size,
+                                )
+
+    def logp(self, value):
+        """
+        Calculate log-probability of TruncatedNormal distribution at specified value.
+
+        Parameters
+        ----------
+        value : numeric
+            Value(s) for which log-probability is calculated. If the log probabilities for multiple
+            values are desired the values must be provided in a numpy array or theano tensor.
+
+        Returns
+        -------
+        TensorVariable
+        """
+        sd = self.sd
+        tau = self.tau
+        mu = self.mu
+        a = self.a
+        b = self.b
+
+        # In case a or b is not specified, the corresponding normalization term
+        # simplifies to 1.0 or 0.0
+        # https://en.wikipedia.org/wiki/Truncated_normal_distribution
+        norm_left, norm_right = 1.0, 0.0
+
+        # Define normalization: norm_left = Phi((b - mu) / sd), norm_right = Phi((a - mu) / sd)
+        if b is not None:
+            norm_left = self._cdf((b - mu) / sd)
+
+        if a is not None:
+            norm_right = self._cdf((a - mu) / sd)
+
+        f = self._pdf((value - mu) / sd) / sd / (norm_left - norm_right)
+
+        # Only apply the bound checks for truncation points that were actually
+        # specified, so one-sided truncation does not compare against None
+        conditions = [sd > 0]
+        if a is not None:
+            conditions.append(value >= a)
+        if b is not None:
+            conditions.append(value <= b)
+        return bound(tt.log(f), *conditions)
+
+    def _cdf(self, value):
+        """
+        Calculate cdf of standard normal distribution
+
+        Parameters
+        ----------
+        value : numeric
+            Value(s) at which the standard normal cdf is evaluated. Multiple values
+            must be provided in a numpy array or theano tensor.
+
+        Returns
+        -------
+        TensorVariable
+        """
+        return 0.5 * (1.0 + tt.erf(value / tt.sqrt(2)))
+
+    def _pdf(self, value):
+        """
+        Calculate pdf of standard normal distribution
+
+        Parameters
+        ----------
+        value : numeric
+            Value(s) at which the standard normal pdf is evaluated. Multiple values
+            must be provided in a numpy array or theano tensor.
+
+        Returns
+        -------
+        TensorVariable
+        """
+        return 1.0 / tt.sqrt(2 * np.pi) * tt.exp(-0.5 * (value ** 2))
+
+    def _repr_latex_(self, name=None, dist=None):
+        if dist is None:
+            dist = self
+        sd = dist.sd
+        mu = dist.mu
+        a = dist.a
+        b = dist.b
+        name = r'\text{%s}' % name
+        return r'${} \sim \text{{TruncatedNormal}}(\mathit{{mu}}={},~\mathit{{sd}}={},a={},b={})$'.format(name,
+                                                                                                           get_variable_name(mu),
+                                                                                                           get_variable_name(sd),
+                                                                                                           get_variable_name(a),
+                                                                                                           get_variable_name(b))
+
+    def _get_boundary_parameters(self):
+        """
+        Calculate the standard normal pdf and cdf at the standardized boundary
+        points alpha = (a - mu) / sd and beta = (b - mu) / sd.
+
+        Returns
+        -------
+        pdf(alpha), pdf(beta), cdf(alpha), cdf(beta) if a, b are defined;
+        the limits 0.0, 0.0, 0.0, 1.0 are used for unspecified bounds.
+        """
+        # The bounds are standardized so that the mean formula
+        # mu + sd * (phi(alpha) - phi(beta)) / Z holds for general mu, sd;
+        # pdf -> 0 at +/- inf
+        pdf_a = self._pdf((self.a - self.mu) / self.sd) if self.a is not None else 0.0
+        pdf_b = self._pdf((self.b - self.mu) / self.sd) if self.b is not None else 0.0
+
+        # b -> inf gives cdf(beta) = 1.0; a -> -inf gives cdf(alpha) = 0.0
+        cdf_a = self._cdf((self.a - self.mu) / self.sd) if self.a is not None else 0.0
+        cdf_b = self._cdf((self.b - self.mu) / self.sd) if self.b is not None else 1.0
+        return pdf_a, pdf_b, cdf_a, cdf_b

 class HalfNormal(PositiveContinuous):
     R"""

pymc3/tests/test_distributions.py

Lines changed: 12 additions & 1 deletion
@@ -12,7 +12,7 @@
     ZeroInflatedNegativeBinomial, Constant, Poisson, Bernoulli, Beta,
     BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto,
     InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace,
-    NegativeBinomial, Geometric, Exponential, ExGaussian, Normal,
+    NegativeBinomial, Geometric, Exponential, ExGaussian, Normal, TruncatedNormal,
     Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
     Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull,
     Gumbel, Logistic, OrderedLogistic, LogitNormal, Interpolated,
@@ -534,6 +534,17 @@ def test_normal(self):
                                  decimal=select_by_precision(float64=6, float32=1)
                                  )

+    def test_truncated_normal(self):
+        # The Rplusbig domain is used for sd and the bounds, to avoid silly cases such as
+        # {'mu': array(-2.1), 'a': array(-100.), 'b': array(0.01), 'value': array(0.), 'sd': array(0.01)}
+        # TruncatedNormal: pdf = 0.0, logpdf = -inf
+        # Scipy's answer: pdf = 0.0, logpdf = -22048.413!!!
+        self.pymc3_matches_scipy(TruncatedNormal, R, {'mu': R, 'sd': Rplusbig, 'a': -Rplusbig, 'b': Rplusbig},
+                                 lambda value, mu, sd, a, b: sp.truncnorm.logpdf(value, (a - mu) / sd, (b - mu) / sd,
+                                                                                 loc=mu, scale=sd),
+                                 decimal=select_by_precision(float64=6, float32=1)
+                                 )
+
     def test_half_normal(self):
         self.pymc3_matches_scipy(HalfNormal, Rplus, {'sd': Rplus},
                                  lambda value, sd: sp.halfnorm.logpdf(value, scale=sd),
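The corner case called out in the test comment above can be reproduced with scipy alone; a small sketch, using the exact parameter values quoted in that comment:

import scipy.stats as st

mu, sd, a, b, value = -2.1, 0.01, -100., 0.01, 0.
an, bn = (a - mu) / sd, (b - mu) / sd
print(st.truncnorm.pdf(value, an, bn, loc=mu, scale=sd))     # density underflows to 0.0
print(st.truncnorm.logpdf(value, an, bn, loc=mu, scale=sd))  # huge negative but finite, rather than -inf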

pymc3/tests/test_distributions_random.py

Lines changed: 8 additions & 0 deletions
@@ -217,6 +217,9 @@ class TestNormal(BaseTestCases.BaseTestCase):
     distribution = pm.Normal
     params = {'mu': 0., 'tau': 1.}

+class TestTruncatedNormal(BaseTestCases.BaseTestCase):
+    distribution = pm.TruncatedNormal
+    params = {'mu': 0., 'tau': 1., 'a': -0.5, 'b': 0.5}

 class TestSkewNormal(BaseTestCases.BaseTestCase):
     distribution = pm.SkewNormal
@@ -419,6 +422,11 @@ def ref_rand(size, mu, sd):
         return st.norm.rvs(size=size, loc=mu, scale=sd)
     pymc3_random(pm.Normal, {'mu': R, 'sd': Rplus}, ref_rand=ref_rand)

+    def test_truncated_normal(self):
+        def ref_rand(size, mu, sd, a, b):
+            return st.truncnorm.rvs((a - mu) / sd, (b - mu) / sd, size=size, loc=mu, scale=sd)
+        pymc3_random(pm.TruncatedNormal, {'mu': R, 'sd': Rplusbig, 'a': -Rplusbig, 'b': Rplusbig}, ref_rand=ref_rand)
+
     def test_skew_normal(self):
         def ref_rand(size, alpha, mu, sd):
             return st.skewnorm.rvs(size=size, a=alpha, loc=mu, scale=sd)
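The same bound standardization used by ref_rand above can be run on its own to draw reference samples and confirm they respect the truncation; a short sketch with illustrative parameter values:

import scipy.stats as st

mu, sd, a, b = 0., 10., 0., 1.
# scipy's truncnorm takes the bounds in standardized (z-score) form
samples = st.truncnorm.rvs((a - mu) / sd, (b - mu) / sd, loc=mu, scale=sd, size=1000)
assert samples.min() >= a and samples.max() <= b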
