Skip to content

Commit ae6a25b

Browse files
fjetterFlorian Jetter
authored and
Florian Jetter
committed
fix cache_readonly
1 parent baa8539 commit ae6a25b

File tree

3 files changed

+50
-29
lines changed

3 files changed

+50
-29
lines changed

pandas/_libs/properties.pyx

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,28 @@ cdef class CachedProperty:
2424

2525
# Get the cache or set a default one if needed
2626
cache = getattr(obj, '_cache', None)
27+
28+
if cache is not None:
29+
# When accessing cython extension types, the attribute is already
30+
# registered and known to the class, unlike for python object. To
31+
# ensure we're not accidentally using a global scope / class level
32+
# cache we'll need to check whether the instance and class
33+
# attribute is identical
34+
cache_class = getattr(typ, "_cache", None)
35+
if cache_class is not None and cache_class is cache:
36+
raise TypeError(
37+
f"Class {typ} defines a `_cache` attribute on class level "
38+
"which is forbidden in combination with @cache_readonly."
39+
)
40+
2741
if cache is None:
2842
try:
2943
cache = obj._cache = {}
3044
except (AttributeError):
31-
return self
45+
raise TypeError(
46+
f"Cython extension type {type(obj)} must declare attribute "
47+
"`_cache` to use @cache_readonly."
48+
)
3249

3350
if PyDict_Contains(cache, self.name):
3451
# not necessary to Py_INCREF

pandas/core/dtypes/dtypes.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pytz
2222

