Skip to content

Fix: make unsupported comparisons return NotImplemented #1199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog.
should be treated as immutable.
- Graph type sets (`neo4j.graph.EntitySetView`) can no longer by indexed by legacy `id` (`int`, e.g., `graph.nodes[0]`).
Use the `element_id` instead (`str`, e.g., `graph.nodes["..."]`).
- Make all comparator magic methods return `NotImplemented` instead of `False` (or raising `TypeError` in some
instances) if the other operand is not of a supported type.
This means that when comparing a driver type with another type is doesn't support, the other type get the chance to
handle the comparison.
Affected types:
- `neo4j.Record`
- `neo4j.graph.Node`, `neo4j.graph.Relationship`, `neo4j.graph.Path`
- `neo4j.time.Date`, `neo4j.time.Time`, `neo4j.time.DateTime`
- `neo4j.spatial.Point` (and subclasses)


## Version 5.28
Expand Down
5 changes: 1 addition & 4 deletions src/neo4j/_codec/packstream/_python/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ def __eq__(self, other):
try:
return self.tag == other.tag and self.fields == other.fields
except AttributeError:
return False

def __ne__(self, other):
return not self.__eq__(other)
return NotImplementedError

def __len__(self):
return len(self.fields)
Expand Down
5 changes: 1 addition & 4 deletions src/neo4j/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,7 @@ def __eq__(self, other: object) -> bool:
other = t.cast(t.Mapping, other)
return dict(self) == dict(other)
else:
return False

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
return NotImplemented

def __hash__(self):
return reduce(xor_operator, map(hash, self.items()))
Expand Down
8 changes: 4 additions & 4 deletions src/neo4j/_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,28 @@ def __ne__(self, other: object) -> bool:
return self.version != other
return NotImplemented

def __lt__(self, other: object) -> bool:
def __lt__(self, other: BoltProtocolVersion | tuple) -> bool:
if isinstance(other, BoltProtocolVersion):
return self.version < other.version
if isinstance(other, tuple):
return self.version < other
return NotImplemented

def __le__(self, other: object) -> bool:
def __le__(self, other: BoltProtocolVersion | tuple) -> bool:
if isinstance(other, BoltProtocolVersion):
return self.version <= other.version
if isinstance(other, tuple):
return self.version <= other
return NotImplemented

def __gt__(self, other: object) -> bool:
def __gt__(self, other: BoltProtocolVersion | tuple) -> bool:
if isinstance(other, BoltProtocolVersion):
return self.version > other.version
if isinstance(other, tuple):
return self.version > other
return NotImplemented

def __ge__(self, other: object) -> bool:
def __ge__(self, other: BoltProtocolVersion | tuple) -> bool:
if isinstance(other, BoltProtocolVersion):
return self.version >= other.version
if isinstance(other, tuple):
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
if parameters:
self.parameters = parameters

def __eq__(self, other: _t.Any) -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, Auth):
return NotImplemented
return vars(self) == vars(other)
Expand Down
12 changes: 2 additions & 10 deletions src/neo4j/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,14 @@ def __init__(
}

def __eq__(self, other: _t.Any) -> bool:
# TODO: 6.0 - return NotImplemented on type mismatch instead of False
try:
return (
type(self) is type(other)
and self.graph == other.graph
and self.element_id == other.element_id
)
except AttributeError:
return False

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
return NotImplemented

def __hash__(self):
return hash(self._element_id)
Expand Down Expand Up @@ -325,17 +321,13 @@ def __repr__(self) -> str:
)

def __eq__(self, other: _t.Any) -> bool:
# TODO: 6.0 - return NotImplemented on type mismatch instead of False
try:
return (
self.start_node == other.start_node
and self.relationships == other.relationships
)
except AttributeError:
return False

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
return NotImplemented

def __hash__(self):
value = hash(self._nodes[0])
Expand Down
5 changes: 1 addition & 4 deletions src/neo4j/spatial/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ def __eq__(self, other: object) -> bool:
_t.cast(Point, other)
)
except (AttributeError, TypeError):
return False

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
return NotImplemented

