Skip to content

Commit 0662f2b

Browse files
committed
infer types
1 parent 7c29393 commit 0662f2b

File tree

5 files changed

+25
-17
lines changed

5 files changed

+25
-17
lines changed

pandas/core/arrays/sparse.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,18 +1927,17 @@ def make_sparse(arr, kind='block', fill_value=None, dtype=None, copy=False):
19271927
index = _make_index(length, indices, kind)
19281928
sparsified_values = arr[mask]
19291929

1930-
# careful about casting here
1931-
# as we could easily specify a type that cannot hold the resulting values
1932-
# e.g. integer when we have floats
1930+
# careful about casting here as we could easily specify a type that
1931+
# cannot hold the resulting values, e.g. integer when we have floats
1932+
# only cast when neither dtype is object or bool; then cast to their common type
19331933
if dtype is not None:
1934-
try:
1935-
sparsified_values = astype_nansafe(
1936-
sparsified_values, dtype=dtype, casting='same_kind')
1937-
except TypeError:
1938-
dtype = 'float64'
1939-
sparsified_values = astype_nansafe(
1940-
sparsified_values, dtype=dtype, casting='unsafe')
19411934

1935+
ok_to_cast = all(not (is_object_dtype(t) or is_bool_dtype(t))
1936+
for t in (dtype, sparsified_values.dtype))
1937+
if ok_to_cast:
1938+
dtype = find_common_type([dtype, sparsified_values.dtype])
1939+
sparsified_values = astype_nansafe(
1940+
sparsified_values, dtype=dtype)
19421941

19431942
# TODO: copy
19441943
return sparsified_values, index, fill_value

pandas/core/internals/construction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,10 @@ def sanitize_array(data, index, dtype=None, copy=False,
666666
data = np.array(data, dtype=dtype, copy=False)
667667
subarr = np.array(data, dtype=object, copy=copy)
668668

669-
if is_object_dtype(subarr.dtype) and dtype != 'object':
669+
if (not (is_extension_array_dtype(subarr.dtype) or
670+
is_extension_array_dtype(dtype)) and
671+
is_object_dtype(subarr.dtype) and
672+
not is_object_dtype(dtype)):
670673
inferred = lib.infer_dtype(subarr, skipna=False)
671674
if inferred == 'period':
672675
try:

pandas/core/sparse/frame.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,20 +284,26 @@ def _unpickle_sparse_frame_compat(self, state):
284284
def to_dense(self):
285285
return SparseFrameAccessor(self).to_dense()
286286

287-
def _apply_columns(self, func):
287+
def _apply_columns(self, func, *args, **kwargs):
288288
"""
289289
Get new SparseDataFrame applying func to each column
290290
"""
291291

292-
new_data = {col: func(series)
292+
new_data = {col: func(series, *args, **kwargs)
293293
for col, series in self.items()}
294294

295295
return self._constructor(
296296
data=new_data, index=self.index, columns=self.columns,
297297
default_fill_value=self.default_fill_value).__finalize__(self)
298298

299-
def astype(self, dtype):
300-
return self._apply_columns(lambda x: x.astype(dtype))
299+
def astype(self, dtype, **kwargs):
300+
301+
def f(x, dtype, **kwargs):
302+
if isinstance(dtype, (dict, Series)):
303+
dtype = dtype[x.name]
304+
return x.astype(dtype, **kwargs)
305+
306+
return self._apply_columns(f, dtype=dtype, **kwargs)
301307

302308
def copy(self, deep=True):
303309
"""

pandas/tests/sparse/frame/test_analytics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,5 @@ def test_ufunc(data, dtype, func):
5555
result = func(df)
5656
expected = DataFrame(
5757
{'A': Series(func(data),
58-
dtype=dtype)})
58+
dtype=SparseDtype('float64', dtype.fill_value))})
5959
tm.assert_frame_equal(result, expected)

pandas/tests/sparse/series/test_analytics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@ def test_ufunc(data, dtype, func):
1616
s = Series(data, dtype=dtype)
1717
result = func(s)
1818
expected = Series(func(data),
19-
dtype=dtype)
19+
dtype=SparseDtype('float64', dtype.fill_value))
2020
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)