PERF: short-circuit (left == right).all() comparisons #32339

Open
@jbrockmendel

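In places like equals methods and array_equivalent, we do things like (left == right).all() or ((left == right) | (isna(left) & isna(right))).all(). Both build a full boolean mask before reducing it, even when the first elements already differ. For large arrays that are not equal, we can do much better by stopping at the first mismatch, with something like: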

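from numpy cimport int64_t

def all_match(left, right) -> bool:
    # Assumes 1-D arrays of equal length with an 8-byte dtype.
    if left.dtype.kind != "i":
        # Viewing as i8 compares raw bits, so NaNs with identical bit
        # patterns are treated as equal (but 0.0 and -0.0 are not).
        return _all_match_i8(left.view("i8"), right.view("i8"))

    return _all_match_i8(left, right)

cdef bint _all_match_i8(const int64_t[:] left, const int64_t[:] right):
    cdef:
        Py_ssize_t i, n = len(left)

    # Return at the first mismatch instead of building a full boolean mask.
    for i in range(n):
        if left[i] != right[i]:
            return False

    return True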
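
To make the semantics concrete, here is a small plain-NumPy sketch of the two existing patterns and of what the i8 view changes. It is only an illustration: np.isnan stands in for pandas' isna, and the variable names are made up for this example.

```python
import numpy as np

left = np.array([1.0, np.nan, 3.0])
right = left.copy()

# Plain elementwise comparison: NaN != NaN, so this reports a mismatch.
plain_equal = bool((left == right).all())

# NaN-aware comparison; np.isnan is a stand-in for pandas' isna here.
nan_aware_equal = bool(
    ((left == right) | (np.isnan(left) & np.isnan(right))).all()
)

# Bitwise comparison through an int64 view: identical NaN bit patterns match.
bitwise_equal = bool((left.view("i8") == right.view("i8")).all())

# Caveat: 0.0 and -0.0 compare equal as floats but differ bitwise.
zeros = np.array([0.0])
neg_zeros = np.array([-0.0])
float_zero_equal = bool((zeros == neg_zeros).all())
bitwise_zero_equal = bool((zeros.view("i8") == neg_zeros.view("i8")).all())
```

So the i8 view gives NaN-equality for free, at the cost of distinguishing signed zeros, which may or may not matter depending on the caller.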

Some profiling results:

In [2]: arr = np.arange(10**6)
In [3]: arr2 = arr.copy()
In [4]: arr2[0] = -1

In [5]: %timeit np.array_equal(arr, arr2)
831 µs ± 42.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [6]: %timeit all_match(arr, arr2)
1.27 µs ± 58.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [7]: %timeit np.array_equal(arr, arr)
416 µs ± 16.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [8]: %timeit all_match(arr, arr)
812 µs ± 5.84 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

So in cases that short-circuit early, we can get massive speedups (1.27 µs versus 831 µs above), but this implementation is actually about 2x slower in cases that don't short-circuit (812 µs versus 416 µs), for reasons that are not clear to me.
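
One possible middle ground, sketched here in plain NumPy rather than Cython, is to compare in fixed-size chunks: each chunk still goes through NumPy's vectorized comparison, while the loop can exit after the first chunk that contains a mismatch. The helper name all_match_chunked and the chunk_size default are hypothetical, the sketch assumes 1-D arrays and is not NaN-aware, and whether it actually closes the gap in the no-mismatch case would need to be measured.

```python
import numpy as np


def all_match_chunked(left, right, chunk_size=4096):
    """Hypothetical helper: chunk-wise equality check with early exit.

    Assumes 1-D NumPy arrays; NaNs are not treated as equal.
    """
    if left.shape != right.shape:
        return False

    n = len(left)
    for start in range(0, n, chunk_size):
        stop = start + chunk_size
        # Vectorized comparison on one chunk; stop at the first failing chunk.
        if not (left[start:stop] == right[start:stop]).all():
            return False

    return True
```

A mismatch at index 0 only costs one chunk comparison, while fully equal arrays pay for the same elementwise work as (left == right).all() plus the Python-level loop overhead.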

Labels: Numeric Operations (Arithmetic, Comparison, and Logical operations), Performance (Memory or execution speed performance)