def __hash__(self):
return hash(type(self)) ^ hash(tuple(self))
Expand Down
140 changes: 57 additions & 83 deletions src/neo4j/time/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,60 +1217,31 @@ def __hash__(self):
def __eq__(self, other: object) -> bool:
"""``==`` comparison with :class:`.Date` or :class:`datetime.date`."""
if not isinstance(other, (Date, _date)):
# TODO: 6.0 - return NotImplemented for non-Date objects
# return NotImplemented
return False
return NotImplemented
return self.toordinal() == other.toordinal()

def __ne__(self, other: object) -> bool:
"""``!=`` comparison with :class:`.Date` or :class:`datetime.date`."""
# TODO: 6.0 - return NotImplemented for non-Date objects
# if not isinstance(other, (Date, date)):
# return NotImplemented
return not self.__eq__(other)

def __lt__(self, other: Date | _date) -> bool:
"""``<`` comparison with :class:`.Date` or :class:`datetime.date`."""
if not isinstance(other, (Date, _date)):
# TODO: 6.0 - return NotImplemented for non-Date objects
# return NotImplemented
raise TypeError(
"'<' not supported between instances of 'Date' and "
f"{type(other).__name__!r}"
)
return NotImplemented
return self.toordinal() < other.toordinal()

def __le__(self, other: Date | _date) -> bool:
"""``<=`` comparison with :class:`.Date` or :class:`datetime.date`."""
if not isinstance(other, (Date, _date)):
# TODO: 6.0 - return NotImplemented for non-Date objects
# return NotImplemented
raise TypeError(
"'<=' not supported between instances of 'Date' and "
f"{type(other).__name__!r}"
)
return NotImplemented
return self.toordinal() <= other.toordinal()

def __ge__(self, other: Date | _date) -> bool:
"""``>=`` comparison with :class:`.Date` or :class:`datetime.date`."""
if not isinstance(other, (Date, _date)):
# TODO: 6.0 - return NotImplemented for non-Date objects
# return NotImplemented
raise TypeError(
"'>=' not supported between instances of 'Date' and "
f"{type(other).__name__!r}"
)
return NotImplemented
return self.toordinal() >= other.toordinal()

def __gt__(self, other: Date | _date) -> bool:
"""``>`` comparison with :class:`.Date` or :class:`datetime.date`."""
if not isinstance(other, (Date, _date)):
# TODO: 6.0 - return NotImplemented for non-Date objects
# return NotImplemented
raise TypeError(
"'>' not supported between instances of 'Date' and "
f"{type(other).__name__!r}"
)
return NotImplemented
return self.toordinal() > other.toordinal()

def __add__(self, other: Duration) -> Date: # type: ignore[override]
Expand Down Expand Up @@ -1917,29 +1888,37 @@ def tzinfo(self) -> _tzinfo | None:

# OPERATIONS #

def _get_both_normalized_ticks(self, other: object, strict=True):
if isinstance(other, (_time, Time)) and (
(self.utc_offset() is None) ^ (other.utcoffset() is None)
):
@_t.overload
def _get_both_normalized_ticks(
self, other: Time | _time, strict: _t.Literal[True] = True
) -> tuple[int, int]: ...

@_t.overload
def _get_both_normalized_ticks(
self, other: Time | _time, strict: _t.Literal[False]
) -> tuple[int, int] | None: ...

def _get_both_normalized_ticks(
self, other: Time | _time, strict: bool = True
) -> tuple[int, int] | None:
if (self.utc_offset() is None) ^ (other.utcoffset() is None):
if strict:
raise TypeError(
"can't compare offset-naive and offset-aware times"
)
else:
return None, None
return None
other_ticks: int
if isinstance(other, Time):
other_ticks = other.__ticks
elif isinstance(other, _time):
else:
assert isinstance(other, _time)
other_ticks = int(
3600000000000 * other.hour
+ 60000000000 * other.minute
+ _NANO_SECONDS * other.second
+ 1000 * other.microsecond
)
else:
return None, None
assert isinstance(other, (Time, _time))
utc_offset: _timedelta | None = other.utcoffset()
if utc_offset is not None:
other_ticks -= int(utc_offset.total_seconds() * _NANO_SECONDS)
Expand All @@ -1959,43 +1938,40 @@ def __hash__(self):

