Skip to content

Commit ce45656

Browse files
authored
Fix: make unsupported comparisons return NotImplemented (#1199)
Make all comparator magic methods return `NotImplemented` instead of `False` (or raising `TypeError` in some instances) if the other operand is not of a supported type. This means that when comparing a driver type with another type is doesn't support, the other type get the chance to handle the comparison. Affected types: * `neo4j.Record` * `neo4j.graph.Node`, `neo4j.graph.Relationship`, `neo4j.graph.Path` * `neo4j.time.Date`, `neo4j.time.Time`, `neo4j.time.DateTime` * `neo4j.spatial.Point` (and subclasses)
1 parent a32e403 commit ce45656

File tree

10 files changed

+188
-110
lines changed

10 files changed

+188
-110
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog.
156156
should be treated as immutable.
157157
- Graph type sets (`neo4j.graph.EntitySetView`) can no longer by indexed by legacy `id` (`int`, e.g., `graph.nodes[0]`).
158158
Use the `element_id` instead (`str`, e.g., `graph.nodes["..."]`).
159+
- Make all comparator magic methods return `NotImplemented` instead of `False` (or raising `TypeError` in some
160+
instances) if the other operand is not of a supported type.
161+
This means that when comparing a driver type with another type is doesn't support, the other type get the chance to
162+
handle the comparison.
163+
Affected types:
164+
- `neo4j.Record`
165+
- `neo4j.graph.Node`, `neo4j.graph.Relationship`, `neo4j.graph.Path`
166+
- `neo4j.time.Date`, `neo4j.time.Time`, `neo4j.time.DateTime`
167+
- `neo4j.spatial.Point` (and subclasses)
159168

160169

161170
## Version 5.28

src/neo4j/_codec/packstream/_python/_common.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@ def __eq__(self, other):
2828
try:
2929
return self.tag == other.tag and self.fields == other.fields
3030
except AttributeError:
31-
return False
32-
33-
def __ne__(self, other):
34-
return not self.__eq__(other)
31+
return NotImplementedError
3532

3633
def __len__(self):
3734
return len(self.fields)

src/neo4j/_data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,7 @@ def __eq__(self, other: object) -> bool:
110110
other = t.cast(t.Mapping, other)
111111
return dict(self) == dict(other)
112112
else:
113-
return False
114-
115-
def __ne__(self, other: object) -> bool:
116-
return not self.__eq__(other)
113+
return NotImplemented
117114

118115
def __hash__(self):
119116
return reduce(xor_operator, map(hash, self.items()))

src/neo4j/_io/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,28 +61,28 @@ def __ne__(self, other: object) -> bool:
6161
return self.version != other
6262
return NotImplemented
6363

64-
def __lt__(self, other: object) -> bool:
64+
def __lt__(self, other: BoltProtocolVersion | tuple) -> bool:
6565
if isinstance(other, BoltProtocolVersion):
6666
return self.version < other.version
6767
if isinstance(other, tuple):
6868
return self.version < other
6969
return NotImplemented
7070

71-
def __le__(self, other: object) -> bool:
71+
def __le__(self, other: BoltProtocolVersion | tuple) -> bool:
7272
if isinstance(other, BoltProtocolVersion):
7373
return self.version <= other.version
7474
if isinstance(other, tuple):
7575
return self.version <= other
7676
return NotImplemented
7777

78-
def __gt__(self, other: object) -> bool:
78+
def __gt__(self, other: BoltProtocolVersion | tuple) -> bool:
7979
if isinstance(other, BoltProtocolVersion):
8080
return self.version > other.version
8181
if isinstance(other, tuple):
8282
return self.version > other
8383
return NotImplemented
8484

85-
def __ge__(self, other: object) -> bool:
85+
def __ge__(self, other: BoltProtocolVersion | tuple) -> bool:
8686
if isinstance(other, BoltProtocolVersion):
8787
return self.version >= other.version
8888
if isinstance(other, tuple):

src/neo4j/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113
if parameters:
114114
self.parameters = parameters
115115

116-
def __eq__(self, other: _t.Any) -> bool:
116+
def __eq__(self, other: object) -> bool:
117117
if not isinstance(other, Auth):
118118
return NotImplemented
119119
return vars(self) == vars(other)

src/neo4j/graph/__init__.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,14 @@ def __init__(
117117
}
118118

119119
def __eq__(self, other: _t.Any) -> bool:
120-
# TODO: 6.0 - return NotImplemented on type mismatch instead of False
121120
try:
122121
return (
123122
type(self) is type(other)
124123
and self.graph == other.graph
125124
and self.element_id == other.element_id
126125
)
127126
except AttributeError:
128-
return False
129-
130-
def __ne__(self, other: object) -> bool:
131-
return not self.__eq__(other)
127+
return NotImplemented
132128

133129
def __hash__(self):
134130
return hash(self._element_id)
@@ -325,17 +321,13 @@ def __repr__(self) -> str:
325321
)
326322

