JAX has `jax.dtypes.canonicalize_dtype`, and PyTorch also has a notion of default dtypes (see `torch.get_default_dtype`).
Can we provide `canonicalize_dtype` for all libraries?
Something like:
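```python
from __future__ import annotations  # keeps the Namespace/DType annotations unevaluated

from array_api_compat import is_jax_namespace

def canonicalize_dtype(xp: Namespace, dtype: DType | type[complex]) -> DType:
    if is_jax_namespace(xp):
        from jax.dtypes import canonicalize_dtype
        return canonicalize_dtype(dtype)  # Avoids JAX's dtype-truncation warning.
    # Other libraries: let the namespace resolve the dtype via a 0-d array.
    return xp.empty((), dtype=dtype).dtype
```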