File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -174,16 +174,16 @@ def one_hot(
174
174
# Delegate where possible.
175
175
if is_jax_namespace (xp ):
176
176
assert is_jax_array (x )
177
- from jax .nn import one_hot
177
+ from jax .nn import one_hot as jax_one_hot
178
178
179
- return one_hot (x , num_classes , dtype = dtype , axis = axis )
179
+ return jax_one_hot (x , num_classes , dtype = dtype , axis = axis )
180
180
if is_torch_namespace (xp ):
181
181
assert is_torch_array (x )
182
- from torch .nn .functional import one_hot
182
+ from torch .nn .functional import one_hot as torch_one_hot
183
183
184
184
x = xp .astype (x , xp .int64 ) # PyTorch only supports int64 here.
185
185
try :
186
- out = one_hot (x , num_classes )
186
+ out = torch_one_hot (x , num_classes )
187
187
except RuntimeError as e :
188
188
raise IndexError from e
189
189
out = xp .astype (out , dtype )
You can’t perform that action at this time.
0 commit comments