Skip to content

Commit 3d04679

Browse files
committed
replace pytensorf identity with pytensor identity
1 parent 8c98a1f commit 3d04679

File tree

2 files changed

+1
-32
lines changed

2 files changed

+1
-32
lines changed

pymc/pytensorf.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import pytensor.tensor as pt
2222
import scipy.sparse as sps
2323

24-
from pytensor import scalar
2524
from pytensor.compile import Function, Mode, get_mode
2625
from pytensor.compile.builders import OpFromGraph
2726
from pytensor.gradient import grad
@@ -38,6 +37,7 @@
3837
from pytensor.graph.fg import FunctionGraph
3938
from pytensor.graph.op import Op
4039
from pytensor.scalar.basic import Cast
40+
from pytensor.scalar.basic import identity as scalar_identity
4141
from pytensor.scan.op import Scan
4242
from pytensor.tensor.basic import _as_tensor_variable
4343
from pytensor.tensor.elemwise import Elemwise
@@ -378,28 +378,6 @@ def hessian_diag(f, vars=None):
378378
return empty_gradient
379379

380380

381-
class IdentityOp(scalar.UnaryScalarOp):
382-
@staticmethod
383-
def st_impl(x):
384-
return x
385-
386-
def impl(self, x):
387-
return x
388-
389-
def grad(self, inp, grads):
390-
return grads
391-
392-
def c_code(self, node, name, inp, out, sub):
393-
return f"{out[0]} = {inp[0]};"
394-
395-
def __eq__(self, other):
396-
return isinstance(self, type(other))
397-
398-
def __hash__(self):
399-
return hash(type(self))
400-
401-
402-
scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
403381
identity = Elemwise(scalar_identity, name="identity")
404382

405383

pymc/sampling/jax.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from pymc.distributions.multivariate import PosDefMatrix
4747
from pymc.initial_point import StartDict
4848
from pymc.logprob.utils import CheckParameterValue
49-
from pymc.pytensorf import IdentityOp
5049
from pymc.sampling.mcmc import _init_jitter
5150
from pymc.stats.convergence import log_warnings, run_convergence_checks
5251
from pymc.util import (
@@ -70,14 +69,6 @@
7069
)
7170

7271

73-
@jax_funcify.register(IdentityOp)
74-
def jax_funcify_Identity(op, **kwargs):
75-
def identity_fn(value):
76-
return value
77-
78-
return identity_fn
79-
80-
8172
@jax_funcify.register(Assert)
8273
@jax_funcify.register(CheckParameterValue)
8374
def jax_funcify_Assert(op, **kwargs):

0 commit comments

Comments
 (0)