Description
This RFC proposes to add an API to the standard for returning library defaults (e.g., dtypes and devices).
Currently, the standard requires that conforming array libraries explicitly state their default dtypes in their documentation; no guidance is provided regarding devices. The standard provides no APIs for querying library defaults.
The inability to query defaults requires manual workarounds, such as allocating a fresh array and checking its dtype or device attribute. This is less than ideal, especially for third-party array libraries wanting to generically extend an array library's namespace and adhere to the same default behavior (e.g., wrap library x to expose additional array creation functions for specialized matrices).
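The workaround described above might look like the following. This is a minimal sketch, assuming the namespace xp provides the standard's empty() and asarray() creation functions and that arrays expose dtype and device attributes. The helper name infer_defaults and the stand-in namespace fake_xp are hypothetical and exist only so the example runs without a real array library.

```python
from types import SimpleNamespace


def infer_defaults(xp):
    """Infer library defaults by allocating throwaway arrays (the manual workaround)."""
    # No dtype or device is passed, so the library chooses both.
    probe = xp.empty(0)
    return {
        "device": probe.device,
        "dtypes.real_floating_point": probe.dtype,
        # Python scalars are converted using the library's default dtype for their kind.
        "dtypes.integral": xp.asarray(0).dtype,
        "dtypes.complex_floating_point": xp.asarray(0j).dtype,
    }


# Hypothetical stand-in for an array library namespace (not a real library).
def _fake_array(dtype):
    return SimpleNamespace(dtype=dtype, device="cpu")


_DTYPE_BY_SCALAR = {int: "int64", float: "float64", complex: "complex128"}

fake_xp = SimpleNamespace(
    empty=lambda shape: _fake_array("float64"),
    asarray=lambda obj: _fake_array(_DTYPE_BY_SCALAR[type(obj)]),
)

print(infer_defaults(fake_xp))
```

Every lookup costs at least one allocation, and the helper has to know which creation call reveals which default, which is the kind of indirection a query API would remove.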
Prior art
PyTorch
torch.get_default_dtype() → torch.dtype
Returns the current default floating-point dtype.
JAX
jax.default_backend() -> str
Returns the platform name of the default XLA backend.
Proposal
defaults() -> dict[str, any]
defaults(device: device) -> dict[str, any]
defaults(name: str) -> any
defaults(name: str, device: device) -> any
If called without arguments, the function would return a dictionary with the following keys:
- device: default device.
- dtypes.real_floating_point: default real floating-point dtype.
- dtypes.complex_floating_point: default complex floating-point dtype.
- dtypes.integral: default integral dtype.
- dtypes.indexing: default index dtype.
More keys could be added in the future, depending on evolution of the standard.
If provided a device argument, the function would return a dictionary as described above, but specific to the specified device.
If provided a name argument, the function would return the default for the specific setting. E.g.,
>>> d = xp.defaults("device")
If provided both a name and device argument, the function would return the default for the specified setting for the specified device.
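The four call forms above could be served by a single function. The following is a rough sketch, not a reference implementation: it assumes device is passed as a keyword argument, uses strings in place of real dtype and device objects, and uses a made-up per-device table (_DTYPE_DEFAULTS) with made-up values. It also assumes the device entry always reports the library-wide default device, which the proposal does not specify.

```python
from typing import Any, Optional

# Hypothetical library state; values are illustrative only.
_DEFAULT_DEVICE = "cpu"
_DTYPE_DEFAULTS = {
    "cpu": {
        "dtypes.real_floating_point": "float64",
        "dtypes.complex_floating_point": "complex128",
        "dtypes.integral": "int64",
        "dtypes.indexing": "int64",
    },
    "gpu": {
        "dtypes.real_floating_point": "float32",
        "dtypes.complex_floating_point": "complex64",
        "dtypes.integral": "int32",
        "dtypes.indexing": "int64",
    },
}


def defaults(name: Optional[str] = None, *, device: Optional[str] = None) -> Any:
    """Return all defaults, or a single default when `name` is given."""
    target = _DEFAULT_DEVICE if device is None else device
    if target not in _DTYPE_DEFAULTS:
        raise ValueError(f"unknown device: {target!r}")
    # Assumption: "device" reports the library-wide default, not the queried device.
    settings = {"device": _DEFAULT_DEVICE, **_DTYPE_DEFAULTS[target]}
    if name is None:
        return settings
    if name not in settings:
        raise KeyError(f"unknown default: {name!r}")
    return settings[name]


print(defaults("device"))
print(defaults("dtypes.real_floating_point", device="gpu"))
```

Making device keyword-only avoids having to guess whether a lone positional argument is a setting name or a device, which would otherwise be ambiguous for libraries that identify devices by string.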
Notes
- If we wanted to support a standardized means for configuring defaults (e.g., setting the default real floating-point dtype to float32 instead of float64), we may want to rename the API to something like get_defaults(), and the setter could then be set_default().
- When invoked without a device argument, the function would return default dtypes based on the current device context, as array libraries may have differing default dtypes depending on the device. Accordingly, users should be advised not to assume that defaults are static.
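Both notes can be illustrated together. The sketch below is one possible shape for a getter/setter pair whose results follow the current device context; it is an assumption-laden illustration, not part of the proposal. The device_context() manager, the string dtype names, and the per-device table are all hypothetical, and only a single setting is modeled to keep the example short.

```python
import contextlib
import contextvars

# Hypothetical per-device default real floating-point dtypes.
_FLOAT_DEFAULTS = {"cpu": "float64", "gpu": "float32"}

# Tracks the device that is current for the running context.
_current_device = contextvars.ContextVar("current_device", default="cpu")


@contextlib.contextmanager
def device_context(device):
    """Temporarily make `device` the current device (hypothetical helper)."""
    token = _current_device.set(device)
    try:
        yield
    finally:
        _current_device.reset(token)


def get_defaults():
    """Return defaults for whichever device is current at call time."""
    device = _current_device.get()
    return {
        "device": device,
        "dtypes.real_floating_point": _FLOAT_DEFAULTS[device],
    }


def set_default(name, value):
    """Change a default for the current device; only one setting is modeled here."""
    if name != "dtypes.real_floating_point":
        raise KeyError(f"unsupported default: {name!r}")
    _FLOAT_DEFAULTS[_current_device.get()] = value


print(get_defaults())
with device_context("gpu"):
    print(get_defaults())
```

Because the result depends on both the active context and any earlier set_default() call, code that caches the returned dictionary at import time can silently go stale, which is why callers should query defaults at the point of use.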