Skip to content

Commit 4bd02b4

Browse files
Add test for cumulative_logsumexp for geometric series summation, testing against closed form
1 parent 6520b8b commit 4bd02b4

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

dpctl/tests/test_tensor_accumulation.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,34 @@ def test_logcumsumexp_basic():
379379
x = dpt.ones(10, dtype=dt)
380380
r = dpt.cumulative_logsumexp(x)
381381

382-
x_np = dpt.asnumpy(x)
383-
expected = np.logaddexp.accumulate(x_np, dtype=dt)
382+
expected = 1 + np.log(np.arange(1, 11, dtype=dt))
384383

385-
tol = 32 * dpt.finfo(dt).resolution
384+
tol = 4 * dpt.finfo(dt).resolution
386385
assert np.allclose(dpt.asnumpy(r), expected, atol=tol, rtol=tol)
386+
387+
388+
def geometric_series_closed_form(n, dtype=None, device=None):
389+
"""Closed form for cumulative_logsumexp(dpt.arange(-n, 0))
390+
391+
:math:`r[k] == -n + k + log(1 - exp(-k-1)) - log(1-exp(-1))`
392+
"""
393+
x = dpt.arange(-n, 0, dtype=dtype, device=device)
394+
y = dpt.arange(-1, -n - 1, step=-1, dtype=dtype, device=device)
395+
y = dpt.exp(y, out=y)
396+
y = dpt.negative(y, out=y)
397+
y = dpt.log1p(y, out=y)
398+
y -= y[0]
399+
return x + y
400+
401+
402+
@pytest.mark.parametrize("fpdt", rfp_types)
403+
def test_cumulative_logsumexp_closed_form(fpdt):
404+
q = get_queue_or_skip()
405+
skip_if_dtype_not_supported(fpdt, q)
406+
407+
n = 128
408+
r = dpt.cumulative_logsumexp(dpt.arange(-n, 0, dtype=fpdt, device=q))
409+
expected = geometric_series_closed_form(n, dtype=fpdt, device=q)
410+
411+
tol = 4 * dpt.finfo(fpdt).eps
412+
assert dpt.allclose(r, expected, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)