Skip to content

Commit 2d6a2c3

Browse files
authored
REF: move ArrowStringArray.__setitem__ and related methods to ArrowExtensionArray (#46439)
1 parent afec0e9 commit 2d6a2c3

File tree

2 files changed

+214
-138
lines changed

2 files changed

+214
-138
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 208 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import (
55
TYPE_CHECKING,
66
Any,
7+
Iterator,
78
Literal,
89
Sequence,
910
TypeVar,
@@ -28,7 +29,11 @@
2829
npt,
2930
type_t,
3031
)
31-
from pandas.compat import pa_version_under2p0
32+
from pandas.compat import (
33+
pa_version_under1p01,
34+
pa_version_under2p0,
35+
pa_version_under5p0,
36+
)
3237
from pandas.errors import AbstractMethodError
3338
from pandas.util._decorators import doc
3439
from pandas.util._validators import (
@@ -38,15 +43,21 @@
3843
)
3944

4045
from pandas.core.dtypes.common import (
46+
is_bool_dtype,
4147
is_dtype_equal,
48+
is_integer,
49+
is_scalar,
4250
pandas_dtype,
4351
)
4452
from pandas.core.dtypes.dtypes import (
4553
DatetimeTZDtype,
4654
ExtensionDtype,
4755
PeriodDtype,
4856
)
49-
from pandas.core.dtypes.missing import array_equivalent
57+
from pandas.core.dtypes.missing import (
58+
array_equivalent,
59+
isna,
60+
)
5061

5162
from pandas.core import missing
5263
from pandas.core.algorithms import (
@@ -65,10 +76,11 @@
6576
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
6677
)
6778

68-
if TYPE_CHECKING:
69-
79+
if not pa_version_under1p01:
7080
import pyarrow as pa
81+
import pyarrow.compute as pc
7182

83+
if TYPE_CHECKING:
7284
from pandas._typing import (
7385
NumpySorter,
7486
NumpyValueArrayLike,
@@ -607,3 +619,195 @@ def _concat_same_type(
607619
chunks = [array for ea in to_concat for array in ea._data.iterchunks()]
608620
arr = pa.chunked_array(chunks)
609621
return cls(arr)
622+
623+
def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
    """Set one or more values inplace.

    Parameters
    ----------
    key : int, ndarray, or slice
        When called from, e.g. ``Series.__setitem__``, ``key`` will be
        one of

        * scalar int
        * ndarray of integers.
        * boolean ndarray
        * slice object

    value : ExtensionDtype.type, Sequence[ExtensionDtype.type], or object
        value or values to be set of ``key``.

    Returns
    -------
    None
    """
    key = check_array_indexer(self, key)
    indices = self._indexing_key_to_indices(key)
    value = self._maybe_convert_setitem_value(value)

    # Work with sorted positions so they can be matched against chunks
    # left-to-right; remember the ordering to permute the values too.
    order = np.argsort(indices)
    indices = indices[order]

    if is_scalar(value):
        # A scalar fills every targeted position.
        value = np.broadcast_to(value, len(self))
    else:
        if len(indices) != len(value):
            raise ValueError("Length of indexer and values mismatch")
        # Align values with the sorted positions.
        value = np.asarray(value)[order]

    self._data = self._set_via_chunk_iteration(indices=indices, value=value)
659+
660+
def _indexing_key_to_indices(
661+
self, key: int | slice | np.ndarray
662+
) -> npt.NDArray[np.intp]:
663+
"""
664+
Convert indexing key for self into positional indices.
665+
666+
Parameters
667+
----------
668+
key : int | slice | np.ndarray
669+
670+
Returns
671+
-------
672+
npt.NDArray[np.intp]
673+
"""
674+
n = len(self)
675+
if isinstance(key, slice):
676+
indices = np.arange(n)[key]
677+
elif is_integer(key):
678+
indices = np.arange(n)[[key]] # type: ignore[index]
679+
elif is_bool_dtype(key):
680+
key = np.asarray(key)
681+
if len(key) != n:
682+
raise ValueError("Length of indexer and values mismatch")
683+
indices = key.nonzero()[0]
684+
else:
685+
key = np.asarray(key)
686+
indices = np.arange(n)[key]
687+
return indices
688+
689+
def _maybe_convert_setitem_value(self, value):
    """
    Hook: coerce ``value`` into a pyarrow-compatible form.

    This base implementation is abstract; concrete subclasses must
    override it before ``__setitem__`` can be used.
    """
    raise NotImplementedError()
692+
693+
def _set_via_chunk_iteration(
    self, indices: npt.NDArray[np.intp], value: npt.NDArray[Any]
) -> pa.ChunkedArray:
    """
    Loop through the array chunks and set the new values while
    leaving the chunking layout unchanged.

    ``indices`` must be sorted; ``value`` is consumed front-to-back,
    one slice per chunk.
    """
    chunks = list(self._data.iterchunks())

    for pos, within_chunk in enumerate(self._indices_to_chunk_indices(indices)):
        count = len(within_chunk)
        if not count:
            # No target positions fall inside this chunk.
            continue
        # Peel off the values destined for this chunk.
        head, value = value[:count], value[count:]
        chunks[pos] = self._replace_with_indices(chunks[pos], within_chunk, head)

    return pa.chunked_array(chunks)
711+
712+
def _indices_to_chunk_indices(
    self, indices: npt.NDArray[np.intp]
) -> Iterator[npt.NDArray[np.intp]]:
    """
    Convert *sorted* indices for self into a list of ndarrays
    each containing the indices *within* each chunk of the
    underlying ChunkedArray.

    Parameters
    ----------
    indices : npt.NDArray[np.intp]
        Position indices for the underlying ChunkedArray.

    Returns
    -------
    Generator yielding positional indices for each chunk

    Notes
    -----
    Assumes that indices is sorted. Caller is responsible for sorting.
    """
    remaining = indices
    for start, stop in self._chunk_positional_ranges():
        if len(remaining) == 0 or remaining[0] >= stop:
            # Nothing left, or the next target lies beyond this chunk.
            yield np.array([], dtype=np.intp)
            continue
        # Everything strictly below ``stop`` belongs to this chunk;
        # shift to chunk-local positions.
        split = int(np.searchsorted(remaining, stop, side="left"))
        yield remaining[:split] - start
        remaining = remaining[split:]
741+
742+
def _chunk_positional_ranges(self) -> tuple[tuple[int, int], ...]:
    """
    Return a tuple of tuples each containing the left (inclusive)
    and right (exclusive) positional bounds of each chunk's values
    within the underlying ChunkedArray.

    Returns
    -------
    tuple[tuple]
    """
    bounds = []
    offset = 0
    for chunk in self._data.iterchunks():
        upper = offset + len(chunk)
        bounds.append((offset, upper))
        offset = upper
    return tuple(bounds)
758+
759+
@classmethod
def _replace_with_indices(
    cls,
    chunk: pa.Array,
    indices: npt.NDArray[np.intp],
    value: npt.NDArray[Any],
) -> pa.Array:
    """
    Replace items selected with a set of positional indices.

    Analogous to pyarrow.compute.replace_with_mask, except that replacement
    positions are identified via indices rather than a mask.

    Parameters
    ----------
    chunk : pa.Array
    indices : npt.NDArray[np.intp]
    value : npt.NDArray[Any]
        Replacement value(s).

    Returns
    -------
    pa.Array
    """
    count = len(indices)
    if count == 0:
        return chunk

    first, last = indices[[0, -1]]

    if (last - first) == (count - 1):
        # Fast path: the targeted positions form one contiguous run,
        # so splice the replacement between the untouched ends.
        pieces = [
            chunk[:first],
            pa.array(value, type=chunk.type),
            chunk[last + 1 :],
        ]
        pieces = [piece for piece in pieces if len(piece)]
        if len(pieces) == 1:
            return pieces[0]
        return pa.concat_arrays(pieces)

    # General path: express the target positions as a boolean mask.
    mask = np.zeros(len(chunk), dtype=np.bool_)
    mask[indices] = True

    if pa_version_under5p0:
        # Older pyarrow: round-trip through numpy and rebuild the array.
        out = chunk.to_numpy(zero_copy_only=False)
        out[mask] = value
        return pa.array(out, type=chunk.type)

    if isna(value).all():
        # NOTE(review): all-NA replacements are routed through if_else —
        # presumably replace_with_mask cannot take all-null values; confirm.
        return pc.if_else(mask, None, chunk)

    return pc.replace_with_mask(chunk, mask, value)

0 commit comments

Comments
 (0)