327323
def __eq__(self, other: _t.Any) -> bool:
328-
# TODO: 6.0 - return NotImplemented on type mismatch instead of False
329324
try:
330325
return (
331326
self.start_node == other.start_node
332327
and self.relationships == other.relationships
333328
)
334329
except AttributeError:
335-
return False
336-
337-
def __ne__(self, other: object) -> bool:
338-
return not self.__eq__(other)
330+
return NotImplemented
339331

340332
def __hash__(self):
341333
value = hash(self._nodes[0])

src/neo4j/spatial/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ def __eq__(self, other: object) -> bool:
7777
_t.cast(Point, other)
7878
)
7979
except (AttributeError, TypeError):
80-
return False
81-
82-
def __ne__(self, other: object) -> bool:
83-
return not self.__eq__(other)
80+
return NotImplemented
8481

8582
def __hash__(self):
8683
return hash(type(self)) ^ hash(tuple(self))

src/neo4j/time/__init__.py

Lines changed: 57 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,60 +1217,31 @@ def __hash__(self):
12171217
def __eq__(self, other: object) -> bool:
12181218
"""``==`` comparison with :class:`.Date` or :class:`datetime.date`."""
12191219
if not isinstance(other, (Date, _date)):
1220-
# TODO: 6.0 - return NotImplemented for non-Date objects
1221-
# return NotImplemented
1222-
return False
1220+
return NotImplemented
12231221
return self.toordinal() == other.toordinal()
12241222

1225-
def __ne__(self, other: object) -> bool:
1226-
"""``!=`` comparison with :class:`.Date` or :class:`datetime.date`."""
1227-
# TODO: 6.0 - return NotImplemented for non-Date objects
1228-
# if not isinstance(other, (Date, date)):
1229-
# return NotImplemented
1230-
return not self.__eq__(other)
1231-
12321223
def __lt__(self, other: Date | _date) -> bool:
12331224
"""``<`` comparison with :class:`.Date` or :class:`datetime.date`."""
12341225
if not isinstance(other, (Date, _date)):
1235-
# TODO: 6.0 - return NotImplemented for non-Date objects
1236-
# return NotImplemented
1237-
raise TypeError(
1238-
"'<' not supported between instances of 'Date' and "
1239-
f"{type(other).__name__!r}"
1240-
)
1226+
return NotImplemented
12411227
return self.toordinal() < other.toordinal()
12421228

12431229
def __le__(self, other: Date | _date) -> bool:
12441230
"""``<=`` comparison with :class:`.Date` or :class:`datetime.date`."""
12451231
if not isinstance(other, (Date, _date)):
1246-
# TODO: 6.0 - return NotImplemented for non-Date objects
1247-
# return NotImplemented
1248-
raise TypeError(
1249-
"'<=' not supported between instances of 'Date' and "
1250-
f"{type(other).__name__!r}"
1251-
)
1232+
return NotImplemented
12521233
return self.toordinal() <= other.toordinal()
12531234

12541235
def __ge__(self, other: Date | _date) -> bool:
12551236
"""``>=`` comparison with :class:`.Date` or :class:`datetime.date`."""
12561237
if not isinstance(other, (Date, _date)):
1257-
# TODO: 6.0 - return NotImplemented for non-Date objects
1258-
# return NotImplemented
1259-
raise TypeError(
1260-
"'>=' not supported between instances of 'Date' and "
1261-
f"{type(other).__name__!r}"
1262-
)
1238+
return NotImplemented
12631239
return self.toordinal() >= other.toordinal()
12641240

