Skip to content

Commit 1428e4c

Browse files
committed
added calc_utils
2 parents 0271b8f + 094a269 commit 1428e4c

File tree

5 files changed

+187
-86
lines changed

5 files changed

+187
-86
lines changed

PyMC.tmbundle/Snippets/Impute.tmSnippet

Lines changed: 0 additions & 16 deletions
This file was deleted.

pymc/calc_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
'''
2+
Created on Jan 20, 2011
3+
4+
@author: jsalvatier
5+
'''
6+
import numpy as np
7+
from collections import defaultdict
8+
9+
_sts_memory = defaultdict(dict)
10+
def sum_to_shape(key1,key2, value, sum_shape):
11+
12+
try :
13+
axes, lx = _sts_memory[key1][key2]
14+
15+
except KeyError:
16+
17+
value_shape = np.array(np.shape(value))
18+
19+
sum_shape_expanded = np.zeros(value_shape.size)
20+
sum_shape_expanded[0:len(sum_shape)] += np.array(sum_shape)
21+
22+
axes = np.where(sum_shape_expanded != value_shape)[0]
23+
lx = np.size(axes)
24+
25+
_sts_memory[key1][key2] = (axes, lx )
26+
27+
if lx > 0:
28+
return np.apply_over_axes(np.sum, value, axes)
29+
30+
else:
31+
return value

pymc/distributions.py

Lines changed: 66 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ class ArgumentError(AttributeError):
5959
'degenerate', 'exponential', 'exponweib',
6060
'gamma', 'half_normal', 'hypergeometric',
6161
'inverse_gamma', 'laplace', 'logistic',
62-
'lognormal', 'normal', 'pareto', 't',
63-
'truncated_pareto', 'uniform',
62+
'lognormal', 'noncentral_t', 'normal',
63+
'pareto', 't', 'truncated_pareto', 'uniform',
6464
'weibull', 'skew_normal', 'truncated_normal',
6565
'von_mises']
6666
sc_bool_distributions = ['bernoulli']
@@ -2241,9 +2241,9 @@ def truncated_pareto_expval(alpha, m, b):
22412241

22422242
if alpha <= 1:
22432243
return inf
2244-
part1 = (m**alpha)/(1 - (m/b)**alpha)
2245-
part2 = alpha/(alpha-1)
2246-
part3 = (1/(m**(alpha-1)) - 1/(b**(alpha-1)))
2244+
part1 = (m**alpha)/(1. - (m/b)**alpha)
2245+
part2 = 1.*alpha/(alpha-1)
2246+
part3 = (1./(m**(alpha-1)) - 1./(b**(alpha-1.)))
22472247
return part1*part2*part3
22482248

22492249
def truncated_pareto_like(x, alpha, m, b):
@@ -2404,72 +2404,23 @@ def truncated_poisson_like(x,mu,k):
24042404

24052405
# Truncated normal distribution--------------------------
24062406
@randomwrap
2407-
def rtruncated_normal(mu, tau, a=None, b=None, size=None):
2407+
def rtruncated_normal(mu, tau, a=-np.inf, b=np.inf, size=None):
24082408
"""rtruncated_normal(mu, tau, a, b, size=1)
24092409
2410-
Random truncated normal variates using method from Robert (1995).
2410+
Random truncated normal variates.
24112411
"""
2412-
2413-
factor = 10
2414-
sign = 1.0
2415-
2416-
while True:
2417-
2418-
if a is None and b is None:
2419-
raise ValueError, 'No truncation boundary given.'
2420-
2421-
elif a is None or b is None:
2422-
# One-sided truncation
2423-
2424-
if a is None:
2425-
# See top of p.123 in Robert (1995)
2426-
a = -b
2427-
mu = -mu
2428-
sign = -1.0
2429-
2430-
# Algorithm is in terms of standard normal
2431-
a = np.sqrt(tau) * (a - mu)
2432-
2433-
# Parameter of exponential proposal
2434-
beta = (a + np.sqrt(a**2 + 4))/2.0
2435-
# Sample from exponential
2436-
z = np.random.exponential(1./beta, size*factor) + a
2437-
2438-
if a<beta:
2439-
2440-
x = np.exp(-0.5 * (beta - z)**2)
2441-
2442-
else:
2443-
2444-
x = np.exp(0.5 * (a - beta)**2) * np.exp(-(beta - z)**2)
2445-
2446-
else:
2447-
# Two-sided truncation
2448-
2449-
# Algorithm is in terms of standard normal
2450-
a = np.sqrt(tau) * (a - mu)
2451-
b = np.sqrt(tau) * (b - mu)
2452-
2453-
# Sample z ~ U(a,b)
2454-
z = (b - a) * random_number(size*factor) + a
2455-
2456-
if a<=0<=b:
2457-
x = np.exp(-0.5 * z**2)
2458-
elif b<0:
2459-
x = np.exp(0.5 * (b**2 - z**2))
2460-
else:
2461-
x = np.exp(0.5 * (a**2 - z**2))
2462-
2463-
# Accept-reject
2464-
u = random_number(size*factor)
2465-
y = sign * (z[u <= x] / np.sqrt(tau) + mu)
2466-
2467-
# Return <size> samples
2468-
if len(y) >= size:
2469-
return y[:size]
2470-
else:
2471-
# Get a larger sample next time
2472-
factor *=10
2412+
2413+
sigma = 1./np.sqrt(tau)
2414+
na = pymc.utils.normcdf((a-mu)/sigma)
2415+
nb = pymc.utils.normcdf((b-mu)/sigma)
2416+
2417+
# Use the inverse CDF generation method.
2418+
U = np.random.mtrand.uniform(size=size)
2419+
q = U * nb + (1-U)*na
2420+
R = pymc.utils.invcdf(q)
2421+
2422+
# Unnormalize
2423+
return R*sigma + mu
24732424

