Skip to content

Commit 5106e0a

Browse files
committed
replace pytensorf identity with pytensor identity
1 parent 0defe25 commit 5106e0a

File tree

2 files changed

+1
-32
lines changed

2 files changed

+1
-32
lines changed

pymc/pytensorf.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import pytensor.tensor as pt
3333
import scipy.sparse as sps
3434

35-
from pytensor import scalar
3635
from pytensor.compile import Function, Mode, get_mode
3736
from pytensor.gradient import grad
3837
from pytensor.graph import Type, rewrite_graph
@@ -47,6 +46,7 @@
4746
from pytensor.graph.fg import FunctionGraph
4847
from pytensor.graph.op import Op
4948
from pytensor.scalar.basic import Cast
49+
from pytensor.scalar.basic import identity as scalar_identity
5050
from pytensor.scan.op import Scan
5151
from pytensor.tensor.basic import _as_tensor_variable
5252
from pytensor.tensor.elemwise import Elemwise
@@ -387,28 +387,6 @@ def hessian_diag(f, vars=None):
387387
return empty_gradient
388388

389389

390-
class IdentityOp(scalar.UnaryScalarOp):
391-
@staticmethod
392-
def st_impl(x):
393-
return x
394-
395-
def impl(self, x):
396-
return x
397-
398-
def grad(self, inp, grads):
399-
return grads
400-
401-
def c_code(self, node, name, inp, out, sub):
402-
return f"{out[0]} = {inp[0]};"
403-
404-
def __eq__(self, other):
405-
return isinstance(self, type(other))
406-
407-
def __hash__(self):
408-
return hash(type(self))
409-
410-
411-
scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
412390
identity = Elemwise(scalar_identity, name="identity")
413391

414392

pymc/sampling/jax.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
from pymc.distributions.multivariate import PosDefMatrix
4646
from pymc.initial_point import StartDict
4747
from pymc.logprob.utils import CheckParameterValue
48-
from pymc.pytensorf import IdentityOp
4948
from pymc.sampling.mcmc import _init_jitter
5049
from pymc.util import (
5150
RandomSeed,
@@ -68,14 +67,6 @@
6867
)
6968

7069

71-
@jax_funcify.register(IdentityOp)
72-
def jax_funcify_Identity(op, **kwargs):
73-
def identity_fn(value):
74-
return value
75-
76-
return identity_fn
77-
78-
7970
@jax_funcify.register(Assert)
8071
@jax_funcify.register(CheckParameterValue)
8172
def jax_funcify_Assert(op, **kwargs):

0 commit comments

Comments
 (0)