Skip to content

Commit 49456c2

Browse files
committed
Make DataFrame.__dataframe__ pass protocol tests
Signed-off-by: Vasily Litvinov <vasilij.n.litvinov@intel.com>
1 parent 734b811 commit 49456c2

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

pandas/api/exchange/implementation.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import collections
1+
import collections.abc
22
import ctypes
33

44
from typing import Tuple, Any
@@ -14,9 +14,8 @@ def from_dataframe(df : DataFrameXchg,
1414
"""
1515
Construct a pandas DataFrame from ``df`` if it supports ``__dataframe__``
1616
"""
17-
# NOTE: commented out for roundtrip testing
18-
# if isinstance(df, pd.DataFrame):
19-
# return df
17+
if isinstance(df, pd.DataFrame):
18+
return df
2019

2120
if not hasattr(df, '__dataframe__'):
2221
raise ValueError("`df` does not support __dataframe__")
@@ -606,20 +605,31 @@ def get_columns(self):
606605
for name in self._df.columns]
607606

608607
def select_columns(self, indices):
609-
if not isinstance(indices, collections.Sequence):
608+
if not isinstance(indices, collections.abc.Sequence):
610609
raise ValueError("`indices` is not a sequence")
610+
if not isinstance(indices, list):
611+
indices = list(indices)
611612

612-
return _PandasDataFrameXchg(self._df.iloc[:, indices])
613+
return _PandasDataFrameXchg(self._df.iloc[:, indices], self._nan_as_null, self._allow_copy)
613614

614615
def select_columns_by_name(self, names):
615-
if not isinstance(names, collections.Sequence):
616+
if not isinstance(names, collections.abc.Sequence):
616617
raise ValueError("`names` is not a sequence")
618+
if not isinstance(names, list):
619+
names = list(names)
617620

618-
return _PandasDataFrameXchg(self._df.xs(names, axis='columns'))
621+
return _PandasDataFrameXchg(self._df.loc[:, names], self._nan_as_null, self._allow_copy)
619622

620623
def get_chunks(self, n_chunks=None):
621624
"""
622625
Return an iterator yielding the chunks.
623626
"""
624-
#TODO: implement chunking when n_chunks > 1
625-
return (self,)
627+
if n_chunks and n_chunks > 1:
628+
size = len(self._df)
629+
step = size // n_chunks
630+
if size % n_chunks != 0:
631+
step +=1
632+
for start in range(0, step * n_chunks, step):
633+
yield _PandasDataFrameXchg(self._df.iloc[start:start + step, :], self._nan_as_null, self._allow_copy)
634+
else:
635+
yield self

pandas/tests/api/conftest.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
import pytest
22
import pandas as pd
3+
from pandas.api.exchange.implementation import _from_dataframe
34

45
@pytest.fixture(scope='package')
5-
def create_df_from_dict():
6-
def maker(dct):
7-
return pd.DataFrame(dct)
6+
def df_from_dict():
7+
def maker(dct, is_categorical=False):
8+
df = pd.DataFrame(dct)
9+
return df.astype('category') if is_categorical else df
10+
return maker
11+
12+
@pytest.fixture(scope='package')
13+
def df_from_xchg():
14+
def maker(xchg):
15+
return _from_dataframe(xchg)
816
return maker

0 commit comments

Comments
 (0)