Description
This RFC proposes adding an API to the standard that returns the list of available devices. Currently, the only standardized way for an array API consumer to obtain a device object is to create an array and access its .device attribute:
>>> x = xp.zeros((2, 2))
>>> d = x.device
More generally, the standard currently lacks APIs for device inspection, and device objects are left unspecified apart from the requirement that they support equality comparison. A way to list devices would aid introspection in scripts and REPL environments, especially when probing the capabilities of an unknown platform.
Prior art
JAX
jax.devices(backend=None) -> List[Device]
Returns the list of devices from the default backend (e.g., gpu, tpu, or cpu).
jax.local_devices(process_index=None, backend=None, host_id=None)
Returns a list of devices local to a given process.
Proposal
devices() -> List[device]
The proposed API would be a nullary function which returns a list of device objects.
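To make the proposal concrete, here is a minimal sketch of a toy array namespace that exposes such a function. Everything in it is hypothetical and written only for illustration: the Device class, the fixed _DEVICES tuple, and the stripped-down Array and zeros stand-ins are not part of any real library. The sketch assumes device objects only need equality comparison (which is all the standard currently requires) and that creation functions accept a device keyword.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Hypothetical device object; frozen dataclasses provide __eq__ and __hash__."""
    kind: str
    index: int = 0


# Hypothetical fixed set of devices exposed by this toy namespace.
_DEVICES = (Device("cpu"), Device("gpu", 0), Device("gpu", 1))


def devices() -> list[Device]:
    """Nullary function returning the list of available device objects."""
    # Return a fresh list so callers cannot mutate the namespace's internal state.
    return list(_DEVICES)


class Array:
    """Minimal stand-in for an array that records the device it lives on."""

    def __init__(self, shape, device):
        self.shape = shape
        self.device = device


def zeros(shape, *, device=None):
    """Toy creation function; only tracks shape and device, no data."""
    if device is None:
        device = _DEVICES[0]  # assumed default: the first listed device
    if device not in _DEVICES:
        raise ValueError(f"unknown device: {device!r}")
    return Array(shape, device)


# REPL-style probing: allocate a small array on every listed device.
for d in devices():
    x = zeros((2, 2), device=d)
    assert x.device == d
```

With a function like this, the device objects a consumer gets from devices() are the same kind of object it would read back from x.device, so the two can be compared directly.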
Notes
- When the backend kwarg is not provided, JAX returns the list of devices for the default backend. I am not sure whether this behavior is preferable to returning the list of devices across all backends.
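The two candidate behaviors in the note above can be contrasted with a small sketch. The backend registry, the default-backend choice, and both function names are hypothetical and exist only to show how the results differ. The proposed API would expose just one nullary devices() with whichever semantics is chosen.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Hypothetical device object identified by backend name and index."""
    backend: str
    index: int


# Hypothetical registry: backend name -> number of devices it exposes.
_BACKENDS = {"cpu": 1, "gpu": 2}
_DEFAULT_BACKEND = "gpu"  # assumed default backend for this sketch


def devices_default_backend() -> list[Device]:
    """Option A (JAX-like): only the devices of the default backend."""
    n = _BACKENDS[_DEFAULT_BACKEND]
    return [Device(_DEFAULT_BACKEND, i) for i in range(n)]


def devices_all_backends() -> list[Device]:
    """Option B: devices across every available backend."""
    return [Device(name, i) for name, n in _BACKENDS.items() for i in range(n)]
```

Under option A the cpu device is hidden whenever another backend is the default. Option B always exposes it, at the cost of mixing backends in a single list.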