Skip to content

Commit e91ab52

Browse files
committed
refactor dtypes
1 parent 2dd5780 commit e91ab52

File tree

3 files changed

+50
-37
lines changed

3 files changed

+50
-37
lines changed

pandas/core/arrays/mask/_base.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,42 @@
77

88
import copy
99
import numpy as np
10+
from pandas import compat
1011
from pandas.core.missing import isna
1112
from pandas.core.arrays.base import ExtensionArray
13+
from pandas.api.extensions import ExtensionDtype
14+
15+
16+
class BoolDtype(ExtensionDtype):
17+
18+
type = np.bool_
19+
kind = 'b'
20+
name = 'bool'
21+
22+
@classmethod
23+
def construct_from_string(cls, string):
24+
if string == cls.name:
25+
return cls()
26+
else:
27+
raise TypeError("Cannot construct a '{}' from "
28+
"'{}'".format(cls, string))
29+
30+
def _is_boolean(self):
31+
return True
32+
33+
def __hash__(self):
34+
return hash(str(self))
35+
36+
def __eq__(self, other):
37+
# compare == to np.dtype('bool')
38+
if isinstance(other, compat.string_types):
39+
return other == self.name
40+
elif other is self:
41+
return True
42+
elif isinstance(other, np.dtype):
43+
return other == 'bool'
44+
else:
45+
return hash(self) == hash(other)
1246

1347

1448
class BoolArray(ExtensionArray):

pandas/core/arrays/mask/_numpy.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,19 @@
44

55
import numpy as np
66

7-
from pandas.core.arrays.mask._base import BoolArray
8-
from pandas.api.extensions import (
9-
ExtensionDtype, register_extension_dtype)
7+
from pandas.core.arrays.mask._base import BoolDtype, BoolArray
8+
from pandas.api.extensions import register_extension_dtype, take
109

1110

1211
@register_extension_dtype
13-
class NumpyBoolDtype(ExtensionDtype):
12+
class NumpyBoolDtype(BoolDtype):
1413

15-
type = np.bool_
16-
kind = 'b'
17-
name = 'bool'
1814
na_value = np.nan
1915

20-
@classmethod
21-
def construct_from_string(cls, string):
22-
if string == cls.name:
23-
return cls()
24-
else:
25-
raise TypeError("Cannot construct a '{}' from "
26-
"'{}'".format(cls, string))
27-
2816
@classmethod
2917
def construct_array_type(cls):
3018
return NumpyBoolArray
3119

32-
def _is_boolean(self):
33-
return True
34-
3520

3621
class NumpyBoolArray(BoolArray):
3722
"""Generic class which can be used to represent missing data.
@@ -79,5 +64,13 @@ def reshape(self, shape, **kwargs):
7964
def astype(self, dtype, copy=False):
8065
return np.array(self, copy=False).astype(dtype, copy=copy)
8166

82-
def take(self, indicies, **kwargs):
83-
return np.array(self, copy=False).take(indicies)
67+
def take(self, indices, allow_fill=False, fill_value=None, axis=None):
68+
# TODO: had to add axis here
69+
data = self._data
70+
71+
if allow_fill and fill_value is None:
72+
fill_value = self.dtype.na_value
73+
74+
result = take(data, indices, fill_value=fill_value,
75+
allow_fill=allow_fill)
76+
return self._from_sequence(result, dtype=self.dtype)

pandas/core/arrays/mask/_pyarrow.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
from pandas.api.types import is_scalar
1414
from pandas.api.extensions import (
15-
ExtensionDtype, register_extension_dtype, take)
16-
from pandas.core.arrays.mask._base import BoolArray
15+
register_extension_dtype, take)
16+
from pandas.core.arrays.mask._base import BoolDtype, BoolArray
1717

1818

1919
# we require pyarrow >= 0.10.0
@@ -27,28 +27,14 @@
2727

2828

2929
@register_extension_dtype
30-
class ArrowBoolDtype(ExtensionDtype):
30+
class ArrowBoolDtype(BoolDtype):
3131

32-
type = np.bool_
33-
kind = 'b'
34-
name = 'arrow_bool'
3532
na_value = pa.NULL
3633

37-
@classmethod
38-
def construct_from_string(cls, string):
39-
if string == cls.name:
40-
return cls()
41-
else:
42-
raise TypeError("Cannot construct a '{}' from "
43-
"'{}'".format(cls, string))
44-
4534
@classmethod
4635
def construct_array_type(cls):
4736
return ArrowBoolArray
4837

49-
def _is_boolean(self):
50-
return True
51-
5238

5339
class ArrowBoolArray(BoolArray):
5440

0 commit comments

Comments
 (0)