We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d09dde1 commit 6f2ddafCopy full SHA for 6f2ddaf
src/array_api_extra/_lib/_funcs.py
@@ -420,10 +420,10 @@ def one_hot(
420
out = xp.zeros((x.size, num_classes), dtype=dtype)
421
x_flattened = xp.reshape(x, (-1,))
422
if is_numpy_namespace(xp):
423
- at(out)[xp.arange(x_size), x_flattened].set(1)
+ out = at(out)[xp.arange(x_size), x_flattened].set(1)
424
else:
425
for i in range(x_size):
426
- at(out)[i, int(x_flattened[i])].set(1)
+ out = at(out)[i, int(x_flattened[i])].set(1)
427
if x.ndim != 1:
428
out = xp.reshape(out, (*x.shape, num_classes))
429
if axis != -1:
0 commit comments