Skip to content

Commit 8319f5f

Browse files
committed
Formatting
1 parent 1dc8bd4 commit 8319f5f

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,12 +404,14 @@ def one_hot(
404404
if is_jax_namespace(xp):
405405
assert is_jax_array(x)
406406
from jax.nn import one_hot
407+
407408
if dtype is None:
408409
dtype = xp.float_
409410
return one_hot(x, num_classes, dtype=dtype, axis=axis)
410411
if is_torch_namespace(xp):
411412
assert is_torch_array(x)
412413
from torch.nn.functional import one_hot
414+
413415
x = xp.astype(x, xp.int64) # PyTorch only supports int64 here.
414416
try:
415417
out = one_hot(x, num_classes)

tests/test_funcs.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -469,15 +469,11 @@ def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
469469

470470
def test_basic(self, xp: ModuleType):
471471
actual = one_hot(xp.asarray([0, 1, 2]), 3)
472-
expected = xp.asarray([[1., 0., 0.],
473-
[0., 1., 0.],
474-
[0., 0., 1.]])
472+
expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
475473
xp_assert_equal(actual, expected)
476474

477475
actual = one_hot(xp.asarray([1, 2, 0]), 3)
478-
expected = xp.asarray([[0., 1., 0.],
479-
[0., 0., 1.],
480-
[1., 0., 0.]])
476+
expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
481477
xp_assert_equal(actual, expected)
482478

483479
@pytest.mark.skip_xp_backend(
@@ -489,32 +485,29 @@ def test_out_of_bound(self, xp: ModuleType):
489485
actual = one_hot(xp.asarray([-1, 3]), 3)
490486
except IndexError:
491487
return
492-
expected = xp.asarray([[0., 0., 0.],
493-
[0., 0., 0.]])
488+
expected = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
494489
xp_assert_equal(actual, expected)
495490

496-
@pytest.mark.parametrize("int_dtype", ['int8', 'int16', 'int32', 'int64', 'uint8',
497-
'uint16', 'uint32', 'uint64'])
491+
@pytest.mark.parametrize(
492+
"int_dtype",
493+
["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
494+
)
498495
def test_int_types(self, xp: ModuleType, int_dtype: str):
499496
dtype = getattr(xp, int_dtype)
500497
x = xp.asarray([0, 1, 2], dtype=dtype)
501498
actual = one_hot(x, 3)
502-
expected = xp.asarray([[1., 0., 0.],
503-
[0., 1., 0.],
504-
[0., 0., 1.]])
499+
expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
505500
xp_assert_equal(actual, expected)
506501

507502
def test_custom_dtype(self, xp: ModuleType):
508503
actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
509-
expected = xp.asarray([[True, False, False],
510-
[False, True, False],
511-
[False, False, True]])
504+
expected = xp.asarray(
505+
[[True, False, False], [False, True, False], [False, False, True]]
506+
)
512507
xp_assert_equal(actual, expected)
513508

514509
def test_axis(self, xp: ModuleType):
515-
expected = xp.asarray([[0., 1., 0.],
516-
[0., 0., 1.],
517-
[1., 0., 0.]]).T
510+
expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T
518511
actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0)
519512
xp_assert_equal(actual, expected)
520513

0 commit comments

Comments
 (0)