
Commit 105e954

Rename remaining instances of aes and aer
Parent: 99d5ec4
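The rename replaces the old `aes`/`aer` aliases with `ps`/`ptr` throughout. In user code the convention amounts to this (a sketch of the naming, not a line from the diff):

    from pytensor import scalar as ps               # formerly aliased as aes
    import pytensor.tensor.random.basic as ptr      # formerly aliased as aer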

File tree

20 files changed (+334, -334 lines)

pytensor/compile/profiling.py

Lines changed: 52 additions & 52 deletions
@@ -1485,65 +1485,65 @@ def print_tips(self, file):
             file=file,
         )

-        from pytensor import scalar as aes
+        from pytensor import scalar as ps
         from pytensor.tensor.elemwise import Elemwise
         from pytensor.tensor.math import Dot

         scalar_op_amdlibm_no_speed_up = [
-            aes.LT,
-            aes.GT,
-            aes.LE,
-            aes.GE,
-            aes.EQ,
-            aes.NEQ,
-            aes.InRange,
-            aes.Switch,
-            aes.OR,
-            aes.XOR,
-            aes.AND,
-            aes.Invert,
-            aes.ScalarMaximum,
-            aes.ScalarMinimum,
-            aes.Add,
-            aes.Mul,
-            aes.Sub,
-            aes.TrueDiv,
-            aes.IntDiv,
-            aes.Clip,
-            aes.Second,
-            aes.Identity,
-            aes.Cast,
-            aes.Sign,
-            aes.Neg,
-            aes.Reciprocal,
-            aes.Sqr,
+            ps.LT,
+            ps.GT,
+            ps.LE,
+            ps.GE,
+            ps.EQ,
+            ps.NEQ,
+            ps.InRange,
+            ps.Switch,
+            ps.OR,
+            ps.XOR,
+            ps.AND,
+            ps.Invert,
+            ps.ScalarMaximum,
+            ps.ScalarMinimum,
+            ps.Add,
+            ps.Mul,
+            ps.Sub,
+            ps.TrueDiv,
+            ps.IntDiv,
+            ps.Clip,
+            ps.Second,
+            ps.Identity,
+            ps.Cast,
+            ps.Sign,
+            ps.Neg,
+            ps.Reciprocal,
+            ps.Sqr,
         ]
         scalar_op_amdlibm_speed_up = [
-            aes.Mod,
-            aes.Pow,
-            aes.Ceil,
-            aes.Floor,
-            aes.RoundHalfToEven,
-            aes.RoundHalfAwayFromZero,
-            aes.Log,
-            aes.Log2,
-            aes.Log10,
-            aes.Log1p,
-            aes.Exp,
-            aes.Sqrt,
-            aes.Abs,
-            aes.Cos,
-            aes.Sin,
-            aes.Tan,
-            aes.Tanh,
-            aes.Cosh,
-            aes.Sinh,
-            aes.Sigmoid,
-            aes.Softplus,
+            ps.Mod,
+            ps.Pow,
+            ps.Ceil,
+            ps.Floor,
+            ps.RoundHalfToEven,
+            ps.RoundHalfAwayFromZero,
+            ps.Log,
+            ps.Log2,
+            ps.Log10,
+            ps.Log1p,
+            ps.Exp,
+            ps.Sqrt,
+            ps.Abs,
+            ps.Cos,
+            ps.Sin,
+            ps.Tan,
+            ps.Tanh,
+            ps.Cosh,
+            ps.Sinh,
+            ps.Sigmoid,
+            ps.Softplus,
         ]

         def get_scalar_ops(s):
-            if isinstance(s, aes.Composite):
+            if isinstance(s, ps.Composite):
                 l = []
                 for node in s.fgraph.toposort():
                     l += get_scalar_ops(node.op)
@@ -1552,7 +1552,7 @@ def get_scalar_ops(s):
                 return [s]

         def list_scalar_op(op):
-            if isinstance(op.scalar_op, aes.Composite):
+            if isinstance(op.scalar_op, ps.Composite):
                 return get_scalar_ops(op.scalar_op)
             else:
                 return [op.scalar_op]
@@ -1579,7 +1579,7 @@ def exp_float32_op(op):
                 return False
             else:
                 l = list_scalar_op(op)
-                return any(s_op.__class__ in [aes.Exp] for s_op in l)
+                return any(s_op.__class__ in [ps.Exp] for s_op in l)

         printed_tip = False
         # tip 1
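For context, a minimal sketch (not part of this diff) of how the renamed `ps` alias is typically used alongside the tip logic above; the helper name `uses_amdlibm_exp` is hypothetical:

    from pytensor import scalar as ps
    from pytensor.tensor.elemwise import Elemwise


    def uses_amdlibm_exp(op):
        # True if `op` is an Elemwise whose scalar op is (or contains) ps.Exp,
        # i.e. one of the ops the amdlibm tip above would flag as speed-up candidates.
        if not isinstance(op, Elemwise):
            return False
        scalar_op = op.scalar_op
        if isinstance(scalar_op, ps.Composite):
            # Composite scalar ops bundle several scalar ops in an inner fgraph.
            return any(isinstance(n.op, ps.Exp) for n in scalar_op.fgraph.toposort())
        return isinstance(scalar_op, ps.Exp)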

pytensor/ifelse.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any):
                 input_f.type, HasDataType
             ):
                 # TODO: Be smarter about dtype casting.
-                # up_dtype = aes.upcast(input_t.type.dtype, input_f.type.dtype)
+                # up_dtype = ps.upcast(input_t.type.dtype, input_f.type.dtype)

                 if input_t.type.dtype != input_f.type.dtype:
                     raise TypeError(
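As a hedged illustration of the commented-out TODO above (assuming `pytensor.scalar.upcast` keeps its usual dtype-promotion behaviour), the renamed alias would be used like this:

    from pytensor import scalar as ps

    # Resolve the common dtype of the two branches instead of requiring an exact match.
    up_dtype = ps.upcast("int32", "float64")
    print(up_dtype)  # expected: "float64"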

pytensor/link/jax/dispatch/random.py

Lines changed: 27 additions & 27 deletions
@@ -7,7 +7,7 @@
     _coerce_to_uint32_array,
 )

-import pytensor.tensor.random.basic as aer
+import pytensor.tensor.random.basic as ptr
 from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
 from pytensor.link.jax.dispatch.shape import JAXShapeTuple
 from pytensor.tensor.shape import Shape, Shape_i
@@ -86,7 +86,7 @@ def jax_typify_Generator(rng, **kwargs):
     return state


-@jax_funcify.register(aer.RandomVariable)
+@jax_funcify.register(ptr.RandomVariable)
 def jax_funcify_RandomVariable(op, node, **kwargs):
     """JAX implementation of random variables."""
     rv = node.outputs[1]
@@ -121,10 +121,10 @@ def jax_sample_fn(op):
     )


-@jax_sample_fn.register(aer.BetaRV)
-@jax_sample_fn.register(aer.DirichletRV)
-@jax_sample_fn.register(aer.PoissonRV)
-@jax_sample_fn.register(aer.MvNormalRV)
+@jax_sample_fn.register(ptr.BetaRV)
+@jax_sample_fn.register(ptr.DirichletRV)
+@jax_sample_fn.register(ptr.PoissonRV)
+@jax_sample_fn.register(ptr.MvNormalRV)
 def jax_sample_fn_generic(op):
     """Generic JAX implementation of random variables."""
     name = op.name
@@ -140,12 +140,12 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


-@jax_sample_fn.register(aer.CauchyRV)
-@jax_sample_fn.register(aer.GumbelRV)
-@jax_sample_fn.register(aer.LaplaceRV)
-@jax_sample_fn.register(aer.LogisticRV)
-@jax_sample_fn.register(aer.NormalRV)
-@jax_sample_fn.register(aer.StandardNormalRV)
+@jax_sample_fn.register(ptr.CauchyRV)
+@jax_sample_fn.register(ptr.GumbelRV)
+@jax_sample_fn.register(ptr.LaplaceRV)
+@jax_sample_fn.register(ptr.LogisticRV)
+@jax_sample_fn.register(ptr.NormalRV)
+@jax_sample_fn.register(ptr.StandardNormalRV)
 def jax_sample_fn_loc_scale(op):
     """JAX implementation of random variables in the loc-scale families.

@@ -168,8 +168,8 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


-@jax_sample_fn.register(aer.BernoulliRV)
-@jax_sample_fn.register(aer.CategoricalRV)
+@jax_sample_fn.register(ptr.BernoulliRV)
+@jax_sample_fn.register(ptr.CategoricalRV)
 def jax_sample_fn_no_dtype(op):
     """Generic JAX implementation of random variables."""
     name = op.name
@@ -185,9 +185,9 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


-@jax_sample_fn.register(aer.RandIntRV)
-@jax_sample_fn.register(aer.IntegersRV)
-@jax_sample_fn.register(aer.UniformRV)
+@jax_sample_fn.register(ptr.RandIntRV)
+@jax_sample_fn.register(ptr.IntegersRV)
+@jax_sample_fn.register(ptr.UniformRV)
 def jax_sample_fn_uniform(op):
     """JAX implementation of random variables with uniform density.

@@ -197,7 +197,7 @@ def jax_sample_fn_uniform(op):
     """
     name = op.name
     # IntegersRV is equivalent to RandintRV
-    if isinstance(op, aer.IntegersRV):
+    if isinstance(op, ptr.IntegersRV):
         name = "randint"
     jax_op = getattr(jax.random, name)

@@ -214,8 +214,8 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


-@jax_sample_fn.register(aer.ParetoRV)
-@jax_sample_fn.register(aer.GammaRV)
+@jax_sample_fn.register(ptr.ParetoRV)
+@jax_sample_fn.register(ptr.GammaRV)
 def jax_sample_fn_shape_scale(op):
     """JAX implementation of random variables in the shape-scale family.

@@ -236,7 +236,7 @@ def sample_fn(rng, size, dtype, shape, scale):
     return sample_fn


-@jax_sample_fn.register(aer.ExponentialRV)
+@jax_sample_fn.register(ptr.ExponentialRV)
 def jax_sample_fn_exponential(op):
     """JAX implementation of `ExponentialRV`."""

@@ -251,7 +251,7 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


-@jax_sample_fn.register(aer.StudentTRV)
+@jax_sample_fn.register(ptr.StudentTRV)
 def jax_sample_fn_t(op):
     """JAX implementation of `StudentTRV`."""

@@ -270,7 +270,7 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


-@jax_sample_fn.register(aer.ChoiceRV)
+@jax_sample_fn.register(ptr.ChoiceRV)
 def jax_funcify_choice(op):
     """JAX implementation of `ChoiceRV`."""

@@ -285,7 +285,7 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


-@jax_sample_fn.register(aer.PermutationRV)
+@jax_sample_fn.register(ptr.PermutationRV)
 def jax_sample_fn_permutation(op):
     """JAX implementation of `PermutationRV`."""

@@ -300,7 +300,7 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


-@jax_sample_fn.register(aer.BinomialRV)
+@jax_sample_fn.register(ptr.BinomialRV)
 def jax_sample_fn_binomial(op):
     if not numpyro_available:
         raise NotImplementedError(
@@ -323,7 +323,7 @@ def sample_fn(rng, size, dtype, n, p):
     return sample_fn


-@jax_sample_fn.register(aer.MultinomialRV)
+@jax_sample_fn.register(ptr.MultinomialRV)
 def jax_sample_fn_multinomial(op):
     if not numpyro_available:
         raise NotImplementedError(
@@ -346,7 +346,7 @@ def sample_fn(rng, size, dtype, n, p):
     return sample_fn


-@jax_sample_fn.register(aer.VonMisesRV)
+@jax_sample_fn.register(ptr.VonMisesRV)
 def jax_sample_fn_vonmises(op):
     if not numpyro_available:
         raise NotImplementedError(
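A minimal usage sketch (not part of the commit; assumes `jax` is installed and PyTensor's "JAX" compilation mode is available) that exercises the dispatch above through the renamed `ptr` module:

    import pytensor
    import pytensor.tensor.random.basic as ptr

    # NormalRV is one of the ops registered to jax_sample_fn_loc_scale above.
    x = ptr.normal(0, 1, size=(3,))
    fn = pytensor.function([], x, mode="JAX")
    print(fn())  # three draws produced by the JAX sampler registered above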

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 2 additions & 2 deletions
@@ -27,8 +27,8 @@
 ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
 to be constants. The graph that you defined thus cannot be JIT-compiled
 by JAX. An example of a graph that can be compiled to JAX:
->>> import pytensor.tensor basic
->>> at.arange(1, 10, 2)
+>>> import pytensor.tensor as pt
+>>> pt.arange(1, 10, 2)
 """

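A hedged sketch of the corrected doctest in the error message (again assuming the "JAX" mode is available): constant bounds let `pt.arange` lower to `jax.numpy.arange`, whereas non-constant bounds trigger the error above at compile time.

    import pytensor
    import pytensor.tensor as pt

    out = pt.arange(1, 10, 2)  # constant start/stop/step, so JAX can JIT-compile it
    fn = pytensor.function([], out, mode="JAX")
    print(fn())  # [1 3 5 7 9]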
