Skip to content

Commit 07d96b8

Browse files
committed
pyright errors
1 parent 20cc969 commit 07d96b8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
465465
assert y.shape == (*x.shape, num_classes)
466466
for *i_list, j in ndindex(*shape, num_classes):
467467
i = tuple(i_list)
468-
assert float(y[*i, j]) == (int(x[i]) == j)
468+
assert float(y[(*i, j)]) == (int(x[i]) == j)
469469

470470
def test_basic(self, xp: ModuleType):
471471
actual = one_hot(xp.asarray([0, 1, 2]), 3)
@@ -516,7 +516,7 @@ def test_axis(self, xp: ModuleType):
516516

517517
def test_non_integer(self, xp: ModuleType):
518518
with pytest.raises(TypeError):
519-
one_hot(xp.asarray([1.0]), 3)
519+
_ = one_hot(xp.asarray([1.0]), 3)
520520

521521

522522
@pytest.mark.skip_xp_backend(

0 commit comments

Comments
 (0)