33
33
ctypedef unsigned char UChar
34
34
35
35
cimport util
36
- from util cimport is_array, _checknull, _checknan
36
+ from util cimport is_array, _checknull, _checknan, get_nat
37
+
38
+ cdef int64_t iNaT = get_nat()
37
39
38
40
# import datetime C API
39
41
PyDateTime_IMPORT
@@ -1159,16 +1161,15 @@ def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1159
1161
Only aggregates on axis=0
1160
1162
'''
1161
1163
cdef:
1162
- Py_ssize_t i, j, N, K, lab
1163
- %(dest_type2)s val
1164
- ndarray[%(dest_type2)s, ndim=2] nobs = np.zeros_like(out)
1165
-
1164
+ Py_ssize_t i, j, lab
1165
+ Py_ssize_t N = values.shape[0], K = values.shape[1]
1166
+ %(c_type)s val
1167
+ ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1168
+ dtype=np.int64)
1166
1169
1167
- if not len(values) == len(labels):
1170
+ if len(values) != len(labels):
1168
1171
raise AssertionError("len(index) != len(labels)")
1169
1172
1170
- N, K = (<object> values).shape
1171
-
1172
1173
for i in range(N):
1173
1174
lab = labels[i]
1174
1175
if lab < 0:
@@ -1179,7 +1180,7 @@ def group_count_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1179
1180
val = values[i, j]
1180
1181
1181
1182
# not nan
1182
- nobs[lab, j] += val == val
1183
+ nobs[lab, j] += val == val and val != iNaT
1183
1184
1184
1185
for i in range(len(counts)):
1185
1186
for j in range(K):
@@ -1198,20 +1199,14 @@ def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1198
1199
Only aggregates on axis=0
1199
1200
'''
1200
1201
cdef:
1201
- Py_ssize_t i, j, N, K, ngroups, b
1202
- %(dest_type2)s val, count
1203
- ndarray[%(dest_type2)s, ndim=2] nobs
1204
-
1205
- nobs = np.zeros_like(out )
1202
+ Py_ssize_t i, j, ngroups
1203
+ Py_ssize_t N = values.shape[0], K = values.shape[1], b = 0
1204
+ %(c_type)s val
1205
+ ndarray[int64_t, ndim=2] nobs = np.zeros((out.shape[0], out.shape[1]),
1206
+ dtype= np.int64 )
1206
1207
1207
- if bins[len(bins) - 1] == len(values):
1208
- ngroups = len(bins)
1209
- else:
1210
- ngroups = len(bins) + 1
1208
+ ngroups = len(bins) + (bins[len(bins) - 1] != N)
1211
1209
1212
- N, K = (<object> values).shape
1213
-
1214
- b = 0
1215
1210
for i in range(N):
1216
1211
while b < ngroups - 1 and i >= bins[b]:
1217
1212
b += 1
@@ -1221,7 +1216,7 @@ def group_count_bin_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
1221
1216
val = values[i, j]
1222
1217
1223
1218
# not nan
1224
- nobs[b, j] += val == val
1219
+ nobs[b, j] += val == val and val != iNaT
1225
1220
1226
1221
for i in range(ngroups):
1227
1222
for j in range(K):
@@ -2224,7 +2219,8 @@ def put2d_%(name)s_%(dest_type)s(ndarray[%(c_type)s, ndim=2, cast=True] values,
2224
2219
#-------------------------------------------------------------------------
2225
2220
# Generators
2226
2221
2227
- def generate_put_template (template , use_ints = True , use_floats = True ):
2222
+ def generate_put_template (template , use_ints = True , use_floats = True ,
2223
+ use_objects = False ):
2228
2224
floats_list = [
2229
2225
('float64' , 'float64_t' , 'float64_t' , 'np.float64' ),
2230
2226
('float32' , 'float32_t' , 'float32_t' , 'np.float32' ),
@@ -2235,11 +2231,14 @@ def generate_put_template(template, use_ints = True, use_floats = True):
2235
2231
('int32' , 'int32_t' , 'float64_t' , 'np.float64' ),
2236
2232
('int64' , 'int64_t' , 'float64_t' , 'np.float64' ),
2237
2233
]
2234
+ object_list = [('object' , 'object' , 'float64_t' , 'np.float64' )]
2238
2235
function_list = []
2239
2236
if use_floats :
2240
2237
function_list .extend (floats_list )
2241
2238
if use_ints :
2242
2239
function_list .extend (ints_list )
2240
+ if use_objects :
2241
+ function_list .extend (object_list )
2243
2242
2244
2243
output = StringIO ()
2245
2244
for name , c_type , dest_type , dest_dtype in function_list :
@@ -2373,7 +2372,7 @@ def generate_take_cython_file(path='generated.pyx'):
2373
2372
print (generate_put_template (template , use_ints = False ), file = f )
2374
2373
2375
2374
for template in groupby_count :
2376
- print (generate_put_template (template ), file = f )
2375
+ print (generate_put_template (template , use_objects = True ), file = f )
2377
2376
2378
2377
# for template in templates_1d_datetime:
2379
2378
# print >> f, generate_from_template_datetime(template)
0 commit comments