  50 |   50 |     tensor,
  51 |   51 |     uint_dtypes,
  52 |   52 | )
  53 |      | -from pytensor.tensor.utils import as_list, normalize_reduce_axis
     |   53 | +from pytensor.tensor.utils import normalize_reduce_axis
  54 |   54 | from pytensor.tensor.variable import (
  55 |   55 |     TensorVariable,
  56 |   56 |     _tensor_py_operators,
@@ -3208,133 +3208,6 @@ def dense_dot(a, b):
3208 | 3208 |     return _dot(a, b)
3209 | 3209 |
3210 | 3210 |
3211 |      | -def _tensordot_as_dot(a, b, axes, dot, batched):
3212 |      | -    """
3213 |      | -    Reduces a tensor dot product to a matrix or vector dot product. Based
3214 |      | -    on code from Tijmen Tieleman's gnumpy
3215 |      | -    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
3216 |      | -
3217 |      | -    Please see the documentation of tensordot for the meaning of the a, b
3218 |      | -    and axes arguments.
3219 |      | -
3220 |      | -    :param dot: a function that accepts two symbolic variables and computes
3221 |      | -                the appropriate dot product (e.g. dot, batched_dot)
3222 |      | -    :type dot: function
3223 |      | -
3224 |      | -    :param batched: whether to treat the first axis of a and b as a batch
3225 |      | -                    axis. If so, this axis will be preserved in the output,
3226 |      | -                    allowing this function to be used also for batched
3227 |      | -                    tensor dot products.
3228 |      | -    :type batched: boolean
3229 |      | -
3230 |      | -    :returns: a tensor with shape equal to the concatenation of a's shape
3231 |      | -              (less any dimensions that were summed over) and b's shape
3232 |      | -              (less the first dimension and any dimensions that were summed
3233 |      | -              over).
3234 |      | -    :rtype: symbolic tensor
3235 |      | -    """
3236 |      | -    a, b = as_tensor_variable(a), as_tensor_variable(b)
3237 |      | -
3238 |      | -    if not np.isscalar(axes) and len(axes) != 2:
3239 |      | -        raise ValueError(
3240 |      | -            "Axes should be an integer or a "
3241 |      | -            f"list/tuple of len 2 ({axes} was provided)"
3242 |      | -        )
3243 |      | -
3244 |      | -    # if 'axes' is a number of axes to multiply and sum over (trailing axes
3245 |      | -    # of a, leading axes of b), we can just reshape and use dot.
3246 |      | -    elif np.isscalar(axes):
3247 |      | -        axes = int(axes)
3248 |      | -
3249 |      | -        for operand_name, operand in (("a", a), ("b", b)):
3250 |      | -            if axes > operand.ndim:
3251 |      | -                raise ValueError(
3252 |      | -                    f"axes can not be larger than the dimension of {operand_name} "
3253 |      | -                    f"({operand_name}.ndim={operand.ndim}, axes={axes})"
3254 |      | -                )
3255 |      | -            if batched and axes == operand.ndim:
3256 |      | -                raise ValueError(
3257 |      | -                    "axes to sum over must not include the batch axis "
3258 |      | -                    f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})"
3259 |      | -                )
3260 |      | -
3261 |      | -        batch_axes = 1 if batched else 0
3262 |      | -        a_outaxes = slice(0, a.ndim - axes)
3263 |      | -        b_outaxes = slice(batch_axes + axes, b.ndim)
3264 |      | -        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
3265 |      | -        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
3266 |      | -        outndim = len(outbcast)
3267 |      | -
3268 |      | -        a_shape = [1] * 2
3269 |      | -        b_shape = [1] * 2
3270 |      | -
3271 |      | -        # compute total size of summed axes
3272 |      | -        for i in range(0, axes):
3273 |      | -            a_shape[1] *= a.shape[-(i + 1)]
3274 |      | -            b_shape[0] *= b.shape[batch_axes + i]
3275 |      | -        # compute total size of other axes
3276 |      | -        for i in range(0, a.ndim - axes - batch_axes):
3277 |      | -            a_shape[0] *= a.shape[batch_axes + i]
3278 |      | -        for i in range(0, b.ndim - axes - batch_axes):
3279 |      | -            b_shape[1] *= b.shape[-(i + 1)]
3280 |      | -
3281 |      | -        if batched:
3282 |      | -            a_shape.insert(0, a.shape[0])
3283 |      | -            b_shape.insert(0, b.shape[0])
3284 |      | -
3285 |      | -        a_reshaped = a.reshape(a_shape)
3286 |      | -        b_reshaped = b.reshape(b_shape)
3287 |      | -
3288 |      | -        out_reshaped = dot(a_reshaped, b_reshaped)
3289 |      | -        out = out_reshaped.reshape(outshape, ndim=outndim)
3290 |      | -        # Make sure the broadcastable pattern of the result is correct,
3291 |      | -        # since some shape information can be lost in the reshapes.
3292 |      | -        if out.type.broadcastable != outbcast:
3293 |      | -            out = specify_broadcastable(
3294 |      | -                out, *(ax for (ax, b) in enumerate(outbcast) if b)
3295 |      | -            )
3296 |      | -        return out
3297 |      | -
3298 |      | -    # if 'axes' is a list, transpose a and b such that the summed axes of a
3299 |      | -    # are last and the summed axes of b are first.
3300 |      | -    else:
3301 |      | -        axes = [as_list(axes_) for axes_ in axes]
3302 |      | -
3303 |      | -        if len(axes[0]) != len(axes[1]):
3304 |      | -            raise ValueError("Axes elements must have the same length.")
3305 |      | -
3306 |      | -        for i, (operand_name, operand) in enumerate((("a", a), ("b", b))):
3307 |      | -            if len(axes[i]) > operand.ndim:
3308 |      | -                raise ValueError(
3309 |      | -                    f"axes[{i}] should be array_like with length less than "
3310 |      | -                    f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})."
3311 |      | -                )
3312 |      | -            if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
3313 |      | -                raise ValueError(
3314 |      | -                    f"axes[{i}] contains dimensions greater than or equal "
3315 |      | -                    f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})."
3316 |      | -                )
3317 |      | -            if batched and 0 in axes[i]:
3318 |      | -                raise ValueError(
3319 |      | -                    "axes to sum over must not contain the batch axis "
3320 |      | -                    f"(axes[{i}]={axes[i]})"
3321 |      | -                )
3322 |      | -
3323 |      | -        batch_axes = [0] if batched else []
3324 |      | -        other_axes = [
3325 |      | -            [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes]
3326 |      | -            for i, operand in enumerate((a, b))
3327 |      | -        ]
3328 |      | -
3329 |      | -        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
3330 |      | -        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])
3331 |      | -
3332 |      | -        # now that a and b are in the right order, recur with integer axes
3333 |      | -        return _tensordot_as_dot(
3334 |      | -            a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched
3335 |      | -        )
3336 |      | -
3337 |      | -
3338 | 3211 | def tensordot(
3339 | 3212 |     a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2
3340 | 3213 | ) -> TensorVariable:
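
A note on the deleted helper: for an integer `axes`, `_tensordot_as_dot` flattened the contracted trailing axes of `a` and the contracted leading axes of `b`, performed a single matrix product, and reshaped the result back. A minimal NumPy sketch of that reshape-and-dot strategy (the function name, shapes, and test below are illustrative only, not part of this diff):

from math import prod

import numpy as np


def tensordot_via_matmul(a: np.ndarray, b: np.ndarray, axes: int) -> np.ndarray:
    """Contract the last `axes` dims of `a` with the first `axes` dims of `b`
    by flattening the contracted dimensions and doing one matrix product."""
    # Result keeps the leading (non-contracted) dims of `a` followed by the
    # trailing (non-contracted) dims of `b`.
    out_shape = a.shape[: a.ndim - axes] + b.shape[axes:]
    a2 = a.reshape(-1, prod(a.shape[a.ndim - axes :]))  # (other, summed)
    b2 = b.reshape(prod(b.shape[:axes]), -1)            # (summed, other)
    return (a2 @ b2).reshape(out_shape)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))
b = rng.standard_normal((3, 4, 5))
np.testing.assert_allclose(tensordot_via_matmul(a, b, axes=2), np.tensordot(a, b, axes=2))

This is essentially the reduction `np.tensordot` performs internally as well, so the assertion compares two routes to the same contraction.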
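
With `batched=True`, the helper additionally preserved axis 0 of both operands as a batch axis, prepending it to the flattened shapes so that a batched dot contracts each pair of slices independently. A comparable sketch using `np.matmul`'s leading-axis batching (again, names and shapes are illustrative assumptions):

from math import prod

import numpy as np


def batched_tensordot_via_matmul(a: np.ndarray, b: np.ndarray, axes: int) -> np.ndarray:
    """Integer-axes contraction where axis 0 of both inputs is a shared
    batch axis that is preserved in the output."""
    batch = a.shape[0]
    assert b.shape[0] == batch
    out_shape = a.shape[: a.ndim - axes] + b.shape[1 + axes :]
    a3 = a.reshape(batch, -1, prod(a.shape[a.ndim - axes :]))  # (batch, other, summed)
    b3 = b.reshape(batch, prod(b.shape[1 : 1 + axes]), -1)     # (batch, summed, other)
    return np.matmul(a3, b3).reshape(out_shape)


rng = np.random.default_rng(0)
a = rng.standard_normal((7, 2, 3, 4))  # batch of 7
b = rng.standard_normal((7, 3, 4, 5))
expected = np.stack([np.tensordot(a[i], b[i], axes=2) for i in range(7)])
np.testing.assert_allclose(batched_tensordot_via_matmul(a, b, axes=2), expected)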
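
For the list-of-axes form, the helper first `dimshuffle`d the operands so that the contracted axes of `a` trail and the contracted axes of `b` lead, then recursed into the integer-axes path. The same idea in plain NumPy, kept self-contained rather than reusing the earlier sketch (names and shapes are illustrative):

from math import prod

import numpy as np


def tensordot_via_transpose(a: np.ndarray, b: np.ndarray, axes) -> np.ndarray:
    """Axes given as a pair of sequences: move the contracted axes of `a`
    to the end and those of `b` to the front, then flatten and matmul."""
    a_axes, b_axes = (list(ax) for ax in axes)
    if len(a_axes) != len(b_axes):
        raise ValueError("Axes elements must have the same length.")
    a_other = [i for i in range(a.ndim) if i not in a_axes]
    b_other = [i for i in range(b.ndim) if i not in b_axes]
    a_t = a.transpose(a_other + a_axes)  # contracted axes last
    b_t = b.transpose(b_axes + b_other)  # contracted axes first
    out_shape = tuple(a.shape[i] for i in a_other) + tuple(b.shape[i] for i in b_other)
    a2 = a_t.reshape(-1, prod(a_t.shape[len(a_other) :]))
    b2 = b_t.reshape(prod(b_t.shape[: len(b_axes)]), -1)
    return (a2 @ b2).reshape(out_shape)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))
b = rng.standard_normal((4, 5, 3))
np.testing.assert_allclose(
    tensordot_via_transpose(a, b, [[1, 2], [2, 0]]),
    np.tensordot(a, b, axes=[[1, 2], [2, 0]]),
)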