-
Notifications
You must be signed in to change notification settings - Fork 6k
[training] feat: enable quantization for hidream lora training. #11494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@@ -179,7 +179,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin): | |||
This is a wrapper class about all possible attributes and features that you can play with a model that has been | |||
loaded using `bitsandbytes`. | |||
|
|||
This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive. | |||
This replaces `load_in_8bit` or `load_in_4bit` therefore both options are mutually exclusive. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Harmless change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice! 🤏🤏🤏
btw @sayakpaul - do we usually do similar quantization support with a json config? |
Not sure what you mean 👀 |
What does this PR do?
This PR adds support to apply quantization from
bitsandbytes
to the base model before we attach LoRA params to it and train them. This helps reduce the memory consumption quite a bit:With
--offload
, we can reduce further. The reason why we see some memory before device placement in the case of quantization is because, by default bnb quantized models are placed on the GPU first.Quick test:
The bnb config json:
Results:

WandB: https://wandb.ai/sayakpaul/dreambooth-hidream-lora/runs/01l8vy12
TODO