ENH: new canonicalize DType function? #151

Closed
@NeilGirdhar

Description

JAX has canonicalize_dtype, and PyTorch also has a notion of default dtypes.

Can we provide canonicalize_dtype for all libraries?

Something like:

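def canonicalize_dtype(xp: Namespace, dtype: DType | type[complex]) -> DType:
    if is_jax_namespace(xp):
        # Aliased so the import does not shadow this function's own name.
        from jax.dtypes import canonicalize_dtype as jax_canonicalize_dtype
        # Unlike xp.empty(..., dtype=dtype), this does not emit a truncation warning.
        return jax_canonicalize_dtype(dtype)
    # Generic fallback: allocate a 0-d array and read back the dtype actually used.
    return xp.empty((), dtype=dtype).dtype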
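
To illustrate the generic fallback branch in isolation, here is a minimal sketch that runs without any array library installed. ToyNamespace and ToyArray are hypothetical stand-ins invented for this example, and the 32-bit default mapping is an assumption chosen to mimic a library that downcasts 64-bit requests; no real library is guaranteed to behave exactly this way.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical stand-in for an array; it only carries a dtype."""
    dtype: str


class ToyNamespace:
    """Hypothetical namespace that defaults to 32-bit dtypes (illustration only)."""

    # Assumed mapping from the requested dtype to the dtype actually used.
    _defaults = {
        float: "float32",
        complex: "complex64",
        int: "int32",
        "float64": "float32",
        "complex128": "complex64",
    }

    def empty(self, shape: tuple[int, ...], dtype: Any) -> ToyArray:
        # Unknown dtypes are passed through unchanged.
        return ToyArray(dtype=self._defaults.get(dtype, dtype))


def canonicalize_dtype(xp: Any, dtype: Any) -> Any:
    """Generic fallback: allocate a 0-d array and read back its dtype."""
    return xp.empty((), dtype=dtype).dtype


xp = ToyNamespace()
result_complex = canonicalize_dtype(xp, complex)    # Python builtin type
result_f64 = canonicalize_dtype(xp, "float64")      # 64-bit request is downcast
result_f32 = canonicalize_dtype(xp, "float32")      # already canonical
```

The fallback relies only on the namespace providing empty() and on the resulting array exposing .dtype, which is why the same one-liner could work across libraries; the cost is a small throwaway allocation per call.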