2323
from pandas._libs.interval import Interval
24+
from pandas._libs.properties import cache_readonly
2425
from pandas._libs.tslibs import (
2526
BaseOffset,
2627
NaT,
@@ -86,7 +87,7 @@ class PandasExtensionDtype(ExtensionDtype):
8687
base: Optional[DtypeObj] = None
8788
isbuiltin = 0
8889
isnative = 0
89-
_cache: Dict[str_type, PandasExtensionDtype] = {}
90+
_cache_dtypes: Dict[str_type, PandasExtensionDtype] = {}
9091

9192
def __str__(self) -> str_type:
9293
"""
@@ -110,7 +111,7 @@ def __getstate__(self) -> Dict[str_type, Any]:
110111
@classmethod
111112
def reset_cache(cls) -> None:
112113
""" clear the cache """
113-
cls._cache = {}
114+
cls._cache_dtypes = {}
114115

115116

116117
class CategoricalDtypeType(type):
@@ -182,7 +183,7 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
182183
str = "|O08"
183184
base = np.dtype("O")
184185
_metadata = ("categories", "ordered")
185-
_cache: Dict[str_type, PandasExtensionDtype] = {}
186+
_cache_dtypes: Dict[str_type, PandasExtensionDtype] = {}
186187

187188
def __init__(self, categories=None, ordered: Ordered = False):
188189
self._finalize(categories, ordered, fastpath=False)
@@ -351,17 +352,6 @@ def __setstate__(self, state: MutableMapping[str_type, Any]) -> None:
351352
self._categories = state.pop("categories", None)
352353
self._ordered = state.pop("ordered", False)
353354

354-
def __hash__(self) -> int:
355-
# _hash_categories returns a uint64, so use the negative
356-
# space for when we have unknown categories to avoid a conflict
357-
if self.categories is None:
358-
if self.ordered:
359-
return -1
360-
else:
361-
return -2
362-
# We *do* want to include the real self.ordered here
363-
return int(self._hash_categories(self.categories, self.ordered))
364-
365355
def __eq__(self, other: Any) -> bool:
366356
"""
367357
Rules for CDT equality:
@@ -434,14 +424,28 @@ def __repr__(self) -> str_type:
434424
data = data.rstrip(", ")
435425
return f"CategoricalDtype(categories={data}, ordered={self.ordered})"
436426

437-
@staticmethod
438-
def _hash_categories(categories, ordered: Ordered = True) -> int:
427+
def __hash__(self) -> int:
428+
# _hash_categories returns a uint64, so use the negative
429+
# space for when we have unknown categories to avoid a conflict
430+
if self.categories is None:
431+
if self.ordered:
432+
return -1
433+
else:
434+
return -2
435+
return int(self._hash_categories)
436+
437+
@cache_readonly
438+
def _hash_categories(self) -> int:
439439
from pandas.core.util.hashing import (
440440
combine_hash_arrays,
441441
hash_array,
442442
hash_tuples,
443443
)
444444

445+
# We *do* want to include the real self.ordered here
446+
categories = self.categories
447+
ordered = self.ordered
448+
445449
if len(categories) and isinstance(categories[0], tuple):
446450
# assumes if any individual category is a tuple, then all our. ATM
447451
# I don't really want to support just some of the categories being
@@ -678,7 +682,7 @@ class DatetimeTZDtype(PandasExtensionDtype):
678682
na_value = NaT
679683
_metadata = ("unit", "tz")
680684
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
681-
_cache: Dict[str_type, PandasExtensionDtype] = {}
685+
_cache_dtypes: Dict[str_type, PandasExtensionDtype] = {}
682686

683687
def __init__(self, unit: Union[str_type, DatetimeTZDtype] = "ns", tz=None):
684688
if isinstance(unit, DatetimeTZDtype):
@@ -844,7 +848,7 @@ class PeriodDtype(dtypes.PeriodDtypeBase, PandasExtensionDtype):
844848
num = 102
845849
_metadata = ("freq",)
846850
_match = re.compile(r"(P|p)eriod\[(?P<freq>.+)\]")
847-
_cache: Dict[str_type, PandasExtensionDtype] = {}
851+
_cache_dtypes: Dict[str_type, PandasExtensionDtype] = {}
848852

849853
def __new__(cls, freq=None):
850854
"""
@@ -866,12 +870,12 @@ def __new__(cls, freq=None):
866870
freq = cls._parse_dtype_strict(freq)
867871

868872
try:
869-
return cls._cache[freq.freqstr]
873+
return cls._cache_dtypes[freq.freqstr]
870874
except KeyError:
871875
dtype_code = freq._period_dtype_code
872876
u = dtypes.PeriodDtypeBase.__new__(cls, dtype_code)
873877
u._freq = freq
874-
cls._cache[freq.freqstr] = u
878+
cls._cache_dtypes[freq.freqstr] = u
875879
return u
876880

877881
def __reduce__(self):
@@ -1049,7 +1053,7 @@ class IntervalDtype(PandasExtensionDtype):
10491053
_match = re.compile(
10501054
r"(I|i)nterval\[(?P<subtype>[^,]+)(, (?P<closed>(right|left|both|neither)))?\]"
10511055
)
1052-
_cache: Dict[str_type, PandasExtensionDtype] = {}
1056+
_cache_dtypes: Dict[str_type, PandasExtensionDtype] = {}
10531057

10541058
def __new__(cls, subtype=None, closed: Optional[str_type] = None):
10551059
from pandas.core.dtypes.common import (
@@ -1106,12 +1110,12 @@ def __new__(cls, subtype=None, closed: Optional[str_type] = None):
11061110

11071111
key = str(subtype) + str(closed)
11081112
try:
1109-
return cls._cache[key]
1113+
return cls._cache_dtypes[key]
11101114
except KeyError:
11111115
u = object.__new__(cls)
11121116
u._subtype = subtype
11131117
u._closed = closed
1114-
cls._cache[key] = u
1118+
cls._cache_dtypes[key] = u
11151119
return u
11161120

11171121
@property

pandas/tests/dtypes/test_dtypes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,15 @@ def test_pickle(self, dtype):
6666

6767
# clear the cache
6868
type(dtype).reset_cache()
69-
assert not len(dtype._cache)
69+
assert not len(dtype._cache_dtypes)
7070

7171
# force back to the cache
7272
result = tm.round_trip_pickle(dtype)
7373
if not isinstance(dtype, PeriodDtype):
7474
# Because PeriodDtype has a cython class as a base class,
7575
# it has different pickle semantics, and its cache is re-populated
7676
# on un-pickling.
77-
assert not len(dtype._cache)
77+
assert not len(dtype._cache_dtypes)
7878
assert result == dtype
7979

8080

@@ -791,14 +791,14 @@ def test_basic_dtype(self):
791791
def test_caching(self):
792792
IntervalDtype.reset_cache()
793793
dtype = IntervalDtype("int64", "right")
794-
assert len(IntervalDtype._cache) == 1
794+
assert len(IntervalDtype._cache_dtypes) == 1
795795

796796
IntervalDtype("interval")
797-
assert len(IntervalDtype._cache) == 2
797+
assert len(IntervalDtype._cache_dtypes) == 2
798798

799799
IntervalDtype.reset_cache()
800800
tm.round_trip_pickle(dtype)
801-
assert len(IntervalDtype._cache) == 0
801+
assert len(IntervalDtype._cache_dtypes) == 0
802802

803803
def test_not_string(self):
804804
# GH30568: though IntervalDtype has object kind, it cannot be string

0 commit comments

Comments
 (0)