def __eq__(self, other: object) -> bool:
"""`==` comparison with :class:`.Time` or :class:`datetime.time`."""
self_ticks, other_ticks = self._get_both_normalized_ticks(
other, strict=False
)
if self_ticks is None:
if not isinstance(other, (Time, _time)):
return NotImplemented
ticks = self._get_both_normalized_ticks(other, strict=False)
if ticks is None:
return False
self_ticks, other_ticks = ticks
return self_ticks == other_ticks

def __ne__(self, other: object) -> bool:
"""`!=` comparison with :class:`.Time` or :class:`datetime.time`."""
return not self.__eq__(other)

def __lt__(self, other: Time | _time) -> bool:
"""`<` comparison with :class:`.Time` or :class:`datetime.time`."""
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
if self_ticks is None:
if not isinstance(other, (Time, _time)):
return NotImplemented
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
return self_ticks < other_ticks

def __le__(self, other: Time | _time) -> bool:
"""`<=` comparison with :class:`.Time` or :class:`datetime.time`."""
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
if self_ticks is None:
if not isinstance(other, (Time, _time)):
return NotImplemented
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
return self_ticks <= other_ticks

def __ge__(self, other: Time | _time) -> bool:
"""`>=` comparison with :class:`.Time` or :class:`datetime.time`."""
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
if self_ticks is None:
if not isinstance(other, (Time, _time)):
return NotImplemented
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
return self_ticks >= other_ticks

def __gt__(self, other: Time | _time) -> bool:
"""`>` comparison with :class:`.Time` or :class:`datetime.time`."""
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
if self_ticks is None:
if not isinstance(other, (Time, _time)):
return NotImplemented
self_ticks, other_ticks = self._get_both_normalized_ticks(other)
return self_ticks > other_ticks

# INSTANCE METHODS #
Expand Down Expand Up @@ -2603,29 +2579,36 @@ def hour_minute_second_nanosecond(self) -> tuple[int, int, int, int]:

# OPERATIONS #

def _get_both_normalized(self, other, strict=True):
if isinstance(other, (_datetime, DateTime)) and (
(self.utc_offset() is None) ^ (other.utcoffset() is None)
):
@_t.overload
def _get_both_normalized(
self, other: _datetime | DateTime, strict: _t.Literal[True] = True
) -> tuple[DateTime, DateTime | _datetime]: ...

@_t.overload
def _get_both_normalized(
self, other: _datetime | DateTime, strict: _t.Literal[False]
) -> tuple[DateTime, DateTime | _datetime] | None: ...

def _get_both_normalized(
self, other: _datetime | DateTime, strict: bool = True
) -> tuple[DateTime, DateTime | _datetime] | None:
if (self.utc_offset() is None) ^ (other.utcoffset() is None):
if strict:
raise TypeError(
"can't compare offset-naive and offset-aware datetimes"
)
else:
return None, None
return None
self_norm = self
utc_offset = self.utc_offset()
if utc_offset is not None:
self_norm -= utc_offset
self_norm = self_norm.replace(tzinfo=None)
other_norm = other
if isinstance(other, (_datetime, DateTime)):
utc_offset = other.utcoffset()
if utc_offset is not None:
other_norm -= utc_offset
other_norm = other_norm.replace(tzinfo=None)
else:
return None, None
utc_offset = other.utcoffset()
if utc_offset is not None:
other_norm -= utc_offset
other_norm = other_norm.replace(tzinfo=None)
return self_norm, other_norm

def __hash__(self):
Expand All @@ -2647,21 +2630,12 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
if self.utc_offset() == other.utcoffset():
return self.date() == other.date() and self.time() == other.time()
self_norm, other_norm = self._get_both_normalized(other, strict=False)
if self_norm is None:
normalized = self._get_both_normalized(other, strict=False)
if normalized is None:
return False
self_norm, other_norm = normalized
return self_norm == other_norm

def __ne__(self, other: object) -> bool:
"""
``!=`` comparison with another datetime.

Accepts :class:`.DateTime` and :class:`datetime.datetime`.
"""
if not isinstance(other, (DateTime, _datetime)):
return NotImplemented
return not self.__eq__(other)

def __lt__( # type: ignore[override]
self, other: _datetime | DateTime
) -> bool:
Expand Down
Loading