Skip to content

Commit eea5935

Browse files
committed
Move reshape
1 parent 07d96b8 commit eea5935

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/array_api_extra/_delegation.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,6 @@ def one_hot(
190190
else:
191191
out = _funcs.one_hot(x, num_classes, x_size=x_size, dtype=dtype, xp=xp)
192192

193-
if x.ndim != 1:
194-
out = xp.reshape(out, (*x.shape, num_classes))
195193
if axis != -1:
196194
out = xp.moveaxis(out, -1, axis)
197195
return out

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,9 +393,11 @@ def one_hot(
393393
out = xp.zeros((x.size, num_classes), dtype=dtype)
394394
x_flattened = xp.reshape(x, (-1,))
395395
if is_numpy_namespace(xp):
396-
return at(out)[xp.arange(x_size), x_flattened].set(1)
396+
out = at(out)[xp.arange(x_size), x_flattened].set(1)
397397
for i in range(x_size):
398398
out = at(out)[i, int(x_flattened[i])].set(1)
399+
if x.ndim != 1:
400+
out = xp.reshape(out, (*x.shape, num_classes))
399401
return out
400402

401403

0 commit comments

Comments
 (0)