Skip to content

Commit 6f2ddaf

Browse files
committed
Use at's return value
1 parent d09dde1 commit 6f2ddaf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,10 @@ def one_hot(
420420
out = xp.zeros((x.size, num_classes), dtype=dtype)
421421
x_flattened = xp.reshape(x, (-1,))
422422
if is_numpy_namespace(xp):
423-
at(out)[xp.arange(x_size), x_flattened].set(1)
423+
out = at(out)[xp.arange(x_size), x_flattened].set(1)
424424
else:
425425
for i in range(x_size):
426-
at(out)[i, int(x_flattened[i])].set(1)
426+
out = at(out)[i, int(x_flattened[i])].set(1)
427427
if x.ndim != 1:
428428
out = xp.reshape(out, (*x.shape, num_classes))
429429
if axis != -1:

0 commit comments

Comments
 (0)