12651241
def __gt__(self, other: Date | _date) -> bool:
12661242
"""``>`` comparison with :class:`.Date` or :class:`datetime.date`."""
12671243
if not isinstance(other, (Date, _date)):
1268-
# TODO: 6.0 - return NotImplemented for non-Date objects
1269-
# return NotImplemented
1270-
raise TypeError(
1271-
"'>' not supported between instances of 'Date' and "
1272-
f"{type(other).__name__!r}"
1273-
)
1244+
return NotImplemented
12741245
return self.toordinal() > other.toordinal()
12751246

12761247
def __add__(self, other: Duration) -> Date: # type: ignore[override]
@@ -1917,29 +1888,37 @@ def tzinfo(self) -> _tzinfo | None:
19171888

19181889
# OPERATIONS #
19191890

1920-
def _get_both_normalized_ticks(self, other: object, strict=True):
1921-
if isinstance(other, (_time, Time)) and (
1922-
(self.utc_offset() is None) ^ (other.utcoffset() is None)
1923-
):
1891+
@_t.overload
1892+
def _get_both_normalized_ticks(
1893+
self, other: Time | _time, strict: _t.Literal[True] = True
1894+
) -> tuple[int, int]: ...
1895+
1896+
@_t.overload
1897+
def _get_both_normalized_ticks(
1898+
self, other: Time | _time, strict: _t.Literal[False]
1899+
) -> tuple[int, int] | None: ...
1900+
1901+
def _get_both_normalized_ticks(
1902+
self, other: Time | _time, strict: bool = True
1903+
) -> tuple[int, int] | None:
1904+
if (self.utc_offset() is None) ^ (other.utcoffset() is None):
19241905
if strict:
19251906
raise TypeError(
19261907
"can't compare offset-naive and offset-aware times"
19271908
)
19281909
else:
1929-
return None, None
1910+
return None
19301911
other_ticks: int
19311912
if isinstance(other, Time):
19321913
other_ticks = other.__ticks
1933-
elif isinstance(other, _time):
1914+
else:
1915+
assert isinstance(other, _time)
19341916
other_ticks = int(
19351917
3600000000000 * other.hour
19361918
+ 60000000000 * other.minute
19371919
+ _NANO_SECONDS * other.second
19381920
+ 1000 * other.microsecond
19391921
)
1940-
else:
1941-
return None, None
1942-
assert isinstance(other, (Time, _time))
19431922
utc_offset: _timedelta | None = other.utcoffset()
19441923
if utc_offset is not None:
19451924
other_ticks -= int(utc_offset.total_seconds() * _NANO_SECONDS)
@@ -1959,43 +1938,40 @@ def __hash__(self):
19591938

19601939
def __eq__(self, other: object) -> bool:
19611940
"""`==` comparison with :class:`.Time` or :class:`datetime.time`."""
1962-
self_ticks, other_ticks = self._get_both_normalized_ticks(
1963-
other, strict=False
1964-
)
1965-
if self_ticks is None:
1941+
if not isinstance(other, (Time, _time)):
1942+
return NotImplemented
1943+
ticks = self._get_both_normalized_ticks(other, strict=False)
1944+
if ticks is None:
19661945
return False
1946+
self_ticks, other_ticks = ticks
19671947
return self_ticks == other_ticks
19681948

1969-
def __ne__(self, other: object) -> bool:
1970-
"""`!=` comparison with :class:`.Time` or :class:`datetime.time`."""
1971-
return not self.__eq__(other)
1972-
19731949
def __lt__(self, other: Time | _time) -> bool:
19741950
"""`<` comparison with :class:`.Time` or :class:`datetime.time`."""
1975-
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
1976-
if self_ticks is None:
1951+
if not isinstance(other, (Time, _time)):
19771952
return NotImplemented
1953+
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
19781954
return self_ticks < other_ticks
19791955

19801956
def __le__(self, other: Time | _time) -> bool:
19811957
"""`<=` comparison with :class:`.Time` or :class:`datetime.time`."""
1982-
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
1983-
if self_ticks is None:
1958+
if not isinstance(other, (Time, _time)):
19841959
return NotImplemented
1960+
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
19851961
return self_ticks <= other_ticks
19861962

