|
10 | 10 | from collections.abc import Sequence
|
11 | 11 | from functools import partial
|
12 | 12 | from numbers import Number
|
13 |
| -from typing import TYPE_CHECKING |
| 13 | +from typing import TYPE_CHECKING, Union |
14 | 14 | from typing import cast as type_cast
|
15 | 15 |
|
16 | 16 | import numpy as np
|
|
33 | 33 | from pytensor.link.c.op import COp
|
34 | 34 | from pytensor.link.c.params_type import ParamsType
|
35 | 35 | from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
|
36 |
| -from pytensor.raise_op import CheckAndRaise, assert_op |
| 36 | +from pytensor.raise_op import CheckAndRaise |
37 | 37 | from pytensor.scalar import int32
|
38 | 38 | from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
|
39 | 39 | from pytensor.tensor import (
|
@@ -3084,87 +3084,132 @@ def flatten(x, ndim=1):
|
3084 | 3084 | return x_reshaped
|
3085 | 3085 |
|
3086 | 3086 |
|
3087 |
| -def tile(x, reps, ndim=None): |
| 3087 | +def tile( |
| 3088 | + A: "TensorLike", reps: Union[Sequence[Union[int, "TensorLike"]], "TensorLike"] |
| 3089 | +) -> TensorVariable: |
3088 | 3090 | """
|
3089 |
| - Tile input array `x` according to `reps`. |
| 3091 | + Tile input tensor `A` according to `reps`. |
3090 | 3092 |
|
3091 | 3093 | See the docstring of `numpy.tile` for details.
|
3092 | 3094 |
|
3093 |
| - 'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]), |
3094 |
| - symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector()) |
3095 |
| - or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]). |
| 3095 | + If `reps` is a PyTensor vector, its length must be statically known. |
| 3096 | + You can use `specify_shape` to set the length. |
| 3097 | +
|
| 3098 | + Examples |
| 3099 | + -------- |
| 3100 | +
|
| 3101 | + .. testcode:: |
| 3102 | +
|
| 3103 | + import pytensor.tensor as pt |
| 3104 | +
|
| 3105 | + A = pt.matrix("A", dtype=int) |
| 3106 | + A_tiled = pt.tile(A, 2) |
| 3107 | + print(A_tiled.eval({A: [[1, 2], [3, 4]]})) |
| 3108 | +
|
| 3109 | + .. testoutput:: |
| 3110 | +
|
| 3111 | + [[1 2 1 2] |
| 3112 | + [3 4 3 4]] |
| 3113 | +
|
| 3114 | + Reps can be a sequence of constants and/ or symbolic integer variables |
| 3115 | +
|
| 3116 | + .. testcode:: |
| 3117 | +
|
| 3118 | + rep0 = pt.scalar("rep0", dtype=int) |
| 3119 | + A_tiled = pt.tile(A, (rep0, 1)) |
| 3120 | + print(A_tiled.eval({A: [[1, 2], [3, 4]], rep0: 2})) |
| 3121 | +
|
| 3122 | + .. testoutput:: |
| 3123 | +
|
| 3124 | + [[1 2] |
| 3125 | + [3 4] |
| 3126 | + [1 2] |
| 3127 | + [3 4]] |
| 3128 | +
|
| 3129 | + Reps can be a single integer vector, in which case its length must be statically known. |
| 3130 | + Either of the following is a valid way to specify the length: |
| 3131 | +
|
| 3132 | + .. testcode:: |
| 3133 | +
|
| 3134 | + reps = pt.vector("reps", dtype=int, shape=(2,)) |
| 3135 | + A_tiled = pt.tile(A, reps) |
| 3136 | + print(A_tiled.eval({A: [[1, 2], [3, 4]], reps: [1, 2]})) |
| 3137 | +
|
| 3138 | + .. testoutput:: |
| 3139 | +
|
| 3140 | + [[1 2 1 2] |
| 3141 | + [3 4 3 4]] |
| 3142 | +
|
| 3143 | + .. testcode:: |
3096 | 3144 |
|
3097 |
| - ndim is the number of the dimensions of the output, if it is provided, ndim |
3098 |
| - should be equal or larger than x.ndim and len(reps), otherwise, we will use |
3099 |
| - max(x.ndim, len(reps)) as ndim. If reps is symbolic vector, the ndim has to |
3100 |
| - be provided. |
| 3145 | + reps = pt.vector("reps", dtype=int) |
| 3146 | + reps = pt.specify_shape(reps, (2,)) |
| 3147 | + A_tiled = pt.tile(A, reps) |
| 3148 | + print(A_tiled.eval({A: [[1, 2], [3, 4]], reps: [2, 2]})) |
| 3149 | +
|
| 3150 | + .. testoutput:: |
| 3151 | +
|
| 3152 | + [[1 2 1 2] |
| 3153 | + [3 4 3 4] |
| 3154 | + [1 2 1 2] |
| 3155 | + [3 4 3 4]] |
3101 | 3156 |
|
3102 | 3157 | """
|
3103 |
| - from pytensor.tensor.math import ge |
3104 | 3158 |
|
3105 |
| - _x = as_tensor_variable(x) |
3106 |
| - if ndim is not None and ndim < _x.ndim: |
3107 |
| - raise ValueError("ndim should be equal or larger than _x.ndim") |
| 3159 | + A = as_tensor_variable(A) |
3108 | 3160 |
|
3109 |
| - # If reps is a scalar, integer or vector, we convert it to a list. |
| 3161 | + # Convert symbolic reps to a tuple |
3110 | 3162 | if not isinstance(reps, list | tuple):
|
3111 |
| - reps_astensor = as_tensor_variable(reps) |
3112 |
| - ndim_check = reps_astensor.ndim |
3113 |
| - if reps_astensor.dtype not in discrete_dtypes: |
3114 |
| - raise ValueError("elements of reps must be integer dtype") |
3115 |
| - |
3116 |
| - # The scalar/integer case |
3117 |
| - if ndim_check == 0: |
3118 |
| - reps = [reps] |
3119 |
| - |
3120 |
| - # The vector case |
3121 |
| - elif ndim_check == 1: |
3122 |
| - if ndim is None: |
| 3163 | + reps = as_tensor_variable(reps) |
| 3164 | + if reps.type.ndim == 0: |
| 3165 | + reps = (reps,) |
| 3166 | + elif reps.type.ndim == 1: |
| 3167 | + try: |
| 3168 | + reps = tuple(reps) |
| 3169 | + except ValueError: |
3123 | 3170 | raise ValueError(
|
3124 |
| - "if reps is tensor.vector, you should specify the ndim" |
| 3171 | + "Length of repetitions tensor cannot be determined. Use specify_shape to set the length." |
3125 | 3172 | )
|
3126 |
| - else: |
3127 |
| - offset = ndim - reps.shape[0] |
3128 |
| - |
3129 |
| - # assert that reps.shape[0] does not exceed ndim |
3130 |
| - offset = assert_op(offset, ge(offset, 0)) |
| 3173 | + else: |
| 3174 | + raise ValueError( |
| 3175 | + f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}" |
| 3176 | + ) |
3131 | 3177 |
|
3132 |
| - # if reps.ndim is less than _x.ndim, we pad the reps with |
3133 |
| - # "1" so that reps will have the same ndim as _x. |
3134 |
| - reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)] |
3135 |
| - reps = reps_ |
| 3178 | + reps = [as_tensor_variable(rep) for rep in reps] |
| 3179 | + if not all( |
| 3180 | + rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps |
| 3181 | + ): |
| 3182 | + raise ValueError( |
| 3183 | + f"All reps entries shoud be scalar integers, got {reps} of type {[rep.type for rep in reps]}" |
| 3184 | + ) |
3136 | 3185 |
|
3137 |
| - # For others, raise an error |
3138 |
| - else: |
3139 |
| - raise ValueError("the dimension of reps should not exceed 1") |
3140 |
| - else: |
3141 |
| - if ndim is not None and len(reps) > ndim: |
3142 |
| - raise ValueError("len(reps) should be equal or less than ndim") |
3143 |
| - if not all( |
3144 |
| - isinstance(r, int) |
3145 |
| - or (isinstance(r, TensorVariable) and r.dtype in discrete_dtypes) |
3146 |
| - for r in reps |
3147 |
| - ): |
3148 |
| - raise ValueError("elements of reps must be scalars of integer dtype") |
| 3186 | + len_reps = len(reps) |
| 3187 | + out_ndim = builtins.max(len_reps, A.type.ndim) |
| 3188 | + |
| 3189 | + # Pad reps on the left (if needed) |
| 3190 | + if len_reps < out_ndim: |
| 3191 | + reps = (*((1,) * (out_ndim - len_reps)), *reps) |
| 3192 | + |
| 3193 | + # Pad A's shape on the left (if needed) |
| 3194 | + elif A.type.ndim < out_ndim: |
| 3195 | + A = shape_padleft(A, out_ndim - A.type.ndim) |
| 3196 | + |
| 3197 | + # Expand every other dim of A and expand n-reps via Alloc |
| 3198 | + # A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1]) |
| 3199 | + A_shape = A.shape |
| 3200 | + interleaved_reps_shape = [ |
| 3201 | + d for pair in zip(reps, A_shape, strict=True) for d in pair |
| 3202 | + ] |
| 3203 | + every_other_axis = tuple(range(0, out_ndim * 2, 2)) |
| 3204 | + A_replicated = alloc( |
| 3205 | + expand_dims(A, every_other_axis), |
| 3206 | + *interleaved_reps_shape, |
| 3207 | + ) |
3149 | 3208 |
|
3150 |
| - # If reps.ndim is less than _x.ndim, we pad the reps with |
3151 |
| - # "1" so that reps will have the same ndim as _x |
3152 |
| - reps = list(reps) |
3153 |
| - if ndim is None: |
3154 |
| - ndim = builtins.max(len(reps), _x.ndim) |
3155 |
| - if len(reps) < ndim: |
3156 |
| - reps = [1] * (ndim - len(reps)) + reps |
3157 |
| - |
3158 |
| - _shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)] |
3159 |
| - alloc_shape = reps + _shape |
3160 |
| - y = alloc(_x, *alloc_shape) |
3161 |
| - shuffle_ind = np.arange(ndim * 2).reshape(2, ndim) |
3162 |
| - shuffle_ind = shuffle_ind.transpose().flatten() |
3163 |
| - y = y.dimshuffle(*shuffle_ind) |
3164 |
| - new_shapes = [sh * reps[i] for i, sh in enumerate(_shape)] |
3165 |
| - y = y.reshape(new_shapes) |
3166 |
| - |
3167 |
| - return y |
| 3209 | + # Combine replicate and original dimensions via reshape |
| 3210 | + # A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1]) |
| 3211 | + tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True)) |
| 3212 | + return A_replicated.reshape(tiled_shape) |
3168 | 3213 |
|
3169 | 3214 |
|
3170 | 3215 | class ARange(Op):
|
|
0 commit comments