 from copy import copy
+from textwrap import dedent

 import numpy as np
 from numpy.core.numeric import normalize_axis_tuple
@@ -1145,6 +1146,7 @@ def c_support_code_apply(self, node, nodename):
         return support_code

     def c_code_cache_version_apply(self, node):
+        return None
         version = [15]  # the version corresponding to the c code in this Op

         # now we insert versions for the ops on which we depend...
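
Aside: the `return None` added above short-circuits the version computation. A falsy cache version tells the linker not to reuse previously compiled modules, so this presumably disables C-code caching while the generated code is being reworked. For orientation, a minimal sketch of how such a version method is typically assembled; `iter_scalar_dependency_versions` is a hypothetical stand-in for the dependency walk the real method performs below:

```python
# Minimal sketch, not PyTensor's actual implementation.
def c_code_cache_version_apply(self, node):
    version = [15]  # the version of this Op's own C code
    # Hypothetical helper standing in for collecting the cache versions
    # of the scalar ops the generated C code depends on.
    for dep_version in self.iter_scalar_dependency_versions(node):
        if not dep_version:
            return ()  # an unversioned dependency disables caching entirely
        version.append(dep_version)
    return tuple(version)
```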
@@ -1466,116 +1468,114 @@ def infer_shape(self, fgraph, node, shapes):
             return ((),)
         return ([ishape[i] for i in range(node.inputs[0].type.ndim) if i not in axis],)

-    def _c_all(self, node, name, inames, onames, sub):
-        input = node.inputs[0]
-        output = node.outputs[0]
+    def _c_all(self, node, name, input_names, output_names, sub):
+        [inp] = node.inputs
+        [out] = node.outputs
+        ndim = inp.type.ndim

-        iname = inames[0]
-        oname = onames[0]
+        [inp_name] = input_names
+        [out_name] = output_names

-        idtype = input.type.dtype_specs()[1]
-        odtype = output.type.dtype_specs()[1]
+        inp_dtype = inp.type.dtype_specs()[1]
+        out_dtype = out.type.dtype_specs()[1]

         acc_dtype = getattr(self, "acc_dtype", None)

         if acc_dtype is not None:
             if acc_dtype == "float16":
                 raise MethodNotDefined("no c_code for float16")
             acc_type = TensorType(shape=node.outputs[0].type.shape, dtype=acc_dtype)
-            adtype = acc_type.dtype_specs()[1]
+            acc_dtype = acc_type.dtype_specs()[1]
         else:
-            adtype = odtype
+            acc_dtype = out_dtype

         axis = self.axis
         if axis is None:
-            axis = list(range(input.type.ndim))
+            axis = list(range(inp.type.ndim))

         if len(axis) == 0:
+            # This is just an Elemwise cast operation
             # The acc_dtype is never a downcast compared to the input dtype
             # So we just need a cast to the output dtype.
-            var = pytensor.tensor.basic.cast(input, node.outputs[0].dtype)
-            if var is input:
-                var = Elemwise(scalar_identity)(input)
+            var = pytensor.tensor.basic.cast(inp, node.outputs[0].dtype)
+            if var is inp:
+                var = Elemwise(scalar_identity)(inp)
             assert var.dtype == node.outputs[0].dtype
-            return var.owner.op._c_all(var.owner, name, inames, onames, sub)
-
-        order1 = [i for i in range(input.type.ndim) if i not in axis]
-        order = order1 + list(axis)
+            return var.owner.op._c_all(var.owner, name, input_names, output_names, sub)

-        nnested = len(order1)
+        inp_dims = list(range(ndim))
+        non_reduced_dims = [i for i in inp_dims if i not in axis]
+        counter = iter(range(ndim))
+        acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]

-        sub = dict(sub)
-        for i, (input, iname) in enumerate(zip(node.inputs, inames)):
-            sub[f"lv{i}"] = iname
+        sub = sub.copy()
+        sub["lv0"] = inp_name
+        sub["lv1"] = out_name
+        sub["olv"] = out_name

-        decl = ""
-        if adtype != odtype:
+        if acc_dtype != out_dtype:
             # Create an accumulator variable different from the output
-            aname = "acc"
-            decl = acc_type.c_declare(aname, sub)
-            decl += acc_type.c_init(aname, sub)
+            acc_name = "acc"
+            setup = acc_type.c_declare(acc_name, sub) + acc_type.c_init(acc_name, sub)
         else:
             # the output is the accumulator variable
-            aname = oname
-
-        decl += cgen.make_declare([order], [idtype], sub)
-        checks = cgen.make_checks([order], [idtype], sub)
-
-        alloc = ""
-        i += 1
-        sub[f"lv{i}"] = oname
-        sub["olv"] = oname
-
-        # Allocate output buffer
-        alloc += cgen.make_declare(
-            [list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
-        )
-        alloc += cgen.make_alloc([order1], odtype, sub)
-        alloc += cgen.make_checks(
-            [list(range(nnested)) + ["x"] * len(axis)], [odtype], dict(sub, lv0=oname)
+            acc_name = out_name
+            setup = ""
+
+        # Define strides of input array
+        setup += cgen.make_declare(
+            [inp_dims], [inp_dtype], sub, compute_stride_jump=False
+        ) + cgen.make_checks([inp_dims], [inp_dtype], sub, compute_stride_jump=False)
+
+        # Define strides of output array and allocate it
+        out_sub = sub | {"lv0": out_name}
+        alloc = (
+            cgen.make_declare(
+                [acc_dims], [out_dtype], out_sub, compute_stride_jump=False
+            )
+            + cgen.make_alloc([non_reduced_dims], out_dtype, sub)
+            + cgen.make_checks(
+                [acc_dims], [out_dtype], out_sub, compute_stride_jump=False
+            )
         )

-        if adtype != odtype:
-            # Allocate accumulation buffer
-            sub[f"lv{i}"] = aname
-            sub["olv"] = aname
+        if acc_dtype != out_dtype:
+            # Define strides of accumulation buffer and allocate it
+            sub["lv1"] = acc_name
+            sub["olv"] = acc_name

-            alloc += cgen.make_declare(
-                [list(range(nnested)) + ["x"] * len(axis)],
-                [adtype],
-                dict(sub, lv0=aname),
-            )
-            alloc += cgen.make_alloc([order1], adtype, sub)
-            alloc += cgen.make_checks(
-                [list(range(nnested)) + ["x"] * len(axis)],
-                [adtype],
-                dict(sub, lv0=aname),
+            acc_sub = sub | {"lv0": acc_name}
+            alloc += (
+                cgen.make_declare(
+                    [acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
+                )
+                + cgen.make_alloc([non_reduced_dims], acc_dtype, sub)
+                + cgen.make_checks(
+                    [acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
+                )
             )

         identity = self.scalar_op.identity
-
         if np.isposinf(identity):
-            if input.type.dtype in ("float32", "float64"):
+            if inp.type.dtype in ("float32", "float64"):
                 identity = "__builtin_inf()"
-            elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
+            elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
                 identity = "1"
             else:
-                identity = "NPY_MAX_" + str(input.type.dtype).upper()
+                identity = "NPY_MAX_" + str(inp.type.dtype).upper()
         elif np.isneginf(identity):
-            if input.type.dtype in ("float32", "float64"):
+            if inp.type.dtype in ("float32", "float64"):
                 identity = "-__builtin_inf()"
-            elif input.type.dtype.startswith("uint") or input.type.dtype == "bool":
+            elif inp.type.dtype.startswith("uint") or inp.type.dtype == "bool":
                 identity = "0"
             else:
-                identity = "NPY_MIN_" + str(input.type.dtype).upper()
+                identity = "NPY_MIN_" + str(inp.type.dtype).upper()
         elif identity is None:
             raise TypeError(f"The {self.scalar_op} does not define an identity.")

-        task0_decl = f"{adtype}& {aname}_i = *{aname}_iter;\n{aname}_i = {identity};"
-
-        task1_decl = f"{idtype}& {inames[0]}_i = *{inames[0]}_iter;\n"
+        initial_value = f"{acc_name}_i = {identity};"

-        task1_code = self.scalar_op.c_code(
+        inner_task = self.scalar_op.c_code(
             Apply(
                 self.scalar_op,
                 [
@@ -1588,44 +1588,44 @@ def _c_all(self, node, name, inames, onames, sub):
                 ],
             ),
             None,
-            [f"{aname}_i", f"{inames[0]}_i"],
-            [f"{aname}_i"],
+            [f"{acc_name}_i", f"{inp_name}_i"],
+            [f"{acc_name}_i"],
             sub,
         )
-        code1 = f"""
-        {{
-            {task1_decl}
-            {task1_code}
-        }}
-        """

-        if node.inputs[0].type.ndim:
-            if len(axis) == 1:
-                all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
-            else:
-                all_code = (
-                    [("", "")] * nnested
-                    + [(task0_decl, "")]
-                    + [("", "")] * (len(axis) - 2)
-                    + [("", code1), ""]
-                )
+        if out.type.ndim == 0:
+            # Simple case where everything is reduced, no need for loop ordering
+            loop = cgen.make_complete_loop_careduce(
+                inp_var=inp_name,
+                acc_var=acc_name,
+                inp_dtype=inp_dtype,
+                acc_dtype=acc_dtype,
+                initial_value=initial_value,
+                inner_task=inner_task,
+            )
         else:
-            all_code = [task0_decl + code1]
-        loop = cgen.make_loop_careduce(
-            [order, list(range(nnested)) + ["x"] * len(axis)],
-            [idtype, adtype],
-            all_code,
-            sub,
-        )
+            loop = cgen.make_reordered_loop_careduce(
+                inp_var=inp_name,
+                acc_var=acc_name,
+                inp_dtype=inp_dtype,
+                acc_dtype=acc_dtype,
+                inp_ndim=ndim,
+                reduction_axes=axis,
+                initial_value=initial_value,
+                inner_task=inner_task,
+            )

-        end = ""
-        if adtype != odtype:
-            end = f"""
-            PyArray_CopyInto({oname}, {aname});
-            """
-            end += acc_type.c_cleanup(aname, sub)
+        if acc_dtype != out_dtype:
+            cast = dedent(
+                f"""
+                PyArray_CopyInto({out_name}, {acc_name});
+                {acc_type.c_cleanup(acc_name, sub)}
+                """
+            )
+        else:
+            cast = ""

-        return decl, checks, alloc, loop, end
+        return setup, alloc, loop, cast

     def c_code(self, node, name, inames, onames, sub):
         code = "\n".join(self._c_all(node, name, inames, onames, sub))
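
To make the refactored code generation easier to follow, here is a rough NumPy reference for what the generated `setup`/`alloc`/`loop`/`cast` sections compute together, assuming a sum-like `scalar_op` whose identity is 0. `careduce_reference` is an illustration only, not part of the diff:

```python
import numpy as np

def careduce_reference(inp, axis, acc_dtype, out_dtype):
    # alloc: an accumulator buffer with the non-reduced dims of the input
    out_shape = [s for i, s in enumerate(inp.shape) if i not in axis]
    acc = np.zeros(out_shape, dtype=acc_dtype)  # initial_value: acc_i = identity
    # loop: visit every input element, accumulating into its output position
    for idx in np.ndindex(*inp.shape):
        out_idx = tuple(i for d, i in enumerate(idx) if d not in axis)
        acc[out_idx] = acc[out_idx] + inp[idx]  # inner_task: acc_i = acc_i + inp_i
    # cast: copy back to the output dtype when acc_dtype != out_dtype
    return acc.astype(out_dtype)

x = np.arange(24, dtype="float32").reshape(2, 3, 4)
expected = x.sum(axis=(0, 2), dtype="float64").astype("float32")
assert np.allclose(
    careduce_reference(x, axis=(0, 2), acc_dtype="float64", out_dtype="float32"),
    expected,
)
```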
@@ -1637,7 +1637,7 @@ def c_headers(self, **kwargs):

     def c_code_cache_version_apply(self, node):
         # the version corresponding to the c code in this Op
-        version = [9]
+        version = [10]

         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
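
One more aside on the refactor above: the `acc_dims` list built in `_c_all` is a DimShuffle-style pattern that numbers each surviving input dimension by its position in the output and marks reduced dimensions with `"x"`. A quick illustration of the list comprehension from the diff, with example values for `ndim` and `axis`:

```python
>>> ndim, axis = 4, (1, 3)
>>> counter = iter(range(ndim))
>>> ["x" if i in axis else next(counter) for i in range(ndim)]
[0, 'x', 1, 'x']
```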