diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index a48d9abc3c13b..1f9e72bcf8a2f 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -39,6 +39,7 @@ is_list_like, is_timedelta64_dtype, ) +from pandas.core.dtypes.generic import ABCExtensionArray from pandas.core.dtypes.missing import array_equivalent from pandas import ( @@ -53,7 +54,7 @@ concat, isna, ) -from pandas._typing import FrameOrSeries +from pandas._typing import ArrayLike, FrameOrSeries from pandas.core.arrays.categorical import Categorical import pandas.core.common as com from pandas.core.computation.pytables import PyTablesExpr, maybe_expression @@ -2959,7 +2960,7 @@ def read_index_node( data = node[start:stop] # If the index was an empty array write_array_empty() will # have written a sentinel. Here we relace it with the original. - if "shape" in node._v_attrs and self._is_empty_array(node._v_attrs.shape): + if "shape" in node._v_attrs and np.prod(node._v_attrs.shape) == 0: data = np.empty(node._v_attrs.shape, dtype=node._v_attrs.value_type,) kind = _ensure_decoded(node._v_attrs.kind) name = None @@ -3005,25 +3006,27 @@ def read_index_node( return index - def write_array_empty(self, key: str, value): + def write_array_empty(self, key: str, value: ArrayLike): """ write a 0-len array """ # ugly hack for length 0 axes arr = np.empty((1,) * value.ndim) self._handle.create_array(self.group, key, arr) - getattr(self.group, key)._v_attrs.value_type = str(value.dtype) - getattr(self.group, key)._v_attrs.shape = value.shape + node = getattr(self.group, key) + node._v_attrs.value_type = str(value.dtype) + node._v_attrs.shape = value.shape - def _is_empty_array(self, shape) -> bool: - """Returns true if any axis is zero length.""" - return any(x == 0 for x in shape) + def write_array(self, key: str, value: ArrayLike, items: Optional[Index] = None): + # TODO: we only have one test that gets here, the only EA + # that gets passed is DatetimeArray, and we never have + # both self._filters and EA + assert isinstance(value, (np.ndarray, ABCExtensionArray)), type(value) - def write_array(self, key: str, value, items=None): if key in self.group: self._handle.remove_node(self.group, key) # Transform needed to interface with pytables row/col notation - empty_array = self._is_empty_array(value.shape) + empty_array = value.size == 0 transposed = False if is_categorical_dtype(value): @@ -3038,29 +3041,29 @@ def write_array(self, key: str, value, items=None): value = value.T transposed = True + atom = None if self._filters is not None: - atom = None try: # get the atom for this datatype atom = _tables().Atom.from_dtype(value.dtype) except ValueError: pass - if atom is not None: - # create an empty chunked array and fill it from value - if not empty_array: - ca = self._handle.create_carray( - self.group, key, atom, value.shape, filters=self._filters - ) - ca[:] = value - getattr(self.group, key)._v_attrs.transposed = transposed + if atom is not None: + # We only get here if self._filters is non-None and + # the Atom.from_dtype call succeeded - else: - self.write_array_empty(key, value) + # create an empty chunked array and fill it from value + if not empty_array: + ca = self._handle.create_carray( + self.group, key, atom, value.shape, filters=self._filters + ) + ca[:] = value - return + else: + self.write_array_empty(key, value) - if value.dtype.type == np.object_: + elif value.dtype.type == np.object_: # infer the type, warn if we have a non-string type here (for # performance) @@ -3070,35 +3073,30 @@ def write_array(self, key: str, value, items=None): elif inferred_type == "string": pass else: - try: - items = list(items) - except TypeError: - pass ws = performance_doc % (inferred_type, key, items) warnings.warn(ws, PerformanceWarning, stacklevel=7) vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) vlarr.append(value) + + elif empty_array: + self.write_array_empty(key, value) + elif is_datetime64_dtype(value.dtype): + self._handle.create_array(self.group, key, value.view("i8")) + getattr(self.group, key)._v_attrs.value_type = "datetime64" + elif is_datetime64tz_dtype(value.dtype): + # store as UTC + # with a zone + self._handle.create_array(self.group, key, value.asi8) + + node = getattr(self.group, key) + node._v_attrs.tz = _get_tz(value.tz) + node._v_attrs.value_type = "datetime64" + elif is_timedelta64_dtype(value.dtype): + self._handle.create_array(self.group, key, value.view("i8")) + getattr(self.group, key)._v_attrs.value_type = "timedelta64" else: - if empty_array: - self.write_array_empty(key, value) - else: - if is_datetime64_dtype(value.dtype): - self._handle.create_array(self.group, key, value.view("i8")) - getattr(self.group, key)._v_attrs.value_type = "datetime64" - elif is_datetime64tz_dtype(value.dtype): - # store as UTC - # with a zone - self._handle.create_array(self.group, key, value.asi8) - - node = getattr(self.group, key) - node._v_attrs.tz = _get_tz(value.tz) - node._v_attrs.value_type = "datetime64" - elif is_timedelta64_dtype(value.dtype): - self._handle.create_array(self.group, key, value.view("i8")) - getattr(self.group, key)._v_attrs.value_type = "timedelta64" - else: - self._handle.create_array(self.group, key, value) + self._handle.create_array(self.group, key, value) getattr(self.group, key)._v_attrs.transposed = transposed