
Commit f59601f

S1ro1 and hanouticelina committed
Feat: support DTensor when saving (#3042)
* Feat: support DTensor for storage size and id
* Feat: tests

Co-authored-by: célina <hanouticelina@gmail.com>
1 parent 7fde08d commit f59601f
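The helpers patched below feed the shard-planning logic used when a torch state dict is split for saving. A minimal single-process sketch of what the change enables, assuming a CPU gloo process group (the same trick the new test uses) and the public split_torch_state_dict_into_shards helper; this is illustrative only and is not code from this commit:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

from huggingface_hub import split_torch_state_dict_into_shards

# single-process "distributed" setup so a DTensor can be built without torchrun
dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

state_dict = {"linear.weight": DTensor.from_local(torch.randn(4, 4), mesh, [Replicate()])}

# shard planning now sees the DTensor's full size and a valid storage id
split = split_torch_state_dict_into_shards(state_dict)
print(split.filename_to_tensors)

dist.destroy_process_group()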

File tree

src/huggingface_hub/serialization/_torch.py
tests/test_serialization.py

2 files changed: 55 additions & 0 deletions

src/huggingface_hub/serialization/_torch.py

Lines changed: 18 additions & 0 deletions
@@ -706,6 +706,15 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
    if the input is a wrapper tensor subclass Tensor
    """
+    try:
+        from torch.distributed.tensor import DTensor
+
+        if isinstance(tensor, DTensor):
+            local_tensor = tensor.to_local()
+            return local_tensor.storage().data_ptr()
+    except ImportError:
+        pass
+
    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

@@ -753,6 +762,15 @@ def get_torch_storage_size(tensor: "torch.Tensor") -> int:
    """
    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
    """
+    try:
+        from torch.distributed.tensor import DTensor
+
+        if isinstance(tensor, DTensor):
+            # this returns the size of the FULL tensor in bytes
+            return tensor.nbytes
+    except ImportError:
+        pass
+
    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
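
The two hunks are intentionally asymmetric: the unique id comes from the local shard's storage pointer (deduplication happens per rank), while the reported storage size is the full, unsharded tensor's nbytes, so shard sizing reflects the real checkpoint footprint. A hedged single-process sketch of that difference, assuming get_torch_storage_id and get_torch_storage_size are exported from huggingface_hub:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

from huggingface_hub import get_torch_storage_id, get_torch_storage_size

# one-rank gloo group so DTensor can be constructed in a single process
dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
dt = DTensor.from_local(local, mesh, [Replicate()])

print(get_torch_storage_id(dt))    # id derived from the local shard's storage pointer
print(get_torch_storage_size(dt))  # 10 -> full tensor: 5 elements * 2 bytes (float16)

dist.destroy_process_group()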

tests/test_serialization.py

Lines changed: 37 additions & 0 deletions
@@ -46,6 +46,16 @@ def is_wrapper_tensor_subclass_available():
        return False


+def is_dtensor_available():
+    try:
+        from torch.distributed.device_mesh import init_device_mesh  # type: ignore[import] # noqa: F401
+        from torch.distributed.tensor import DTensor  # type: ignore[import] # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
@pytest.fixture
def dummy_state_dict() -> Dict[str, List[int]]:
    return {

@@ -250,6 +260,33 @@ def test_get_torch_storage_size():
    assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2


+@requires("torch")
+@pytest.mark.skipif(not is_dtensor_available(), reason="requires torch with dtensor available")
+def test_get_torch_storage_size_dtensor():
+    # testing distributed sharded tensors isn't very easy, would need to subprocess call torchrun, so this should be good enough
+    import torch
+    import torch.distributed as dist
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor import DTensor, Replicate
+
+    if dist.is_available() and not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo",
+            store=dist.HashStore(),
+            rank=0,
+            world_size=1,
+        )
+
+    mesh = init_device_mesh("cpu", (1,))
+    local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
+    dt = DTensor.from_local(local, mesh, [Replicate()])
+
+    assert get_torch_storage_size(dt) == 5 * 2
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
@requires("torch")
@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
def test_get_torch_storage_size_wrapper_tensor_subclass():
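
The new test sidesteps torchrun by creating a one-rank gloo process group backed by an in-memory HashStore, so the DTensor path runs inside an ordinary pytest process. Assuming a local checkout, an invocation along these lines should exercise it:

python -m pytest tests/test_serialization.py -k dtensor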
