Skip to content

Commit 20ee0de

Browse files
committed
Use np.logaddexp.accumulate in hopes of better numerical accuracy of expected result for cumulative_logsumexp
1 parent d2fa2a2 commit 20ee0de

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

dpctl/tests/test_tensor_accumulation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from random import randrange
1818

19+
import numpy as np
1920
import pytest
2021
from helper import get_queue_or_skip, skip_if_dtype_not_supported
2122

@@ -377,7 +378,9 @@ def test_logcumsumexp_basic():
377378
dt = dpt.float32
378379
x = dpt.ones(100, dtype=dt)
379380
r = dpt.cumulative_logsumexp(x)
380-
expected = dpt.log(dpt.cumulative_sum(dpt.exp(x)))
381+
382+
x_np = dpt.asnumpy(x)
383+
expected = np.logaddexp.accumulate(x_np, dtype=dt)
381384

382385
tol = 32 * dpt.finfo(dt).resolution
383-
assert dpt.allclose(r, expected, atol=tol, rtol=tol)
386+
assert np.allclose(dpt.asnumpy(r), expected, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)