ENH: add NDArrayBackedExtensionArray to public API #45544
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,6 +135,110 @@ by some other storage type, like Python lists. | |
See the `extension array source`_ for the interface definition. The docstrings | ||
and comments contain guidance for properly implementing the interface. | ||
|
||
:class:`~pandas.api.extensions.NDArrayBackedExtensionArray` | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
For ExtensionArrays backed by a single NumPy array, the | ||
:class:`~pandas.api.extensions.NDArrayBackedExtensionArray` class can save you | ||
some effort. It contains a private property ``_ndarray`` with the backing NumPy | ||
array and implements the extension array interface. | ||
|
||
Implement the following: | ||
|
||
``_box_func`` | ||
Convert from array values to the type you wish to expose to users. | ||
|
||
``_internal_fill_value`` | ||
Scalar used to denote ``NA`` value inside our ``self._ndarray``, e.g. ``-1`` | ||
for ``Categorical``, ``iNaT`` for ``Period``. | ||
|
||
``_validate_scalar`` | ||
Convert from an object to a value which can be stored in the NumPy array. | ||
|
||
``_validate_setitem_value`` | ||
Convert a value or values for use in setting a value or values in the backing | ||
NumPy array. | ||
|
||
``_validate_searchsorted_value`` | ||
Review comment: In 2.0 i think this is going away and we'll re-use _validate_setitem_value for this
Review comment: Clarified that most implementations will be identical to
Review comment: _validate_searchsorted_value is gone now
Review comment: can you remove _validate_searchsorted_value here |
||
Convert a value for use in searching for a value in the backing NumPy array. | ||
Note: in most cases, the implementation can be identical to that of | ||
``_validate_setitem_value``. | ||
|
||
.. code-block:: python | ||
|
||
class DateArray(NDArrayBackedExtensionArray): | ||
_internal_fill_value = numpy.datetime64("NaT") | ||
|
||
def __init__(self, values): | ||
backing_array_dtype = "<M8[ns]" | ||
Review comment: can you make this a np.dtype object instead of a string |
||
super().__init__(values=values, dtype=backing_array_dtype) | ||
jreback marked this conversation as resolved.
|
||
|
||
def _box_func(self, value): | ||
if pandas.isna(value): | ||
return pandas.NaT | ||
return value.astype("datetime64[us]").item().date() | ||
|
||
def _validate_scalar(self, scalar): | ||
if pandas.isna(scalar): | ||
return numpy.datetime64("NaT") | ||
elif isinstance(scalar, datetime.date): | ||
return pandas.Timestamp( | ||
year=scalar.year, month=scalar.month, day=scalar.day | ||
).to_datetime64() | ||
else: | ||
raise TypeError("Invalid value type", scalar) | ||
|
||
def _validate_setitem_value(self, value): | ||
if pandas.api.types.is_list_like(value): | ||
return [self._validate_scalar(v) for v in value] | ||
Review comment: this should be an ndarray of the same dtype as self._ndarray |
||
return self._validate_scalar(value) | ||
|
||
def _validate_searchsorted_value(self, value): | ||
return self._validate_setitem_value(value) | ||
|
||
|
||
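The two review comments on this example (use an `np.dtype` object, and return an ndarray of the backing dtype from `_validate_setitem_value`) could be addressed roughly as follows. This is a minimal standalone sketch, not the PR's final code: `BACKING_DTYPE`, `validate_scalar`, and `validate_setitem_value` are hypothetical module-level stand-ins for the class attributes and methods, so that the snippet runs without subclassing anything.

```python
import datetime

import numpy as np
import pandas as pd

# Assumption: the backing array stores nanosecond datetimes.
# A np.dtype object is used instead of the "<M8[ns]" string.
BACKING_DTYPE = np.dtype("datetime64[ns]")


def validate_scalar(scalar):
    """Convert a date-like scalar to a value storable in the backing array."""
    if pd.isna(scalar):
        return np.datetime64("NaT", "ns")
    if isinstance(scalar, datetime.date):
        # Keep only the calendar date, then cast to the backing resolution.
        day = datetime.date(scalar.year, scalar.month, scalar.day)
        return np.datetime64(day).astype(BACKING_DTYPE)
    raise TypeError(f"Invalid value type: {type(scalar).__name__}")


def validate_setitem_value(value):
    """Return a scalar, or an ndarray with the same dtype as the backing array."""
    if pd.api.types.is_list_like(value):
        converted = [validate_scalar(v) for v in value]
        return np.array(converted, dtype=BACKING_DTYPE)
    return validate_scalar(value)
```

Returning an ndarray rather than a Python list means the list-like result has the same dtype as the array it is assigned into.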
To support 2D arrays, use the ``_from_backing_data`` helper function when a | ||
method is called on multi-dimensional data of the same dtype as ``_ndarray``. | ||
|
||
.. code-block:: python | ||
|
||
class CustomArray(NDArrayBackedExtensionArray): | ||
|
||
... | ||
|
||
def min(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs): | ||
pandas.compat.numpy.function.validate_min((), kwargs) | ||
result = pandas.core.nanops.nanmin( | ||
values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna | ||
) | ||
if axis is None or self.ndim == 1: | ||
return self._box_func(result) | ||
return self._from_backing_data(result) | ||
|
||
|
||
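The branching in the `min` example above can be illustrated without pandas internals. The sketch below uses a hypothetical `TinyBackedArray` class (not the pandas class) that only mimics the three names from the docs, `_ndarray`, `_box_func`, and `_from_backing_data`, and assumes a float backing array for simplicity.

```python
import numpy as np


class TinyBackedArray:
    """Hypothetical stand-in for an ndarray-backed array (illustration only)."""

    def __init__(self, values):
        self._ndarray = np.asarray(values, dtype="float64")

    @property
    def ndim(self):
        return self._ndarray.ndim

    def _box_func(self, value):
        # Convert a raw backing value to the user-facing scalar type.
        return float(value)

    def _from_backing_data(self, arr):
        # Re-wrap a raw ndarray of the backing dtype in the array class.
        return type(self)(arr)

    def min(self, *, axis=None):
        result = np.nanmin(self._ndarray, axis=axis)
        if axis is None or self.ndim == 1:
            # Full reduction: the result is a single backing value.
            return self._box_func(result)
        # Reduction along one axis of 2D data: the result is still an array.
        return self._from_backing_data(result)


arr = TinyBackedArray([[1, 5], [3, 2]])
overall = arr.min()          # boxed scalar
per_column = arr.min(axis=0) # wrapped array
```

The point of the pattern is that a reduction along one axis of 2D data yields raw backing values that still need to be wrapped, whereas a full reduction yields one value that needs to be boxed.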
Subclass the tests in :mod:`pandas.tests.extension.base` in your test suite to | ||
validate your implementation. | ||
|
||
.. code-block:: python | ||
|
||
@pytest.fixture | ||
def data(): | ||
return CustomArray(numpy.arange(-10, 10, 1)) | ||
|
||
|
||
class Test2DCompat(base.NDArrayBacked2DTests): | ||
pass | ||
|
||
|
||
class TestComparisonOps(base.BaseComparisonOpsTests): | ||
pass | ||
|
||
... | ||
|
||
class TestSetitem(base.BaseSetitemTests): | ||
pass | ||
|
||
|
||
.. _extending.extension.operator: | ||
|
||
:class:`~pandas.api.extensions.ExtensionArray` operator support | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,6 +65,7 @@ Other enhancements | |
- :func:`timedelta_range` now supports a ``unit`` keyword ("s", "ms", "us", or "ns") to specify the desired resolution of the output index (:issue:`49824`) | ||
- :meth:`DataFrame.to_json` now supports a ``mode`` keyword with supported inputs 'w' and 'a'. Defaulting to 'w', 'a' can be used when lines=True and orient='records' to append record oriented json lines to an existing json file. (:issue:`35849`) | ||
- Added ``name`` parameter to :meth:`IntervalIndex.from_breaks`, :meth:`IntervalIndex.from_arrays` and :meth:`IntervalIndex.from_tuples` (:issue:`48911`) | ||
- :class:`NDArrayBackedExtensionArray` now exposed in the public API. (:issue:`45544`) | ||
Review comment: no trailing period |
||
- | ||
|
||
.. --------------------------------------------------------------------------- | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
import pandas as pd | ||
from pandas import api | ||
import pandas._testing as tm | ||
from pandas.api import extensions | ||
|
||
|
||
class Base: | ||
|
@@ -241,6 +242,33 @@ def test_api(self): | |
self.check(api, self.allowed) | ||
|
||
|
||
class TestExtensions(Base): | ||
# top-level classes | ||
classes = [ | ||
"ExtensionDtype", | ||
"ExtensionArray", | ||
"ExtensionScalarOpsMixin", | ||
"NDArrayBackedExtensionArray", | ||
] | ||
|
||
# top-level functions | ||
funcs = [ | ||
"register_extension_dtype", | ||
"register_dataframe_accessor", | ||
"register_index_accessor", | ||
"register_series_accessor", | ||
"take", | ||
] | ||
|
||
# misc | ||
misc = ["no_default"] | ||
|
||
def test_api(self): | ||
checkthese = self.classes + self.funcs + self.misc | ||
|
||
self.check(namespace=extensions, expected=checkthese) | ||
Review comment: im not that familiar with this test file. what is being tested here? |
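On the question of what is being tested: an API test of this shape typically asserts that the set of public names in a namespace matches an expected list exactly, so that adding or removing a public name fails the test until the list is updated. The sketch below is an assumption about that general pattern, not a copy of the `Base.check` helper in this file; `public_names`, `check`, and `fake_extensions` are hypothetical names.

```python
from types import SimpleNamespace


def public_names(namespace):
    """List the non-dunder attribute names exposed by a namespace."""
    return sorted(name for name in dir(namespace) if not name.startswith("__"))


def check(namespace, expected):
    """Fail if the namespace exposes more or fewer names than expected."""
    result = public_names(namespace)
    assert result == sorted(expected), f"unexpected API surface: {result}"


# Hypothetical namespace standing in for pandas.api.extensions.
fake_extensions = SimpleNamespace(ExtensionArray=object, take=len)
check(fake_extensions, ["take", "ExtensionArray"])
```

Under this reading, adding `"NDArrayBackedExtensionArray"` to `classes` records that the new name is an intentional part of the public namespace.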
||
|
||
|
||
class TestTesting(Base): | ||
funcs = [ | ||
"assert_frame_equal", | ||
|