-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add ZeroSumNormal
distribution
#6121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
6260c84
af96016
71e5651
3cadb26
a66c586
e3be495
759de36
a5a1e45
c9eea6e
0582d7c
0bdcdd7
854ef4c
fd3aefa
e94e4f1
dec4a9f
f7a55c5
da6eaab
a5ed1f0
126e76b
3a8d898
4c52737
7e4ed0a
44b5b91
99dbb38
e3dc1d4
09f0d91
cf5b384
3e86a3e
ce68f02
09d849c
b50909e
c204131
7ba1d0f
5ee950a
95ffc94
13a54e6
ca655bc
9d419ef
85da56c
f363118
64eca5c
c5e76c9
08c9df0
c120f7e
ba5f3a1
48dafe9
6612a24
6b07a2a
135ed47
cba0187
566f308
5954e65
3e72922
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
from aesara.tensor.random.utils import broadcast_params | ||
from aesara.tensor.slinalg import Cholesky, SolveTriangular | ||
from aesara.tensor.type import TensorType | ||
from numpy.core.numeric import normalize_axis_tuple | ||
from scipy import linalg, stats | ||
|
||
import pymc as pm | ||
|
@@ -63,15 +64,17 @@ | |
_change_dist_size, | ||
broadcast_dist_samples_to, | ||
change_dist_size, | ||
convert_dims, | ||
rv_size_is_none, | ||
to_tuple, | ||
) | ||
from pymc.distributions.transforms import Interval, _default_transform | ||
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform | ||
from pymc.math import kron_diag, kron_dot | ||
from pymc.util import check_dist_not_registered | ||
|
||
__all__ = [ | ||
"MvNormal", | ||
"ZeroSumNormal", | ||
"MvStudentT", | ||
"Dirichlet", | ||
"Multinomial", | ||
|
@@ -2380,3 +2383,169 @@ def logp(value, alpha, K): | |
K > 0, | ||
msg="alpha > 0, K > 0", | ||
) | ||
|
||
|
||
class ZeroSumNormalRV(SymbolicRandomVariable): | ||
"""ZeroSumNormal random variable""" | ||
|
||
_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}") | ||
zerosum_axes = None | ||
|
||
def __init__(self, *args, zerosum_axes, **kwargs): | ||
self.zerosum_axes = zerosum_axes | ||
super().__init__(*args, **kwargs) | ||
|
||
|
||
class ZeroSumNormal(Distribution): | ||
r""" | ||
ZeroSumNormal distribution, i.e Normal distribution where one or | ||
several axes are constrained to sum to zero. | ||
By default, the last axis is constrained to sum to zero. | ||
See `zerosum_axes` kwarg for more details. | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Parameters | ||
---------- | ||
sigma : tensor_like of float | ||
Standard deviation (sigma > 0). | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Defaults to 1 if not specified. | ||
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint. | ||
zerosum_axes: list or tuple of strings or integers | ||
Axis (or axes) along which the zero-sum constraint is enforced. | ||
Defaults to [-1], i.e. the last axis. | ||
If strings are passed, then ``dims`` is needed. | ||
Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions. | ||
dims: list or tuple of strings, optional | ||
The dimension names of the axes. | ||
Necessary when ``zerosum_axes`` is specified with strings. | ||
|
||
Warnings | ||
-------- | ||
``sigma`` has to be a scalar, to ensure the zero-sum constraint. | ||
The ability to specify a vector of ``sigma`` may be added in future versions. | ||
|
||
Examples | ||
-------- | ||
.. code-block:: python | ||
COORDS = { | ||
"regions": ["a", "b", "c"], | ||
"answers": ["yes", "no", "whatever", "don't understand question"], | ||
} | ||
with pm.Model(coords=COORDS) as m: | ||
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers") | ||
|
||
with pm.Model(coords=COORDS) as m: | ||
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers")) | ||
|
||
with pm.Model(coords=COORDS) as m: | ||
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1) | ||
""" | ||
rv_type = ZeroSumNormalRV | ||
|
||
def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs): | ||
dims = convert_dims(dims) | ||
if zerosum_axes is None: | ||
zerosum_axes = [-1] | ||
if not isinstance(zerosum_axes, (list, tuple)): | ||
zerosum_axes = [zerosum_axes] | ||
|
||
if isinstance(zerosum_axes[0], str): | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we might want to handle the case where There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't that just a Normal distribution? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is, but if you write more general code that somehow produces the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh so you mean erroring out in that case, gotcha There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, erroring out would be ok, but can't we just handle that case correctly? I can't think of anything that should go wrong in this case. So why not just
And maybe add a test that checks if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh ok, I understand what you mean now. I'm curious what @ricardoV94 thinks, but I would prefer not to do that: if people wanna use a Normal they should just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I agree with you now @aseyboldt (mainly because that behavior would be consistent with other PyMC distributions' behavior). |
||
if not dims: | ||
raise ValueError("You need to specify dims if zerosum_axes are strings.") | ||
else: | ||
zerosum_axes_ = [] | ||
for axis in zerosum_axes: | ||
zerosum_axes_.append(dims.index(axis)) | ||
zerosum_axes = zerosum_axes_ | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs) | ||
|
||
@classmethod | ||
def dist(cls, sigma=1, zerosum_axes=None, **kwargs): | ||
if zerosum_axes is None: | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
zerosum_axes = [-1] | ||
|
||
sigma = at.as_tensor_variable(floatX(sigma)) | ||
if sigma.ndim > 0: | ||
raise ValueError("sigma has to be a scalar") | ||
|
||
return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs) | ||
|
||
# TODO: This is if we want ZeroSum constraint on other dists than Normal | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# def dist(cls, dist, lower, upper, **kwargs): | ||
# if not isinstance(dist, TensorVariable) or not isinstance( | ||
# dist.owner.op, (RandomVariable, SymbolicRandomVariable) | ||
# ): | ||
# raise ValueError( | ||
# f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}" | ||
# ) | ||
# if dist.owner.op.ndim_supp > 0: | ||
# raise NotImplementedError( | ||
# "Censoring of multivariate distributions has not been implemented yet" | ||
# ) | ||
# check_dist_not_registered(dist) | ||
# return super().dist([dist, lower, upper], **kwargs) | ||
|
||
@classmethod | ||
def rv_op(cls, sigma, zerosum_axes, size=None): | ||
if size is None: | ||
zerosum_axes_ = np.asarray(zerosum_axes) | ||
# just a placeholder size to infer minimum shape | ||
size = np.ones( | ||
max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
).tolist() | ||
|
||
# check if zerosum_axes is valid | ||
normalize_axis_tuple(zerosum_axes, len(size)) | ||
|
||
normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size)) | ||
normal_dist_, sigma_ = normal_dist.type(), sigma.type() | ||
|
||
# Zerosum-normaling is achieved by subtracting the mean along the given zerosum_axes | ||
zerosum_rv_ = normal_dist_ | ||
for axis in zerosum_axes: | ||
zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True) | ||
|
||
return ZeroSumNormalRV( | ||
inputs=[normal_dist_, sigma_], | ||
outputs=[zerosum_rv_], | ||
zerosum_axes=zerosum_axes, | ||
ndim_supp=0, | ||
)(normal_dist, sigma) | ||
|
||
|
||
@_change_dist_size.register(ZeroSumNormalRV) | ||
def change_zerosum_size(op, normal_dist, new_size, expand=False): | ||
normal_dist, sigma = normal_dist.owner.inputs | ||
if expand: | ||
new_size = tuple(new_size) + tuple(normal_dist.shape) | ||
return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size) | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@_moment.register(ZeroSumNormalRV) | ||
def zerosumnormal_moment(op, rv, *rv_inputs): | ||
return at.zeros_like(rv) | ||
|
||
|
||
@_default_transform.register(ZeroSumNormalRV) | ||
def zerosum_default_transform(op, rv): | ||
return ZeroSumTransform(op.zerosum_axes) | ||
|
||
|
||
@_logprob.register(ZeroSumNormalRV) | ||
def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs): | ||
(value,) = values | ||
shape = value.shape | ||
_deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1) | ||
_full_size = at.prod(shape) | ||
_degrees_of_freedom = at.prod(_deg_free_shape) | ||
zerosums = [ | ||
at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes | ||
] | ||
# out = at.sum( | ||
# pm.logp(dist, value) * _degrees_of_freedom / _full_size, | ||
# axis=op.zerosum_axes, | ||
# ) | ||
# figure out how dimensionality should be handled for logp | ||
# for now, we assume ZSN is a scalar distribution, which is not correct | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size | ||
return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0") |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,10 @@ | |
from aesara.graph import Op | ||
from aesara.tensor import TensorVariable | ||
|
||
# ignore mypy error because it somehow considers that | ||
# "numpy.core.numeric has no attribute normalize_axis_tuple" | ||
from numpy.core.numeric import normalize_axis_tuple # type: ignore | ||
|
||
__all__ = [ | ||
"RVTransform", | ||
"simplex", | ||
|
@@ -39,6 +43,7 @@ | |
"circular", | ||
"CholeskyCovPacked", | ||
"Chain", | ||
"ZeroSumTransform", | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
] | ||
|
||
|
||
|
@@ -266,6 +271,66 @@ def bounds_fn(*rv_inputs): | |
super().__init__(args_fn=bounds_fn) | ||
|
||
|
||
class ZeroSumTransform(RVTransform): | ||
""" | ||
Constrains the samples of a Normal distribution to sum to zero | ||
twiecki marked this conversation as resolved.
Show resolved
Hide resolved
|
||
along the user-provided ``zerosum_axes``. | ||
By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed | ||
on the last axis. | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
|
||
name = "zerosum" | ||
|
||
__props__ = ("zerosum_axes",) | ||
|
||
def __init__(self, zerosum_axes): | ||
""" | ||
Parameters | ||
---------- | ||
zerosum_axes : list of ints | ||
Must be a list of integers (positive or negative). | ||
By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed | ||
on the last axis. | ||
""" | ||
self.zerosum_axes = zerosum_axes | ||
|
||
def forward(self, value, *rv_inputs): | ||
for axis in self.zerosum_axes: | ||
value = extend_axis_rev(value, axis=axis) | ||
return value | ||
|
||
def backward(self, value, *rv_inputs): | ||
for axis in self.zerosum_axes: | ||
value = extend_axis(value, axis=axis) | ||
return value | ||
|
||
def log_jac_det(self, value, *rv_inputs): | ||
return at.constant(0.0) | ||
|
||
|
||
def extend_axis(array, axis): | ||
n = array.shape[axis] + 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could maybe add a comment here saying that this is using a householder reflection plus a projection operator to move forward from the constrained space onto the zero sum manifold. I’ll look up our notes and write something here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you find your notes @lucianopaz ? |
||
sum_vals = array.sum(axis, keepdims=True) | ||
norm = sum_vals / (np.sqrt(n) + n) | ||
fill_val = norm - sum_vals / np.sqrt(n) | ||
|
||
out = at.concatenate([array, fill_val], axis=axis) | ||
return out - norm | ||
|
||
|
||
def extend_axis_rev(array, axis): | ||
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] | ||
|
||
n = array.shape[normalized_axis] | ||
last = at.take(array, [-1], axis=normalized_axis) | ||
|
||
sum_vals = -last * np.sqrt(n) | ||
norm = sum_vals / (np.sqrt(n) + n) | ||
slice_before = (slice(None, None),) * normalized_axis | ||
|
||
return array[slice_before + (slice(None, -1),)] + norm | ||
|
||
|
||
log_exp_m1 = LogExpM1() | ||
log_exp_m1.__doc__ = """ | ||
Instantiation of :class:`pymc.distributions.transforms.LogExpM1` | ||
|
Uh oh!
There was an error while loading. Please reload this page.