
[training] feat: enable quantization for hidream lora training. #11494

Merged · 6 commits · May 5, 2025
27 changes: 27 additions & 0 deletions examples/dreambooth/README_hidream.md
@@ -117,3 +117,30 @@ We provide several options for memory optimization:
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.

Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.

## Using quantization

You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:

```json
{
"load_in_4bit": true,
"bnb_4bit_quant_type": "nf4"
}
```
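
For reference, here is a minimal sketch of what the training script does with this file when you pass `--bnb_quantization_config_path` (the file name and model id below are placeholders, not part of the script):

```python
import json

import torch
from diffusers import BitsAndBytesConfig, HiDreamImageTransformer2DModel

# Read the JSON file and build a BitsAndBytesConfig from its keyword arguments.
with open("bnb_nf4.json", "r") as f:  # placeholder path
    config_kwargs = json.load(f)
if config_kwargs.get("load_in_4bit"):
    # The compute dtype follows the mixed-precision dtype used for training.
    config_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16
quantization_config = BitsAndBytesConfig(**config_kwargs)

# The transformer is then loaded with the quantization config attached.
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # placeholder model id
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```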

Below, we provide memory numbers for training with and without NF4 quantization:

```
(with quantization)
Memory (before device placement): 9.085089683532715 GB.
Memory (after device placement): 34.59585428237915 GB.
Memory (after backward): 36.90267467498779 GB.

(without quantization)
Memory (before device placement): 0.0 GB.
Memory (after device placement): 57.6400408744812 GB.
Memory (after backward): 59.932212829589844 GB.
```

We see non-zero memory before device placement in the quantized case because, by default, bnb-quantized models are placed on the GPU as they are loaded.
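
Only the LoRA adapter parameters are trained on top of the frozen, quantized transformer. Continuing the sketch above (rank and target modules are illustrative, not the script's exact defaults):

```python
from peft import LoraConfig, prepare_model_for_kbit_training

# Quantized base weights stay frozen and are prepared for k-bit training.
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
transformer.requires_grad_(False)

# Attach a LoRA adapter; only these low-rank weights receive gradients.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # illustrative attention projections
)
transformer.add_adapter(lora_config)

# bnb-quantized modules manage their own dtype, so only the device is set here.
transformer.to("cuda")
```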
56 changes: 40 additions & 16 deletions examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -16,6 +16,7 @@
import argparse
import copy
import itertools
import json
import logging
import math
import os
@@ -27,14 +28,13 @@

import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, set_peft_model_state_dict
from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -47,6 +47,7 @@
import diffusers
from diffusers import (
AutoencoderKL,
BitsAndBytesConfig,
FlowMatchEulerDiscreteScheduler,
HiDreamImagePipeline,
HiDreamImageTransformer2DModel,
@@ -282,6 +283,12 @@ def parse_args(input_args=None):
default="meta-llama/Meta-Llama-3.1-8B-Instruct",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--bnb_quantization_config_path",
type=str,
default=None,
help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
)
parser.add_argument(
"--revision",
type=str,
@@ -1056,6 +1063,14 @@ def main(args):
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
)

# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0
@@ -1064,20 +1079,31 @@
text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)

vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
)
quantization_config = None
if args.bnb_quantization_config_path is not None:
with open(args.bnb_quantization_config_path, "r") as f:
config_kwargs = json.load(f)
if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
quantization_config = BitsAndBytesConfig(**config_kwargs)

transformer = HiDreamImageTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
revision=args.revision,
variant=args.variant,
quantization_config=quantization_config,
torch_dtype=weight_dtype,
force_inference_output=True,
)
if args.bnb_quantization_config_path is not None:
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)

# We only train the additional adapter LoRA layers
transformer.requires_grad_(False)
@@ -1087,14 +1113,6 @@
text_encoder_three.requires_grad_(False)
text_encoder_four.requires_grad_(False)

# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
@@ -1109,7 +1127,12 @@
text_encoder_three.to(**to_kwargs)
text_encoder_four.to(**to_kwargs)
# we never offload the transformer to CPU, so we can just use the accelerator device
transformer.to(accelerator.device, dtype=weight_dtype)
transformer_to_kwargs = (
{"device": accelerator.device}
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)

# Initialize a text encoding pipeline and keep it to CPU for now.
text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(
@@ -1695,10 +1718,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

HiDreamImagePipeline.save_lora_weights(
2 changes: 1 addition & 1 deletion src/diffusers/quantizers/quantization_config.py
@@ -179,7 +179,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
This is a wrapper class about all possible attributes and features that you can play with a model that has been
loaded using `bitsandbytes`.

This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
This replaces `load_in_8bit` or `load_in_4bit` therefore both options are mutually exclusive.
Member Author

Harmless change.


Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
then more arguments will be added to this class.