Skip to content

Commit 8625b30

Browse files
committed
Additional backend tests for nan_to_num
1 parent 61bbb3c commit 8625b30

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

test/test_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def test_empty_backend():
264264
nx.detach(M)
265265
with pytest.raises(NotImplementedError):
266266
nx.matmul(M, M.T)
267+
with pytest.raises(NotImplementedError):
268+
nx.nan_to_num(M)
267269

268270

269271
def test_func_backends(nx):
@@ -667,6 +669,11 @@ def test_func_backends(nx):
667669
lst_b.append(nx.to_numpy(A))
668670
lst_name.append("matmul broadcast")
669671

672+
vec = nx.from_numpy(np.array([1, np.nan, -1]))
673+
vec = nx.nan_to_num(vec, nan=0)
674+
lst_b.append(nx.to_numpy(vec))
675+
lst_name.append("nan_to_num")
676+
670677
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
671678
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
672679
assert not nx.array_equal(

0 commit comments

Comments
 (0)