HiDream Image #11231

Merged: 25 commits merged into huggingface:main on Apr 11, 2025

Conversation

@hlky (Contributor) commented Apr 8, 2025

What does this PR do?

Original code

Weights

Code

import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import (
    UniPCMultistepScheduler,
    HiDreamImagePipeline,
    HiDreamImageTransformer2DModel,
)

scheduler = UniPCMultistepScheduler(
    flow_shift=3.0,
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
)

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct"
)

text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    scheduler=scheduler,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe(
    'A cat holding a sign that says "Hi-Dreams.ai".',
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]

image.save("hidream.png")

Output

[output image: hidream.png]

NOTES

  • Scheduler changes are not necessarily required; the test above of HiDream-ai/HiDream-I1-Full uses the existing UniPCMultistepScheduler with prediction_type="flow_prediction" and use_flow_sigmas=True (see the sketch below).
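For reference, a minimal sketch (not part of the PR) of the same scheduler configuration applied after loading the pipeline, using from_config so the repo's saved scheduler config is kept and only the named fields are overridden. It assumes the pipe loaded in the snippet above.

from diffusers import UniPCMultistepScheduler

# Sketch: switch an already-loaded pipe to flow-matching UniPC.
# from_config keeps the saved scheduler config and overrides only the fields passed here.
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    flow_shift=3.0,
)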

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hlky (Contributor, Author) commented Apr 9, 2025

@kebe7jun What is your transformers version? Can you do pip install -U transformers and try with latest?

@kebe7jun commented Apr 9, 2025

> @kebe7jun What is your transformers version? Can you do pip install -U transformers and try with latest?

Thanks, this works.

@hlky hlky marked this pull request as ready for review April 10, 2025 13:31
)


class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
Collaborator:

FromOriginalModelMixin shouldn't be needed here I think since the weights are diffusers format?

hlky (Contributor, Author):

Nice catch, thanks, spotted a couple other things in 9d43a32

@DN6 (Collaborator) left a comment

Looking good 👍🏽. Could we add fast tests for the Pipeline and Model?

@hlky (Contributor, Author) commented Apr 10, 2025

@bot /style

Style fixes have been applied. View the workflow run here.

_, seq_len, _ = prompt_embeds.shape

# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
Member:

Just nits, but we discussed doing these repeat/expand parts in encode_prompt. Not a blocker for merging to main atm, so feel free to take it up in a follow-up PR. Same comment for the other similar repeats.

Collaborator:
Agree here, it should just return a single prompt_embeds.
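For context, a minimal illustrative sketch of the repeat/reshape pattern under discussion (generic names, not the exact PR code): the embeddings are tiled along the sequence axis, then the copies are folded into the batch axis, once per requested image.

import torch

def duplicate_prompt_embeds(prompt_embeds: torch.Tensor, num_images_per_prompt: int) -> torch.Tensor:
    # prompt_embeds: (batch_size, seq_len, dim)
    batch_size, seq_len, dim = prompt_embeds.shape
    # repeat along the sequence axis (mps friendly), then fold the copies into the batch axis
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    return prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, dim)

# e.g. a (1, 128, 4096) tensor becomes (4, 128, 4096) for num_images_per_prompt=4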

@hlky (Contributor, Author) commented Apr 10, 2025

@a-r-r-o-w Removing _init_weights changes the slice in local tests; I need to test whether the real output is perceptually different. There are a few other cases where we have these functions.

@a-r-r-o-w (Member)
I think the test slice differences are expected, no? It just changes the random initialization of the matrices, so if we're loading pretrained weights it wouldn't cause a perceptual difference.

Co-authored-by: Aryan <contact.aryanvs@gmail.com>
@ShuyUSTC commented Apr 11, 2025

Hi @hlky,

Thank you for your contribution and effort in integrating our HiDream-I1 into the diffusers library! We’re the official team behind this model (HiDream-I1), and we’re currently working on its official integration.

We’d love to collaborate with you to refine this PR—whether by reviewing the implementation, adding missing components (e.g., docs, tests), or assisting with upstream merging. Let us know how you’d prefer to proceed (e.g., we can co-author this PR or build upon your work).

Again, we appreciate your initiative! Looking forward to your thoughts.

Best,
HiDream.ai

@yiyixuxu (Collaborator) commented Apr 11, 2025

@ShuyUSTC
Thanks for the message, and congrats on such great work!

Feel free to give the PR a review and help test it :)

@ShuyUSTC left a comment

Move the operation noise_pred = -noise_pred from pipeline_hidream_image to transformer_hidream_image.

hidden_states = ff_output_i + hidden_states
encoder_hidden_states = ff_output_t + encoder_hidden_states
return hidden_states, encoder_hidden_states


Suggested change
class NegateLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return -x

Add a NegateLayer to convert the input x to -x

Member:

@ShuyUSTC The comments about returning the negative value are not really in diffusers coding style for how we write the modeling code. So, we will be unable to add those changes here, and will have to keep the negation in the pipeline.

Feel free to let us know if there's anything else that you'd like us to change

caption_projection = []
for caption_channel in caption_channels:
    caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
self.caption_projection = nn.ModuleList(caption_projection)


Suggested change:
self.caption_projection = nn.ModuleList(caption_projection)
self.negate_layer = NegateLayer()

Initialize a negate_layer


hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
output = self.final_layer(hidden_states, adaln_input)
output = self.unpatchify(output, img_sizes, self.training)


Suggested change:
output = self.unpatchify(output, img_sizes, self.training)
output = self.negate_layer(output)

Convert the output to -output

img_ids=img_ids,
return_dict=False,
)[0]
noise_pred = -noise_pred


Suggested change (remove this line):
noise_pred = -noise_pred

Remove this operation and add a negate_layer to HiDreamImageTransformer2DModel to convert the output

@hlky hlky dismissed ShuyUSTC’s stale review April 11, 2025 08:47

NegateLayer does not fit with the coding style; there are other cases of -noise_pred in the codebase.

@a-r-r-o-w (Member) left a comment

@yiyixuxu Addressed some of the review comments. LMK if any further changes are required (we can do follow-up PRs too). I've only tested with the Full model for now and will do the other two tomorrow unless someone else can finish up the PR. All models seem to be working, but the outputs of "Fast" feel a bit off. It might be scheduler related; looking into it.

Also tested that num_images_per_prompt > 1 works with the changes to encode_prompt. PR LGTM to merge for a first pass 👍

@vladmandic (Contributor) commented Apr 11, 2025

fyi, i've tested hidream-i1-fast and it works fine with this PR, but there is one issue:
using low_cpu_mem_usage=True sometimes/often breaks offloading, as the model fails with:

ValueError: weight is on the meta device, we need a value to put in on 0.

@a-r-r-o-w (Member)

@vladmandic Did you try with the latest changes? I encountered the issue too, but I don't get it anymore after the refactor.

@vladmandic (Contributor)

just updated the codebase; i can't reproduce it anymore at the moment either - will run a few more tests, as it was pretty random to start with.

@tin2tin commented Apr 11, 2025

Apparently, here are a couple of improvements:
https://github.com/lum3on/comfyui_HiDream-Sampler

@vladmandic (Contributor)
update: no issues with offloading using the latest codebase.
also works fine with both bnb and optimum.quanto quantization_config.
regarding the llama replacement, yes, that's totally ok, but imo that's not really up to diffusers to provide, other than a one-liner "here is how you load te4", which is already in the docs.

@yiyixuxu yiyixuxu merged commit 0ef2935 into huggingface:main Apr 11, 2025
12 checks passed
@yiyixuxu (Collaborator)

merged PR - we can add any follow-up changes in a new one

@Skquark commented Apr 11, 2025

Would this work for the NF4 quantized 4-bit models? I had already implemented this using the fork https://github.com/hykilpikonna/HiDream-I1-nf4 and the azaneko/HiDream-I1-Dev-nf4 models, because it didn't run in less than 24GB otherwise. I'm transitioning to this implementation from the GitHub repo and just want to make sure I can keep the code mostly the same. Wouldn't mind seeing an example of memory-optimized code that runs in <=16GB...

@vladmandic (Contributor)
@Skquark on-the-fly quantization using bitsandbytes and/or optimum.quanto together with the diffusers implementation works just fine; you don't need random unofficial fixed quants. with bnb-nf4 it works with 16gb vram, and with quanto-int4 it works even with 12gb.

@Skquark commented Apr 16, 2025

@vladmandic It'd still be nice to have a working example of quantization in the docs for this one, since it takes more than 24GB and is a bit more complicated. Do we just run a BitsAndBytesConfig 4-bit quant on the Transformer or the Tokenizer, and can we optimize the Llama encoder with nf4 too? Could it also use group offloading? Thanks.

@nitinmukesh
@Skquark

See if this helps
#11337

@Skquark commented Apr 16, 2025

@nitinmukesh Interesting, but not what I was expecting for loading those models. That's similar to what I was originally doing, but he was saying on-the-fly quantization is better than using the modded int4 models. Since there seem to be about 4 different ways to optimize this, I'll just add an Optimization Mode option in my app for which one to try and figure it out from there. Any better ways?

@nitinmukesh commented Apr 16, 2025

On-the-fly quantization is very time consuming; each launch will quantize again.
You can create your own repo using whatever settings you prefer and then save_pretrained it. Put it on HF (locally also works) and use it; see the sketch below.

Also, I have added a GGUF version if you know how to use it (same topic).
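As a rough sketch of that workflow (the local path and NF4 settings are assumptions, and it presumes your bitsandbytes version supports 4-bit serialization): quantize the transformer once, persist it with save_pretrained, and load the saved copy on later launches.

import torch
from diffusers import BitsAndBytesConfig, HiDreamImageTransformer2DModel

# one-time: quantize on the fly, then persist the quantized weights (local path is illustrative)
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
    ),
    torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("./hidream-i1-full-transformer-nf4")

# later launches: load the prequantized copy directly and skip the quantization step
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "./hidream-i1-full-transformer-nf4", torch_dtype=torch.bfloat16
)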

@vladmandic (Contributor) commented Apr 16, 2025

> @vladmandic It'd still be nice to have a working example of quantization in the docs for this one, since it takes more than 24GB and is a bit more complicated. Do we just run a BitsAndBytesConfig 4-bit quant on the Transformer or the Tokenizer, and can we optimize the Llama encoder with nf4 too? Could it also use group offloading? Thanks.

you just pass quantization_config when loading the transformer, text_encoder_3 (t5), and text_encoder_4 (llama).
you can quantize any of them or all 3 (not the tokenizer, and not te1/te2).

and quantization_config can be any valid bitsandbytes or optimum.quanto config.
it should also work with torchao and layerwise methods, but i didn't test those.
you can also mix & match, e.g. you can run the transformer in nf4 and te4 in fp8.

and when you load the individual components, you assemble the pipeline.
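To make the component-wise assembly above concrete, a rough sketch (untested here; the quantization settings are assumptions, and bnb NF4 for the transformer with quanto float8 for te4 is just one of the combinations described): load the two large components with their own quantization_config, then assemble the pipeline and enable offloading.

import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, QuantoConfig
from diffusers import BitsAndBytesConfig, HiDreamImagePipeline, HiDreamImageTransformer2DModel

# transformer quantized on the fly with bitsandbytes NF4
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
    ),
    torch_dtype=torch.bfloat16,
)

# text_encoder_4 (llama) quantized with optimum.quanto float8 via transformers
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    quantization_config=QuantoConfig(weights="float8"),
    torch_dtype=torch.bfloat16,
)
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

# assemble the pipeline from the individually loaded components
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    transformer=transformer,
    text_encoder_4=text_encoder_4,
    tokenizer_4=tokenizer_4,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

te1/te2 (the CLIP encoders) and the tokenizers are left untouched, matching the note above.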
