@@ -13,7 +13,7 @@
 # limitations under the License.

 from contextlib import contextmanager, nullcontext
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union

 import torch

@@ -55,7 +55,7 @@ def __init__(
         parameters: Optional[List[torch.nn.Parameter]] = None,
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
-        stream: Optional[torch.cuda.Stream] = None,
+        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
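The widened annotation matters because CUDA transfers use `torch.cuda.Stream`, while the XPU path creates a device-agnostic `torch.Stream` (available in recent PyTorch releases); both, plus `None` for the synchronous path, must type-check. A minimal sketch of the accepted values, with `StreamLike` as an illustrative alias rather than a name from the library:

```python
import torch
from typing import Union

# Illustrative alias mirroring the new annotation in the hunk above.
StreamLike = Union[torch.cuda.Stream, torch.Stream, None]

def describe(stream: StreamLike) -> str:
    # None selects the synchronous code path; either stream class enables async copies.
    return "synchronous" if stream is None else type(stream).__qualname__
```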
@@ -115,8 +115,13 @@ def _pinned_memory_tensors(self):

     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
-        context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
-        current_stream = torch.cuda.current_stream() if self.record_stream else None
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
+        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
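The pattern above resolves the backend-specific namespace (`torch.cuda`, `torch.xpu`, ...) once, then calls `stream()` and `current_stream()` on whichever module was found. A standalone sketch of the same dispatch, assuming a machine where the `torch.accelerator` API (introduced around PyTorch 2.6) reports an accelerator:

```python
import torch

# Resolve the per-backend namespace at runtime; mirrors the hook's fallback to
# torch.cuda on PyTorch builds that predate the torch.accelerator API.
if hasattr(torch, "accelerator"):
    device_type = torch.accelerator.current_accelerator().type  # e.g. "cuda" or "xpu"
    torch_accelerator_module = getattr(torch, device_type)
else:
    torch_accelerator_module = torch.cuda

transfer_stream = torch_accelerator_module.Stream()
with torch_accelerator_module.stream(transfer_stream):
    # Copies queued here run on transfer_stream, overlapping with compute on
    # torch_accelerator_module.current_stream().
    ...
```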
@@ -162,9 +167,15 @@ def onload_(self):

     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
         if self.stream is not None:
             if not self.record_stream:
-                torch.cuda.current_stream().synchronize()
+                torch_accelerator_module.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
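Without `record_stream`, the offload path has to synchronize before `param.data` is repointed at the pinned CPU copy: a `non_blocking` copy queued on the compute stream may still be reading the device tensor when the allocator reclaims it. A hedged sketch of that ordering (names are illustrative, not the library's):

```python
import torch

def offload_param(param: torch.nn.Parameter, cpu_copy: torch.Tensor,
                  accel=torch.cuda, record_stream: bool = False) -> None:
    # accel stands in for the torch_accelerator_module resolved above.
    if not record_stream:
        # Wait for all queued work (including in-flight non_blocking copies)
        # so the device tensor is safe to drop.
        accel.current_stream().synchronize()
    # Repoint the parameter at its pinned CPU copy; the device memory can
    # now be reused by the caching allocator.
    param.data = cpu_copy
```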
@@ -429,8 +440,10 @@ def apply_group_offloading(
     if use_stream:
         if torch.cuda.is_available():
             stream = torch.cuda.Stream()
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            stream = torch.Stream()
         else:
-            raise ValueError("Using streams for data transfer requires a CUDA device.")
+            raise ValueError("Using streams for data transfer requires a CUDA device or an Intel XPU device.")

     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

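With the branch above, stream-assisted offloading can be requested on an Intel GPU exactly as on CUDA. A usage sketch against diffusers' public `apply_group_offloading` entry point (the toy model and group size are placeholders; check your diffusers version for the exact signature):

```python
import torch
from diffusers.hooks import apply_group_offloading

class ToyModel(torch.nn.Module):
    # block_level offloading groups ModuleList/Sequential children, so the toy
    # model keeps its blocks in a ModuleList like a real transformer would.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(1024, 1024) for _ in range(8))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = ToyModel()
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

apply_group_offloading(
    model,
    onload_device=torch.device(device),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,  # after this change, valid on XPU as well as CUDA
)
```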
@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
    offload_device: torch.device,
    onload_device: torch.device,
    non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
) -> None:
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
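`record_stream` is the alternative to the explicit `synchronize()` seen in `offload_`: `Tensor.record_stream` tells the caching allocator that a tensor allocated on one stream is consumed on another, so its memory is not recycled until that stream catches up. A minimal CUDA illustration of the primitive itself:

```python
import torch

copy_stream = torch.cuda.Stream()
cpu_tensor = torch.randn(1024, pin_memory=True)

# Allocate/copy on the side stream, as the offloading hook does.
with torch.cuda.stream(copy_stream):
    gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)

# The compute (current) stream will consume gpu_tensor; recording it prevents
# the allocator from reusing the memory before that use completes.
gpu_tensor.record_stream(torch.cuda.current_stream())
```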
@@ -572,7 +585,7 @@ def _apply_group_offloading_leaf_level(
    offload_device: torch.device,
    onload_device: torch.device,
    non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
) -> None:
@@ -592,7 +605,7 @@ def _apply_group_offloading_leaf_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor