@@ -2081,63 +2081,65 @@ def local_pow_to_nested_squaring(fgraph, node):
2081
2081
Note: This sounds like the kind of thing any half-decent compiler can do by itself?
2082
2082
"""
2083
2083
2084
- if node .op == at_pow :
2085
- # the idea here is that we have pow(x, y)
2086
- odtype = node .outputs [0 ].dtype
2087
- xsym = node .inputs [0 ]
2088
- ysym = node .inputs [1 ]
2089
- y = get_constant (ysym )
2090
-
2091
- # the next line is needed to fix a strange case that I don't
2092
- # know how to make a separate test.
2093
- # That happen in the `test_log_erfc` test.
2094
- # y is a ndarray with dtype int8 and value 2,4 or 6. This make
2095
- # the abs(y) <= 512 fail!
2096
- # taking the value outside ndarray solve the problem.
2097
- # it could be that in that case, numpy make the comparison
2098
- # into the wrong type(do in int8 that overflow.)
2099
- if isinstance (y , np .ndarray ):
2100
- assert y .size == 1
2101
- try :
2102
- y = y [0 ]
2103
- except IndexError :
2104
- pass
2105
- if (y is not None ) and not broadcasted_by (xsym , ysym ):
2106
- rval = None
2107
- # 512 is too small for the cpu and too big for some gpu!
2108
- if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
2109
- pow2 = [xsym ]
2110
- pow2_scal = [aes .get_scalar_type (xsym .dtype )()]
2111
- y_to_do = abs (y )
2112
- for i in range (int (np .log2 (y_to_do ))):
2113
- pow2 .append (sqr (pow2 [i ]))
2114
- pow2_scal .append (aes .sqr (pow2_scal [i ]))
2115
- rval1 = None
2116
- rval1_scal = None
2117
- while y_to_do > 0 :
2118
- log_to_do = int (np .log2 (y_to_do ))
2119
- if rval1 :
2120
- rval1 *= pow2 [log_to_do ]
2121
- rval1_scal *= pow2_scal [log_to_do ]
2122
- else :
2123
- rval1 = pow2 [log_to_do ]
2124
- rval1_scal = pow2_scal [log_to_do ]
2125
- y_to_do -= 2 ** log_to_do
2126
-
2127
- if abs (y ) > 2 :
2128
- # We fuse all the pow together here to make
2129
- # compilation faster
2130
- rval1 = Elemwise (
2131
- aes .Composite ([pow2_scal [0 ]], [rval1_scal ])
2132
- ).make_node (xsym )
2133
- if y < 0 :
2134
- rval = [reciprocal (rval1 )]
2084
+ # the idea here is that we have pow(x, y)
2085
+ odtype = node .outputs [0 ].dtype
2086
+ xsym = node .inputs [0 ]
2087
+ ysym = node .inputs [1 ]
2088
+ y = get_constant (ysym )
2089
+
2090
+ # the next line is needed to fix a strange case that I don't
2091
+ # know how to make a separate test.
2092
+ # That happen in the `test_log_erfc` test.
2093
+ # y is a ndarray with dtype int8 and value 2,4 or 6. This make
2094
+ # the abs(y) <= 512 fail!
2095
+ # taking the value outside ndarray solve the problem.
2096
+ # it could be that in that case, numpy make the comparison
2097
+ # into the wrong type(do in int8 that overflow.)
2098
+ if isinstance (y , np .ndarray ):
2099
+ assert y .size == 1
2100
+ try :
2101
+ y = y [0 ]
2102
+ except IndexError :
2103
+ pass
2104
+ if (y is not None ) and not broadcasted_by (xsym , ysym ):
2105
+ rval = None
2106
+ # 512 is too small for the cpu and too big for some gpu!
2107
+ if abs (y ) == int (abs (y )) and abs (y ) <= 512 :
2108
+ pow2 = [xsym ]
2109
+ pow2_scal = [aes .get_scalar_type (xsym .dtype )()]
2110
+ y_to_do = abs (y )
2111
+ for i in range (int (np .log2 (y_to_do ))):
2112
+ pow2 .append (sqr (pow2 [i ]))
2113
+ pow2_scal .append (aes .sqr (pow2_scal [i ]))
2114
+ rval1 = None
2115
+ rval1_scal = None
2116
+ while y_to_do > 0 :
2117
+ log_to_do = int (np .log2 (y_to_do ))
2118
+ if rval1 :
2119
+ rval1 *= pow2 [log_to_do ]
2120
+ rval1_scal *= pow2_scal [log_to_do ]
2135
2121
else :
2136
- rval = [rval1 ]
2137
- if rval :
2138
- rval [0 ] = cast (rval [0 ], odtype )
2139
- assert rval [0 ].type == node .outputs [0 ].type , (rval , node .outputs )
2140
- return rval
2122
+ rval1 = pow2 [log_to_do ]
2123
+ rval1_scal = pow2_scal [log_to_do ]
2124
+ y_to_do -= 2 ** log_to_do
2125
+
2126
+ if abs (y ) > 2 :
2127
+ # We fuse all the pow together here to make
2128
+ # compilation faster
2129
+ rval1 = Elemwise (aes .Composite ([pow2_scal [0 ]], [rval1_scal ])).make_node (
2130
+ xsym
2131
+ )
2132
+ if y < 0 :
2133
+ rval = [reciprocal (rval1 )]
2134
+ else :
2135
+ rval = [rval1 ]
2136
+ if rval :
2137
+ rval [0 ] = cast (rval [0 ], odtype )
2138
+ # TODO: We can add a specify_broadcastable and/or unbroadcast to make the
2139
+ # output types compatible. Or work on #408 and let TensorType.filter_variable do it.
2140
+ if rval [0 ].type .broadcastable != node .outputs [0 ].type .broadcastable :
2141
+ return None
2142
+ return rval
2141
2143
2142
2144
2143
2145
@register_specialize
0 commit comments