
Need for dealing with strides when writing performant code against multiple array libraries #641

Closed
@rgommers


Early on when putting the array API standard together, we decided that strides must not be exposed in the Python API, because multiple array libraries don't support them, and memory layout is more of an implementation detail than an API that should be exposed to Python users. We made that decision so early that I believe there isn't even a dedicated issue for it; the closest one is gh-571 about C vs. Fortran memory order. It's also related to gh-24 (strides are mentioned explicitly there).
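For readers less familiar with strides: in NumPy, an array records, per axis, how many bytes to step in memory to reach the next element, and operations such as transposing or slicing with a step return a view with different strides instead of copying data. A small illustration, assuming NumPy is installed and using 8-byte `int64` elements:

```python
import numpy as np

# A C-contiguous 3x4 array of 8-byte integers
x = np.arange(12, dtype=np.int64).reshape(3, 4)
print(x.strides)                # (32, 8): next row is 4*8 bytes away, next column 8 bytes

# Transposing only swaps the strides; no data is copied
xt = x.T
print(xt.strides)               # (8, 32)
print(np.shares_memory(x, xt))  # True

# Slicing with a step just scales a stride
print(x[:, ::2].strides)        # (32, 16)
```

The standard's array object defines attributes such as `shape`, `ndim`, `dtype` and `device`, but no `strides`, so none of this is reachable through the standardized Python API.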

The strides topic came up a couple of times recently though, and we should think about this more:

  1. In the SciPy proceedings paper, one of the main review comments was about this: Paper: Python Array API Standard: Toward Array Interoperability in the Scientific Python Ecosystem scipy-conference/scipy_proceedings#822
  2. In the SciPy RFC for array API standard adoption, it came up quickly because one of the first functions for which a conversion was tried used as_strided: RFC: SciPy array types & libraries support scipy/scipy#18286 (comment)

Previous comments

I'll copy some of the key comments below to have them in a single place:

From @jpivarski: The thing I was most surprised about in the read-through was that "strides" are not considered part of the array object. I recognize that some libraries (e.g. JAX) are able to do great things by not allowing strides to be free. But then, you show in the SciPy benchmark (Fig. 2c, 2d) that not having this handle implies a severe performance penalty, which is likely to be a major take-away from this paper by most readers. In the discussion, you say that it illustrates that users will need to apply library-specific optimizations after all. (Implementing a backend or runtime switching system is non-goal number 2.) It seems to me like strides-awareness would be a good candidate for one of the optional extensions, if extensions are not strictly sets of functions. Perhaps that's a question or suggestion I should take up with the Consortium's normal channels, not a review of the paper, but it's a question the paper leads me to have.

From @leofang: Regarding strides: I have conflicting views and will continue to self-debate. On the one hand, when over half of the array libraries (Jax, TF, Dask, cuNumeric) do not offer strides for various reasons (ex: it makes no sense for distributed arrays), it's my opinion that even making it optional is not the right way to go, just like many other design considerations that we've made. Also, for those supporting the stride semantics, their strides are accessible via DLPack, it's just that currently it's hard to access them from within pure Python. On the other hand, as a low-level C++ library developer (in short, we write backends for the array libraries), I do see the pain point that it's hard for us to even imagine how our library would be adopted and consumed by the array libraries that do not have strides. It could be possible that our library simply cannot serve as the battery behind them due to different design philosophies, idk. Though, to be fair, the standard aims at the array libraries, not their backends, so if there's a standardization effort targeting array backends, it might be a better place for standardizing strides.

From @tylerjereddy: scipy.signal.welch seemed to be a well-behaved target for prototyping based on that so I tried doing that from scratch on a feature branch locally, progressively allowing more welch() tests to accept CuPy arrays when an env variable is set.
Even for this "well-behaved"/reasonable target, I ran into problems pretty early on. For example, both in the original feature branch and in my branch, there doesn't seem to be an elegant solution made for handling numpy.lib.stride_tricks.as_strided. The docs for that function don't even recommend using it, and yet CuPy (and apparently torch from the Quansight example) do provide it, outside of the array API standard proper.
So, I guess my first real-world experience makes me wonder what our policy on special casing in these scenarios will be--ideally, I'd like to just remove the usage of as_strided() and substitute with some other approach that doesn't require conditionals on the exact array type/namespace. While this is a rather specific blocker, if I encounter something like this even for a "well behaved" case, I can imagine quite a few headaches for the cases that are not well behaved.

From @rgommers: Let's have a look at this case - this is the code in question:

    import numpy as np

    # Inputs: x (data, segmented along the last axis), nperseg (segment
    # length), noverlap (number of samples shared by consecutive segments)
    # Create strided array of data segments
    if nperseg == 1 and noverlap == 0:
        result = x[..., np.newaxis]
    else:
        # https://stackoverflow.com/a/5568169
        step = nperseg - noverlap
        shape = x.shape[:-1]+((x.shape[-1]-noverlap)//step, nperseg)
        strides = x.strides[:-1]+(step*x.strides[-1], x.strides[-1])
        result = np.lib.stride_tricks.as_strided(x, shape=shape,
                                                 strides=strides)

The as_strided usage is there to save memory. It's actually pretty difficult to understand exactly what the code does, though, so it's good for performance but not great for maintainability. <.... example with more details cut, see here> At that point, we can say we're happy with more readability at the cost of some performance (TBD whether the for-loop matters; my guess is it will). Or we just keep the special-casing for numpy using as_strided. At that point, we do have two code paths, but no performance loss and also easier-to-understand code.
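A sketch of the two options described above, assuming `1 <= nperseg <= x.shape[-1]` and `0 <= noverlap < nperseg`. `segments_portable` and `segments` are hypothetical names, not SciPy code. The portable version builds the same `(..., nseg, nperseg)` result using only slicing and `xp.stack`, so it needs no strides, at the cost of a copy and a Python-level loop. The dispatching version keeps a NumPy-only zero-copy path; instead of hand-computed `as_strided` strides it uses `numpy.lib.stride_tricks.sliding_window_view` (NumPy >= 1.20), which returns a read-only strided view.

```python
import numpy as np


def segments_portable(x, nperseg, noverlap, xp):
    """Split the last axis of x into overlapping segments of length nperseg.

    Uses only slicing and xp.stack, so no strides are needed. The result is
    a copy, built with a Python-level loop over the segments.
    """
    step = nperseg - noverlap
    nseg = (x.shape[-1] - noverlap) // step
    segs = [x[..., i * step:i * step + nperseg] for i in range(nseg)]
    # Stacking along axis -2 gives shape x.shape[:-1] + (nseg, nperseg)
    return xp.stack(segs, axis=-2)


def segments(x, nperseg, noverlap, xp):
    """Zero-copy strided view for NumPy arrays, portable copy otherwise."""
    if isinstance(x, np.ndarray):
        step = nperseg - noverlap
        # All length-nperseg windows along the last axis (read-only view) ...
        windows = np.lib.stride_tricks.sliding_window_view(x, nperseg, axis=-1)
        # ... keeping every step-th one, as the as_strided code does
        return windows[..., ::step, :]
    return segments_portable(x, nperseg, noverlap, xp)


x = np.arange(10.0)
print(segments(x, nperseg=4, noverlap=2, xp=np))
# rows: [0 1 2 3], [2 3 4 5], [4 5 6 7], [6 7 8 9]
```

On NumPy input `segments` shares memory with `x`; for any other namespace it copies. The `isinstance` check is used rather than comparing `xp` to `numpy`, because a compatibility wrapper namespace may be passed for plain NumPy arrays.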

To discuss

It's clear that strides cannot be universally supported, and equally clear that users who write algorithmic code may need to access strides in some cases to avoid significant performance loss.

Essentially, I think that handling strides by hand, in libraries that allow it, does work that in a more full-featured array library will (or should) be done by a compiler (e.g. the JIT/AOT compilers in JAX or PyTorch).

We need to at least suggest and document some strategy for cases like the as_strided usage above. In the similar but more specific case of order=, which mattered quite a bit for performance in scikit-learn, they ended up writing a library-specific utility function that uses the order keyword when it is supported and avoids it otherwise: https://github.com/scikit-learn/scikit-learn/blob/21312644df0a6b4c6f3c27a74ac9d26cf49c2304/sklearn/utils/_array_api.py#L360
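A minimal sketch of that pattern, in the spirit of (but not copied from) the scikit-learn utility linked above. `asarray_with_optional_order` and `_ORDER_AWARE` are hypothetical names, and the assumption that only namespaces named `numpy` or `array_api_compat.numpy` accept `order=` is for illustration; the standard's `asarray` has no `order` keyword, so the hint is simply dropped elsewhere.

```python
import numpy as np

# Assumption: namespaces (by module name) whose asarray accepts order=
_ORDER_AWARE = {"numpy", "array_api_compat.numpy"}


def asarray_with_optional_order(x, xp, dtype=None, order=None):
    """Honour a memory-order hint only where the namespace supports it.

    For NumPy, request the layout explicitly (e.g. "F" so that column
    slices are contiguous). For other namespaces, fall back to the
    standard asarray(x, dtype=...) and let the library choose its layout.
    """
    if order is not None and xp.__name__ in _ORDER_AWARE:
        return xp.asarray(x, dtype=dtype, order=order)
    return xp.asarray(x, dtype=dtype)


a = asarray_with_optional_order([[1.0, 2.0], [3.0, 4.0]], np, order="F")
print(a.flags["F_CONTIGUOUS"])  # True
```

The design choice is to treat the layout as a performance hint rather than a requirement, so the same call site works unchanged with libraries that have no notion of memory order.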
