From f48ed5e2fd5ddd3951339d95bfe1a130c7bdf01c Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 16 Nov 2023 17:05:42 +0100
Subject: [PATCH] Implement Polygamma Op

---
 pytensor/gradient.py                |  2 +-
 pytensor/scalar/math.py             | 57 +++++++++++++++++++++++++++--
 pytensor/tensor/math.py             |  6 +++
 pytensor/tensor/rewriting/math.py   | 24 +++++++++++-
 tests/link/jax/test_scalar.py       | 15 ++++++++
 tests/tensor/rewriting/test_math.py | 20 +++++++++-
 tests/tensor/test_math.py           | 44 ++++++++++++++++++++++
 7 files changed, 159 insertions(+), 9 deletions(-)

diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index 95ce3c9553..fdd096f2e9 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -101,7 +101,7 @@ def grad_undefined(op, x_pos, x, comment=""):
     return (
         NullType(
             "This variable is Null because the grad method for "
-            f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}"
+            f"input {x_pos} ({x}) of the {op} op is undefined. {comment}"
         )
     )()
 
diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index f87f42066c..1f326b3fab 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -13,7 +13,7 @@
 import scipy.stats
 
 from pytensor.configdefaults import config
-from pytensor.gradient import grad_not_implemented
+from pytensor.gradient import grad_not_implemented, grad_undefined
 from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
 from pytensor.scalar.basic import abs as scalar_abs
 from pytensor.scalar.basic import (
@@ -473,8 +473,12 @@ def st_impl(x):
     def impl(self, x):
         return TriGamma.st_impl(x)
 
-    def grad(self, inputs, outputs_gradients):
-        raise NotImplementedError()
+    def L_op(self, inputs, outputs, outputs_gradients):
+        (x,) = inputs
+        (g_out,) = outputs_gradients
+        if x in complex_types:
+            raise NotImplementedError("gradient not implemented for complex types")
+        return [g_out * polygamma(2, x)]
 
     def c_support_code(self, **kwargs):
         # The implementation has been copied from
@@ -541,7 +545,52 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floating point is implemented")
 
 
-tri_gamma = TriGamma(upgrade_to_float, name="tri_gamma")
+# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
+tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
+
+
+class PolyGamma(BinaryScalarOp):
+    """Polygamma function of order n evaluated at x.
+
+    It corresponds to the (n+1)th derivative of the log gamma function.
+
+    TODO: Because the first input is discrete and the output is continuous,
+        the default elemwise inplace won't work, as it always tries to store the results in the first input.
+    """
+
+    nfunc_spec = ("scipy.special.polygamma", 2, 1)
+
+    @staticmethod
+    def output_types_preference(n_type, x_type):
+        if n_type not in discrete_types:
+            raise TypeError(
+                f"Polygamma order parameter must be discrete, got {n_type} dtype"
+            )
+        # Scipy doesn't support complex inputs
+        return upgrade_to_float_no_complex(x_type)
+
+    @staticmethod
+    def st_impl(n, x):
+        return scipy.special.polygamma(n, x)
+
+    def impl(self, n, x):
+        return PolyGamma.st_impl(n, x)
+
+    def L_op(self, inputs, outputs, output_gradients):
+        (n, x) = inputs
+        (g_out,) = output_gradients
+        if x in complex_types:
+            raise NotImplementedError("gradient not implemented for complex types")
+        return [
+            grad_undefined(self, 0, n),
+            g_out * self(n + 1, x),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+polygamma = PolyGamma(name="polygamma")
 
 
 class Chi2SF(BinaryScalarOp):
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 56777eeb67..5a95b71a2b 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -1369,6 +1369,11 @@ def tri_gamma(a):
     """second derivative of the log gamma function"""
 
 
+@scalar_elemwise
+def polygamma(n, x):
+    """Polygamma function of order n evaluated at x"""
+
+
 @scalar_elemwise
 def chi2sf(x, k):
     """chi squared survival function"""
@@ -3008,6 +3013,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     "psi",
     "digamma",
     "tri_gamma",
+    "polygamma",
     "chi2sf",
     "gammainc",
     "gammaincc",
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 14479998a2..97c7138d3a 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -52,6 +52,7 @@
 from pytensor.tensor.math import abs as at_abs
 from pytensor.tensor.math import (
     add,
+    digamma,
     dot,
     eq,
     erf,
@@ -68,7 +69,7 @@
     makeKeepDims,
 )
 from pytensor.tensor.math import max as at_max
-from pytensor.tensor.math import maximum, mul, neg
+from pytensor.tensor.math import maximum, mul, neg, polygamma
 from pytensor.tensor.math import pow as at_pow
 from pytensor.tensor.math import (
     prod,
@@ -81,7 +82,7 @@
     sub,
 )
 from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.math import true_div
+from pytensor.tensor.math import tri_gamma, true_div
 from pytensor.tensor.rewriting.basic import (
     alloc_like,
     broadcasted_by,
@@ -3638,3 +3639,22 @@ def local_useless_conj(fgraph, node):
     x = node.inputs[0]
     if x.type.dtype not in complex_dtypes:
         return [x]
+
+
+local_polygamma_to_digamma = PatternNodeRewriter(
+    (polygamma, 0, "x"),
+    (digamma, "x"),
+    allow_multiple_clients=True,
+    name="local_polygamma_to_digamma",
+)
+
+register_specialize(local_polygamma_to_digamma)
+
+local_polygamma_to_tri_gamma = PatternNodeRewriter(
+    (polygamma, 1, "x"),
+    (tri_gamma, "x"),
+    allow_multiple_clients=True,
+    name="local_polygamma_to_tri_gamma",
+)
+
+register_specialize(local_polygamma_to_tri_gamma)
diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py
index 837c07085f..18877a496c 100644
--- a/tests/link/jax/test_scalar.py
+++ b/tests/link/jax/test_scalar.py
@@ -20,6 +20,7 @@
     iv,
     log,
     log1mexp,
+    polygamma,
     psi,
     sigmoid,
     softplus,
@@ -178,6 +179,20 @@ def test_tri_gamma():
     compare_jax_and_py(fg, [np.array([3.0, 5.0])])
 
 
+def test_polygamma():
+    n = vector("n", dtype="int32")
+    x = vector("x", dtype="float32")
+    out = polygamma(n, x)
+    fg = FunctionGraph([n, x], [out])
+    compare_jax_and_py(
+        fg,
+        [
+            np.array([0, 1, 2]).astype("int32"),
+            np.array([0.5, 0.9, 2.5]).astype("float32"),
+        ],
+    )
+
+
 def test_log1mexp():
     x = vector("x")
     out = log1mexp(x)
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 0cf7370c23..a37a161d62 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -29,7 +29,7 @@
 from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import debugprint
-from pytensor.scalar import Pow
+from pytensor.scalar import PolyGamma, Pow, Psi, TriGamma
 from pytensor.tensor import inplace
 from pytensor.tensor.basic import Alloc, constant, join, second, switch
 from pytensor.tensor.blas import Dot22, Gemv
@@ -69,7 +69,7 @@
 from pytensor.tensor.math import max as at_max
 from pytensor.tensor.math import maximum
 from pytensor.tensor.math import min as at_min
-from pytensor.tensor.math import minimum, mul, neg, neq
+from pytensor.tensor.math import minimum, mul, neg, neq, polygamma
 from pytensor.tensor.math import pow as pt_pow
 from pytensor.tensor.math import (
     prod,
@@ -4236,3 +4236,19 @@ def test_logdiffexp():
     np.testing.assert_almost_equal(
         f(x_test, y_test), np.log(np.exp(x_test) - np.exp(y_test))
     )
+
+
+def test_polygamma_specialization():
+    x = vector("x")
+
+    y1 = polygamma(0, x)
+    y2 = polygamma(1, x)
+    y3 = polygamma(2, x)
+
+    fn = pytensor.function(
+        [x], [y1, y2, y3], mode=get_default_mode().including("specialize")
+    )
+    fn_outs = fn.maker.fgraph.outputs
+    assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
+    assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
+    assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index fccb02e215..af653c2f51 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import pytest
+import scipy.special
 from numpy.testing import assert_array_equal
 from scipy.special import logsumexp as scipy_logsumexp
 
@@ -64,6 +65,7 @@
     cov,
     deg2rad,
     dense_dot,
+    digamma,
     dot,
     eq,
     exp,
@@ -93,6 +95,7 @@
     neg,
     neq,
     outer,
+    polygamma,
     power,
     ptp,
     rad2deg,
@@ -3470,3 +3473,44 @@ def test_dot22_opt(self):
         fn = function([x, y], x @ y, mode="FAST_RUN")
         [node] = fn.maker.fgraph.apply_nodes
         assert isinstance(node.op, Dot22)
+
+
+class TestPolyGamma:
+    def test_basic(self):
+        n = vector("n", dtype="int64")
+        x = scalar("x")
+
+        np.testing.assert_allclose(
+            polygamma(n, x).eval({n: [0, 1], x: 0.5}),
+            scipy.special.polygamma([0, 1], 0.5),
+        )
+
+    def test_continuous_n_raises(self):
+        n = scalar("n", dtype="float64")
+        with pytest.raises(TypeError, match="must be discrete"):
+            polygamma(n, 0.5)
+
+    def test_complex_x_raises(self):
+        x = scalar(dtype="complex128")
+        with pytest.raises(TypeError, match="complex argument not supported"):
+            polygamma(0, x)
+
+    def test_output_dtype(self):
+        n = scalar("n", dtype="int64")
+        assert polygamma(n, scalar("x", dtype="float32")).dtype == "float32"
+        assert polygamma(n, scalar("x", dtype="float64")).dtype == "float64"
+        assert polygamma(n, scalar("x", dtype="int32")).dtype == "float64"
+
+    def test_grad_x(self):
+        x = scalar("x")
+        op_grad = grad(polygamma(0, x), wrt=x)
+        ref_grad = grad(digamma(x), wrt=x)
+        np.testing.assert_allclose(
+            op_grad.eval({x: 0.9}),
+            ref_grad.eval({x: 0.9}),
+        )
+
+    def test_grad_n_undefined(self):
+        n = scalar(dtype="int64")
+        with pytest.raises(NullTypeGradError):
+            grad(polygamma(n, 0.5), wrt=n)
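
Not part of the patch: a minimal usage sketch of the new polygamma Op, assuming the patch above is applied to a PyTensor checkout with NumPy and SciPy available. It uses only names introduced in the diffs above.

    # Usage sketch (hypothetical script; assumes this patch is applied).
    import numpy as np
    import scipy.special

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.math import polygamma

    # The order n must have a discrete dtype; x is upgraded to float and
    # complex dtypes are rejected (see output_types_preference above).
    n = pt.vector("n", dtype="int64")
    x = pt.vector("x", dtype="float64")

    fn = pytensor.function([n, x], polygamma(n, x))
    np.testing.assert_allclose(
        fn([0, 1, 2], [0.5, 0.9, 2.5]),
        scipy.special.polygamma([0, 1, 2], [0.5, 0.9, 2.5]),
    )

    # The gradient w.r.t. x uses d/dx polygamma(n, x) = polygamma(n + 1, x);
    # the gradient w.r.t. the integer order n is undefined (NullTypeGradError).
    y = pt.scalar("y", dtype="float64")
    g = pytensor.grad(polygamma(0, y), wrt=y)
    np.testing.assert_allclose(g.eval({y: 0.9}), scipy.special.polygamma(1, 0.9))

With the specialize rewrites registered above, constant orders 0 and 1 are rewritten to the existing Psi and TriGamma scalar Ops (as exercised by test_polygamma_specialization), while the generic PolyGamma Op falls back to the SciPy implementation since its c_code raises NotImplementedError.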