Skip to content

Add allow_copy flag to interchange protocol #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion protocol/dataframe_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,16 +347,23 @@ class DataFrame:
``__dataframe__`` method of a public data frame class in a library adhering
to the dataframe interchange protocol specification.
"""
def __dataframe__(self, nan_as_null : bool = False) -> dict:
def __dataframe__(self, nan_as_null : bool = False,
allow_copy : bool = True) -> dict:
"""
Produces a dictionary object following the dataframe protocol specification.

``nan_as_null`` is a keyword intended for the consumer to tell the
producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
It is intended for cases where the consumer does not support the bit
mask or byte mask that is the producer's native representation.

``allow_copy`` is a keyword that defines whether or not the library is
allowed to make a copy of the data. For example, copying data would be
necessary if a library supports strided buffers, given that this protocol
specifies contiguous buffers.
"""
self._nan_as_null = nan_as_null
        self._allow_zero_copy = allow_copy
return {
"dataframe": self, # DataFrame object adhering to the protocol
"version": 0 # Version number of the protocol
Expand Down
90 changes: 59 additions & 31 deletions protocol/pandas_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
ColumnObject = Any


def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
def from_dataframe(df : DataFrameObject,
allow_copy : bool = True) -> pd.DataFrame:
"""
Construct a pandas DataFrame from ``df`` if it supports ``__dataframe__``
"""
Expand All @@ -46,7 +47,7 @@ def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
if not hasattr(df, '__dataframe__'):
raise ValueError("`df` does not support __dataframe__")

return _from_dataframe(df.__dataframe__())
return _from_dataframe(df.__dataframe__(allow_copy=allow_copy))


def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
Expand All @@ -63,19 +64,24 @@ def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
# least for now, deal with non-numpy dtypes later).
columns = dict()
_k = _DtypeKind
_buffers = [] # hold on to buffers, keeps memory alive
for name in df.column_names():
col = df.get_column_by_name(name)
if col.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
# Simple numerical or bool dtype, turn into numpy array
columns[name] = convert_column_to_ndarray(col)
columns[name], _buf = convert_column_to_ndarray(col)
elif col.dtype[0] == _k.CATEGORICAL:
columns[name] = convert_categorical_column(col)
columns[name], _buf = convert_categorical_column(col)
elif col.dtype[0] == _k.STRING:
columns[name] = convert_string_column(col)
columns[name], _buf = convert_string_column(col)
else:
raise NotImplementedError(f"Data type {col.dtype[0]} not handled yet")

return pd.DataFrame(columns)
_buffers.append(_buf)

df_new = pd.DataFrame(columns)
df_new._buffers = _buffers
return df_new


class _DtypeKind(enum.IntEnum):
Expand All @@ -100,7 +106,7 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
"sentinel values not handled yet")

_buffer, _dtype = col.get_buffers()["data"]
return buffer_to_ndarray(_buffer, _dtype)
return buffer_to_ndarray(_buffer, _dtype), _buffer


def buffer_to_ndarray(_buffer, _dtype) -> np.ndarray:
Expand Down Expand Up @@ -159,7 +165,7 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series:
raise NotImplementedError("Only categorical columns with sentinel "
"value supported at the moment")

return series
return series, codes_buffer


def convert_string_column(col : ColumnObject) -> np.ndarray:
Expand Down Expand Up @@ -196,7 +202,7 @@ def convert_string_column(col : ColumnObject) -> np.ndarray:
            v = mbuf[i // 8]
if null_value == 1:
v = ~v

if v & (1<<(i%8)):
str_list.append(np.nan)
continue
Expand All @@ -218,10 +224,11 @@ def convert_string_column(col : ColumnObject) -> np.ndarray:
str_list.append(s)

# Convert the string list to a NumPy array
return np.asarray(str_list, dtype="object")
return np.asarray(str_list, dtype="object"), buffers


def __dataframe__(cls, nan_as_null : bool = False) -> dict:
def __dataframe__(cls, nan_as_null : bool = False,
allow_copy : bool = True) -> dict:
"""
The public method to attach to pd.DataFrame.

Expand All @@ -232,12 +239,21 @@ def __dataframe__(cls, nan_as_null : bool = False) -> dict:
producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
This currently has no effect; once support for nullable extension
dtypes is added, this value should be propagated to columns.

``allow_copy`` is a keyword that defines whether or not the library is
allowed to make a copy of the data. For example, copying data would be
necessary if a library supports strided buffers, given that this protocol
specifies contiguous buffers.
Currently, if the flag is set to ``False`` and a copy is needed, a
``RuntimeError`` will be raised.
"""
return _PandasDataFrame(cls, nan_as_null=nan_as_null)
return _PandasDataFrame(
cls, nan_as_null=nan_as_null, allow_copy=allow_copy)


# Monkeypatch the Pandas DataFrame class to support the interchange protocol
pd.DataFrame.__dataframe__ = __dataframe__
pd.DataFrame._buffers = []


# Implementation of interchange protocol
Expand All @@ -248,16 +264,18 @@ class _PandasBuffer:
Data in the buffer is guaranteed to be contiguous in memory.
"""

