Skip to content

Commit b2791c1

Browse files
committed
Add CompWrapper to clean code
1 parent 33f91d8 commit b2791c1

File tree

6 files changed

+97
-32
lines changed

6 files changed

+97
-32
lines changed

pandas/core/arrays/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from .array_ import array # noqa
22
from .base import (ExtensionArray, # noqa
33
ExtensionOpsMixin,
4-
ExtensionScalarOpsMixin)
4+
ExtensionScalarOpsMixin,
5+
CompWrapper)
56
from .categorical import Categorical # noqa
67
from .datetimes import DatetimeArray # noqa
78
from .interval import IntervalArray # noqa

pandas/core/arrays/base.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
without warning.
77
"""
88
import operator
9+
from functools import wraps
910

1011
import numpy as np
1112

@@ -15,7 +16,7 @@
1516
from pandas.util._decorators import Appender, Substitution
1617

1718
from pandas.core.dtypes.common import is_list_like
18-
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
19+
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
1920
from pandas.core.dtypes.missing import isna
2021

2122
from pandas.core import ops
@@ -1118,3 +1119,83 @@ def _create_arithmetic_method(cls, op):
11181119
@classmethod
11191120
def _create_comparison_method(cls, op):
11201121
return cls._create_method(op, coerce_to_dtype=False)
1122+
1123+
'''
1124+
def validate_comp_other(comp, list_to_array=False, validate_len=False,
1125+
zerodim=False, inst_from_senior_cls=False):
1126+
def wrapper(self, other):
1127+
if list_to_array is True:
1128+
if is_list_like(other):
1129+
other = np.asarray(other)
1130+
1131+
if validate_len is True:
1132+
if is_list_like(other) and len(other) != len(self):
1133+
raise ValueError("Lenghts must match")
1134+
1135+
if zerodim is True:
1136+
import pandas._libs as lib
1137+
other = lib.item_from_zerodim(other)
1138+
1139+
if inst_from_senior_cls is True:
1140+
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
1141+
# Rely on pandas to unbox and dispatch to us.
1142+
return NotImplemented
1143+
1144+
comp(self, other)
1145+
return wrapper
1146+
'''
1147+
1148+
1149+
class CompWrapper(object):
1150+
__key__ = ['list_to_array', 'validate_len',
1151+
'zerodim', 'inst_from_senior_cls']
1152+
1153+
def __init__(self,
1154+
list_to_array=None,
1155+
validate_len=None,
1156+
zerodim=None,
1157+
inst_from_senior_cls=None):
1158+
self.list_to_array = list_to_array
1159+
self.validate_len = validate_len
1160+
self.zerodim = zerodim
1161+
self.inst_from_senior_cls = inst_from_senior_cls
1162+
1163+
def _list_to_array(self, comp):
1164+
@wraps(comp)
1165+
def wrapper(comp_self, comp_other):
1166+
if is_list_like(comp_other):
1167+
comp_other = np.asarray(comp_other)
1168+
return comp(comp_self, comp_other)
1169+
return wrapper
1170+
1171+
def _validate_len(self, comp):
1172+
@wraps(comp)
1173+
def wrapper(comp_self, comp_other):
1174+
if is_list_like(comp_other) and len(comp_other) != len(comp_self):
1175+
raise ValueError("Lengths must match to compare")
1176+
return comp(comp_self, comp_other)
1177+
return wrapper
1178+
1179+
def _zerodim(self, comp):
1180+
@wraps(comp)
1181+
def wrapper(comp_self, comp_other):
1182+
from pandas._libs import lib
1183+
comp_other = lib.item_from_zerodim(comp_other)
1184+
return comp(comp_self, comp_other)
1185+
return wrapper
1186+
1187+
def _inst_from_senior_cls(self, comp):
1188+
@wraps(comp)
1189+
def wrapper(comp_self, comp_other):
1190+
if isinstance(comp_other, (ABCDataFrame,
1191+
ABCSeries, ABCIndexClass)):
1192+
# Rely on pandas to unbox and dispatch to us.
1193+
return NotImplemented
1194+
return comp(comp_self, comp_other)
1195+
return wrapper
1196+
1197+
def __call__(self, comp):
1198+
for key in CompWrapper.__key__:
1199+
if getattr(self, key) is True:
1200+
comp = getattr(self, '_' + key)(comp)
1201+
return comp

pandas/core/arrays/datetimes.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
is_string_dtype, is_timedelta64_dtype, pandas_dtype)
2121
from pandas.core.dtypes.dtypes import DatetimeTZDtype
2222
from pandas.core.dtypes.generic import (
23-
ABCDataFrame, ABCIndexClass, ABCPandasArray, ABCSeries)
23+
ABCIndexClass, ABCPandasArray, ABCSeries)
2424
from pandas.core.dtypes.missing import isna
2525

2626
from pandas.core import ops
2727
from pandas.core.algorithms import checked_add_with_arr
28-
from pandas.core.arrays import datetimelike as dtl
28+
from pandas.core.arrays import datetimelike as dtl, CompWrapper
2929
from pandas.core.arrays._ranges import generate_regular_range
3030
import pandas.core.common as com
3131

@@ -129,12 +129,8 @@ def _dt_array_cmp(cls, op):
129129
opname = '__{name}__'.format(name=op.__name__)
130130
nat_result = True if opname == '__ne__' else False
131131

132+
@CompWrapper(inst_from_senior_cls=True, validate_len=True, zerodim=True)
132133
def wrapper(self, other):
133-
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
134-
return NotImplemented
135-
136-
other = lib.item_from_zerodim(other)
137-
138134
if isinstance(other, (datetime, np.datetime64, compat.string_types)):
139135
if isinstance(other, (datetime, np.datetime64)):
140136
# GH#18435 strings get a pass from tzawareness compat
@@ -151,8 +147,8 @@ def wrapper(self, other):
151147
result.fill(nat_result)
152148
elif lib.is_scalar(other) or np.ndim(other) == 0:
153149
return ops.invalid_comparison(self, other, op)
154-
elif len(other) != len(self):
155-
raise ValueError("Lengths must match")
150+
#elif len(other) != len(self):
151+
# raise ValueError("Lengths must match")
156152
else:
157153
if isinstance(other, list):
158154
try:

pandas/core/arrays/integer.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pandas.core.dtypes.missing import isna, notna
1919

2020
from pandas.core import nanops
21-
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
21+
from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin, CompWrapper
2222
from pandas.core.tools.numeric import to_numeric
2323

2424

@@ -496,22 +496,17 @@ def _values_for_argsort(self):
496496

497497
@classmethod
498498
def _create_comparison_method(cls, op):
499+
@CompWrapper(validate_len=True, inst_from_senior_cls=True)
499500
def cmp_method(self, other):
500501

501502
op_name = op.__name__
502503
mask = None
503504

504-
if isinstance(other, (ABCSeries, ABCIndexClass)):
505-
# Rely on pandas to unbox and dispatch to us.
506-
return NotImplemented
507-
508505
if isinstance(other, IntegerArray):
509506
other, mask = other._data, other._mask
510507

511508
elif is_list_like(other):
512509
other = np.asarray(other)
513-
if other.ndim > 0 and len(self) != len(other):
514-
raise ValueError('Lengths must match to compare')
515510

516511
other = lib.item_from_zerodim(other)
517512

pandas/core/arrays/period.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515

1616
from pandas.core.dtypes.common import (
1717
_TD_DTYPE, ensure_object, is_datetime64_dtype, is_float_dtype,
18-
is_list_like, is_period_dtype, pandas_dtype)
18+
is_period_dtype, pandas_dtype)
1919
from pandas.core.dtypes.dtypes import PeriodDtype
2020
from pandas.core.dtypes.generic import (
21-
ABCDataFrame, ABCIndexClass, ABCPeriodIndex, ABCSeries)
21+
ABCIndexClass, ABCPeriodIndex, ABCSeries)
2222
from pandas.core.dtypes.missing import isna, notna
2323

2424
import pandas.core.algorithms as algos
2525
from pandas.core.arrays import datetimelike as dtl
26+
from pandas.core.arrays.base import CompWrapper
2627
import pandas.core.common as com
2728

2829
from pandas.tseries import frequencies
@@ -47,15 +48,10 @@ def _period_array_cmp(cls, op):
4748
opname = '__{name}__'.format(name=op.__name__)
4849
nat_result = True if opname == '__ne__' else False
4950

51+
@CompWrapper(validate_len=True, inst_from_senior_cls=True)
5052
def wrapper(self, other):
5153
op = getattr(self.asi8, opname)
5254

53-
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
54-
return NotImplemented
55-
56-
if is_list_like(other) and len(other) != len(self):
57-
raise ValueError("Lengths must match")
58-
5955
if isinstance(other, Period):
6056
self._check_compatible_with(other)
6157

pandas/core/arrays/timedeltas.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pandas.compat as compat
1515
from pandas.util._decorators import Appender
1616

17+
from pandas.core.arrays import CompWrapper
1718
from pandas.core.dtypes.common import (
1819
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_float_dtype,
1920
is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
@@ -63,10 +64,8 @@ def _td_array_cmp(cls, op):
6364
opname = '__{name}__'.format(name=op.__name__)
6465
nat_result = True if opname == '__ne__' else False
6566

67+
@CompWrapper(validate_len=True, inst_from_senior_cls=True)
6668
def wrapper(self, other):
67-
if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
68-
return NotImplemented
69-
7069
if _is_convertible_to_td(other) or other is NaT:
7170
try:
7271
other = Timedelta(other)
@@ -81,9 +80,6 @@ def wrapper(self, other):
8180
elif not is_list_like(other):
8281
return ops.invalid_comparison(self, other, op)
8382

84-
elif len(other) != len(self):
85-
raise ValueError("Lengths must match")
86-
8783
else:
8884
try:
8985
other = type(self)._from_sequence(other)._data

0 commit comments

Comments
 (0)