Commit 4ac1e63

Simplify implementation of tile
Deprecate obscure ndim kwarg
1 parent c22e79e commit 4ac1e63
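
Note on the second commit-message line: in the old implementation shown below, a `reps` sequence shorter than the requested `ndim` was left-padded with 1s. A minimal migration sketch (variable names are mine, not part of this diff; assumes a 1-d input) is to write that padding into `reps` itself:

    import pytensor.tensor as pt

    x = pt.vector("x")

    # Previously one could write pt.tile(x, (2,), ndim=3); the kwarg
    # left-padded reps with 1s up to the requested output rank.
    # The equivalent call without the kwarg passes the padded reps explicitly:
    y = pt.tile(x, (1, 1, 2))
    print(y.ndim)  # 3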

File tree

2 files changed: +205 -234 lines changed

pytensor/tensor/basic.py

Lines changed: 113 additions & 68 deletions
@@ -10,7 +10,7 @@
 from collections.abc import Sequence
 from functools import partial
 from numbers import Number
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 from typing import cast as type_cast

 import numpy as np
@@ -33,7 +33,7 @@
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
-from pytensor.raise_op import CheckAndRaise, assert_op
+from pytensor.raise_op import CheckAndRaise
 from pytensor.scalar import int32
 from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
 from pytensor.tensor import (
@@ -3084,87 +3084,132 @@ def flatten(x, ndim=1):
     return x_reshaped


-def tile(x, reps, ndim=None):
+def tile(
+    A: "TensorLike", reps: Union[Sequence[Union[int, "TensorLike"]], "TensorLike"]
+) -> TensorVariable:
     """
-    Tile input array `x` according to `reps`.
+    Tile input tensor `A` according to `reps`.

     See the docstring of `numpy.tile` for details.

-    'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]),
-    symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector())
-    or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]).
+    If `reps` is a PyTensor vector, its length must be statically known.
+    You can use `specify_shape` to set the length.
+
+    Examples
+    --------
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        A = pt.matrix("A", dtype=int)
+        A_tiled = pt.tile(A, 2)
+        print(A_tiled.eval({A: [[1, 2], [3, 4]]}))
+
+    .. testoutput::
+
+        [[1 2 1 2]
+         [3 4 3 4]]
+
+    Reps can be a sequence of constants and/or symbolic integer variables
+
+    .. testcode::
+
+        rep0 = pt.scalar("rep0", dtype=int)
+        A_tiled = pt.tile(A, (rep0, 1))
+        print(A_tiled.eval({A: [[1, 2], [3, 4]], rep0: 2}))
+
+    .. testoutput::
+
+        [[1 2]
+         [3 4]
+         [1 2]
+         [3 4]]
+
+    Reps can be a single integer vector, in which case its length must be statically known.
+    Either of the following is a valid way to specify the length:
+
+    .. testcode::
+
+        reps = pt.vector("reps", dtype=int, shape=(2,))
+        A_tiled = pt.tile(A, reps)
+        print(A_tiled.eval({A: [[1, 2], [3, 4]], reps: [1, 2]}))
+
+    .. testoutput::
+
+        [[1 2 1 2]
+         [3 4 3 4]]
+
+    .. testcode::

-    ndim is the number of the dimensions of the output, if it is provided, ndim
-    should be equal or larger than x.ndim and len(reps), otherwise, we will use
-    max(x.ndim, len(reps)) as ndim. If reps is symbolic vector, the ndim has to
-    be provided.
+        reps = pt.vector("reps", dtype=int)
+        reps = pt.specify_shape(reps, (2,))
+        A_tiled = pt.tile(A, reps)
+        print(A_tiled.eval({A: [[1, 2], [3, 4]], reps: [2, 2]}))
+
+    .. testoutput::
+
+        [[1 2 1 2]
+         [3 4 3 4]
+         [1 2 1 2]
+         [3 4 3 4]]

     """
-    from pytensor.tensor.math import ge

-    _x = as_tensor_variable(x)
-    if ndim is not None and ndim < _x.ndim:
-        raise ValueError("ndim should be equal or larger than _x.ndim")
+    A = as_tensor_variable(A)

-    # If reps is a scalar, integer or vector, we convert it to a list.
+    # Convert symbolic reps to a tuple
     if not isinstance(reps, list | tuple):
-        reps_astensor = as_tensor_variable(reps)
-        ndim_check = reps_astensor.ndim
-        if reps_astensor.dtype not in discrete_dtypes:
-            raise ValueError("elements of reps must be integer dtype")
-
-        # The scalar/integer case
-        if ndim_check == 0:
-            reps = [reps]
-
-        # The vector case
-        elif ndim_check == 1:
-            if ndim is None:
+        reps = as_tensor_variable(reps)
+        if reps.type.ndim == 0:
+            reps = (reps,)
+        elif reps.type.ndim == 1:
+            try:
+                reps = tuple(reps)
+            except ValueError:
                 raise ValueError(
-                    "if reps is tensor.vector, you should specify the ndim"
+                    "Length of repetitions tensor cannot be determined. Use specify_shape to set the length."
                 )
-            else:
-                offset = ndim - reps.shape[0]
-
-                # assert that reps.shape[0] does not exceed ndim
-                offset = assert_op(offset, ge(offset, 0))
+        else:
+            raise ValueError(
+                f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}"
+            )

-                # if reps.ndim is less than _x.ndim, we pad the reps with
-                # "1" so that reps will have the same ndim as _x.
-                reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)]
-                reps = reps_
+    reps = [as_tensor_variable(rep) for rep in reps]
+    if not all(
+        rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps
+    ):
+        raise ValueError(
+            f"All reps entries should be scalar integers, got {reps} of type {[rep.type for rep in reps]}"
+        )

-        # For others, raise an error
-        else:
-            raise ValueError("the dimension of reps should not exceed 1")
-    else:
-        if ndim is not None and len(reps) > ndim:
-            raise ValueError("len(reps) should be equal or less than ndim")
-        if not all(
-            isinstance(r, int)
-            or (isinstance(r, TensorVariable) and r.dtype in discrete_dtypes)
-            for r in reps
-        ):
-            raise ValueError("elements of reps must be scalars of integer dtype")
+    len_reps = len(reps)
+    out_ndim = builtins.max(len_reps, A.type.ndim)
+
+    # Pad reps on the left (if needed)
+    if len_reps < out_ndim:
+        reps = (*((1,) * (out_ndim - len_reps)), *reps)
+
+    # Pad A's shape on the left (if needed)
+    elif A.type.ndim < out_ndim:
+        A = shape_padleft(A, out_ndim - A.type.ndim)
+
+    # Expand every other dim of A and expand n-reps via Alloc
+    # A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1])
+    A_shape = A.shape
+    interleaved_reps_shape = [
+        d for pair in zip(reps, A_shape, strict=True) for d in pair
+    ]
+    every_other_axis = tuple(range(0, out_ndim * 2, 2))
+    A_replicated = alloc(
+        expand_dims(A, every_other_axis),
+        *interleaved_reps_shape,
+    )

-        # If reps.ndim is less than _x.ndim, we pad the reps with
-        # "1" so that reps will have the same ndim as _x
-        reps = list(reps)
-        if ndim is None:
-            ndim = builtins.max(len(reps), _x.ndim)
-        if len(reps) < ndim:
-            reps = [1] * (ndim - len(reps)) + reps
-
-        _shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)]
-        alloc_shape = reps + _shape
-        y = alloc(_x, *alloc_shape)
-        shuffle_ind = np.arange(ndim * 2).reshape(2, ndim)
-        shuffle_ind = shuffle_ind.transpose().flatten()
-        y = y.dimshuffle(*shuffle_ind)
-        new_shapes = [sh * reps[i] for i, sh in enumerate(_shape)]
-        y = y.reshape(new_shapes)
-
-        return y
+    # Combine replicate and original dimensions via reshape
+    # A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1])
+    tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True))
+    return A_replicated.reshape(tiled_shape)


 class ARange(Op):
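
The comments added in the new implementation summarize the tiling strategy: insert a length-1 axis in front of every original axis, broadcast those new axes to their repetition counts via `alloc`, then collapse each (rep, dim) pair with a single reshape. A rough NumPy sketch of the same idea on a concrete 2x2 example (my own illustration, with `np.broadcast_to` standing in for `alloc`):

    import numpy as np

    A = np.array([[1, 2], [3, 4]])
    reps = (2, 3)

    # Insert a length-1 axis in front of every original axis:
    # shape (2, 2) -> (1, 2, 1, 2), i.e. A[None, :, None, :].
    A_expanded = np.expand_dims(A, axis=(0, 2))

    # Broadcast each inserted axis up to its repetition count;
    # this is the role alloc plays in the PyTensor graph.
    A_replicated = np.broadcast_to(
        A_expanded, (reps[0], A.shape[0], reps[1], A.shape[1])
    )

    # Merge every (rep, dim) pair back into a single axis.
    A_tiled = A_replicated.reshape(reps[0] * A.shape[0], reps[1] * A.shape[1])

    assert np.array_equal(A_tiled, np.tile(A, reps))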
