@@ -326,6 +326,28 @@ def test_logsumexp_keepdims():
326
326
assert s .shape == (3 , 1 , 1 , 6 , 1 )
327
327
328
328
329
+ def test_logsumexp_keepdims_zero_size ():
330
+ get_queue_or_skip ()
331
+ n = 10
332
+ a = dpt .ones ((n , 0 , n ))
333
+
334
+ s1 = dpt .logsumexp (a , keepdims = True )
335
+ assert s1 .shape == (1 , 1 , 1 )
336
+
337
+ s2 = dpt .logsumexp (a , axis = (0 , 1 ), keepdims = True )
338
+ assert s2 .shape == (1 , 1 , n )
339
+
340
+ s3 = dpt .logsumexp (a , axis = (1 , 2 ), keepdims = True )
341
+ assert s3 .shape == (n , 1 , 1 )
342
+
343
+ s4 = dpt .logsumexp (a , axis = (0 , 2 ), keepdims = True )
344
+ assert s4 .shape == (1 , 0 , 1 )
345
+
346
+ a0 = a [0 ]
347
+ s5 = dpt .logsumexp (a0 , keepdims = True )
348
+ assert s5 .shape == (1 , 1 )
349
+
350
+
329
351
def test_logsumexp_scalar ():
330
352
get_queue_or_skip ()
331
353
@@ -337,6 +359,29 @@ def test_logsumexp_scalar():
337
359
assert s .shape == ()
338
360
339
361
362
+ def test_logsumexp_complex ():
363
+ get_queue_or_skip ()
364
+
365
+ x = dpt .zeros (1 , dtype = "c8" )
366
+ with pytest .raises (TypeError ):
367
+ dpt .logsumexp (x )
368
+
369
+
370
+ def test_logsumexp_int_axis ():
371
+ get_queue_or_skip ()
372
+
373
+ x = dpt .zeros ((8 , 10 ), dtype = "f4" )
374
+ res = dpt .logsumexp (x , axis = 0 )
375
+ assert res .ndim == 1
376
+ assert res .shape [0 ] == 10
377
+
378
+
379
+ def test_logsumexp_invalid_arr ():
380
+ x = dict ()
381
+ with pytest .raises (TypeError ):
382
+ dpt .logsumexp (x )
383
+
384
+
340
385
@pytest .mark .parametrize ("arg_dtype" , _no_complex_dtypes [1 :])
341
386
def test_hypot_arg_dtype_default_output_dtype_matrix (arg_dtype ):
342
387
q = get_queue_or_skip ()
@@ -376,3 +421,11 @@ def test_hypot_arg_out_dtype_matrix(arg_dtype, out_dtype):
376
421
377
422
assert isinstance (r , dpt .usm_ndarray )
378
423
assert r .dtype == dpt .dtype (out_dtype )
424
+
425
+
426
+ def test_hypot_complex ():
427
+ get_queue_or_skip ()
428
+
429
+ x = dpt .zeros (1 , dtype = "c8" )
430
+ with pytest .raises (TypeError ):
431
+ dpt .reduce_hypot (x )
0 commit comments