Skip to content

Commit 690a09f

Browse files
committed
ENH: SparseArray constructor supports 1d scipy.sparse.spmatrix input
1 parent 639fc6f commit 690a09f

File tree

2 files changed

+39
-8
lines changed

2 files changed

+39
-8
lines changed

pandas/core/sparse/array.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
is_bool_dtype,
2424
is_list_like,
2525
is_string_dtype,
26-
is_scalar, is_dtype_equal)
26+
is_scalar, is_dtype_equal,
27+
is_scipy_sparse)
2728
from pandas.core.dtypes.cast import (
2829
maybe_convert_platform, maybe_promote,
2930
astype_nansafe, find_common_type)
@@ -164,11 +165,13 @@ class SparseArray(PandasObject, np.ndarray):
164165
165166
Parameters
166167
----------
167-
data : {array-like (1-D), Series, SparseSeries, dict}
168+
data : {array-like (1-D), Series, SparseSeries, dict, \
169+
scipy.sparse.spmatrix}
168170
kind : {'block', 'integer'}
169171
fill_value : float
170172
Code for missing value. Defaults depends on dtype.
171-
0 for int dtype, False for bool dtype, and NaN for other dtypes
173+
0 for int dtype or scipy sparse matrix, False for bool dtype, and NaN
174+
for other dtypes
172175
sparse_index : {BlockIndex, IntIndex}, optional
173176
Only if you have one. Mainly used internally
174177
@@ -197,17 +200,27 @@ def __new__(cls, data, sparse_index=None, index=None, kind='integer',
197200
values.fill(data)
198201
data = values
199202

200-
if isinstance(data, ABCSparseSeries):
201-
data = data.values
202-
is_sparse_array = isinstance(data, SparseArray)
203-
204203
if dtype is not None:
205204
dtype = np.dtype(dtype)
206205

207-
if is_sparse_array:
206+
if isinstance(data, ABCSparseSeries):
207+
data = data.values
208+
209+
if isinstance(data, SparseArray):
208210
sparse_index = data.sp_index
209211
values = data.sp_values
210212
fill_value = data.fill_value
213+
elif is_scipy_sparse(data):
214+
if not any(ax == 1 for ax in data.shape):
215+
raise ValueError('Need 1D sparse matrix shaped '
216+
'(n, 1) or (1, n)')
217+
coo = data.tocoo()
218+
values = coo.data
219+
indices = coo.row if coo.shape[0] != 1 else coo.col
220+
sparse_index = _make_index(max(coo.shape), indices, kind)
221+
# SciPy Sparse matrices imply missing value = 0
222+
if fill_value is None:
223+
fill_value = 0
211224
else:
212225
# array-like
213226
if sparse_index is None:

pandas/tests/sparse/test_array.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,24 @@ def test_constructor_spindex_dtype(self):
105105
assert arr.dtype == np.int64
106106
assert arr.fill_value == 0
107107

108+
def test_constructor_spmatrix(self):
109+
# GH-15634
110+
tm.skip_if_no_package('scipy')
111+
from scipy.sparse import csr_matrix
112+
113+
spm = csr_matrix(np.arange(5))
114+
115+
arr = SparseArray(spm)
116+
assert arr.dtype == spm.dtype
117+
assert arr.fill_value == 0
118+
119+
arr = SparseArray(spm, kind='block', dtype=float, fill_value=np.nan)
120+
assert arr.dtype == float
121+
assert np.isnan(arr.fill_value)
122+
123+
tm.assert_raises_regex(ValueError, '1D',
124+
lambda: SparseArray(csr_matrix(np.eye(3))))
125+
108126
def test_sparseseries_roundtrip(self):
109127
# GH 13999
110128
for kind in ['integer', 'block']:

0 commit comments

Comments
 (0)