
[training] feat: enable quantization for hidream lora training. #11494


Merged: 6 commits merged into main from quantized-hidream-training on May 5, 2025

Conversation

@sayakpaul sayakpaul (Member) commented May 5, 2025

What does this PR do?

This PR adds support for applying bitsandbytes quantization to the base model before we attach LoRA params to it and train them. This reduces memory consumption quite a bit:

(with quantization)
Memory (before device placement): 9.085089683532715 GB.
Memory (after device placement): 34.59585428237915 GB.
Memory (after backward): 36.90267467498779 GB.

(without quantization)
Memory (before device placement): 0.0 GB.
Memory (after device placement): 57.6400408744812 GB.
Memory (after backward): 59.932212829589844 GB.
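
For reference, memory readings like these can be collected from torch's CUDA memory stats; this is only a sketch, and the exact logging in the training script may differ:

# Hedged sketch: how "Memory (<stage>): <value> GB." lines could be produced.
import torch

def print_memory(stage: str) -> None:
    # Peak allocation observed on the current CUDA device so far.
    gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Memory ({stage}): {gb} GB.")

print_memory("after device placement")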

With --offload, we can reduce memory further. The reason we see some memory before device placement in the quantized case is that bnb-quantized models are placed on the GPU by default.
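
Conceptually, the flow is: quantize the base transformer with bitsandbytes first, then attach trainable LoRA params on top. A minimal sketch of that idea (the subfolder name and LoRA target modules are illustrative assumptions, not necessarily what the script uses):

import torch
from diffusers import BitsAndBytesConfig, HiDreamImageTransformer2DModel
from peft import LoraConfig

# 4-bit NF4 quantization for the frozen base model (mirrors bnb_config.json below).
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    subfolder="transformer",          # assumed checkpoint layout
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
transformer.requires_grad_(False)

# Only the low-rank adapter params are trainable.
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # illustrative targets
)
transformer.add_adapter(lora_config)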

Quick test:

export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
export INSTANCE_DIR="linoyts/3d_icon"
export OUTPUT_DIR="trained-hidream-lora"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_dreambooth_lora_hidream.py \
	--pretrained_model_name_or_path=$MODEL_NAME \
	--dataset_name=$INSTANCE_DIR  \
	--output_dir=$OUTPUT_DIR  \
	--bnb_quantization_config_path="bnb_config.json"  \
	--mixed_precision="bf16" \
	--instance_prompt="3d icon"  \
	--caption_column="prompt"  \
	--resolution=1024   --train_batch_size=1   \
	--gradient_accumulation_steps=4   \
	--use_8bit_adam   --rank=8   \
	--learning_rate=2e-4   --report_to="wandb" \
	--lr_scheduler="constant_with_warmup"   --lr_warmup_steps=100 \
	--max_train_steps=1000 \
	--cache_latents  --gradient_checkpointing  \
	--validation_epochs=25   --seed="0"   \
	--final_validation_prompt="a 3dicon, a llama eating ramen"

The bnb config JSON:

{
    "load_in_4bit": true,
    "bnb_4bit_quant_type": "nf4"
}
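
Presumably the new --bnb_quantization_config_path flag simply deserializes this JSON into a diffusers BitsAndBytesConfig; a minimal sketch under that assumption (the helper name here is made up):

import json
from diffusers import BitsAndBytesConfig

def load_bnb_config(path: str) -> BitsAndBytesConfig:
    # Keys in the JSON file map directly to BitsAndBytesConfig constructor arguments.
    with open(path) as f:
        return BitsAndBytesConfig(**json.load(f))

bnb_config = load_bnb_config("bnb_config.json")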

Results:
(result image omitted; see the W&B run below)

WandB: https://wandb.ai/sayakpaul/dreambooth-hidream-lora/runs/01l8vy12

TODO

  • Docs
  • Complete a full reasonable run

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul requested a review from linoytsaban May 5, 2025 14:34
@sayakpaul sayakpaul marked this pull request as ready for review May 5, 2025 14:34
@@ -179,7 +179,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
     This is a wrapper class about all possible attributes and features that you can play with a model that has been
     loaded using `bitsandbytes`.

-    This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
+    This replaces `load_in_8bit` or `load_in_4bit` therefore both options are mutually exclusive.
sayakpaul (Member, Author):
Harmless change.

@linoytsaban linoytsaban (Collaborator) left a comment:
nice! 🤏🤏🤏

@linoytsaban (Collaborator):

btw @sayakpaul - do we usually do similar quantization support with a json config?

@sayakpaul (Member, Author):

Not sure what you mean 👀

@sayakpaul sayakpaul merged commit 071807c into main May 5, 2025
16 checks passed
@sayakpaul sayakpaul deleted the quantized-hidream-training branch May 5, 2025 15:14