Update ptxla training #9864

Merged: 20 commits, Dec 6, 2024
examples/research_projects/pytorch_xla/README.md (21 changes: 12 additions, 9 deletions)
@@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
where we shard the input batches over the TPU devices.
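
For orientation, the sharding described above boils down to a one-dimensional device mesh with a single `data` axis and input batches sharded along it. Below is a minimal, illustrative sketch of that setup; the `get_1d_mesh`, `set_global_mesh`, `ShardingSpec`, and `MpDeviceLoader` calls mirror the ones used in this PR's `train_text_to_image_xla.py`, while the toy dataset and batch size are placeholders.

```python
# Minimal GSPMD data-parallel sketch (illustrative; the toy dataset and sizes
# are placeholders, the torch_xla calls mirror those used by the script).
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()  # enable SPMD/GSPMD execution mode

mesh = xs.get_1d_mesh("data")  # one mesh axis named "data" over all devices
xs.set_global_mesh(mesh)

device = xm.xla_device()
loader = torch.utils.data.DataLoader(
    torch.zeros(64, 3, 512, 512),  # stand-in dataset of image-shaped tensors
    batch_size=8,
)
# Shard the batch dimension of each input over the "data" axis of the mesh.
loader = pl.MpDeviceLoader(
    loader,
    device,
    input_sharding=xs.ShardingSpec(mesh, ("data", None, None, None)),
)
for batch in loader:
    pass  # each device holds only its shard of the batch
```

In the script itself the per-input specs are passed as a dict with `minibatch=True`, as the diff further down shows.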

- As of 9-11-2024, these are some expected step times.
+ As of 10-31-2024, these are some expected step times.

| accelerator | global batch size | step time (seconds) |
| ----------- | ----------------- | --------- |
- | v5p-128 | 1024 | 0.245 |
- | v5p-256 | 2048 | 0.234 |
- | v5p-512 | 4096 | 0.2498 |
+ | v5p-512 | 16384 | 1.01 |
+ | v5p-256 | 8192 | 1.01 |
+ | v5p-128 | 4096 | 1.0 |
+ | v5p-64 | 2048 | 1.01 |
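
As a rough cross-check of these numbers, global batch size divided by step time gives the implied throughput; the figures below are computed from the table rows above, not separately measured.

```python
# Implied throughput (images per second) for the step times listed above.
for accelerator, global_batch, step_time in [
    ("v5p-512", 16384, 1.01),
    ("v5p-256", 8192, 1.01),
    ("v5p-128", 4096, 1.0),
    ("v5p-64", 2048, 1.01),
]:
    print(f"{accelerator}: ~{global_batch / step_time:,.0f} images/s")
# v5p-512 works out to roughly 16,200 images/s, v5p-64 to roughly 2,000.
```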

## Create TPU

@@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
- pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
- pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+ pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+ pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
'
```
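
After installing the nightlies, a quick sanity check on each worker can confirm that the wheels import cleanly and that all TPU devices are visible. The snippet below only relies on standard `torch`/`torch_xla` attributes and the same device-count call the training script uses; run it with Python 3.10 on the TPU VM.

```python
# Post-install sanity check; run on the TPU VM.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

print("torch version:     ", torch.__version__)
print("torch_xla version: ", torch_xla.__version__)
print("xla device:        ", xm.xla_device())
print("device count:      ", xr.global_runtime_device_count())
```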

@@ -88,17 +90,18 @@ are fixed.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
- export XLA_DISABLE_FUNCTIONALIZATION=1
+ export XLA_DISABLE_FUNCTIONALIZATION=0
export PROFILE_DIR=/tmp/
export CACHE_DIR=/tmp/
export DATASET_NAME=lambdalabs/naruto-blip-captions
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
export TRAIN_STEPS=50
export OUTPUT_DIR=/tmp/trained-model/
- python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'

+ python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
```
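
The profiling-related variables above feed the `xp` profiler hooks visible in the script's diff below: a profiler server is started on a local port and, at the configured step, a detached trace is written to `PROFILE_DIR`. The sketch below condenses that flow; the `os.environ` lookup and the port value are assumptions for illustration, since only the exports and the `xp` calls are visible in this PR.

```python
# Condensed sketch of the profiling flow; the environment lookup and port are
# illustrative assumptions, the xp calls mirror those in the training script.
import os
import time
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

PROFILE_DIR = os.environ.get("PROFILE_DIR")  # assumed to match the export above
PORT = 9012  # placeholder port; the script defines its own constant

server = xp.start_server(PORT)
for step in range(50):
    if step == 10 and PROFILE_DIR is not None:
        xm.wait_device_ops()  # drain pending device work before tracing
        xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=80000)
    time.sleep(0.1)  # stands in for the real training step
```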

+ Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
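
The extra cost comes from materializing the loss value on the host. Rather than blocking in the middle of a step, the script registers the print as a step closure that runs once the step's tensors are available (see the `xm.add_step_closure` call in the diff below). A standalone sketch of that pattern:

```python
# Sketch of deferred loss printing via a step closure.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
w = torch.randn(4, 4, device=device)

def print_loss_closure(step, loss):
    # Runs on the host after the step's tensors have materialized values.
    print(f"Step: {step}, Loss: {loss}")

for step in range(3):
    loss = (w * w).mean()
    xm.add_step_closure(print_loss_closure, args=(step, loss))
    xm.mark_step()  # cuts the lazy graph; registered closures fire after this
```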

### Environment Envs Explained

* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
examples/research_projects/pytorch_xla/train_text_to_image_xla.py (119 changes: 60 additions, 59 deletions)
@@ -140,33 +140,43 @@ def run_optimizer(self):
self.optimizer.step()

def start_training(self):
- times = []
- last_time = time.time()
- step = 0
- while True:
- if self.global_step >= self.args.max_train_steps:
- xm.mark_step()
- break
- if step == 4 and PROFILE_DIR is not None:
- xm.wait_device_ops()
- xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+ dataloader_exception = False
+ measure_start_step = args.measure_start_step
+ assert measure_start_step < self.args.max_train_steps
+ total_time = 0
+ for step in range(0, self.args.max_train_steps):
+ try:
+ batch = next(self.dataloader)
+ except Exception as e:
+ dataloader_exception = True
+ print(e)
+ break
+ if step == measure_start_step and PROFILE_DIR is not None:
+ xm.wait_device_ops()
+ xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+ last_time = time.time()
loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
- step_time = time.time() - last_time
- if step >= 10:
- times.append(step_time)
- print(f"step: {step}, step_time: {step_time}")
- if step % 5 == 0:
- print(f"step: {step}, loss: {loss}")
- last_time = time.time()
- self.global_step += 1
- step += 1
- # print(f"Average step time: {sum(times)/len(times)}")
- xm.wait_device_ops()

+ def print_loss_closure(step, loss):
+ print(f"Step: {step}, Loss: {loss}")
+
+ if args.print_loss:
+ xm.add_step_closure(
+ print_loss_closure,
+ args=(
+ self.global_step,
+ loss,
+ ),
+ )
+ xm.mark_step()
+ if not dataloader_exception:
+ xm.wait_device_ops()
+ total_time = time.time() - last_time
+ print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+ else:
+ print("dataloader exception happen, skip result")
+ return

def step_fn(
self,
Expand All @@ -180,7 +190,10 @@ def step_fn(
noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
bsz = latents.shape[0]
timesteps = torch.randint(
- 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ 0,
+ self.noise_scheduler.config.num_train_timesteps,
+ (bsz,),
+ device=latents.device,
)
timesteps = timesteps.long()
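
For readers new to this loop: the `noise` and `timesteps` sampled here drive the usual denoising objective, where latents are noised by the scheduler and the UNet is trained to predict the added noise. The following is a generic sketch of that pattern with a `diffusers` scheduler, not a claim about the exact lines collapsed below.

```python
# Generic denoising-objective sketch (illustrative shapes and a dummy model).
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
latents = torch.randn(4, 4, 64, 64)        # stand-in for VAE latents
noise = torch.randn_like(latents)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (4,)).long()

noisy_latents = scheduler.add_noise(latents, noise, timesteps)
model_pred = torch.zeros_like(noise)       # stand-in for the UNet prediction
loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
print(loss)
```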

@@ -224,9 +237,6 @@ def step_fn(

def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
- parser.add_argument(
- "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
- )
parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
parser.add_argument(
"--pretrained_model_name_or_path",
@@ -258,12 +268,6 @@ def parse_args():
" or to a folder containing files that 🤗 Datasets can understand."
),
)
- parser.add_argument(
- "--dataset_config_name",
- type=str,
- default=None,
- help="The config of the Dataset, leave as None if there's only one config.",
- )
parser.add_argument(
"--train_data_dir",
type=str,
@@ -283,15 +287,6 @@ def parse_args():
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
- parser.add_argument(
- "--max_train_samples",
- type=int,
- default=None,
- help=(
- "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
- ),
- )
parser.add_argument(
"--output_dir",
type=str,
@@ -304,7 +299,6 @@ def parse_args():
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
@@ -374,12 +368,19 @@ def parse_args():
default=1,
help=("Number of subprocesses to use for data loading to cpu."),
)
+ parser.add_argument(
+ "--loader_prefetch_factor",
+ type=int,
+ default=2,
+ help=("Number of batches loaded in advance by each worker."),
+ )
parser.add_argument(
"--device_prefetch_size",
type=int,
default=1,
help=("Number of subprocesses to use for data loading to tpu from cpu. "),
)
parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -394,12 +395,8 @@ def parse_args():
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
choices=["no", "bf16"],
help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -409,6 +406,12 @@ def parse_args():
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
+ parser.add_argument(
+ "--print_loss",
+ default=False,
+ action="store_true",
+ help=("Print loss at every step."),
+ )

args = parser.parse_args()

@@ -436,7 +439,6 @@ def load_dataset(args):
# Downloading and loading a dataset from the hub.
dataset = datasets.load_dataset(
args.dataset_name,
- args.dataset_config_name,
cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
)
@@ -481,9 +483,7 @@ def main(args):
_ = xp.start_server(PORT)

num_devices = xr.global_runtime_device_count()
- device_ids = np.arange(num_devices)
- mesh_shape = (num_devices, 1)
- mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
+ mesh = xs.get_1d_mesh("data")
xs.set_global_mesh(mesh)

text_encoder = CLIPTextModel.from_pretrained(
@@ -520,6 +520,7 @@ def main(args):
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+ unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -530,15 +531,12 @@ def main(args):
# as these weights are only used for inference, keeping weights in full
# precision is not required.
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
if args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

device = xm.xla_device()
print("device: ", device)
print("weight_dtype: ", weight_dtype)

# Move text_encode and vae to device and cast to weight_dtype
text_encoder = text_encoder.to(device, dtype=weight_dtype)
vae = vae.to(device, dtype=weight_dtype)
unet = unet.to(device, dtype=weight_dtype)
@@ -606,24 +604,27 @@ def collate_fn(examples):
collate_fn=collate_fn,
num_workers=args.dataloader_num_workers,
batch_size=args.train_batch_size,
+ prefetch_factor=args.loader_prefetch_factor,
)

train_dataloader = pl.MpDeviceLoader(
train_dataloader,
device,
input_sharding={
"pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
"input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
"pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
"input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
},
loader_prefetch_size=args.loader_prefetch_size,
device_prefetch_size=args.device_prefetch_size,
)

+ num_hosts = xr.process_count()
+ num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
print(f"Instantaneous batch size per device = {args.train_batch_size}")
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
print(f" Total optimization steps = {args.max_train_steps}")
