Skip to content

Commit 839786b

Browse files
committed
transformations now are their own special thing
1 parent d4588fa commit 839786b

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

pymc/distributions/dist_math.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
from __future__ import division
77
from ..quickclass import *
88
import theano.tensor as t
9-
from theano.tensor import sum, switch, log,exp, eq, neq, lt, gt, le, ge, zeros_like, cast,arange, round, max, min
9+
from theano.tensor import (
10+
sum, switch, log,exp,
11+
eq, neq, lt, gt, le, ge, all, any,
12+
cast,arange, round, max, min,
13+
zeros_like, ones, ones_like,
14+
concatenate, constant)
15+
1016

1117
from numpy import pi, inf, nan
1218
from special import gammaln

pymc/distributions/transforms.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ def TransformedDistribtuion():
1212
support = dist.support
1313

1414
def logp(x):
15-
return dist.logp(x) + jacobian_det(x)
15+
return dist.logp(backward(x)) + jacobian_det(x)
1616

1717
if hasattr(dist, "mode"):
18-
mode = backward(dist.mode)
18+
mode = forward(dist.mode)
19+
if hasattr(dist, "median"):
20+
mode = forward(dist.median)
1921

2022
return locals()
2123

@@ -30,5 +32,5 @@ def __str__():
3032

3133
simplex = transform("simplex",
3234
lambda p: p[:-1],
33-
lambda p: concatenate([p, 1- sum(p)]),
34-
0)
35+
lambda p: concatenate([p, 1- sum(p, keepdims = True)]),
36+
lambda p: constant([0]))

pymc/model.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,10 @@ def Var(model, name, dist, shape = 1, dtype = None, testval = try_defaults):
7676
model.factors.append(dist.logp(var))
7777
return var
7878

79-
def TransformedVar(model, name, dist, transform, logjacobian, shape = 1, dtype = None, testval = try_defaults):
80-
if not dtype:
81-
dtype = default_type[dist.support]
82-
83-
var = Variable('transformed_' + name, shape, dtype, get_test_val(dist, testval))
84-
85-
model.vars.append(var)
86-
87-
tvar = transform(var)
88-
model.factors.append(dist.logp(tvar) + logjacobian(var))
79+
def TransformedVar(model, name, dist, trans, shape = 1, dtype = None, testval = try_defaults):
80+
var = model.Var(trans.name + '_' + name, trans.apply(dist), shape, dtype, testval)
8981

90-
return tvar, var
82+
return trans.backward(var), var
9183

9284
@property
9385
def logp(model):

0 commit comments

Comments
 (0)