Skip to content

Commit d219243

Browse files
committed
Add CompWrapper to clean code
1 parent 2626215 commit d219243

File tree

6 files changed

+97
-32
lines changed

6 files changed

+97
-32
lines changed

pandas/core/arrays/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from .array_ import array # noqa
22
from .base import (ExtensionArray, # noqa
33
ExtensionOpsMixin,
4-
ExtensionScalarOpsMixin)
4+
ExtensionScalarOpsMixin,
5+
CompWrapper)
56
from .categorical import Categorical # noqa
67
from .datetimes import DatetimeArray # noqa
78
from .interval import IntervalArray # noqa

pandas/core/arrays/base.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
without warning.
77
"""
88
import operator
9+
from functools import wraps
910

1011
import numpy as np
1112

@@ -15,7 +16,7 @@
1516
from pandas.util._decorators import Appender, Substitution
1617

1718
from pandas.core.dtypes.common import is_list_like
18-
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
19+
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
1920
from pandas.core.dtypes.missing import isna
2021

2122
from pandas.core import ops
@@ -1118,3 +1119,83 @@ def _create_arithmetic_method(cls, op):
11181119
@classmethod
11191120
def _create_comparison_method(cls, op):
11201121
return cls._create_method(op, coerce_to_dtype=False)
1122+
1123+
'''
1124+
def validate_comp_other(comp, list_to_array=False, validate_len=False,
1125+
zerodim=False, inst_from_senior_cls=False):
1126+
def wrapper(self, other):
1127+
if list_to_array is True:
1128+
if is_list_like(other):
1129+
other = np.asarray(other)
1130+
1131+
if validate_len is True:
1132+
if is_list_like(other) and len(other) != len(self):
1133+
raise ValueError("Lenghts must match")
1134+
1135+
if zerodim is True:
1136+
import pandas._libs as lib
1137+
other = lib.item_from_zerodim(other)
1138+
1139+
if inst_from_senior_cls is True:
1140+
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
1141+
# Rely on pandas to unbox and dispatch to us.
1142+
return NotImplemented
1143+
1144+
comp(self, other)
1145+
return wrapper
1146+
'''
1147+
1148+
1149+
class CompWrapper(object):
1150+
__key__ = ['list_to_array', 'validate_len',
1151+
'zerodim', 'inst_from_senior_cls']
1152+
1153+
def __init__(self,
1154+
list_to_array=None,
1155+
validate_len=None,
1156+
zerodim=None,
1157+
inst_from_senior_cls=None):
1158+
self.list_to_array = list_to_array
1159+
self.validate_len = validate_len
1160+
self.zerodim = zerodim
1161+
self.inst_from_senior_cls = inst_from_senior_cls
1162+
1163+
def _list_to_array(self, comp):
1164+
@wraps(comp)
1165+
def wrapper(comp_self, comp_other):
1166+
if is_list_like(comp_other):
1167+
comp_other = np.asarray(comp_other)
1168+
return comp(comp_self, comp_other)
1169+
return wrapper
1170+
1171+
def _validate_len(self, comp):
1172+
@wraps(comp)
1173+
def wrapper(comp_self, comp_other):
1174+
if is_list_like(comp_other) and len(comp_other) != len(comp_self):
1175+
raise ValueError("Lengths must match to compare")
1176+
return comp(comp_self, comp_other)
1177+
return wrapper
1178+
1179+
def _zerodim(self, comp):
1180+
@wraps(comp)
1181+
def wrapper(comp_self, comp_other):
1182+
from pandas._libs import lib
1183+
comp_other = lib.item_from_zerodim(comp_other)
1184+
return comp(comp_self, comp_other)
1185+
return wrapper
1186+
1187+
def _inst_from_senior_cls(self, comp):
1188+
@wraps(comp)
1189+
def wrapper(comp_self, comp_other):
1190+
if isinstance(comp_other, (ABCDataFrame,
1191+
ABCSeries, ABCIndexClass)):
1192+
# Rely on pandas to unbox and dispatch to us.
1193+
return NotImplemented
1194+
return comp(comp_self, comp_other)
1195+
return wrapper
1196+
1197+
def __call__(self, comp):
1198+
for key in CompWrapper.__key__:
1199+
if getattr(self, key) is True:
1200+
comp = getattr(self, '_' + key)(comp)
1201+
return comp

pandas/core/arrays/datetimes.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
is_string_dtype, is_timedelta64_dtype, pandas_dtype)
2222
from pandas.core.dtypes.dtypes import DatetimeTZDtype
2323
from pandas.core.dtypes.generic import (
24-
ABCDataFrame, ABCIndexClass, ABCPandasArray, ABCSeries)
24+
ABCIndexClass, ABCPandasArray, ABCSeries)
2525
from pandas.core.dtypes.missing import isna
2626

2727
from pandas.core import ops
2828
from pandas.core.algorithms import checked_add_with_arr
29-
from pandas.core.arrays import datetimelike as dtl
29+
from pandas.core.arrays import datetimelike as dtl, CompWrapper
3030
from pandas.core.arrays._ranges import generate_regular_range
3131
import pandas.core.common as com
3232

@@ -130,12 +130,8 @@ def _dt_array_cmp(cls, op):
130130
opname = '__{name}__'.format(name=op.__name__)
131131
nat_result = True if opname == '__ne__' else False
132132

133+
@CompWrapper(inst_from_senior_cls=True, validate_len=True, zerodim=True)
133134
def wrapper(self, other):
134-
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
135-
return NotImplemented
136-
137-
other = lib.item_from_zerodim(other)
138-
139135
if isinstance(other, (datetime, np.datetime64, compat.string_types)):
140136
if isinstance(other, (datetime, np.datetime64)):
141137
# GH#18435 strings get a pass from tzawareness compat
@@ -152,8 +148,8 @@ def wrapper(self, other):
152148
result.fill(nat_result)
153149
elif lib.is_scalar(other) or np.ndim(other) == 0:
154150
return ops.invalid_comparison(self, other, op)
155-
elif len(other) != len(self):
156-
raise ValueError("Lengths must match")
151+
#elif len(other) != len(self):
152+
# raise ValueError("Lengths must match")
157153
else:
158154
if isinstance(other, list):
159155
try:

pandas/core/arrays/integer.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pandas.core.dtypes.missing import isna, notna
1919

2020
from pandas.core import nanops
21-
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
21+
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin, CompWrapper
2222
from pandas.core.tools.numeric import to_numeric
2323

2424

@@ -529,22 +529,17 @@ def _values_for_argsort(self):
529529

530530
@classmethod
531531
def _create_comparison_method(cls, op):
532+
@CompWrapper(validate_len=True, inst_from_senior_cls=True)
532533
def cmp_method(self, other):
533534

534535
op_name = op.__name__
535536
mask = None
536537

537-
if isinstance(other, (ABCSeries, ABCIndexClass)):
538-
# Rely on pandas to unbox and dispatch to us.
539-
return NotImplemented
540-
541538
if isinstance(other, IntegerArray):
542539
other, mask = other._data, other._mask
543540

544541
elif is_list_like(other):
545542
other = np.asarray(other)
546-
if other.ndim > 0 and len(self) != len(other):
547-
raise ValueError('Lengths must match to compare')
548543

549544
other = lib.item_from_zerodim(other)
550545

pandas/core/arrays/period.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616

1717
from pandas.core.dtypes.common import (
1818
_TD_DTYPE, ensure_object, is_datetime64_dtype, is_float_dtype,
19-
is_list_like, is_period_dtype, pandas_dtype)
19+
is_period_dtype, pandas_dtype)
2020
from pandas.core.dtypes.dtypes import PeriodDtype
2121
from pandas.core.dtypes.generic import (
22-
ABCDataFrame, ABCIndexClass, ABCPeriodIndex, ABCSeries)
22+
ABCIndexClass, ABCPeriodIndex, ABCSeries)
2323
from pandas.core.dtypes.missing import isna, notna
2424

2525
import pandas.core.algorithms as algos
2626
from pandas.core.arrays import datetimelike as dtl
27+
from pandas.core.arrays.base import CompWrapper
2728
import pandas.core.common as com
2829

2930
from pandas.tseries import frequencies
@@ -48,15 +49,10 @@ def _period_array_cmp(cls, op):
4849
opname = '__{name}__'.format(name=op.__name__)
4950
nat_result = True if opname == '__ne__' else False
5051

52+
@CompWrapper(validate_len=True, inst_from_senior_cls=True)
5153
def wrapper(self, other):
5254
op = getattr(self.asi8, opname)
5355

54-
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
55-
return NotImplemented
56-
57-
if is_list_like(other) and len(other) != len(self):
58-
raise ValueError("Lengths must match")
59-
6056
if isinstance(other, Period):
6157
self._check_compatible_with(other)
6258

pandas/core/arrays/timedeltas.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pandas.compat as compat
1616
from pandas.util._decorators import Appender
1717

18+
from pandas.core.arrays import CompWrapper
1819
from pandas.core.dtypes.common import (
1920
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_dtype_equal,
2021
is_float_dtype, is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
@@ -64,10 +65,8 @@ def _td_array_cmp(cls, op):
6465
opname = '__{name}__'.format(name=op.__name__)
6566
nat_result = True if opname == '__ne__' else False
6667

68+
@CompWrapper(validate_len=True, inst_from_senior_cls=True)
6769
def wrapper(self, other):
68-
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
69-
return NotImplemented
70-
7170
if _is_convertible_to_td(other) or other is NaT:
7271
try:
7372
other = Timedelta(other)
@@ -82,9 +81,6 @@ def wrapper(self, other):
8281
elif not is_list_like(other):
8382
return ops.invalid_comparison(self, other, op)
8483

85-
elif len(other) != len(self):
86-
raise ValueError("Lengths must match")
87-
8884
else:
8985
try:
9086
other = type(self)._from_sequence(other)._data

0 commit comments

Comments
 (0)