def __init__(self, x : np.ndarray) -> None:
def __init__(self, x : np.ndarray, allow_copy : bool = True) -> None:
"""
Handle only regular columns (= numpy arrays) for now.
"""
if not x.strides == (x.dtype.itemsize,):
# Array is not contiguous - this is possible to get in Pandas,
            # there was some discussion on whether to support it. Some extra
# complexity for libraries that don't support it (e.g. Arrow),
# but would help with numpy-based libraries like Pandas.
raise RuntimeError("Design needs fixing - non-contiguous buffer")
# The protocol does not support strided buffers, so a copy is
# necessary. If that's not allowed, we need to raise an exception.
if allow_copy:
x = x.copy()
else:
raise RuntimeError("Exports cannot be zero-copy in the case "
"of a non-contiguous buffer")

# Store the numpy array in which the data resides as a private
# attribute, so we can use it to retrieve the public attributes
Expand Down Expand Up @@ -313,7 +331,8 @@ class _PandasColumn:

"""

def __init__(self, column : pd.Series) -> None:
def __init__(self, column : pd.Series,
allow_copy : bool = True) -> None:
"""
Note: doesn't deal with extension arrays yet, just assume a regular
Series/ndarray for now.
Expand All @@ -324,6 +343,7 @@ def __init__(self, column : pd.Series) -> None:

# Store the column as a private attribute
self._col = column
self._allow_copy = allow_copy

@property
def size(self) -> int:
Expand Down Expand Up @@ -553,11 +573,13 @@ def _get_data_buffer(self) -> Tuple[_PandasBuffer, Any]: # Any is for self.dtyp
"""
_k = _DtypeKind
if self.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
buffer = _PandasBuffer(self._col.to_numpy())
buffer = _PandasBuffer(
self._col.to_numpy(), allow_copy=self._allow_copy)
dtype = self.dtype
elif self.dtype[0] == _k.CATEGORICAL:
codes = self._col.values.codes
buffer = _PandasBuffer(codes)
buffer = _PandasBuffer(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a sanity check, we're not similarly passing allow_copy at L595, L634, and L676 because we have guaranteed contiguous buffers in those cases?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed. np.asarray(some_list) creates a new contiguous array, and the bytearray usage seems to be contiguous too (that one was a bit harder to verify, but I did test it).

codes, allow_copy=self._allow_copy)
dtype = self._dtype_from_pandasdtype(codes.dtype)
elif self.dtype[0] == _k.STRING:
# Marshal the strings from a NumPy object array into a byte array
Expand Down Expand Up @@ -670,7 +692,8 @@ class _PandasDataFrame:
``pd.DataFrame.__dataframe__`` as objects with the methods and
attributes defined on this class.
"""
def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
def __init__(self, df : pd.DataFrame, nan_as_null : bool = False,
allow_copy : bool = True) -> None:
"""
Constructor - an instance of this (private) class is returned from
`pd.DataFrame.__dataframe__`.
Expand All @@ -681,6 +704,7 @@ def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
# This currently has no effect; once support for nullable extension
# dtypes is added, this value should be propagated to columns.
self._nan_as_null = nan_as_null
self._allow_copy = allow_copy

def num_columns(self) -> int:
return len(self._df.columns)
Expand All @@ -695,13 +719,16 @@ def column_names(self) -> Iterable[str]:
return self._df.columns.tolist()

def get_column(self, i: int) -> _PandasColumn:
return _PandasColumn(self._df.iloc[:, i])
return _PandasColumn(
self._df.iloc[:, i], allow_copy=self._allow_copy)

def get_column_by_name(self, name: str) -> _PandasColumn:
return _PandasColumn(self._df[name])
return _PandasColumn(
self._df[name], allow_copy=self._allow_copy)

def get_columns(self) -> Iterable[_PandasColumn]:
return [_PandasColumn(self._df[name]) for name in self._df.columns]
return [_PandasColumn(self._df[name], allow_copy=self._allow_copy)
for name in self._df.columns]

def select_columns(self, indices: Sequence[int]) -> '_PandasDataFrame':
if not isinstance(indices, collections.Sequence):
Expand Down Expand Up @@ -739,13 +766,14 @@ def test_mixed_intfloat():


def test_noncontiguous_columns():
# Currently raises: TBD whether it should work or not, see code comment
# where the RuntimeError is raised.
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
df = pd.DataFrame(arr)
assert df[0].to_numpy().strides == (24,)
pytest.raises(RuntimeError, from_dataframe, df)
#df2 = from_dataframe(df)
df = pd.DataFrame(arr, columns=['a', 'b', 'c'])
assert df['a'].to_numpy().strides == (24,)
df2 = from_dataframe(df) # uses default of allow_copy=True
tm.assert_frame_equal(df, df2)

with pytest.raises(RuntimeError):
from_dataframe(df, allow_copy=False)


def test_categorical_dtype():
Expand Down