|
24 | 24 | from .dist_math import bound, logpow, gammaln, betaln, std_cdf, alltrue_elemwise, SplineWrapper, i0e
|
25 | 25 | from .distribution import Continuous, draw_values, generate_samples
|
26 | 26 |
|
27 |
| -__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'Beta', 'Kumaraswamy', 'Exponential', |
| 27 | +__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'TruncatedNormal', 'Beta', 'Kumaraswamy', 'Exponential', |
28 | 28 | 'Laplace', 'StudentT', 'Cauchy', 'HalfCauchy', 'Gamma', 'Weibull',
|
29 | 29 | 'HalfStudentT', 'Lognormal', 'ChiSquared', 'HalfNormal', 'Wald',
|
30 | 30 | 'Pareto', 'InverseGamma', 'ExGaussian', 'VonMises', 'SkewNormal',
|
@@ -434,6 +434,223 @@ def _repr_latex_(self, name=None, dist=None):
|
434 | 434 | get_variable_name(mu),
|
435 | 435 | get_variable_name(sd))
|
436 | 436 |
|
| 437 | +class TruncatedNormal(Continuous): |
| 438 | + R""" |
| 439 | + Univariate truncated normal log-likelihood. |
| 440 | +
|
| 441 | + The pdf of this distribution is |
| 442 | +
|
| 443 | + .. math:: |
| 444 | +
|
| 445 | + f(x;\mu ,\sigma ,a,b)={\frac {\phi ({\frac {x-\mu }{\sigma }})}{ |
| 446 | + \sigma \left(\Phi ({\frac {b-\mu }{\sigma }})-\Phi ({\frac {a-\mu }{\sigma }})\right)}} |
| 447 | +
|
| 448 | + Truncated normal distribution can be parameterized either in terms of precision |
| 449 | + or standard deviation. The link between the two parametrizations is |
| 450 | + given by |
| 451 | +
|
| 452 | + .. math:: |
| 453 | +
|
| 454 | + \tau = \dfrac{1}{\sigma^2} |
| 455 | +
|
| 456 | +
|
| 457 | + .. plot:: |
| 458 | +
|
| 459 | + import matplotlib.pyplot as plt |
| 460 | + import numpy as np |
| 461 | + import scipy.stats as st |
| 462 | + plt.style.use('seaborn-darkgrid') |
| 463 | + x = np.linspace(-10, 10, 1000) |
| 464 | + mus = [0., 0., 0.] |
| 465 | + sds = [3.,5.,7.] |
| 466 | + a1 = [-3, -5, -5] |
| 467 | + b1 = [7, 5, 4] |
| 468 | + for mu, sd, a, b in zip(mus, sds,a1,b1): |
| 469 | + print mu, sd, a, b |
| 470 | + an, bn = (a - mu) / sd, (b - mu) / sd |
| 471 | + pdf = st.truncnorm.pdf(x, an,bn, loc=mu, scale=sd) |
| 472 | + plt.plot(x, pdf, label=r'$\mu$ = {}, $\sigma$ = {}, a={}, b={}'.format(mu, sd, a, b)) |
| 473 | + plt.xlabel('x', fontsize=12) |
| 474 | + plt.ylabel('f(x)', fontsize=12) |
| 475 | + plt.legend(loc=1) |
| 476 | + plt.show() |
| 477 | +
|
| 478 | + ======== ========================================== |
| 479 | + Support :math:`x \in [a, b]` |
| 480 | + Mean :math:`\mu +{\frac {\phi (\alpha )-\phi (\beta )}{Z}}\sigma` |
| 481 | + Variance :math:`\sigma ^{2}\left[1+{\frac {\alpha \phi (\alpha )-\beta \phi (\beta )}{Z}}- |
| 482 | + \left({\frac {\phi (\alpha )-\phi (\beta )}{Z}}\right)^{2}\right]` |
| 483 | + ======== ========================================== |
| 484 | +
|
| 485 | + Parameters |
| 486 | + ---------- |
| 487 | + mu : float |
| 488 | + Mean. |
| 489 | + sd : float |
| 490 | + Standard deviation (sd > 0). |
| 491 | + a : float (optional) |
| 492 | + Left bound. |
| 493 | + b : float (optional) |
| 494 | + Right bound. |
| 495 | +
|
| 496 | + Examples |
| 497 | + -------- |
| 498 | + .. code-block:: python |
| 499 | +
|
| 500 | + with pm.Model(): |
| 501 | + x = pm.TruncatedNormal('x', mu=0, sd=10, a=0) |
| 502 | +
|
| 503 | + with pm.Model(): |
| 504 | + x = pm.TruncatedNormal('x', mu=0, sd=10, b=1) |
| 505 | +
|
| 506 | + with pm.Model(): |
| 507 | + x = pm.TruncatedNormal('x', mu=0, sd=10, a=0, b=1) |
| 508 | +
|
| 509 | + """ |
| 510 | + |
| 511 | + def __init__(self, mu=0, sd=None, tau=None, a=None, b=None, **kwargs): |
| 512 | + tau, sd = get_tau_sd(tau=tau, sd=sd) |
| 513 | + self.sd = tt.as_tensor_variable(sd) |
| 514 | + self.tau = tt.as_tensor_variable(tau) |
| 515 | + self.a = tt.as_tensor_variable(a) if a is not None else a |
| 516 | + self.b = tt.as_tensor_variable(b) if b is not None else b |
| 517 | + self.mu = tt.as_tensor_variable(mu) |
| 518 | + |
| 519 | + # Calculate mean |
| 520 | + pdf_a, pdf_b, cdf_a, cdf_b = self._get_boundary_parameters() |
| 521 | + z = cdf_b - cdf_a |
| 522 | + self.mean = self.mu + (pdf_a+pdf_b) / z * self.sd |
| 523 | + |
| 524 | + assert_negative_support(sd, 'sd', 'TruncatedNormal') |
| 525 | + assert_negative_support(tau, 'tau', 'TruncatedNormal') |
| 526 | + |
| 527 | + super(TruncatedNormal, self).__init__(**kwargs) |
| 528 | + |
| 529 | + def random(self, point=None, size=None): |
| 530 | + """ |
| 531 | + Draw random values from TruncatedNormal distribution. |
| 532 | +
|
| 533 | + Parameters |
| 534 | + ---------- |
| 535 | + point : dict, optional |
| 536 | + Dict of variable values on which random values are to be |
| 537 | + conditioned (uses default point if not specified). |
| 538 | + size : int, optional |
| 539 | + Desired size of random sample (returns one sample if not |
| 540 | + specified). |
| 541 | +
|
| 542 | + Returns |
| 543 | + ------- |
| 544 | + array |
| 545 | + """ |
| 546 | + mu_v, std_v, a_v, b_v = draw_values([self.mu, self.sd, self.a, self.b], point=point, size=size) |
| 547 | + return generate_samples(stats.truncnorm.rvs, |
| 548 | + a=(a_v - mu_v)/std_v, |
| 549 | + b=(b_v - mu_v) / std_v, |
| 550 | + loc=mu_v, |
| 551 | + scale=std_v, |
| 552 | + dist_shape=self.shape, |
| 553 | + size=size, |
| 554 | + ) |
| 555 | + |
| 556 | + def logp(self, value): |
| 557 | + """ |
| 558 | + Calculate log-probability of TruncatedNormal distribution at specified value. |
| 559 | +
|
| 560 | + Parameters |
| 561 | + ---------- |
| 562 | + value : numeric |
| 563 | + Value(s) for which log-probability is calculated. If the log probabilities for multiple |
| 564 | + values are desired the values must be provided in a numpy array or theano tensor |
| 565 | +
|
| 566 | + Returns |
| 567 | + ------- |
| 568 | + TensorVariable |
| 569 | + """ |
| 570 | + sd = self.sd |
| 571 | + tau = self.tau |
| 572 | + mu = self.mu |
| 573 | + a = self.a |
| 574 | + b = self.b |
| 575 | + |
| 576 | + # In case either a or b are not specified, normalization terms simplify to 1.0 and 0.0 |
| 577 | + # https://en.wikipedia.org/wiki/Truncated_normal_distribution |
| 578 | + norm_left, norm_right = 1.0, 0.0 |
| 579 | + |
| 580 | + # Define normalization |
| 581 | + if b is not None: |
| 582 | + norm_left = self._cdf((b - mu) / sd) |
| 583 | + |
| 584 | + if a is not None: |
| 585 | + norm_right = self._cdf((a - mu) / sd) |
| 586 | + |
| 587 | + f = self._pdf((value - mu) / sd) / sd / ((norm_left - norm_right)) |
| 588 | + |
| 589 | + return bound(tt.log(f), value >= a, value <= b, sd > 0) |
| 590 | + |
| 591 | + |
| 592 | + def _cdf(self, value): |
| 593 | + """ |
| 594 | + Calculate cdf of standard normal distribution |
| 595 | +
|
| 596 | + Parameters |
| 597 | + ---------- |
| 598 | + value : numeric |
| 599 | + Value(s) for which log-probability is calculated. If the log probabilities for multiple |
| 600 | + values are desired the values must be provided in a numpy array or theano tensor |
| 601 | +
|
| 602 | + Returns |
| 603 | + ------- |
| 604 | + TensorVariable |
| 605 | + """ |
| 606 | + return 0.5 * (1.0 + tt.erf(value / tt.sqrt(2))) |
| 607 | + |
| 608 | + def _pdf(self, value): |
| 609 | + """ |
| 610 | + Calculate pdf of standard normal distribution |
| 611 | +
|
| 612 | + Parameters |
| 613 | + ---------- |
| 614 | + value : numeric |
| 615 | + Value(s) for which log-probability is calculated. If the log probabilities for multiple |
| 616 | + values are desired the values must be provided in a numpy array or theano tensor |
| 617 | +
|
| 618 | + Returns |
| 619 | + ------- |
| 620 | + TensorVariable |
| 621 | + """ |
| 622 | + return 1.0 / tt.sqrt(2 * np.pi) * tt.exp(-0.5 * (value ** 2)) |
| 623 | + |
| 624 | + def _repr_latex_(self, name=None, dist=None): |
| 625 | + if dist is None: |
| 626 | + dist = self |
| 627 | + sd = dist.sd |
| 628 | + mu = dist.mu |
| 629 | + a = dist.a |
| 630 | + b = dist.b |
| 631 | + name = r'\text{%s}' % name |
| 632 | + return r'${} \sim \text{{TruncatedNormal}}(\mathit{{mu}}={},~\mathit{{sd}}={},a={},b={})$'.format(name, |
| 633 | + get_variable_name(mu), |
| 634 | + get_variable_name(sd), |
| 635 | + get_variable_name(a), |
| 636 | + get_variable_name(b)) |
| 637 | + |
| 638 | + def _get_boundary_parameters(self): |
| 639 | + """ |
| 640 | + Calcualte values of cdf and pdf at boundary points a and b |
| 641 | +
|
| 642 | + Returns |
| 643 | + ------- |
| 644 | + pdf(a), pdf(b), cdf(a), cdf(b) if a,b defined, otherwise 0.0, 0.0, 0.0, 1.0 |
| 645 | + """ |
| 646 | + # pdf = 0 at +-inf |
| 647 | + pdf_a = self._pdf(self.a) if not self.a is None else 0.0 |
| 648 | + pdf_b = self._pdf(self.b) if not self.b is None else 0.0 |
| 649 | + |
| 650 | + # b-> inf, cdf(b) = 1.0, a->-inf, cdf(a) = 0 |
| 651 | + cdf_a = self._cdf(self.a) if not self.a is None else 0.0 |
| 652 | + cdf_b = self._cdf(self.b) if not self.b is None else 1.0 |
| 653 | + return pdf_a, pdf_b, cdf_a, cdf_b |
437 | 654 |
|
438 | 655 | class HalfNormal(PositiveContinuous):
|
439 | 656 | R"""
|
|
0 commit comments