Skip to content

Commit 1206acb

Browse files
committed
Simplify makeKeepdDims
1 parent 31bf682 commit 1206acb

File tree

1 file changed

+1
-24
lines changed

1 file changed

+1
-24
lines changed

pytensor/tensor/math.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -297,32 +297,9 @@ def makeKeepDims(x, y, axis):
297297
298298
"""
299299
x = as_tensor_variable(x)
300-
y = as_tensor_variable(y)
301-
302300
if axis is None:
303301
axis = list(range(x.type.ndim))
304-
elif isinstance(axis, int | np.integer):
305-
axis = [axis]
306-
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
307-
axis = [int(axis)]
308-
else:
309-
axis = [int(a) for a in axis]
310-
newaxis = []
311-
for a in axis:
312-
if not isinstance(a, int):
313-
raise ValueError("keepdims option can be used only with constant axis")
314-
if a < 0:
315-
a += x.type.ndim
316-
newaxis.append(a)
317-
i = 0
318-
new_dims = []
319-
for j, _ in enumerate(x.type.broadcastable):
320-
if j in newaxis:
321-
new_dims.append("x")
322-
else:
323-
new_dims.append(i)
324-
i += 1
325-
return DimShuffle(y.type.broadcastable, new_dims)(y)
302+
return expand_dims(y, axis)
326303

327304

328305
def check_and_normalize_axes(x, axis):

0 commit comments

Comments
 (0)