Skip to content

Commit ff1a3a9

Browse files
rloufbrandonwillard
authored andcommitted
Move Softmax, LogSoftmax, SoftmaxGrad to new aesara.tensor.special
1 parent e2202bc commit ff1a3a9

File tree

18 files changed

+1221
-1143
lines changed

18 files changed

+1221
-1143
lines changed

aesara/link/jax/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy
55
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
6-
from aesara.tensor.math import LogSoftmax, Softmax, SoftmaxGrad
6+
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
77

88

99
@jax_funcify.register(Elemwise)

aesara/link/numba/dispatch/elemwise.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,8 @@
3838
from aesara.scalar.basic import add as add_as
3939
from aesara.scalar.basic import scalar_maximum
4040
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
41-
from aesara.tensor.math import (
42-
LogSoftmax,
43-
MaxAndArgmax,
44-
MulWithoutZeros,
45-
Softmax,
46-
SoftmaxGrad,
47-
)
41+
from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros
42+
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
4843

4944

5045
@singledispatch

aesara/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
113113

114114
# isort: off
115115
from aesara.tensor import linalg # noqa
116+
from aesara.tensor import special
116117

117118
# For backward compatibility
118119
from aesara.tensor import nlinalg # noqa

0 commit comments

Comments
 (0)