Flax: Fix img2img and align with other pipeline #1824
Conversation
skirsten commented Dec 24, 2022
- Fixes for Flax img2img
- Other misc changes
- Re-aligned the img2img pipe with the normal pipe (mostly copy paste); see the usage sketch below
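For orientation, here is a rough usage sketch of the Flax img2img pipeline this PR fixes. It assumes the post-merge API mirrors the Flax text-to-image pipeline (prepare_inputs, replicated params, a jitted call); the argument names and the checkpoint revision are assumptions, not a reference.

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionImg2ImgPipeline

# Load the Flax weights; the "flax" revision and bfloat16 dtype are assumptions.
pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax", dtype=jnp.bfloat16
)

num_devices = jax.device_count()
init_image = Image.new("RGB", (512, 512))  # placeholder init image
prompt = "A fantasy landscape, trending on artstation"

# prepare_inputs is assumed to tokenize the prompts and preprocess the init images.
prompt_ids, processed_image = pipeline.prepare_inputs(
    prompt=[prompt] * num_devices, image=[init_image] * num_devices
)

# Replicate the params across devices and shard the per-device inputs.
params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

images = pipeline(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=params,
    prng_seed=rng,
    strength=0.75,           # fraction of the schedule that is re-noised and denoised
    num_inference_steps=50,
    height=512,
    width=512,
    jit=True,
).images
```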
Thanks a lot! This clearly improves the existing pipeline. I just left comments with a few suggestions and ideas for a potential improvement (making strength parallelizable too).
Thanks a lot for working on this! Left some comments
    width // self.vae_scale_factor,
)
if noise is None:
    noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
Any specific reason to hardcode the dtype to jnp.float32? I think we should use self.dtype here as before for half-precision inference.
This was a copy paste from the normal pipeline. There are some other places where the dtype is hardcoded to jnp.float32 related to latents. I just remember setting every occurrence to self.dtype and losing all detail in the generated images. I would prefer if all occurrences of jnp.float32 could be removed at the same time, but I can also just revert this one change. Let me know what you think.
cc @pcuenca, do we need noise in float32 in JAX?
I investigated this a bit more and found the "problem". I asked the JAX maintainers for their thoughts here: jax-ml/jax#13798. Basically, to prevent losing detail in the image, the noise has to be generated in float32 and then cast to self.dtype:
noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32).astype(self.dtype)
Still, I would prefer not to do that in this PR and instead create a new PR that fixes all of these occurrences at once.
That experiment was very cool and instructive, thanks a lot for taking the time to clarify the behaviour! I agree to deal with this in another PR.
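For context, a minimal standalone sketch of the two sampling paths discussed above. The dtype and latent shape are illustrative placeholders; the comparison only shows that the paths differ, not the downstream image quality.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
latents_shape = (1, 4, 64, 64)  # illustrative latent shape

# Path A: draw the noise directly in half precision.
noise_direct = jax.random.normal(key, shape=latents_shape, dtype=jnp.bfloat16)

# Path B: draw in float32, then cast to the target dtype (the fix proposed above).
noise_cast = jax.random.normal(key, shape=latents_shape, dtype=jnp.float32).astype(jnp.bfloat16)

# The two paths generally do not produce identical samples; per the linked JAX
# issue, the direct low-precision draw is the one observed to wash out detail.
diff = jnp.max(jnp.abs(noise_direct.astype(jnp.float32) - noise_cast.astype(jnp.float32)))
print(diff)
```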
  # run with python for loop
- for i in range(t_start, len(scheduler_state.timesteps)):
+ for i in range(start_timestep, num_inference_steps):
I think it's safer to use len(scheduler_state.timesteps) than num_inference_steps because, depending on the scheduler, there could be some extra timesteps; for example, if using PNDMScheduler with PRK steps, it'll add some extra timesteps.
🤔 I was not aware of that. Thanks for letting me know. In that case it would also have to be changed here.
Hmm, it looks to me like here it is making sure that the length is always equal to num_inference_steps by dropping some PLMS timesteps. Or am I getting that wrong?
Actually you are right, let's use num_inference_steps here, because the timesteps are extended for 2nd-order schedulers like Heun.
Hmm, I'm not sure about this. Because the timesteps are extended for some schedulers, shouldn't we loop through the timesteps instead?
Those schedulers are not part of Flax yet, no? (Specifically the HeunScheduler.)
No, they are not.
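To make the loop-bound question above concrete, a small self-contained sketch; the helper and the numbers are illustrative stand-ins, not the pipeline code.

```python
def start_timestep_from_strength(num_inference_steps: int, strength: float) -> int:
    # img2img skips the beginning of the schedule: with strength=0.8 and 50
    # steps, denoising starts at index 10 and runs for 40 steps.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    return max(num_inference_steps - init_timestep, 0)


num_inference_steps = 50
timesteps = list(range(981, -1, -20))  # stand-in for scheduler_state.timesteps (50 entries)
start_timestep = start_timestep_from_strength(num_inference_steps, strength=0.8)

# Looping to num_inference_steps matches len(timesteps) only as long as the
# scheduler does not append extra entries (e.g. PNDM PRK steps or 2nd-order
# schedulers like Heun), which is the trade-off discussed in this thread.
for i in range(start_timestep, num_inference_steps):
    t = timesteps[i]
    # ... one denoising step at timestep t would run here ...
```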
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
We compute the height and width in __call__ as well, so here we could make them required.
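A hypothetical illustration of that suggestion; the class and attribute names below are stand-ins, not the pipeline's actual code.

```python
class PipelineSketch:
    def __init__(self, sample_size: int = 64, vae_scale_factor: int = 8):
        self.sample_size = sample_size
        self.vae_scale_factor = vae_scale_factor

    def __call__(self, prompt: str, height: int = None, width: int = None):
        # Resolve the height/width defaults once, in __call__ only.
        height = height or self.sample_size * self.vae_scale_factor
        width = width or self.sample_size * self.vae_scale_factor
        return self._generate(prompt, height, width)

    def _generate(self, prompt: str, height: int, width: int):
        # height/width are required here, so the fallback is not duplicated.
        return (height, width)


assert PipelineSketch()("a prompt") == (512, 512)
```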
All looks good now, thank you for addressing the comments! The quality checks are failing; run make style and make quality, push, and then it should be good to merge after a green light from @pcuenca.
The failing CI seems to be unrelated to these changes (doc_builder).
@pcuenca Thanks, I added your commit. Yes, I was missing some packages (and make did not resolve python to python3), so I just assumed that the problem was not in this branch 🙈
No worries!
@skirsten @patil-suraj is this ready to merge then?
- Flax: Add components function
- Flax: Fix img2img and align with other pipeline
- Flax: Fix PRNGKey type
- Refactor strength to start_timestep
- Fix preprocess images
- Fix processed_images dimen
- latents.shape -> latents_shape
- Fix typo
- Remove "static" comment
- Remove unnecessary optional types in _generate
- Apply doc-builder code style.

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>