Skip to content

Commit 7355cbf

Browse files
committed
replace pytensorf identity with pytensor identity
1 parent 7378c72 commit 7355cbf

File tree

2 files changed

+1
-32
lines changed

2 files changed

+1
-32
lines changed

pymc/pytensorf.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import pytensor.tensor as pt
2727
import scipy.sparse as sps
2828

29-
from pytensor import scalar
3029
from pytensor.compile import Function, Mode, get_mode
3130
from pytensor.gradient import grad
3231
from pytensor.graph import Type, rewrite_graph
@@ -41,6 +40,7 @@
4140
from pytensor.graph.fg import FunctionGraph
4241
from pytensor.graph.op import Op
4342
from pytensor.scalar.basic import Cast
43+
from pytensor.scalar.basic import identity as scalar_identity
4444
from pytensor.scan.op import Scan
4545
from pytensor.tensor.basic import _as_tensor_variable
4646
from pytensor.tensor.elemwise import Elemwise
@@ -381,28 +381,6 @@ def hessian_diag(f, vars=None):
381381
return empty_gradient
382382

383383

384-
class IdentityOp(scalar.UnaryScalarOp):
385-
@staticmethod
386-
def st_impl(x):
387-
return x
388-
389-
def impl(self, x):
390-
return x
391-
392-
def grad(self, inp, grads):
393-
return grads
394-
395-
def c_code(self, node, name, inp, out, sub):
396-
return f"{out[0]} = {inp[0]};"
397-
398-
def __eq__(self, other):
399-
return isinstance(self, type(other))
400-
401-
def __hash__(self):
402-
return hash(type(self))
403-
404-
405-
scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
406384
identity = Elemwise(scalar_identity, name="identity")
407385

408386

pymc/sampling/jax.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from pymc.distributions.multivariate import PosDefMatrix
4747
from pymc.initial_point import StartDict
4848
from pymc.logprob.utils import CheckParameterValue
49-
from pymc.pytensorf import IdentityOp
5049
from pymc.sampling.mcmc import _init_jitter
5150
from pymc.stats.convergence import log_warnings, run_convergence_checks
5251
from pymc.util import (
@@ -70,14 +69,6 @@
7069
)
7170

7271

73-
@jax_funcify.register(IdentityOp)
74-
def jax_funcify_Identity(op, **kwargs):
75-
def identity_fn(value):
76-
return value
77-
78-
return identity_fn
79-
80-
8172
@jax_funcify.register(Assert)
8273
@jax_funcify.register(CheckParameterValue)
8374
def jax_funcify_Assert(op, **kwargs):

0 commit comments

Comments
 (0)