Skip to content

Commit 1dc8bd4

Browse files
committed
Fix PyTorch errors and test int types
1 parent 6f2ddaf commit 1dc8bd4

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,11 @@ def one_hot(
410410
if is_torch_namespace(xp):
411411
assert is_torch_array(x)
412412
from torch.nn.functional import one_hot
413-
out = one_hot(x, num_classes)
413+
x = xp.astype(x, xp.int64) # PyTorch only supports int64 here.
414+
try:
415+
out = one_hot(x, num_classes)
416+
except RuntimeError as e:
417+
raise IndexError from e
414418
if dtype is None:
415419
dtype = xp.float
416420
out = xp.astype(out, dtype)

tests/test_funcs.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,18 +480,32 @@ def test_basic(self, xp: ModuleType):
480480
[1., 0., 0.]])
481481
xp_assert_equal(actual, expected)
482482

483+
@pytest.mark.skip_xp_backend(
484+
Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
485+
)
483486
def test_out_of_bound(self, xp: ModuleType):
484487
# Undefined behavior. Either return zero, or raise.
485488
try:
486489
actual = one_hot(xp.asarray([-1, 3]), 3)
487-
except (IndexError, RuntimeError):
490+
except IndexError:
488491
return
489492
expected = xp.asarray([[0., 0., 0.],
490493
[0., 0., 0.]])
491494
xp_assert_equal(actual, expected)
492495

496+
@pytest.mark.parametrize("int_dtype", ['int8', 'int16', 'int32', 'int64', 'uint8',
497+
'uint16', 'uint32', 'uint64'])
498+
def test_int_types(self, xp: ModuleType, int_dtype: str):
499+
dtype = getattr(xp, int_dtype)
500+
x = xp.asarray([0, 1, 2], dtype=dtype)
501+
actual = one_hot(x, 3)
502+
expected = xp.asarray([[1., 0., 0.],
503+
[0., 1., 0.],
504+
[0., 0., 1.]])
505+
xp_assert_equal(actual, expected)
506+
493507
def test_custom_dtype(self, xp: ModuleType):
494-
actual = one_hot(xp.asarray([0, 1, 2]), 3, dtype=xp.bool)
508+
actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
495509
expected = xp.asarray([[True, False, False],
496510
[False, True, False],
497511
[False, False, True]])

0 commit comments

Comments
 (0)