24742425
rtruncnorm = rtruncated_normal
24752426

@@ -2496,7 +2447,7 @@ def truncated_normal_expval(mu, tau, a, b):
24962447

24972448
truncnorm_expval = truncated_normal_expval
24982449

2499-
def truncated_normal_like(x, mu, tau, a, b):
2450+
def truncated_normal_like(x, mu, tau, a=None, b=None):
25002451
R"""truncnorm_like(x, mu, tau, a, b)
25012452
25022453
Truncated normal log-likelihood.
@@ -2514,7 +2465,9 @@ def truncated_normal_like(x, mu, tau, a, b):
25142465
- `b` : Right bound of the distribution.
25152466
"""
25162467
x = np.atleast_1d(x)
2468+
if a is None: a = -np.inf
25172469
a = np.atleast_1d(a)
2470+
if b is None: b = np.inf
25182471
b = np.atleast_1d(b)
25192472
mu = np.atleast_1d(mu)
25202473
sigma = (1./np.atleast_1d(np.sqrt(tau)))
@@ -2617,6 +2570,50 @@ def t_expval(nu):
26172570
Expectation of Student's t random variables.
26182571
"""
26192572
return 0
2573+
2574+
# Non-central Student's t-----------------------------------
2575+
@randomwrap
2576+
def rnoncentral_t(mu, lam, nu, size=None):
2577+
"""rnoncentral_t(mu, lam, nu, size=1)
2578+
2579+
Non-central Student's t random variates.
2580+
"""
2581+
tau = rgamma(nu/2., nu/(2.*lam), size)
2582+
return rnormal(mu, tau)
2583+
2584+
def noncentral_t_like(x, mu, lam, nu):
2585+
R"""noncentral_t_like(x, mu, lam, nu)
2586+
2587+
Non-central Student's T log-likelihood. Describes a normal variable
2588+
whose precision is gamma distributed.
2589+
2590+
.. math::
2591+
f(x|\mu,\lambda,\nu) = \frac{\Gamma(\frac{\nu +
2592+
1}{2})}{\Gamma(\frac{\nu}{2})}
2593+
\left(\frac{\lambda}{\pi\nu}\right)^{\frac{1}{2}}
2594+
\left[1+\frac{\lambda(x-\mu)^2}{\nu}\right]^{-\frac{\nu+1}{2}}
2595+
2596+
:Parameters:
2597+
- `x` : Input data.
2598+
- `mu` : Location parameter.
2599+
- `lambda` : Scale parameter.
2600+
- `nu` : Degrees of freedom.
2601+
2602+
"""
2603+
mu = np.asarray(mu)
2604+
lam = np.asarray(lam)
2605+
nu = np.asarray(nu)
2606+
return flib.nct(x, mu, lam, nu)
2607+
2608+
def noncentral_t_expval(mu, lam, nu):
2609+
"""noncentral_t_expval(mu, lam, nu)
2610+
2611+
Expectation of non-central Student's t random variables. Only defined
2612+
for nu>1.
2613+
"""
2614+
if nu>1:
2615+
return mu
2616+
return inf
26202617

26212618
def t_grad_setup(x, nu, f):
26222619
nu = np.asarray(nu)

pymc/flib.f

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,65 @@ SUBROUTINE chi2_grad_nu(x,nu,n,nnu,gradlikenu)
14431443
return
14441444
END
14451445

1446+
SUBROUTINE nct(x,mu,lam,nu,n,nmu,nlam,nnu,like)
1447+
1448+
c Non-central Student's t log-likelihood function
1449+
1450+
cf2py double precision dimension(n),intent(in) :: x
1451+
cf2py double precision dimension(nmu),intent(in) :: mu
1452+
cf2py double precision dimension(nlam),intent(in) :: lam
1453+
cf2py double precision dimension(nnu),intent(in) :: nu
1454+
cf2py double precision intent(out) :: like
1455+
cf2py integer intent(hide),depend(x) :: n=len(x)
1456+
cf2py integer intent(hide),depend(mu) :: nmu=len(mu)
1457+
cf2py integer intent(hide),depend(lam) :: nlam=len(lam)
1458+
cf2py integer intent(hide),depend(nu) :: nnu=len(nu)
1459+
cf2py threadsafe
1460+
1461+
IMPLICIT NONE
1462+
INTEGER n, i, nnu, nmu, nlam
1463+
DOUBLE PRECISION x(n)
1464+
DOUBLE PRECISION nu(nnu), mu(nmu), lam(nlam), like, infinity
1465+
DOUBLE PRECISION mut, lamt, nut
1466+
PARAMETER (infinity = 1.7976931348623157d308)
1467+
DOUBLE PRECISION gammln
1468+
DOUBLE PRECISION PI
1469+
PARAMETER (PI=3.141592653589793238462643d0)
1470+
1471+
nut = nu(1)
1472+
mut = mu(1)
1473+
lamt = lam(1)
1474+
1475+
like = 0.0
1476+
do i=1,n
1477+
if (nmu .GT. 1) then
1478+
mut = mu(i)
1479+
endif
1480+
if (nlam .GT. 1) then
1481+
lamt = lam(i)
1482+
endif
1483+
if (nnu .GT. 1) then
1484+
nut = nu(i)
1485+
endif
1486+
1487+
if (nut .LE. 0.0) then
1488+
like = -infinity
1489+
RETURN
1490+
endif
1491+
if (lamt .LE. 0.0) then
1492+
like = -infinity
1493+
RETURN
1494+
endif
1495+
1496+
like = like + gammln((nut+1.0)/2.0)
1497+
like = like - gammln(nut/2.0)
1498+
like = like + 0.5*dlog(lamt) - 0.5*dlog(nut * PI)
1499+
like = like - (nut+1)/2 * dlog(1 + (lamt*(x(i) - mut)**2)/nut)
1500+
enddo
1501+
return
1502+
END
1503+
1504+
14461505
SUBROUTINE multinomial(x,n,p,nx,nn,np,k,like)
14471506

14481507
c Multinomial log-likelihood function

pymc/tests/test_distributions.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,25 @@ def test_consistency(self):
511511
compare_hist(figname='Student t', **figdata)
512512
assert_array_almost_equal(hist, like,1)
513513

514+
class test_noncentral_t(TestCase):
515+
"""Based on gamma."""
516+
def test_consistency(self):
517+
parameters={'mu':-10, 'lam':0.2, 'nu':5}
518+
hist, like, figdata = consistency(rnoncentral_t, noncentral_t_like,
519+
parameters, nrandom=5000)
520+
if PLOT:
521+
compare_hist(figname='noncentral t', **figdata)
522+
assert_array_almost_equal(hist, like,1)
523+
524+
def test_vectorization(self):
525+
a = flib.nct([3,4,5], mu=3, lam=.1, nu=5)
526+
b = flib.nct([3,4,5], mu=[3,3,3], lam=.1, nu=5)
527+
c = flib.nct([3,4,5], mu=[3,3,3], lam=[.1,.1,.1], nu=5)
528+
d = flib.nct([3,4,5], mu=[3,3,3], lam=[.1,.1,.1], nu=[5,5,5])
529+
assert_equal(a,b)
530+
assert_equal(b,c)
531+
assert_equal(c,d)
532+
514533
class test_exponweib(TestCase):
515534
def test_consistency(self):
516535
parameters = {'alpha':2, 'k':2, 'loc':1, 'scale':3}
@@ -823,6 +842,12 @@ def test_consistency(self):
823842
if PLOT:
824843
compare_hist(figname='truncated_pareto', **figdata)
825844
assert_array_almost_equal(hist, like, 1)
845+
846+
def test_random(self):
847+
r = rtruncated_pareto(alpha=3, m=1, b=6, size=10000)
848+
assert_almost_equal(r.mean(), truncated_pareto_expval(3, 1, 6), 1)
849+
assert (r > 1).all()
850+
assert (r < 6).all()
826851

827852
def test_vectorization(self):
828853
a = flib.truncated_pareto([3,4,5], alpha=3, m=1, b=6)
@@ -856,7 +881,12 @@ def test_consistency(self):
856881
if PLOT:
857882
compare_hist(figname='poisson', **figdata)
858883
assert_array_almost_equal(hist, like,1)
859-
884+
885+
def test_random(self):
886+
r = rtruncated_poisson(mu=5, k=1, size=10000)
887+
assert_almost_equal(r.mean(), truncated_poisson_expval(5, 1), 1)
888+
assert (r >= 1).all()
889+
860890
def test_normalization(self):
861891
parameters = {'mu':4., 'k':1}
862892
summation=discrete_normalization(flib.trpoisson,parameters,20)

0 commit comments

Comments
 (0)