Skip to content

Commit 20cc969

Browse files
committed
Pylint error
1 parent c278c32 commit 20cc969

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/array_api_extra/_delegation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,16 +174,16 @@ def one_hot(
174174
# Delegate where possible.
175175
if is_jax_namespace(xp):
176176
assert is_jax_array(x)
177-
from jax.nn import one_hot
177+
from jax.nn import one_hot as jax_one_hot
178178

179-
return one_hot(x, num_classes, dtype=dtype, axis=axis)
179+
return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
180180
if is_torch_namespace(xp):
181181
assert is_torch_array(x)
182-
from torch.nn.functional import one_hot
182+
from torch.nn.functional import one_hot as torch_one_hot
183183

184184
x = xp.astype(x, xp.int64) # PyTorch only supports int64 here.
185185
try:
186-
out = one_hot(x, num_classes)
186+
out = torch_one_hot(x, num_classes)
187187
except RuntimeError as e:
188188
raise IndexError from e
189189
out = xp.astype(out, dtype)

0 commit comments

Comments
 (0)