From 8cef82c074c2c38b82f039f100f0021eddc29741 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 5 May 2025 18:00:29 +0530
Subject: [PATCH 1/4] feat: enable quantization for hidream lora training.

---
 .../train_dreambooth_lora_hidream.py          | 41 +++++++++++++------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index 39de32091408..4ce4ef1d7e80 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -16,6 +16,7 @@
 import argparse
 import copy
 import itertools
+import json
 import logging
 import math
 import os
@@ -27,14 +28,13 @@

 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
-from peft import LoraConfig, set_peft_model_state_dict
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -47,6 +47,7 @@
 import diffusers
 from diffusers import (
     AutoencoderKL,
+    BitsAndBytesConfig,
     FlowMatchEulerDiscreteScheduler,
     HiDreamImagePipeline,
     HiDreamImageTransformer2DModel,
@@ -282,6 +283,12 @@ def parse_args(input_args=None):
         default="meta-llama/Meta-Llama-3.1-8B-Instruct",
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--bnb_quantization_config_path",
+        type=str,
+        default=None,
+        help="Path to a JSON file defining the bitsandbytes quantization config for the transformer (DiT).",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -1056,6 +1063,14 @@ def main(args):
             args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
         )

+    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+    # as these weights are only used for inference, so keeping them in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0
@@ -1064,20 +1079,30 @@ def main(args):
     text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(
         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
     )
-
     vae = AutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",
         revision=args.revision,
         variant=args.variant,
     )
+    quantization_config = None
+    if args.bnb_quantization_config_path is not None:
+        with open(args.bnb_quantization_config_path, "r") as f:
+            config_kwargs = json.load(f)
+        config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+        quantization_config = BitsAndBytesConfig(**config_kwargs)
+
     transformer = HiDreamImageTransformer2DModel.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="transformer",
         revision=args.revision,
         variant=args.variant,
+        quantization_config=quantization_config,
+        torch_dtype=weight_dtype,
         force_inference_output=True,
     )
+    if args.bnb_quantization_config_path is not None:
+        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)

     # We only train the additional adapter LoRA layers
     transformer.requires_grad_(False)
@@ -1087,14 +1112,6 @@ def main(args):
     text_encoder_three.requires_grad_(False)
     text_encoder_four.requires_grad_(False)

-    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
-    # as these weights are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
         # due to pytorch#99272, MPS does not yet support bfloat16.
         raise ValueError(
@@ -1109,7 +1126,7 @@ def main(args):
     text_encoder_three.to(**to_kwargs)
     text_encoder_four.to(**to_kwargs)
     # we never offload the transformer to CPU, so we can just use the accelerator device
-    transformer.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device)

     # Initialize a text encoding pipeline and keep it to CPU for now.
     text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(

From 5d5e80a2ff44b9c8a5fcf2a8ac251d8409e265f7 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 5 May 2025 18:09:43 +0530
Subject: [PATCH 2/4] better handle compute dtype.

---
 examples/dreambooth/train_dreambooth_lora_hidream.py | 3 ++-
 src/diffusers/quantizers/quantization_config.py      | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index 4ce4ef1d7e80..576330cebadf 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -1089,7 +1089,8 @@ def main(args):
     if args.bnb_quantization_config_path is not None:
         with open(args.bnb_quantization_config_path, "r") as f:
             config_kwargs = json.load(f)
-        config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+        if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
+            config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
         quantization_config = BitsAndBytesConfig(**config_kwargs)

     transformer = HiDreamImageTransformer2DModel.from_pretrained(
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 0bc433be0ff3..cc4d4fc1b017 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -179,7 +179,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
     This is a wrapper class about all possible attributes and features that you can play with a model that has been
     loaded using `bitsandbytes`.

-    This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
+    This replaces `load_in_8bit` or `load_in_4bit`; therefore, both options are mutually exclusive.

     Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
     then more arguments will be added to this class.

From 6fd71bddcb76e1a5b6c53d596448ccaa475fa7ac Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 5 May 2025 20:04:26 +0530
Subject: [PATCH 3/4] finalize.

---
 examples/dreambooth/README_hidream.md         | 27 +++++++++++++++++++
 .../train_dreambooth_lora_hidream.py          |  9 ++++---
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md
index 63b19a7f70cc..6e4fd2510ffd 100644
--- a/examples/dreambooth/README_hidream.md
+++ b/examples/dreambooth/README_hidream.md
@@ -117,3 +117,30 @@ We provide several options for optimizing memory optimization:
 * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.

 Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
+
+## Using quantization
+
+You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
+
+```json
+{
+  "load_in_4bit": true,
+  "bnb_4bit_quant_type": "nf4"
+}
+```
+
+Below, we provide some memory numbers for training with and without NF4 quantization:
+
+```
+(with quantization)
+Memory (before device placement): 9.085089683532715 GB.
+Memory (after device placement): 34.59585428237915 GB.
+Memory (after backward): 36.90267467498779 GB.
+
+(without quantization)
+Memory (before device placement): 0.0 GB.
+Memory (after device placement): 57.6400408744812 GB.
+Memory (after backward): 59.932212829589844 GB.
+``` + +The reason why we see some memory before device placement in the case of quantization is because, by default bnb quantized models are placed on the GPU first. \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 576330cebadf..946ffaa65c52 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1713,10 +1713,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - if args.upcast_before_saving: - transformer.to(torch.float32) - else: - transformer = transformer.to(weight_dtype) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) HiDreamImagePipeline.save_lora_weights( From 881024decbf1e1195bb8cf0a93ee0ca18a60d507 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 5 May 2025 20:06:57 +0530 Subject: [PATCH 4/4] fix dtype. --- examples/dreambooth/train_dreambooth_lora_hidream.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 946ffaa65c52..aa3ffb2483e1 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -1127,7 +1127,12 @@ def main(args): text_encoder_three.to(**to_kwargs) text_encoder_four.to(**to_kwargs) # we never offload the transformer to CPU, so we can just use the accelerator device - transformer.to(accelerator.device) + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + transformer.to(**transformer_to_kwargs) # Initialize a text encoding pipeline and keep it to CPU for now. text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(