Skip to content

Commit 4b94241

Browse files
committed
Fix numpy vector_norm(keepdims=True)
1 parent 40603a9 commit 4b94241

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

array_api_compat/common/_linalg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,22 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
110110
# on a single dimension.
111111
if axis is None:
112112
# Note: xp.linalg.norm() doesn't handle 0-D arrays
113-
x = x.ravel()
113+
_x = x.ravel()
114114
_axis = 0
115115
elif isinstance(axis, tuple):
116116
# Note: The axis argument supports any number of axes, whereas
117117
# xp.linalg.norm() only supports a single axis for vector norm.
118118
normalized_axis = normalize_axis_tuple(axis, x.ndim)
119119
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
120120
newshape = axis + rest
121-
x = xp.transpose(x, newshape).reshape(
121+
_x = xp.transpose(x, newshape).reshape(
122122
(xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest]))
123123
_axis = 0
124124
else:
125+
_x = x
125126
_axis = axis
126127

127-
res = xp.linalg.norm(x, axis=_axis, ord=ord)
128+
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
128129

129130
if keepdims:
130131
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks

0 commit comments

Comments
 (0)