Closed
Description
from pytensor.scalar.basic import ScalarOp, ScalarType, discrete_types
from pytensor.gradient import grad_undefined
from pytensor.tensor.elemwise import Elemwise
class PolyGamma(ScalarOp):
    """Scalar op that evaluates ``scipy.special.polygamma(i, x)``.

    ``i`` is the order of the polygamma function and ``x`` the point at
    which it is evaluated. Only the gradient with respect to ``x`` is
    provided; the gradient with respect to ``i`` is reported as undefined.
    """

    # NOTE(review): presumably (function path, number of inputs, number of
    # outputs) as consumed by Elemwise -- confirm against Elemwise's handling
    # of ``nfunc_spec``.
    nfunc_spec = ("scipy.special.polygamma", 2, 1)

    @staticmethod
    def st_impl(i, x):
        """Return the polygamma function of order ``i`` evaluated at ``x``."""
        # This module does not import scipy at the top level, so the name
        # would otherwise be unbound here and the call would raise NameError.
        import scipy.special

        return scipy.special.polygamma(i, x)

    def impl(self, i, x):
        """Python implementation of the op; delegates to ``st_impl``."""
        return self.st_impl(i, x)

    def L_op(self, inputs, outputs, grads):
        """Return the gradients of the output with respect to ``(i, x)``.

        Parameters
        ----------
        inputs
            The two inputs ``(i, x)`` of the op.
        outputs
            The outputs of the op (unused).
        grads
            A one-element sequence holding the gradient flowing into the
            output.

        Returns
        -------
        list
            ``[grad_undefined(...), gz * polygamma(i + 1, x)]``.
        """
        i, x = inputs
        (gz,) = grads
        # Differentiating the order-i polygamma function with respect to x
        # gives the order-(i + 1) polygamma function, hence the recursive
        # application of this op with ``i + 1``.
        return [grad_undefined(self, 0, i), gz * self(i + 1, x)]
def polygamma_dtype_rule(t1: ScalarType, t2: ScalarType) -> tuple[ScalarType]:
    """Choose the output type of ``polygamma`` from its two input types.

    Parameters
    ----------
    t1
        Type of the first input; must be one of ``discrete_types``.
    t2
        Type of the second input; it is passed to ``upgrade_to_float`` to
        obtain the output type.

    Returns
    -------
    tuple[ScalarType]
        The result of ``upgrade_to_float(t2)``.

    Raises
    ------
    TypeError
        If ``t1`` is not in ``discrete_types``.
    """
    # ``upgrade_to_float`` is not among the names this module imports at the
    # top level; import it here so the call below does not raise NameError.
    from pytensor.scalar.basic import upgrade_to_float

    if t1 not in discrete_types:
        raise TypeError("First input to polygamma should be discrete")
    return upgrade_to_float(t2)
# Public ``polygamma``: the scalar PolyGamma op wrapped in Elemwise.
# NOTE(review): the two positional arguments are presumably the output-type
# rule and the op's name -- confirm against ScalarOp.__init__.
polygamma = Elemwise(
    PolyGamma(
        polygamma_dtype_rule,
        "polygamma",
    )
)