diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/README.md
index a6901d5ada9d..06013b8a61e0 100644
--- a/examples/research_projects/pytorch_xla/README.md
+++ b/examples/research_projects/pytorch_xla/README.md
@@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on
 This script implements Distributed Data Parallel using GSPMD feature in XLA compiler where we shard the input batches over the TPU devices.
 
-As of 9-11-2024, these are some expected step times.
+As of 10-31-2024, these are some expected step times.
 
 | accelerator | global batch size | step time (seconds) |
 | ----------- | ----------------- | --------- |
-| v5p-128 | 1024 | 0.245 |
-| v5p-256 | 2048 | 0.234 |
-| v5p-512 | 4096 | 0.2498 |
+| v5p-512 | 16384 | 1.01 |
+| v5p-256 | 8192 | 1.01 |
+| v5p-128 | 4096 | 1.0 |
+| v5p-64 | 2048 | 1.01 |
 
 ## Create TPU
 
@@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions:
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
-pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 '
 ```
 
@@ -88,17 +90,18 @@ are fixed.
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-export XLA_DISABLE_FUNCTIONALIZATION=1
+export XLA_DISABLE_FUNCTIONALIZATION=0
 export PROFILE_DIR=/tmp/
 export CACHE_DIR=/tmp/
 export DATASET_NAME=lambdalabs/naruto-blip-captions
 export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
 export TRAIN_STEPS=50
 export OUTPUT_DIR=/tmp/trained-model/
-python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'
-
+python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
 ```
 
+Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
+
 ### Environment Envs Explained
 
 * `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
index 5d9d8c540f11..9719585d3dfb 100644
--- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
@@ -140,33 +140,43 @@ def run_optimizer(self):
         self.optimizer.step()
 
     def start_training(self):
-        times = []
-        last_time = time.time()
-        step = 0
-        while True:
-            if self.global_step >= self.args.max_train_steps:
-                xm.mark_step()
-                break
-            if step == 4 and PROFILE_DIR is not None:
-                xm.wait_device_ops()
-                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+        dataloader_exception = False
+        measure_start_step = args.measure_start_step
+        assert measure_start_step < self.args.max_train_steps
+        total_time = 0
+        for step in range(0, self.args.max_train_steps):
             try:
                 batch = next(self.dataloader)
             except Exception as e:
+                dataloader_exception = True
                 print(e)
                 break
+            if step == measure_start_step and PROFILE_DIR is not None:
+                xm.wait_device_ops()
+                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+                last_time = time.time()
             loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
-            step_time = time.time() - last_time
-            if step >= 10:
-                times.append(step_time)
-                print(f"step: {step}, step_time: {step_time}")
-            if step % 5 == 0:
-                print(f"step: {step}, loss: {loss}")
-            last_time = time.time()
             self.global_step += 1
-            step += 1
-        # print(f"Average step time: {sum(times)/len(times)}")
-        xm.wait_device_ops()
+
+            def print_loss_closure(step, loss):
+                print(f"Step: {step}, Loss: {loss}")
+
+            if args.print_loss:
+                xm.add_step_closure(
+                    print_loss_closure,
+                    args=(
+                        self.global_step,
+                        loss,
+                    ),
+                )
+            xm.mark_step()
+        if not dataloader_exception:
+            xm.wait_device_ops()
+            total_time = time.time() - last_time
+            print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+        else:
+            print("dataloader exception happen, skip result")
+            return
 
     def step_fn(
         self,
@@ -180,7 +190,10 @@ def step_fn(
         noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
         bsz = latents.shape[0]
         timesteps = torch.randint(
-            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+            0,
+            self.noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            device=latents.device,
         )
         timesteps = timesteps.long()
 
@@ -224,9 +237,6 @@ def step_fn(
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument(
-        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
-    )
     parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
     parser.add_argument(
         "--pretrained_model_name_or_path",
@@ -258,12 +268,6 @@ def parse_args():
             " or to a folder containing files that 🤗 Datasets can understand."
         ),
     )
-    parser.add_argument(
-        "--dataset_config_name",
-        type=str,
-        default=None,
-        help="The config of the Dataset, leave as None if there's only one config.",
-    )
     parser.add_argument(
         "--train_data_dir",
         type=str,
@@ -283,15 +287,6 @@ def parse_args():
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
-    parser.add_argument(
-        "--max_train_samples",
-        type=int,
-        default=None,
-        help=(
-            "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
-        ),
-    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -304,7 +299,6 @@ def parse_args():
         default=None,
         help="The directory where the downloaded models and datasets will be stored.",
     )
-    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
     parser.add_argument(
         "--resolution",
         type=int,
@@ -374,12 +368,19 @@ def parse_args():
         default=1,
         help=("Number of subprocesses to use for data loading to cpu."),
     )
+    parser.add_argument(
+        "--loader_prefetch_factor",
+        type=int,
+        default=2,
+        help=("Number of batches loaded in advance by each worker."),
+    )
     parser.add_argument(
         "--device_prefetch_size",
         type=int,
         default=1,
         help=("Number of subprocesses to use for data loading to tpu from cpu. "),
     )
+    parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -394,12 +395,8 @@ def parse_args():
         "--mixed_precision",
         type=str,
         default=None,
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
-        ),
+        choices=["no", "bf16"],
+        help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
     )
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -409,6 +406,12 @@ def parse_args():
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument(
+        "--print_loss",
+        default=False,
+        action="store_true",
+        help=("Print loss at every step."),
+    )
 
     args = parser.parse_args()
 
@@ -436,7 +439,6 @@ def load_dataset(args):
         # Downloading and loading a dataset from the hub.
         dataset = datasets.load_dataset(
             args.dataset_name,
-            args.dataset_config_name,
             cache_dir=args.cache_dir,
             data_dir=args.train_data_dir,
         )
@@ -481,9 +483,7 @@ def main(args):
     _ = xp.start_server(PORT)
 
     num_devices = xr.global_runtime_device_count()
-    device_ids = np.arange(num_devices)
-    mesh_shape = (num_devices, 1)
-    mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
+    mesh = xs.get_1d_mesh("data")
     xs.set_global_mesh(mesh)
 
     text_encoder = CLIPTextModel.from_pretrained(
@@ -520,6 +520,7 @@ def main(args):
     from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
 
     unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -530,15 +531,12 @@ def main(args):
     # as these weights are only used for inference, keeping weights in full
     # precision is not required.
     weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
+    if args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
     device = xm.xla_device()
-    print("device: ", device)
-    print("weight_dtype: ", weight_dtype)
 
+    # Move text_encoder and vae to device and cast to weight_dtype
     text_encoder = text_encoder.to(device, dtype=weight_dtype)
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=weight_dtype)
@@ -606,24 +604,27 @@ def collate_fn(examples):
         collate_fn=collate_fn,
         num_workers=args.dataloader_num_workers,
         batch_size=args.train_batch_size,
+        prefetch_factor=args.loader_prefetch_factor,
     )
 
     train_dataloader = pl.MpDeviceLoader(
         train_dataloader,
         device,
         input_sharding={
-            "pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
-            "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+            "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
+            "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
         },
         loader_prefetch_size=args.loader_prefetch_size,
         device_prefetch_size=args.device_prefetch_size,
     )
 
+    num_hosts = xr.process_count()
+    num_devices_per_host = num_devices // num_hosts
     if xm.is_master_ordinal():
         print("***** Running training *****")
-        print(f"Instantaneous batch size per device = {args.train_batch_size}")
+        print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
         print(
-            f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
+            f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
         )
         print(f"  Total optimization steps = {args.max_train_steps}")
 
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 13d910db6135..444f201f6376 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -20,8 +20,8 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
-from ..utils.import_utils import is_torch_npu_available, is_xformers_available
+from ..utils import deprecate, is_torch_xla_available, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
 
@@ -36,6 +36,15 @@
 else:
     xformers = None
 
+if is_torch_xla_available():
+    # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
+    if is_torch_xla_version(">", "2.2"):
+        from torch_xla.experimental.custom_kernel import flash_attention
+        from torch_xla.runtime import is_spmd
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 @maybe_allow_in_graph
 class Attention(nn.Module):
@@ -275,6 +284,33 @@ def __init__(
         )
         self.set_processor(processor)
 
+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+    ) -> None:
+        r"""
+        Set whether to use xla flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+        """
+        if use_xla_flash_attention:
+            if not is_torch_xla_available():
+                raise ValueError("torch_xla is not available")
+            elif is_torch_xla_version("<", "2.3"):
+                raise ValueError("flash attention pallas kernel is supported from torch_xla version 2.3")
+            elif is_spmd() and is_torch_xla_version("<", "2.4"):
+                raise ValueError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
+            else:
+                processor = XLAFlashAttnProcessor2_0(partition_spec)
+        else:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
         r"""
         Set whether to use npu flash attention from `torch_npu` or not.
@@ -2753,6 +2789,122 @@ def __call__(
         return hidden_states
 
 
+class XLAFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+            if attention_mask is not None:
+                attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
+                # Convert mask to float and replace 0s with -inf and 1s with 0
+                attention_mask = (
+                    attention_mask.float()
+                    .masked_fill(attention_mask == 0, float("-inf"))
+                    .masked_fill(attention_mask == 1, float(0.0))
+                )
+
+                # Apply attention mask to key
+                key = key + attention_mask
+            query /= math.sqrt(query.shape[3])
+            partition_spec = self.partition_spec if is_spmd() else None
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
+        else:
+            logger.warning(
+                "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
+            )
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
@@ -5074,6 +5226,7 @@ def __init__(self):
     FusedCogVideoXAttnProcessor2_0,
     XFormersAttnAddedKVProcessor,
     XFormersAttnProcessor,
+    XLAFlashAttnProcessor2_0,
     AttnProcessorNPU,
     AttnProcessor2_0,
     MochiVaeAttnProcessor2_0,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 7b2022798d41..4fe457706473 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -208,6 +208,35 @@ def disable_npu_flash_attention(self) -> None:
         """
         self.set_use_npu_flash_attention(False)
 
+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_xla_flash_attention method
+        # gets the message
+        def fn_recursive_set_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_xla_flash_attention"):
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+
+            for child in module.children():
+                fn_recursive_set_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_flash_attention(module)
+
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+        r"""
+        Enable the flash attention pallas kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(True, partition_spec)
+
+    def disable_xla_flash_attention(self):
+        r"""
+        Disable the flash attention pallas kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index c8f64adf3e8a..f91cee8113f2 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -86,6 +86,7 @@
     is_torch_npu_available,
     is_torch_version,
     is_torch_xla_available,
+    is_torch_xla_version,
     is_torchsde_available,
     is_torchvision_available,
     is_transformers_available,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index f1323bf00ea4..e3b7655737a8 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -700,6 +700,21 @@ def is_torch_version(operation: str, version: str):
     return compare_versions(parse(_torch_version), operation, version)
 
 
+def is_torch_xla_version(operation: str, version: str):
+    """
+    Compares the current torch_xla version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A string version of torch_xla
+    """
+    if not is_torch_xla_available():
+        return False
+    return compare_versions(parse(_torch_xla_version), operation, version)
+
+
 def is_transformers_version(operation: str, version: str):
     """
     Compares the current Transformers version to a given reference with an operation.
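The pieces added above span four files, so a short usage sketch may help reviewers see how they fit together. The snippet below is illustrative only and is not part of the diff: it assumes a TPU VM with a torch_xla >= 2.3 nightly installed as described in the README, reuses the model id and 1-D "data" mesh from the training script, and exercises the new `enable_xla_flash_attention` / `disable_xla_flash_attention` helpers and the `is_torch_xla_version` check.

```python
# Minimal sketch (not part of the diff): enabling the pallas flash attention
# kernel on a UNet, mirroring what train_text_to_image_xla.py does.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

from diffusers import UNet2DConditionModel
from diffusers.utils import is_torch_xla_available, is_torch_xla_version

assert is_torch_xla_available() and is_torch_xla_version(">", "2.2")

xr.use_spmd()  # SPMD mode is required for the ("data", ...) partition spec
mesh = xs.get_1d_mesh("data")  # same 1-D data-parallel mesh as the training script
xs.set_global_mesh(mesh)

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-base", subfolder="unet"
).to(xm.xla_device(), dtype=torch.bfloat16)

# Swaps every attention processor for XLAFlashAttnProcessor2_0. The pallas
# kernel is only used when the QKV sequence length is at least 4096; shorter
# sequences fall back to torch.nn.functional.scaled_dot_product_attention.
unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

# ...run training or denoising steps on the XLA device...

unet.disable_xla_flash_attention()  # restore the default attention processor
```

Note that only the attention layers whose sequence length reaches 4096 (for example the 64x64 latent self-attention blocks at 512x512 resolution) hit the pallas kernel; the remaining layers log the warning shown in `XLAFlashAttnProcessor2_0.__call__` and take the `scaled_dot_product_attention` path.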