
Commit 24a8d2d

entrpn, jfacevedo-google, zpcore, sayakpaul, and Pei Zhang committed
Update ptxla training (#9864)
* update ptxla example

---------

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Pei Zhang <zpcore@gmail.com>
Co-authored-by: Pei Zhang <piz@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pei Zhang <pei@Peis-MacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
1 parent e179be2 commit 24a8d2d

File tree

6 files changed: +272 -70 lines changed


examples/research_projects/pytorch_xla/README.md

Lines changed: 12 additions & 9 deletions
@@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on
 This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
 where we shard the input batches over the TPU devices.
 
-As of 9-11-2024, these are some expected step times.
+As of 10-31-2024, these are some expected step times.
 
 | accelerator | global batch size | step time (seconds) |
 | ----------- | ----------------- | --------- |
-| v5p-128 | 1024 | 0.245 |
-| v5p-256 | 2048 | 0.234 |
-| v5p-512 | 4096 | 0.2498 |
+| v5p-512 | 16384 | 1.01 |
+| v5p-256 | 8192 | 1.01 |
+| v5p-128 | 4096 | 1.0 |
+| v5p-64 | 2048 | 1.01 |
 
 ## Create TPU
 
@@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions:
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
-pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 '
 ```
 
@@ -88,17 +90,18 @@ are fixed.
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-export XLA_DISABLE_FUNCTIONALIZATION=1
+export XLA_DISABLE_FUNCTIONALIZATION=0
 export PROFILE_DIR=/tmp/
 export CACHE_DIR=/tmp/
 export DATASET_NAME=lambdalabs/naruto-blip-captions
 export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
 export TRAIN_STEPS=50
 export OUTPUT_DIR=/tmp/trained-model/
-python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'
-
+python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
 ```
 
+Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
+
 ### Environment Envs Explained
 
 * `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
examples/research_projects/pytorch_xla/train_text_to_image_xla.py

Lines changed: 60 additions & 59 deletions
@@ -140,33 +140,43 @@ def run_optimizer(self):
         self.optimizer.step()
 
     def start_training(self):
-        times = []
-        last_time = time.time()
-        step = 0
-        while True:
-            if self.global_step >= self.args.max_train_steps:
-                xm.mark_step()
-                break
-            if step == 4 and PROFILE_DIR is not None:
-                xm.wait_device_ops()
-                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+        dataloader_exception = False
+        measure_start_step = args.measure_start_step
+        assert measure_start_step < self.args.max_train_steps
+        total_time = 0
+        for step in range(0, self.args.max_train_steps):
             try:
                 batch = next(self.dataloader)
             except Exception as e:
+                dataloader_exception = True
                 print(e)
                 break
+            if step == measure_start_step and PROFILE_DIR is not None:
+                xm.wait_device_ops()
+                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+                last_time = time.time()
             loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
-            step_time = time.time() - last_time
-            if step >= 10:
-                times.append(step_time)
-            print(f"step: {step}, step_time: {step_time}")
-            if step % 5 == 0:
-                print(f"step: {step}, loss: {loss}")
-            last_time = time.time()
             self.global_step += 1
-            step += 1
-        # print(f"Average step time: {sum(times)/len(times)}")
-        xm.wait_device_ops()
+
+            def print_loss_closure(step, loss):
+                print(f"Step: {step}, Loss: {loss}")
+
+            if args.print_loss:
+                xm.add_step_closure(
+                    print_loss_closure,
+                    args=(
+                        self.global_step,
+                        loss,
+                    ),
+                )
+            xm.mark_step()
+        if not dataloader_exception:
+            xm.wait_device_ops()
+            total_time = time.time() - last_time
+            print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+        else:
+            print("dataloader exception happen, skip result")
+        return
 
     def step_fn(
         self,
@@ -180,7 +190,10 @@ def step_fn(
         noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
         bsz = latents.shape[0]
         timesteps = torch.randint(
-            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+            0,
+            self.noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            device=latents.device,
         )
         timesteps = timesteps.long()
 
@@ -224,9 +237,6 @@ def step_fn(
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument(
-        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
-    )
     parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
     parser.add_argument(
         "--pretrained_model_name_or_path",
@@ -258,12 +268,6 @@ def parse_args():
             " or to a folder containing files that 🤗 Datasets can understand."
         ),
     )
-    parser.add_argument(
-        "--dataset_config_name",
-        type=str,
-        default=None,
-        help="The config of the Dataset, leave as None if there's only one config.",
-    )
     parser.add_argument(
         "--train_data_dir",
         type=str,
@@ -283,15 +287,6 @@ def parse_args():
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
-    parser.add_argument(
-        "--max_train_samples",
-        type=int,
-        default=None,
-        help=(
-            "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
-        ),
-    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -304,7 +299,6 @@ def parse_args():
         default=None,
         help="The directory where the downloaded models and datasets will be stored.",
     )
-    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
     parser.add_argument(
         "--resolution",
         type=int,
@@ -374,12 +368,19 @@ def parse_args():
         default=1,
         help=("Number of subprocesses to use for data loading to cpu."),
     )
+    parser.add_argument(
+        "--loader_prefetch_factor",
+        type=int,
+        default=2,
+        help=("Number of batches loaded in advance by each worker."),
+    )
     parser.add_argument(
         "--device_prefetch_size",
         type=int,
         default=1,
         help=("Number of subprocesses to use for data loading to tpu from cpu. "),
     )
+    parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -394,12 +395,8 @@ def parse_args():
         "--mixed_precision",
         type=str,
         default=None,
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
-        ),
+        choices=["no", "bf16"],
+        help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
     )
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -409,6 +406,12 @@ def parse_args():
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument(
+        "--print_loss",
+        default=False,
+        action="store_true",
+        help=("Print loss at every step."),
+    )
 
     args = parser.parse_args()
 
@@ -436,7 +439,6 @@ def load_dataset(args):
         # Downloading and loading a dataset from the hub.
         dataset = datasets.load_dataset(
             args.dataset_name,
-            args.dataset_config_name,
             cache_dir=args.cache_dir,
             data_dir=args.train_data_dir,
         )
@@ -481,9 +483,7 @@ def main(args):
    _ = xp.start_server(PORT)

    num_devices = xr.global_runtime_device_count()
-    device_ids = np.arange(num_devices)
-    mesh_shape = (num_devices, 1)
-    mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
+    mesh = xs.get_1d_mesh("data")
    xs.set_global_mesh(mesh)

    text_encoder = CLIPTextModel.from_pretrained(
@@ -520,6 +520,7 @@ def main(args):
    from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

    unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
@@ -530,15 +531,12 @@ def main(args):
    # as these weights are only used for inference, keeping weights in full
    # precision is not required.
    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
+    if args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    device = xm.xla_device()
-    print("device: ", device)
-    print("weight_dtype: ", weight_dtype)

+    # Move text_encode and vae to device and cast to weight_dtype
    text_encoder = text_encoder.to(device, dtype=weight_dtype)
    vae = vae.to(device, dtype=weight_dtype)
    unet = unet.to(device, dtype=weight_dtype)
@@ -606,24 +604,27 @@ def collate_fn(examples):
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
        batch_size=args.train_batch_size,
+        prefetch_factor=args.loader_prefetch_factor,
    )

    train_dataloader = pl.MpDeviceLoader(
        train_dataloader,
        device,
        input_sharding={
-            "pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
-            "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+            "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
+            "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
        },
        loader_prefetch_size=args.loader_prefetch_size,
        device_prefetch_size=args.device_prefetch_size,
    )

+    num_hosts = xr.process_count()
+    num_devices_per_host = num_devices // num_hosts
    if xm.is_master_ordinal():
        print("***** Running training *****")
-        print(f"Instantaneous batch size per device = {args.train_batch_size}")
+        print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
        print(
-            f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
+            f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
        )
        print(f" Total optimization steps = {args.max_train_steps}")
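A note on how the reworked `start_training` loop (first hunk of this file) gets trustworthy timings under XLA's lazy execution: `xm.mark_step()` only cuts and dispatches the traced graph, so wall-clock time is meaningful only after an `xm.wait_device_ops()` barrier, and loss printing is deferred to a step closure rather than read mid-trace. The following standalone sketch isolates that idiom; the matmul "model" and constants are placeholders, not code from this commit:

```python
# Sketch of the timing/logging idiom used by the updated training loop.
# Assumes a machine with an XLA (TPU) device; the "model" is a stand-in matmul.
import time

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
w = torch.randn(128, 128, device=device)

MEASURE_START_STEP = 10  # mirrors the --measure_start_step default
MAX_STEPS = 50

for step in range(MAX_STEPS):
    if step == MEASURE_START_STEP:
        xm.wait_device_ops()  # drain pending work so the clock starts clean
        last_time = time.time()
    x = torch.randn(128, 128, device=device)
    loss = (x @ w).sum()
    # Step closures run after the step's graph executes, so printing the loss
    # does not force a device sync in the middle of tracing (it still adds
    # transfers, which is why --print_loss is opt-in).
    xm.add_step_closure(lambda s, l: print(f"Step: {s}, Loss: {l}"), args=(step, loss))
    xm.mark_step()  # cut the trace and dispatch it asynchronously

xm.wait_device_ops()  # wait for every dispatched step before stopping the clock
print(f"Average step time: {(time.time() - last_time) / (MAX_STEPS - MEASURE_START_STEP)}")
```

The same host/device distinction drives the revised logging in the last hunk: `--train_batch_size` is per host, so the per-device batch is `train_batch_size // num_devices_per_host` and the global batch is `train_batch_size * num_hosts`.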