Skip to content

Implement polygamma #499

Closed
@ferrine

Description

import scipy.special

from pytensor.gradient import grad_undefined
from pytensor.scalar.basic import (
    ScalarOp,
    ScalarType,
    discrete_types,
    upgrade_to_float,
    upgrade_to_float_no_complex,
)
from pytensor.tensor.elemwise import Elemwise

class PolyGamma(ScalarOp):
    """Scalar polygamma function of order ``i`` evaluated at ``x``.

    The polygamma of order ``i`` is the (i+1)-th derivative of
    ``log(gamma(x))``. The first input ``i`` is the (discrete) order and the
    second input ``x`` is the evaluation point. Numerical evaluation is
    delegated to ``scipy.special.polygamma``.
    """

    # The op always takes exactly (order, x) and produces one output.
    # ScalarOp's default ``nin = -1`` means "any number of inputs", so the
    # arity is declared explicitly so that wrongly-sized calls are rejected
    # when the node is built rather than failing later inside SciPy.
    nin = 2
    nout = 1
    nfunc_spec = ("scipy.special.polygamma", 2, 1)

    @staticmethod
    def st_impl(i, x):
        """Evaluate the polygamma function of order ``i`` at ``x`` via SciPy."""
        return scipy.special.polygamma(i, x)

    def impl(self, i, x):
        """Evaluate the op on concrete values by forwarding to :meth:`st_impl`."""
        return self.st_impl(i, x)

    def L_op(self, inputs, outputs, grads):
        """Return the gradients with respect to ``(i, x)``.

        The derivative of the order-``i`` polygamma with respect to ``x`` is
        the order-``i + 1`` polygamma, so the gradient for ``x`` is the
        upstream gradient times this op applied with ``i + 1``. The order
        ``i`` is discrete, so its gradient is marked undefined.
        """
        i, x = inputs
        (gz,) = grads
        return [grad_undefined(self, 0, i), gz * self(i + 1, x)]

def polygamma_dtype_rule(t1: ScalarType, t2: ScalarType) -> tuple[ScalarType]:
    """Output-type rule for the polygamma scalar op.

    Parameters
    ----------
    t1
        Scalar type of the order argument; must be one of ``discrete_types``.
    t2
        Scalar type of the evaluation point.

    Returns
    -------
    tuple[ScalarType]
        One-element tuple holding ``t2`` upgraded to a floating-point type.

    Raises
    ------
    TypeError
        If ``t1`` is not a discrete type, or if ``t2`` is complex
        (``scipy.special.polygamma`` only accepts real-valued input).
    """
    if t1 not in discrete_types:
        raise TypeError("First input to polygamma should be discrete")
    # Plain ``upgrade_to_float`` would propagate a complex dtype and yield an
    # op that SciPy cannot evaluate. The no-complex variant rejects it here,
    # at type-inference time, with a TypeError instead.
    return upgrade_to_float_no_complex(t2)

# Elementwise (tensor-level) polygamma: wraps the scalar PolyGamma op, whose
# output type is chosen by polygamma_dtype_rule.
polygamma = Elemwise(
    PolyGamma(output_types_preference=polygamma_dtype_rule, name="polygamma")
)

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions