|
4 | 4 | from typing import (
|
5 | 5 | TYPE_CHECKING,
|
6 | 6 | Any,
|
| 7 | + Iterator, |
7 | 8 | Literal,
|
8 | 9 | Sequence,
|
9 | 10 | TypeVar,
|
|
28 | 29 | npt,
|
29 | 30 | type_t,
|
30 | 31 | )
|
31 |
| -from pandas.compat import pa_version_under2p0 |
| 32 | +from pandas.compat import ( |
| 33 | + pa_version_under1p01, |
| 34 | + pa_version_under2p0, |
| 35 | + pa_version_under5p0, |
| 36 | +) |
32 | 37 | from pandas.errors import AbstractMethodError
|
33 | 38 | from pandas.util._decorators import doc
|
34 | 39 | from pandas.util._validators import (
|
|
38 | 43 | )
|
39 | 44 |
|
40 | 45 | from pandas.core.dtypes.common import (
|
| 46 | + is_bool_dtype, |
41 | 47 | is_dtype_equal,
|
| 48 | + is_integer, |
| 49 | + is_scalar, |
42 | 50 | pandas_dtype,
|
43 | 51 | )
|
44 | 52 | from pandas.core.dtypes.dtypes import (
|
45 | 53 | DatetimeTZDtype,
|
46 | 54 | ExtensionDtype,
|
47 | 55 | PeriodDtype,
|
48 | 56 | )
|
49 |
| -from pandas.core.dtypes.missing import array_equivalent |
| 57 | +from pandas.core.dtypes.missing import ( |
| 58 | + array_equivalent, |
| 59 | + isna, |
| 60 | +) |
50 | 61 |
|
51 | 62 | from pandas.core import missing
|
52 | 63 | from pandas.core.algorithms import (
|
|
65 | 76 | "NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
|
66 | 77 | )
|
67 | 78 |
|
68 |
| -if TYPE_CHECKING: |
69 |
| - |
| 79 | +if not pa_version_under1p01: |
70 | 80 | import pyarrow as pa
|
| 81 | + import pyarrow.compute as pc |
71 | 82 |
|
| 83 | +if TYPE_CHECKING: |
72 | 84 | from pandas._typing import (
|
73 | 85 | NumpySorter,
|
74 | 86 | NumpyValueArrayLike,
|
@@ -607,3 +619,195 @@ def _concat_same_type(
|
607 | 619 | chunks = [array for ea in to_concat for array in ea._data.iterchunks()]
|
608 | 620 | arr = pa.chunked_array(chunks)
|
609 | 621 | return cls(arr)
|
| 622 | + |
| 623 | + def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None: |
| 624 | + """Set one or more values inplace. |
| 625 | +
|
| 626 | + Parameters |
| 627 | + ---------- |
| 628 | + key : int, ndarray, or slice |
| 629 | + When called from, e.g. ``Series.__setitem__``, ``key`` will be |
| 630 | + one of |
| 631 | +
|
| 632 | + * scalar int |
| 633 | + * ndarray of integers. |
| 634 | + * boolean ndarray |
| 635 | + * slice object |
| 636 | +
|
| 637 | + value : ExtensionDtype.type, Sequence[ExtensionDtype.type], or object |
| 638 | + value or values to be set of ``key``. |
| 639 | +
|
| 640 | + Returns |
| 641 | + ------- |
| 642 | + None |
| 643 | + """ |
| 644 | + key = check_array_indexer(self, key) |
| 645 | + indices = self._indexing_key_to_indices(key) |
| 646 | + value = self._maybe_convert_setitem_value(value) |
| 647 | + |
| 648 | + argsort = np.argsort(indices) |
| 649 | + indices = indices[argsort] |
| 650 | + |
| 651 | + if is_scalar(value): |
| 652 | + value = np.broadcast_to(value, len(self)) |
| 653 | + elif len(indices) != len(value): |
| 654 | + raise ValueError("Length of indexer and values mismatch") |
| 655 | + else: |
| 656 | + value = np.asarray(value)[argsort] |
| 657 | + |
| 658 | + self._data = self._set_via_chunk_iteration(indices=indices, value=value) |
| 659 | + |
| 660 | + def _indexing_key_to_indices( |
| 661 | + self, key: int | slice | np.ndarray |
| 662 | + ) -> npt.NDArray[np.intp]: |
| 663 | + """ |
| 664 | + Convert indexing key for self into positional indices. |
| 665 | +
|
| 666 | + Parameters |
| 667 | + ---------- |
| 668 | + key : int | slice | np.ndarray |
| 669 | +
|
| 670 | + Returns |
| 671 | + ------- |
| 672 | + npt.NDArray[np.intp] |
| 673 | + """ |
| 674 | + n = len(self) |
| 675 | + if isinstance(key, slice): |
| 676 | + indices = np.arange(n)[key] |
| 677 | + elif is_integer(key): |
| 678 | + indices = np.arange(n)[[key]] # type: ignore[index] |
| 679 | + elif is_bool_dtype(key): |
| 680 | + key = np.asarray(key) |
| 681 | + if len(key) != n: |
| 682 | + raise ValueError("Length of indexer and values mismatch") |
| 683 | + indices = key.nonzero()[0] |
| 684 | + else: |
| 685 | + key = np.asarray(key) |
| 686 | + indices = np.arange(n)[key] |
| 687 | + return indices |
| 688 | + |
| 689 | + def _maybe_convert_setitem_value(self, value): |
| 690 | + """Maybe convert value to be pyarrow compatible.""" |
| 691 | + raise NotImplementedError() |
| 692 | + |
def _set_via_chunk_iteration(
    self, indices: npt.NDArray[np.intp], value: npt.NDArray[Any]
) -> pa.ChunkedArray:
    """
    Build a new ChunkedArray with ``value`` written at ``indices``,
    updating chunk by chunk so that the chunk boundaries stay the same.

    Parameters
    ----------
    indices : npt.NDArray[np.intp]
        Positions within the whole array. They are handed to
        ``_indices_to_chunk_indices``, which assumes they are sorted.
    value : npt.NDArray[Any]
        Replacement values, consumed front to back in the same order as
        ``indices``.

    Returns
    -------
    pa.ChunkedArray
    """
    chunks = list(self._data.iterchunks())
    consumed = 0  # number of entries of ``value`` already handed out

    for pos, local_indices in enumerate(self._indices_to_chunk_indices(indices)):
        count = len(local_indices)
        if count == 0:
            # Nothing to set in this chunk; keep it untouched.
            continue
        chunk_values = value[consumed : consumed + count]
        consumed += count
        chunks[pos] = self._replace_with_indices(
            chunks[pos], local_indices, chunk_values
        )

    return pa.chunked_array(chunks)
| 711 | + |
| 712 | + def _indices_to_chunk_indices( |
| 713 | + self, indices: npt.NDArray[np.intp] |
| 714 | + ) -> Iterator[npt.NDArray[np.intp]]: |
| 715 | + """ |
| 716 | + Convert *sorted* indices for self into a list of ndarrays |
| 717 | + each containing the indices *within* each chunk of the |
| 718 | + underlying ChunkedArray. |
| 719 | +
|
| 720 | + Parameters |
| 721 | + ---------- |
| 722 | + indices : npt.NDArray[np.intp] |
| 723 | + Position indices for the underlying ChunkedArray. |
| 724 | +
|
| 725 | + Returns |
| 726 | + ------- |
| 727 | + Generator yielding positional indices for each chunk |
| 728 | +
|
| 729 | + Notes |
| 730 | + ----- |
| 731 | + Assumes that indices is sorted. Caller is responsible for sorting. |
| 732 | + """ |
| 733 | + for start, stop in self._chunk_positional_ranges(): |
| 734 | + if len(indices) == 0 or stop <= indices[0]: |
| 735 | + yield np.array([], dtype=np.intp) |
| 736 | + else: |
| 737 | + n = int(np.searchsorted(indices, stop, side="left")) |
| 738 | + c_ind = indices[:n] - start |
| 739 | + indices = indices[n:] |
| 740 | + yield c_ind |
| 741 | + |
| 742 | + def _chunk_positional_ranges(self) -> tuple[tuple[int, int], ...]: |
| 743 | + """ |
| 744 | + Return a tuple of tuples each containing the left (inclusive) |
| 745 | + and right (exclusive) positional bounds of each chunk's values |
| 746 | + within the underlying ChunkedArray. |
| 747 | +
|
| 748 | + Returns |
| 749 | + ------- |
| 750 | + tuple[tuple] |
| 751 | + """ |
| 752 | + ranges = [] |
| 753 | + stop = 0 |
| 754 | + for c in self._data.iterchunks(): |
| 755 | + start, stop = stop, stop + len(c) |
| 756 | + ranges.append((start, stop)) |
| 757 | + return tuple(ranges) |
| 758 | + |
@classmethod
def _replace_with_indices(
    cls,
    chunk: pa.Array,
    indices: npt.NDArray[np.intp],
    value: npt.NDArray[Any],
) -> pa.Array:
    """
    Replace items selected with a set of positional indices.

    Analogous to pyarrow.compute.replace_with_mask, except that replacement
    positions are identified via indices rather than a mask.

    Parameters
    ----------
    chunk : pa.Array
        The array whose items are replaced. It is not modified; a new
        array is returned (or ``chunk`` itself when ``indices`` is empty).
    indices : npt.NDArray[np.intp]
        Positions within ``chunk`` to replace.
        NOTE(review): the contiguity check below only looks at the first
        and last entry, so it appears to rely on ``indices`` being sorted
        and free of duplicates; confirm callers guarantee this.
    value : npt.NDArray[Any]
        Replacement value(s), one per entry of ``indices``.

    Returns
    -------
    pa.Array
    """
    n = len(indices)

    if n == 0:
        return chunk

    start, stop = indices[[0, -1]]

    if (stop - start) == (n - 1):
        # fast path for a contiguous set of indices: splice the new values
        # between the untouched head and tail of the chunk
        arrays = [
            chunk[:start],
            pa.array(value, type=chunk.type),
            chunk[stop + 1 :],
        ]
        # drop empty pieces so a full replacement needs no concatenation
        arrays = [arr for arr in arrays if len(arr)]
        if len(arrays) == 1:
            return arrays[0]
        return pa.concat_arrays(arrays)

    mask = np.zeros(len(chunk), dtype=np.bool_)
    mask[indices] = True

    if pa_version_under5p0:
        # Fallback for older pyarrow: do the assignment in numpy and
        # convert back.
        arr = chunk.to_numpy(zero_copy_only=False)
        if not arr.flags.writeable:
            # ``to_numpy`` may hand back a read-only zero-copy view of the
            # Arrow buffer; assigning into it would raise, so copy first.
            arr = arr.copy()
        arr[mask] = value
        return pa.array(arr, type=chunk.type)

    if isna(value).all():
        # NOTE(review): all-missing replacements are routed through
        # ``if_else`` with a null scalar, presumably because
        # ``replace_with_mask`` cannot handle them; confirm.
        return pc.if_else(mask, None, chunk)

    return pc.replace_with_mask(chunk, mask, value)
0 commit comments