Skip to content

Commit d1c5ae2

Browse files
committed
Constants are not inputs
1 parent 58840ba commit d1c5ae2

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

pytensor/link/jax/linker.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from numpy.random import Generator, RandomState
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared
6-
from pytensor.graph.basic import Constant
76
from pytensor.link.basic import JITLinker
87

98

@@ -72,12 +71,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
7271
def jit_compile(self, fn):
7372
import jax
7473

75-
# I suppose we can consider `Constant`s to be "static" according to
76-
# JAX.
77-
static_argnums = [
78-
n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
79-
]
80-
return jax.jit(fn, static_argnums=static_argnums)
74+
return jax.jit(fn)
8175

8276
def create_thunk_inputs(self, storage_map):
8377
from pytensor.link.jax.dispatch import jax_typify

0 commit comments

Comments
 (0)