@@ -379,8 +379,34 @@ def test_logcumsumexp_basic():
379
379
x = dpt .ones (10 , dtype = dt )
380
380
r = dpt .cumulative_logsumexp (x )
381
381
382
- x_np = dpt .asnumpy (x )
383
- expected = np .logaddexp .accumulate (x_np , dtype = dt )
382
+ expected = 1 + np .log (np .arange (1 , 11 , dtype = dt ))
384
383
385
- tol = 32 * dpt .finfo (dt ).resolution
384
+ tol = 4 * dpt .finfo (dt ).resolution
386
385
assert np .allclose (dpt .asnumpy (r ), expected , atol = tol , rtol = tol )
386
+
387
+
388
+ def geometric_series_closed_form (n , dtype = None , device = None ):
389
+ """Closed form for cumulative_logsumexp(dpt.arange(-n, 0))
390
+
391
+ :math:`r[k] == -n + k + log(1 - exp(-k-1)) - log(1-exp(-1))`
392
+ """
393
+ x = dpt .arange (- n , 0 , dtype = dtype , device = device )
394
+ y = dpt .arange (- 1 , - n - 1 , step = - 1 , dtype = dtype , device = device )
395
+ y = dpt .exp (y , out = y )
396
+ y = dpt .negative (y , out = y )
397
+ y = dpt .log1p (y , out = y )
398
+ y -= y [0 ]
399
+ return x + y
400
+
401
+
402
+ @pytest .mark .parametrize ("fpdt" , rfp_types )
403
+ def test_cumulative_logsumexp_closed_form (fpdt ):
404
+ q = get_queue_or_skip ()
405
+ skip_if_dtype_not_supported (fpdt , q )
406
+
407
+ n = 128
408
+ r = dpt .cumulative_logsumexp (dpt .arange (- n , 0 , dtype = fpdt , device = q ))
409
+ expected = geometric_series_closed_form (n , dtype = fpdt , device = q )
410
+
411
+ tol = 4 * dpt .finfo (fpdt ).eps
412
+ assert dpt .allclose (r , expected , atol = tol , rtol = tol )
0 commit comments