@@ -140,33 +140,43 @@ def run_optimizer(self):
         self.optimizer.step()
 
     def start_training(self):
-        times = []
-        last_time = time.time()
-        step = 0
-        while True:
-            if self.global_step >= self.args.max_train_steps:
-                xm.mark_step()
-                break
-            if step == 4 and PROFILE_DIR is not None:
-                xm.wait_device_ops()
-                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+        dataloader_exception = False
+        measure_start_step = args.measure_start_step
+        assert measure_start_step < self.args.max_train_steps
+        total_time = 0
+        for step in range(0, self.args.max_train_steps):
             try:
                 batch = next(self.dataloader)
             except Exception as e:
+                dataloader_exception = True
                 print(e)
                 break
+            if step == measure_start_step and PROFILE_DIR is not None:
+                xm.wait_device_ops()
+                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+                last_time = time.time()
             loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
-            step_time = time.time() - last_time
-            if step >= 10:
-                times.append(step_time)
-                print(f"step: {step}, step_time: {step_time}")
-            if step % 5 == 0:
-                print(f"step: {step}, loss: {loss}")
-            last_time = time.time()
             self.global_step += 1
-            step += 1
-        # print(f"Average step time: {sum(times)/len(times)}")
-        xm.wait_device_ops()
+
+            def print_loss_closure(step, loss):
+                print(f"Step: {step}, Loss: {loss}")
+
+            if args.print_loss:
+                xm.add_step_closure(
+                    print_loss_closure,
+                    args=(
+                        self.global_step,
+                        loss,
+                    ),
+                )
+            xm.mark_step()
+        if not dataloader_exception:
+            xm.wait_device_ops()
+            total_time = time.time() - last_time
+            print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
+        else:
+            print("dataloader exception happen, skip result")
+            return
 
     def step_fn(
         self,
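Note on the rewritten loop above: reading the loss inline would materialize the lazy XLA tensor and sync the host every step, so the commit routes printing through xm.add_step_closure, which runs after the step's graph is dispatched at xm.mark_step(). A minimal standalone sketch of that pattern (the train_loop function and its model/loader/optimizer arguments are illustrative, not from this file):

import torch
import torch_xla.core.xla_model as xm

def train_loop(model, loader, optimizer, max_steps, print_loss=False):
    device = xm.xla_device()
    model = model.to(device)
    for step, (x, y) in zip(range(max_steps), loader):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x.to(device)), y.to(device))
        loss.backward()
        xm.optimizer_step(optimizer)
        if print_loss:
            # Deferred host-side read of the loss: avoids forcing an extra device sync per step.
            xm.add_step_closure(lambda s, l: print(f"Step: {s}, Loss: {l}"), args=(step, loss))
        xm.mark_step()  # cut the lazy graph and dispatch this step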
@@ -180,7 +190,10 @@ def step_fn(
         noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
         bsz = latents.shape[0]
         timesteps = torch.randint(
-            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+            0,
+            self.noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            device=latents.device,
         )
         timesteps = timesteps.long()
 
@@ -224,9 +237,6 @@ def step_fn(
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument(
-        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
-    )
     parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
     parser.add_argument(
         "--pretrained_model_name_or_path",
@@ -258,12 +268,6 @@ def parse_args():
             " or to a folder containing files that 🤗 Datasets can understand."
         ),
     )
-    parser.add_argument(
-        "--dataset_config_name",
-        type=str,
-        default=None,
-        help="The config of the Dataset, leave as None if there's only one config.",
-    )
     parser.add_argument(
         "--train_data_dir",
         type=str,
@@ -283,15 +287,6 @@ def parse_args():
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
-    parser.add_argument(
-        "--max_train_samples",
-        type=int,
-        default=None,
-        help=(
-            "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
-        ),
-    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -304,7 +299,6 @@ def parse_args():
         default=None,
         help="The directory where the downloaded models and datasets will be stored.",
     )
-    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
     parser.add_argument(
         "--resolution",
         type=int,
@@ -374,12 +368,19 @@ def parse_args():
         default=1,
         help=("Number of subprocesses to use for data loading to cpu."),
     )
+    parser.add_argument(
+        "--loader_prefetch_factor",
+        type=int,
+        default=2,
+        help=("Number of batches loaded in advance by each worker."),
+    )
     parser.add_argument(
         "--device_prefetch_size",
         type=int,
         default=1,
         help=("Number of subprocesses to use for data loading to tpu from cpu. "),
     )
+    parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -394,12 +395,8 @@ def parse_args():
         "--mixed_precision",
         type=str,
         default=None,
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
-        ),
+        choices=["no", "bf16"],
+        help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
     )
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -409,6 +406,12 @@ def parse_args():
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument(
+        "--print_loss",
+        default=False,
+        action="store_true",
+        help=("Print loss at every step."),
+    )
 
     args = parser.parse_args()
 
@@ -436,7 +439,6 @@ def load_dataset(args):
         # Downloading and loading a dataset from the hub.
         dataset = datasets.load_dataset(
             args.dataset_name,
-            args.dataset_config_name,
             cache_dir=args.cache_dir,
             data_dir=args.train_data_dir,
         )
@@ -481,9 +483,7 @@ def main(args):
         _ = xp.start_server(PORT)
 
     num_devices = xr.global_runtime_device_count()
-    device_ids = np.arange(num_devices)
-    mesh_shape = (num_devices, 1)
-    mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
+    mesh = xs.get_1d_mesh("data")
     xs.set_global_mesh(mesh)
 
     text_encoder = CLIPTextModel.from_pretrained(
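The mesh change above swaps the hand-built (num_devices, 1) mesh for the SPMD helper xs.get_1d_mesh, which spans all global devices along a single axis, here named "data". A rough equivalence sketch (my assumption about the intent, not part of the commit), collapsing the removed two-axis mesh to one axis:

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

num_devices = xr.global_runtime_device_count()
# Roughly what the removed lines constructed by hand, reduced to a single "data" axis ...
mesh_by_hand = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))
# ... and the one-liner the commit uses instead.
mesh = xs.get_1d_mesh("data")
xs.set_global_mesh(mesh)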
@@ -520,6 +520,7 @@ def main(args):
     from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
 
     unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -530,15 +531,12 @@ def main(args):
     # as these weights are only used for inference, keeping weights in full
     # precision is not required.
     weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
+    if args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
     device = xm.xla_device()
-    print("device: ", device)
-    print("weight_dtype: ", weight_dtype)
 
+    # Move text_encode and vae to device and cast to weight_dtype
     text_encoder = text_encoder.to(device, dtype=weight_dtype)
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=weight_dtype)
@@ -606,24 +604,27 @@ def collate_fn(examples):
         collate_fn=collate_fn,
         num_workers=args.dataloader_num_workers,
         batch_size=args.train_batch_size,
+        prefetch_factor=args.loader_prefetch_factor,
     )
 
     train_dataloader = pl.MpDeviceLoader(
         train_dataloader,
         device,
         input_sharding={
-            "pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
-            "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+            "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
+            "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
         },
         loader_prefetch_size=args.loader_prefetch_size,
         device_prefetch_size=args.device_prefetch_size,
     )
 
+    num_hosts = xr.process_count()
+    num_devices_per_host = num_devices // num_hosts
     if xm.is_master_ordinal():
         print("***** Running training *****")
-        print(f"Instantaneous batch size per device = {args.train_batch_size}")
+        print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
         print(
-            f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
+            f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
         )
         print(f" Total optimization steps = {args.max_train_steps}")
 
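On the revised batch-size prints: with minibatch=True input sharding, each host's DataLoader appears to yield args.train_batch_size samples that are then split across that host's local devices, which is how the per-device and global figures in the new prints come about. A worked example with hypothetical numbers (not taken from the commit):

# Hypothetical values, only to illustrate the arithmetic behind the new prints.
train_batch_size = 32                                          # per-host batch from the DataLoader
num_hosts = 4                                                  # xr.process_count()
num_devices = 16                                               # xr.global_runtime_device_count()
num_devices_per_host = num_devices // num_hosts                # 4
per_device_batch = train_batch_size // num_devices_per_host    # 8
global_batch = train_batch_size * num_hosts                    # 128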