Skip to content

Commit 8f47982

Browse files
authored
BUG: array_equivalent_object with mismatched shapes (#49363)
* BUG: array_equivalent_object with mismatched shapes
* test with mismatched dtypes
1 parent d95bf9a commit 8f47982

File tree

4 files changed

+50
-18
lines changed

4 files changed

+50
-18
lines changed

pandas/_libs/lib.pyi

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -230,8 +230,8 @@ def generate_bins_dt64(
230230
hasnans: bool = ...,
231231
) -> np.ndarray: ... # np.ndarray[np.int64, ndim=1]
232232
def array_equivalent_object(
233-
left: np.ndarray, # object[:]
234-
right: np.ndarray, # object[:]
233+
left: npt.NDArray[np.object_],
234+
right: npt.NDArray[np.object_],
235235
) -> bool: ...
236236
def has_infs(arr: np.ndarray) -> bool: ... # const floating[:]
237237
def get_reverse_indexer(

pandas/_libs/lib.pyx

Lines changed: 26 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ from cpython.iterator cimport PyIter_Check
1515
from cpython.number cimport PyNumber_Check
1616
from cpython.object cimport (
1717
Py_EQ,
18+
PyObject,
1819
PyObject_RichCompareBool,
1920
PyTypeObject,
2021
)
@@ -571,25 +572,42 @@ def maybe_booleans_to_slice(ndarray[uint8_t, ndim=1] mask):
571572

572573
@cython.wraparound(False)
573574
@cython.boundscheck(False)
574-
def array_equivalent_object(left: object[:], right: object[:]) -> bool:
575+
def array_equivalent_object(ndarray left, ndarray right) -> bool:
575576
"""
576-
Perform an element by element comparison on 1-d object arrays
577+
Perform an element by element comparison on N-d object arrays
577578
taking into account nan positions.
578579
"""
580+
# left and right both have object dtype, but we cannot annotate that
581+
# without limiting ndim.
579582
cdef:
580-
Py_ssize_t i, n = left.shape[0]
583+
Py_ssize_t i, n = left.size
581584
object x, y
585+
cnp.broadcast mi = cnp.PyArray_MultiIterNew2(left, right)
586+
587+
# Caller is responsible for checking left.shape == right.shape
582588

583589
for i in range(n):
584-
x = left[i]
585-
y = right[i]
590+
# Analogous to: x = left[i]
591+
x = <object>(<PyObject**>cnp.PyArray_MultiIter_DATA(mi, 0))[0]
592+
y = <object>(<PyObject**>cnp.PyArray_MultiIter_DATA(mi, 1))[0]
586593

587594
# we are either not equal or both nan
588595
# I think None == None will be true here
589596
try:
590597
if PyArray_Check(x) and PyArray_Check(y):
591-
if not array_equivalent_object(x, y):
598+
if x.shape != y.shape:
592599
return False
600+
if x.dtype == y.dtype == object:
601+
if not array_equivalent_object(x, y):
602+
return False
603+
else:
604+
# Circular import isn't great, but so it goes.
605+
# TODO: could use np.array_equal?
606+
from pandas.core.dtypes.missing import array_equivalent
607+
608+
if not array_equivalent(x, y):
609+
return False
610+
593611
elif (x is C_NA) ^ (y is C_NA):
594612
return False
595613
elif not (
@@ -612,6 +630,8 @@ def array_equivalent_object(left: object[:], right: object[:]) -> bool:
612630
return False
613631
raise
614632

633+
cnp.PyArray_MultiIter_NEXT(mi)
634+
615635
return True
616636

617637

pandas/core/dtypes/missing.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -565,16 +565,7 @@ def _array_equivalent_object(left: np.ndarray, right: np.ndarray, strict_nan: bo
565565
if not strict_nan:
566566
# isna considers NaN and None to be equivalent.
567567

568-
if left.flags["F_CONTIGUOUS"] and right.flags["F_CONTIGUOUS"]:
569-
# we can improve performance by doing a copy-free ravel
570-
# e.g. in frame_methods.Equals.time_frame_nonunique_equal
571-
# if we transposed the frames
572-
left = left.ravel("K")
573-
right = right.ravel("K")
574-
575-
return lib.array_equivalent_object(
576-
ensure_object(left.ravel()), ensure_object(right.ravel())
577-
)
568+
return lib.array_equivalent_object(ensure_object(left), ensure_object(right))
578569

579570
for left_value, right_value in zip(left, right):
580571
if left_value is NaT and right_value is not NaT:

pandas/tests/dtypes/test_missing.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -466,6 +466,27 @@ def test_array_equivalent_series(val):
466466
assert not array_equivalent(Series([arr, arr]), Series([arr, val]))
467467

468468

469+
def test_array_equivalent_array_mismatched_shape():
470+
# to trigger the motivating bug, the first N elements of the arrays need
471+
# to match
472+
first = np.array([1, 2, 3])
473+
second = np.array([1, 2])
474+
475+
left = Series([first, "a"], dtype=object)
476+
right = Series([second, "a"], dtype=object)
477+
assert not array_equivalent(left, right)
478+
479+
480+
def test_array_equivalent_array_mismatched_dtype():
481+
# same shape, different dtype can still be equivalent
482+
first = np.array([1, 2], dtype=np.float64)
483+
second = np.array([1, 2])
484+
485+
left = Series([first, "a"], dtype=object)
486+
right = Series([second, "a"], dtype=object)
487+
assert array_equivalent(left, right)
488+
489+
469490
def test_array_equivalent_different_dtype_but_equal():
470491
# Unclear if this is exposed anywhere in the public-facing API
471492
assert array_equivalent(np.array([1, 2]), np.array([1.0, 2.0]))

0 commit comments

Comments
 (0)