@@ -236,157 +236,21 @@ def test_Dimshuffle_non_contiguous():
236
236
assert func (np .zeros (3 ), np .array ([1 ])).ndim == 0
237
237
238
238
239
- @pytest .mark .parametrize (
240
- "careduce_fn, axis, v" ,
241
- [
242
- (
243
- lambda x , axis = None , dtype = None , acc_dtype = None : Sum (
244
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
245
- )(x ),
246
- 0 ,
247
- set_test_value (pt .vector (), np .arange (3 , dtype = config .floatX )),
248
- ),
249
- (
250
- lambda x , axis = None , dtype = None , acc_dtype = None : All (axis )(x ),
251
- 0 ,
252
- set_test_value (pt .vector (), np .arange (3 , dtype = config .floatX )),
253
- ),
254
- (
255
- lambda x , axis = None , dtype = None , acc_dtype = None : Any (axis )(x ),
256
- 0 ,
257
- set_test_value (pt .vector (), np .arange (3 , dtype = config .floatX )),
258
- ),
259
- (
260
- lambda x , axis = None , dtype = None , acc_dtype = None : Mean (axis )(x ),
261
- 0 ,
262
- set_test_value (pt .vector (), np .arange (3 , dtype = config .floatX )),
263
- ),
264
- (
265
- lambda x , axis = None , dtype = None , acc_dtype = None : Mean (axis )(x ),
266
- 0 ,
267
- set_test_value (
268
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
269
- ),
270
- ),
271
- (
272
- lambda x , axis = None , dtype = None , acc_dtype = None : Sum (
273
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
274
- )(x ),
275
- 0 ,
276
- set_test_value (
277
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
278
- ),
279
- ),
280
- (
281
- lambda x , axis = None , dtype = None , acc_dtype = None : Sum (
282
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
283
- )(x ),
284
- (0 , 1 ),
285
- set_test_value (
286
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
287
- ),
288
- ),
289
- (
290
- lambda x , axis = None , dtype = None , acc_dtype = None : Sum (
291
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
292
- )(x ),
293
- (1 , 0 ),
294
- set_test_value (
295
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
296
- ),
297
- ),
298
- (
299
- lambda x , axis = None , dtype = None , acc_dtype = None : Sum (
300
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
301
- )(x ),
302
- None ,
303
- set_test_value (
304
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
305
- ),
306
- ),
307
- (
308
- lambda x , axis = None , dtype = None , acc_dtype = None : Sum (
309
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
310
- )(x ),
311
- 1 ,
312
- set_test_value (
313
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
314
- ),
315
- ),
316
- (
317
- lambda x , axis = None , dtype = None , acc_dtype = None : Prod (
318
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
319
- )(x ),
320
- 0 ,
321
- set_test_value (pt .vector (), np .arange (3 , dtype = config .floatX )),
322
- ),
323
- (
324
- lambda x , axis = None , dtype = None , acc_dtype = None : ProdWithoutZeros (
325
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
326
- )(x ),
327
- 0 ,
328
- set_test_value (pt .vector (), np .arange (3 , dtype = config .floatX )),
329
- ),
330
- (
331
- lambda x , axis = None , dtype = None , acc_dtype = None : Prod (
332
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
333
- )(x ),
334
- 0 ,
335
- set_test_value (
336
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
337
- ),
338
- ),
339
- (
340
- lambda x , axis = None , dtype = None , acc_dtype = None : Prod (
341
- axis = axis , dtype = dtype , acc_dtype = acc_dtype
342
- )(x ),
343
- 1 ,
344
- set_test_value (
345
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
346
- ),
347
- ),
348
- (
349
- lambda x , axis = None , dtype = None , acc_dtype = None : Max (axis )(x ),
350
- None ,
351
- set_test_value (
352
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
353
- ),
354
- ),
355
- (
356
- lambda x , axis = None , dtype = None , acc_dtype = None : Max (axis )(x ),
357
- None ,
358
- set_test_value (
359
- pt .lmatrix (), np .arange (3 * 2 , dtype = np .int64 ).reshape ((3 , 2 ))
360
- ),
361
- ),
362
- (
363
- lambda x , axis = None , dtype = None , acc_dtype = None : Min (axis )(x ),
364
- None ,
365
- set_test_value (
366
- pt .matrix (), np .arange (3 * 2 , dtype = config .floatX ).reshape ((3 , 2 ))
367
- ),
368
- ),
369
- (
370
- lambda x , axis = None , dtype = None , acc_dtype = None : Min (axis )(x ),
371
- None ,
372
- set_test_value (
373
- pt .lmatrix (), np .arange (3 * 2 , dtype = np .int64 ).reshape ((3 , 2 ))
374
- ),
375
- ),
376
- ],
377
- )
378
- def test_CAReduce (careduce_fn , axis , v ):
379
- g = careduce_fn (v , axis = axis )
380
- g_fg = FunctionGraph (outputs = [g ])
381
-
382
- compare_numba_and_py (
383
- g_fg ,
384
- [
385
- i .tag .test_value
386
- for i in g_fg .inputs
387
- if not isinstance (i , SharedVariable | Constant )
388
- ],
389
- )
239
+ @pytest .mark .parametrize ("axis" , [0 , None , (0 , 1 )])
240
+ @pytest .mark .parametrize ("op" , [Sum , Prod , ProdWithoutZeros , All , Any , Mean , Max , Min ])
241
+ def test_CAReduce (op , axis ):
242
+ if op == Mean and isinstance (axis , tuple ) and len (axis ) > 1 :
243
+ pytest .xfail ("Mean does not support multiple partial axes" )
244
+
245
+ bool_reduction = op in (All , Any )
246
+ x = pt .tensor3 ("x" , dtype = bool if bool_reduction else config .floatX )
247
+ g = op (axis = axis )(x )
248
+ g_fg = FunctionGraph ([x ], [g ])
249
+
250
+ x_test = np .random .normal (size = (2 , 3 , 4 )).astype (config .floatX )
251
+ if bool_reduction :
252
+ x_test = x_test > 0
253
+ compare_numba_and_py (g_fg , [x_test ])
390
254
391
255
392
256
def test_scalar_Elemwise_Clip ():
0 commit comments