
Implement Polygamma Op #505

Merged (1 commit, Nov 17, 2023)
2 changes: 1 addition & 1 deletion pytensor/gradient.py
@@ -101,7 +101,7 @@ def grad_undefined(op, x_pos, x, comment=""):
return (
NullType(
"This variable is Null because the grad method for "
f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
f"input {x_pos} ({x}) of the {op} op is undefined. {comment}"
)
)()

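As an aside (not part of the diff), the reworded message is what a user sees when requesting a gradient that is declared undefined. A minimal, hypothetical sketch, assuming the `polygamma` op added later in this PR is available as `pytensor.tensor.polygamma`:

```python
# Hypothetical sketch: taking a gradient w.r.t. an input whose grad is
# undefined (here the discrete order of polygamma) raises NullTypeGradError.
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import NullTypeGradError

n = pt.lscalar("n")  # discrete order; its gradient is declared undefined
try:
    pytensor.grad(pt.polygamma(n, 0.5), wrt=n)
except NullTypeGradError as err:
    print(err)  # should surface the "is undefined" reason from the NullType above
```
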
57 changes: 53 additions & 4 deletions pytensor/scalar/math.py
Expand Up @@ -13,7 +13,7 @@
import scipy.stats

from pytensor.configdefaults import config
- from pytensor.gradient import grad_not_implemented
+ from pytensor.gradient import grad_not_implemented, grad_undefined
from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
from pytensor.scalar.basic import abs as scalar_abs
from pytensor.scalar.basic import (
@@ -473,8 +473,12 @@ def st_impl(x):
def impl(self, x):
return TriGamma.st_impl(x)

- def grad(self, inputs, outputs_gradients):
- raise NotImplementedError()
+ def L_op(self, inputs, outputs, outputs_gradients):
+ (x,) = inputs
+ (g_out,) = outputs_gradients
+ if x in complex_types:
+ raise NotImplementedError("gradient not implemented for complex types")
+ return [g_out * polygamma(2, x)]

def c_support_code(self, **kwargs):
# The implementation has been copied from
@@ -541,7 +545,52 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floating point is implemented")


- tri_gamma = TriGamma(upgrade_to_float, name="tri_gamma")
+ # Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
+ tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")


class PolyGamma(BinaryScalarOp):
"""Polygamma function of order n evaluated at x.

It corresponds to the (n+1)th derivative of the log gamma function.

TODO: Because the first input is discrete and the output is continuous,
the default elemwise inplace won't work, as it always tries to store the results in the first input.
"""

nfunc_spec = ("scipy.special.polygamma", 2, 1)

@staticmethod
def output_types_preference(n_type, x_type):
if n_type not in discrete_types:
raise TypeError(
f"Polygamma order parameter must be discrete, got {n_type} dtype"
)
# Scipy's polygamma does not support complex inputs
return upgrade_to_float_no_complex(x_type)

@staticmethod
def st_impl(n, x):
return scipy.special.polygamma(n, x)

def impl(self, n, x):
return PolyGamma.st_impl(n, x)

def L_op(self, inputs, outputs, output_gradients):
(n, x) = inputs
(g_out,) = output_gradients
if x in complex_types:
raise NotImplementedError("gradient not implemented for complex types")
return [
grad_undefined(self, 0, n),
g_out * self(n + 1, x),
]

def c_code(self, *args, **kwargs):
raise NotImplementedError()


polygamma = PolyGamma(name="polygamma")


class Chi2SF(BinaryScalarOp):
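
For reference (not part of the diff): both new `L_op` implementations rely on the recurrence d/dx psi^(n)(x) = psi^(n+1)(x), i.e. differentiating a polygamma of order n yields the polygamma of order n+1. A minimal sketch of how one could sanity-check that rule against SciPy and a central finite difference, assuming the tensor-level `polygamma` added below in `pytensor/tensor/math.py`:

```python
# Sketch only: verify d/dx polygamma(1, x) == polygamma(2, x) numerically.
import numpy as np
import scipy.special

import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
grad_fn = pytensor.function([x], pytensor.grad(pt.polygamma(1, x), wrt=x))

x0, eps = 1.3, 1e-6
finite_diff = (
    scipy.special.polygamma(1, x0 + eps) - scipy.special.polygamma(1, x0 - eps)
) / (2 * eps)
np.testing.assert_allclose(grad_fn(x0), scipy.special.polygamma(2, x0))
np.testing.assert_allclose(grad_fn(x0), finite_diff, rtol=1e-5)
```
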
6 changes: 6 additions & 0 deletions pytensor/tensor/math.py
@@ -1369,6 +1369,11 @@ def tri_gamma(a):
"""second derivative of the log gamma function"""


@scalar_elemwise
def polygamma(n, x):
"""Polygamma function of order n evaluated at x"""


@scalar_elemwise
def chi2sf(x, k):
"""chi squared survival function"""
@@ -3008,6 +3013,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
"psi",
"digamma",
"tri_gamma",
"polygamma",
"chi2sf",
"gammainc",
"gammaincc",
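
A short usage sketch (not part of the diff): because `polygamma` is registered through `@scalar_elemwise`, the integer order broadcasts against `x` just like `scipy.special.polygamma`, and the output dtype follows the float upgrade of `x`:

```python
# Sketch only: elemwise broadcasting of the order against x, checked vs SciPy.
import numpy as np
import scipy.special

import pytensor.tensor as pt

n = pt.lvector("n")  # the order must have a discrete (integer) dtype
x = pt.dvector("x")
out = pt.polygamma(n, x)  # float64 output here

np.testing.assert_allclose(
    out.eval({n: [0, 1, 2], x: [0.5, 1.5, 2.5]}),
    scipy.special.polygamma([0, 1, 2], [0.5, 1.5, 2.5]),
)
```
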
24 changes: 22 additions & 2 deletions pytensor/tensor/rewriting/math.py
@@ -52,6 +52,7 @@
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import (
add,
digamma,
dot,
eq,
erf,
@@ -68,7 +69,7 @@
makeKeepDims,
)
from pytensor.tensor.math import max as at_max
- from pytensor.tensor.math import maximum, mul, neg
+ from pytensor.tensor.math import maximum, mul, neg, polygamma
from pytensor.tensor.math import pow as at_pow
from pytensor.tensor.math import (
prod,
@@ -81,7 +82,7 @@
sub,
)
from pytensor.tensor.math import sum as at_sum
- from pytensor.tensor.math import true_div
+ from pytensor.tensor.math import tri_gamma, true_div
from pytensor.tensor.rewriting.basic import (
alloc_like,
broadcasted_by,
@@ -3638,3 +3639,22 @@ def local_useless_conj(fgraph, node):
x = node.inputs[0]
if x.type.dtype not in complex_dtypes:
return [x]


local_polygamma_to_digamma = PatternNodeRewriter(
(polygamma, 0, "x"),
(digamma, "x"),
allow_multiple_clients=True,
name="local_polygamma_to_digamma",
)

register_specialize(local_polygamma_to_digamma)

local_polygamma_to_tri_gamma = PatternNodeRewriter(
(polygamma, 1, "x"),
(tri_gamma, "x"),
allow_multiple_clients=True,
name="local_polygamma_to_tri_gamma",
)

register_specialize(local_polygamma_to_tri_gamma)
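
To illustrate (not part of the diff): with these two rewrites registered under "specialize", graphs calling `polygamma` with a constant order of 0 or 1 should compile down to the dedicated `Psi` (digamma) and `TriGamma` scalar ops, which is what the new rewrite test below asserts. A minimal sketch, assuming the default compilation mode includes the specialize phase:

```python
# Sketch only: inspect the compiled graph to see the specialized ops.
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
fn = pytensor.function([x], [pt.polygamma(0, x), pt.polygamma(1, x)])
pytensor.dprint(fn)  # expected to show Psi and TriGamma elemwise nodes
```
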
15 changes: 15 additions & 0 deletions tests/link/jax/test_scalar.py
@@ -20,6 +20,7 @@
iv,
log,
log1mexp,
polygamma,
psi,
sigmoid,
softplus,
@@ -178,6 +179,20 @@ def test_tri_gamma():
compare_jax_and_py(fg, [np.array([3.0, 5.0])])


def test_polygamma():
n = vector("n", dtype="int32")
x = vector("x", dtype="float32")
out = polygamma(n, x)
fg = FunctionGraph([n, x], [out])
compare_jax_and_py(
fg,
[
np.array([0, 1, 2]).astype("int32"),
np.array([0.5, 0.9, 2.5]).astype("float32"),
],
)


def test_log1mexp():
x = vector("x")
out = log1mexp(x)
20 changes: 18 additions & 2 deletions tests/tensor/rewriting/test_math.py
@@ -29,7 +29,7 @@
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
- from pytensor.scalar import Pow
+ from pytensor.scalar import PolyGamma, Pow, Psi, TriGamma
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
@@ -69,7 +69,7 @@
from pytensor.tensor.math import max as at_max
from pytensor.tensor.math import maximum
from pytensor.tensor.math import min as at_min
- from pytensor.tensor.math import minimum, mul, neg, neq
+ from pytensor.tensor.math import minimum, mul, neg, neq, polygamma
from pytensor.tensor.math import pow as pt_pow
from pytensor.tensor.math import (
prod,
@@ -4236,3 +4236,19 @@ def test_logdiffexp():
np.testing.assert_almost_equal(
f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
)


def test_polygamma_specialization():
x = vector("x")

y1 = polygamma(0, x)
y2 = polygamma(1, x)
y3 = polygamma(2, x)

fn = pytensor.function(
[x], [y1, y2, y3], mode=get_default_mode().including("specialize")
)
fn_outs = fn.maker.fgraph.outputs
assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
44 changes: 44 additions & 0 deletions tests/tensor/test_math.py
@@ -7,6 +7,7 @@

import numpy as np
import pytest
import scipy.special
from numpy.testing import assert_array_equal
from scipy.special import logsumexp as scipy_logsumexp

@@ -64,6 +65,7 @@
cov,
deg2rad,
dense_dot,
digamma,
dot,
eq,
exp,
@@ -93,6 +95,7 @@
neg,
neq,
outer,
polygamma,
power,
ptp,
rad2deg,
@@ -3470,3 +3473,44 @@ def test_dot22_opt(self):
fn = function([x, y], x @ y, mode="FAST_RUN")
[node] = fn.maker.fgraph.apply_nodes
assert isinstance(node.op, Dot22)


class TestPolyGamma:
def test_basic(self):
n = vector("n", dtype="int64")
x = scalar("x")

np.testing.assert_allclose(
polygamma(n, x).eval({n: [0, 1], x: 0.5}),
scipy.special.polygamma([0, 1], 0.5),
)

def test_continuous_n_raises(self):
n = scalar("n", dtype="float64")
with pytest.raises(TypeError, match="must be discrete"):
polygamma(n, 0.5)

def test_complex_x_raises(self):
x = scalar(dtype="complex128")
with pytest.raises(TypeError, match="complex argument not supported"):
polygamma(0, x)

def test_output_dtype(self):
n = scalar("n", dtype="int64")
polygamma(n, scalar("x", dtype="float32")).dtype == "float32"
polygamma(n, scalar("x", dtype="float64")).dtype == "float64"
polygamma(n, scalar("x", dtype="int32")).dtype == "float64"

def test_grad_x(self):
x = scalar("x")
op_grad = grad(polygamma(0, x), wrt=x)
ref_grad = grad(digamma(x), wrt=x)
np.testing.assert_allclose(
op_grad.eval({x: 0.9}),
ref_grad.eval({x: 0.9}),
)

def test_grad_n_undefined(self):
n = scalar(dtype="int64")
with pytest.raises(NullTypeGradError):
grad(polygamma(n, 0.5), wrt=n)