19871963
def __ge__(self, other: Time | _time) -> bool:
19881964
"""`>=` comparison with :class:`.Time` or :class:`datetime.time`."""
1989-
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
1990-
if self_ticks is None:
1965+
if not isinstance(other, (Time, _time)):
19911966
return NotImplemented
1967+
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
19921968
return self_ticks >= other_ticks
19931969

19941970
def __gt__(self, other: Time | _time) -> bool:
19951971
"""`>` comparison with :class:`.Time` or :class:`datetime.time`."""
1996-
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
1997-
if self_ticks is None:
1972+
if not isinstance(other, (Time, _time)):
19981973
return NotImplemented
1974+
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
19991975
return self_ticks > other_ticks
20001976

20011977
# INSTANCE METHODS #
@@ -2603,29 +2579,36 @@ def hour_minute_second_nanosecond(self) -> tuple[int, int, int, int]:
26032579

26042580
# OPERATIONS #
26052581

2606-
def _get_both_normalized(self, other, strict=True):
2607-
if isinstance(other, (_datetime, DateTime)) and (
2608-
(self.utc_offset() is None) ^ (other.utcoffset() is None)
2609-
):
2582+
@_t.overload
2583+
def _get_both_normalized(
2584+
self, other: _datetime | DateTime, strict: _t.Literal[True] = True
2585+
) -> tuple[DateTime, DateTime | _datetime]: ...
2586+
2587+
@_t.overload
2588+
def _get_both_normalized(
2589+
self, other: _datetime | DateTime, strict: _t.Literal[False]
2590+
) -> tuple[DateTime, DateTime | _datetime] | None: ...
2591+
2592+
def _get_both_normalized(
2593+
self, other: _datetime | DateTime, strict: bool = True
2594+
) -> tuple[DateTime, DateTime | _datetime] | None:
2595+
if (self.utc_offset() is None) ^ (other.utcoffset() is None):
26102596
if strict:
26112597
raise TypeError(
26122598
"can't compare offset-naive and offset-aware datetimes"
26132599
)
26142600
else:
2615-
return None, None
2601+
return None
26162602
self_norm = self
26172603
utc_offset = self.utc_offset()
26182604
if utc_offset is not None:
26192605
self_norm -= utc_offset
26202606
self_norm = self_norm.replace(tzinfo=None)
26212607
other_norm = other
2622-
if isinstance(other, (_datetime, DateTime)):
2623-
utc_offset = other.utcoffset()
2624-
if utc_offset is not None:
2625-
other_norm -= utc_offset
2626-
other_norm = other_norm.replace(tzinfo=None)
2627-
else:
2628-
return None, None
2608+
utc_offset = other.utcoffset()
2609+
if utc_offset is not None:
2610+
other_norm -= utc_offset
2611+
other_norm = other_norm.replace(tzinfo=None)
26292612
return self_norm, other_norm
26302613

26312614
def __hash__(self):
@@ -2647,21 +2630,12 @@ def __eq__(self, other: object) -> bool:
26472630
return NotImplemented
26482631
if self.utc_offset() == other.utcoffset():
26492632
return self.date() == other.date() and self.time() == other.time()
2650-
self_norm, other_norm = self._get_both_normalized(other, strict=False)
2651-
if self_norm is None:
2633+
normalized = self._get_both_normalized(other, strict=False)
2634+
if normalized is None:
26522635
return False
2636+
self_norm, other_norm = normalized
26532637
return self_norm == other_norm
26542638

2655-
def __ne__(self, other: object) -> bool:
2656-
"""
2657-
``!=`` comparison with another datetime.
2658-
2659-
Accepts :class:`.DateTime` and :class:`datetime.datetime`.
2660-
"""
2661-
if not isinstance(other, (DateTime, _datetime)):
2662-
return NotImplemented
2663-
return not self.__eq__(other)
2664-
26652639
def __lt__( # type: ignore[override]
26662640
self, other: _datetime | DateTime
26672641
) -> bool:

0 commit comments

Comments
 (0)