diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index bf7aa5970519f..8539d0547e5d1 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -53,7 +53,7 @@ from pandas.io.formats.printing import adjoin, pprint_thing if TYPE_CHECKING: - from tables import File # noqa:F401 + from tables import File, Node # noqa:F401 # versioning attribute @@ -244,7 +244,7 @@ def to_hdf( key, value, mode=None, - complevel=None, + complevel: Optional[int] = None, complib=None, append=None, **kwargs, @@ -459,12 +459,14 @@ class HDFStore: """ _handle: Optional["File"] + _complevel: int + _fletcher32: bool def __init__( self, path, mode=None, - complevel=None, + complevel: Optional[int] = None, complib=None, fletcher32: bool = False, **kwargs, @@ -526,7 +528,7 @@ def __getattr__(self, name: str): f"'{type(self).__name__}' object has no attribute '{name}'" ) - def __contains__(self, key: str): + def __contains__(self, key: str) -> bool: """ check for existence of this key can match the exact pathname or the pathnm w/o the leading '/' """ @@ -1267,18 +1269,22 @@ def walk(self, where="/"): yield (g._v_pathname.rstrip("/"), groups, leaves) - def get_node(self, key: str): + def get_node(self, key: str) -> Optional["Node"]: """ return the node with the key or None if it does not exist """ self._check_if_open() if not key.startswith("/"): key = "/" + key assert self._handle is not None + assert _table_mod is not None # for mypy try: - return self._handle.get_node(self.root, key) - except _table_mod.exceptions.NoSuchNodeError: # type: ignore + node = self._handle.get_node(self.root, key) + except _table_mod.exceptions.NoSuchNodeError: return None + assert isinstance(node, _table_mod.Node), type(node) + return node + def get_storer(self, key: str) -> Union["GenericFixed", "Table"]: """ return the storer object for a key, raise if not in the file """ group = self.get_node(key) @@ -1296,7 +1302,7 @@ def copy( propindexes: bool = True, keys=None, complib=None, - complevel=None, + complevel: Optional[int] = None, fletcher32: bool = False, overwrite=True, ): @@ -1387,7 +1393,9 @@ def info(self) -> str: return output - # private methods ###### + # ------------------------------------------------------------------------ + # private methods + def _check_if_open(self): if not self.is_open: raise ClosedFileError(f"{self._path} file is not open!") @@ -1559,7 +1567,7 @@ def _write_to_group( if isinstance(s, Table) and index: s.create_index(columns=index) - def _read_group(self, group, **kwargs): + def _read_group(self, group: "Node", **kwargs): s = self._create_storer(group) s.infer_axes() return s.read(**kwargs) @@ -1786,7 +1794,7 @@ def copy(self): new_self = copy.copy(self) return new_self - def infer(self, handler): + def infer(self, handler: "Table"): """infer this column from the table: create and return a new object""" table = handler.table new_self = self.copy() @@ -2499,9 +2507,16 @@ class Fixed: pandas_kind: str obj_type: Type[Union[DataFrame, Series]] ndim: int + parent: HDFStore + group: "Node" is_table = False - def __init__(self, parent, group, encoding=None, errors="strict", **kwargs): + def __init__( + self, parent: HDFStore, group: "Node", encoding=None, errors="strict", **kwargs + ): + assert isinstance(parent, HDFStore), type(parent) + assert _table_mod is not None # needed for mypy + assert isinstance(group, _table_mod.Node), type(group) self.parent = parent self.group = group self.encoding = _ensure_encoding(encoding) @@ -2568,11 +2583,11 @@ def _filters(self): return self.parent._filters @property - def _complevel(self): + def _complevel(self) -> int: return self.parent._complevel @property - def _fletcher32(self): + def _fletcher32(self) -> bool: return self.parent._fletcher32 @property @@ -2637,7 +2652,7 @@ def read( def write(self, **kwargs): raise NotImplementedError( - "cannot write on an abstract storer: sublcasses should implement" + "cannot write on an abstract storer: subclasses should implement" ) def delete( @@ -2803,7 +2818,7 @@ def write_index(self, key: str, index: Index): if isinstance(index, DatetimeIndex) and index.tz is not None: node._v_attrs.tz = _get_tz(index.tz) - def write_multi_index(self, key, index): + def write_multi_index(self, key: str, index: MultiIndex): setattr(self.attrs, f"{key}_nlevels", index.nlevels) for i, (lev, level_codes, name) in enumerate( @@ -2828,7 +2843,7 @@ def write_multi_index(self, key, index): label_key = f"{key}_label{i}" self.write_array(label_key, level_codes) - def read_multi_index(self, key, **kwargs) -> MultiIndex: + def read_multi_index(self, key: str, **kwargs) -> MultiIndex: nlevels = getattr(self.attrs, f"{key}_nlevels") levels = [] @@ -2849,7 +2864,7 @@ def read_multi_index(self, key, **kwargs) -> MultiIndex: ) def read_index_node( - self, node, start: Optional[int] = None, stop: Optional[int] = None + self, node: "Node", start: Optional[int] = None, stop: Optional[int] = None ): data = node[start:stop] # If the index was an empty array write_array_empty() will @@ -3310,7 +3325,7 @@ def values_cols(self) -> List[str]: """ return a list of my values cols """ return [i.cname for i in self.values_axes] - def _get_metadata_path(self, key) -> str: + def _get_metadata_path(self, key: str) -> str: """ return the metadata pathname for this key """ group = self.group._v_pathname return f"{group}/meta/{key}/meta" @@ -3877,10 +3892,10 @@ def process_filter(field, filt): def create_description( self, complib=None, - complevel=None, + complevel: Optional[int] = None, fletcher32: bool = False, expectedrows: Optional[int] = None, - ): + ) -> Dict[str, Any]: """ create the description of the table from the axes & values """ # provided expected rows if its passed @@ -4537,10 +4552,10 @@ def _set_tz(values, tz, preserve_UTC: bool = False, coerce: bool = False): return values -def _convert_index(name: str, index, encoding=None, errors="strict"): +def _convert_index(name: str, index: Index, encoding=None, errors="strict"): assert isinstance(name, str) - index_name = getattr(index, "name", None) + index_name = index.name if isinstance(index, DatetimeIndex): converted = index.asi8 @@ -4630,8 +4645,9 @@ def _convert_index(name: str, index, encoding=None, errors="strict"): ) -def _unconvert_index(data, kind, encoding=None, errors="strict"): - kind = _ensure_decoded(kind) +def _unconvert_index(data, kind: str, encoding=None, errors="strict"): + index: Union[Index, np.ndarray] + if kind == "datetime64": index = DatetimeIndex(data) elif kind == "timedelta64":