
Flax: Fix img2img and align with other pipeline #1824


Merged: 11 commits into huggingface:main on Dec 29, 2022

Conversation

skirsten
Contributor

  • Fixes for flax img2img
  • Other misc changes
  • Re-aligned the img2img pipe with the normal pipe (mostly copy-paste); see the usage sketch below
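
For reference, a rough usage sketch of the re-aligned Flax img2img pipeline. This is not code from the PR itself: the checkpoint, revision, and exact argument names are assumptions based on the prepare_inputs / replicate / shard call pattern of the Flax text-to-image pipeline.

# Rough usage sketch (assumptions noted above; argument names are illustrative).
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionImg2ImgPipeline

pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="flax", dtype=jnp.bfloat16
)

num_devices = jax.device_count()
init_image = Image.open("sketch.png").convert("RGB").resize((768, 512))
prompts = ["A fantasy landscape, trending on artstation"] * num_devices

# prepare_inputs tokenizes the prompts and preprocesses the images into arrays
prompt_ids, processed_images = pipeline.prepare_inputs(
    prompt=prompts, image=[init_image] * num_devices
)

# shard the inputs and replicate the params across devices, one PRNG key per device
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)
images = pipeline(
    prompt_ids=shard(prompt_ids),
    image=shard(processed_images),
    params=replicate(params),
    prng_seed=rng,
    strength=0.75,
    num_inference_steps=50,
    jit=True,
).images  # with jit=True the output keeps a leading device axis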

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Dec 24, 2022

The documentation is not available anymore as the PR was closed or merged.

Member

@pcuenca pcuenca left a comment

Thanks a lot! This clearly improves the existing pipeline. I just left comments with a few suggestions and ideas for a potential improvement (making strength parallelizable too).

Contributor

@patil-suraj patil-suraj left a comment

Thanks a lot for working on this! Left some comments

width // self.vae_scale_factor,
)
if noise is None:
    noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
Contributor

Any specific reason to hardcode the dtype to jnp.float32? I think we should use self.dtype here, as before, for half-precision inference.

Contributor Author

This was copy-pasted from the normal pipeline. There are some other places related to the latents where the dtype is hardcoded to jnp.float32. I remember setting every occurrence to self.dtype and losing all detail in the generated images.

I would prefer if all occurrences of jnp.float32 could be removed at the same time. But I can also just revert this one change. Let me know what you think.

Contributor

cc @pcuenca do we need the noise in float32 in JAX?

Contributor Author

I investigated this a bit more and found the "problem". I asked the JAX maintainers for their thoughts here: jax-ml/jax#13798

Basically, to prevent losing detail in the image, the noise has to be generated in float32 and then cast to self.dtype:

noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32).astype(self.dtype)

Still, I would prefer not to do that in this PR and instead create a new PR that fixes all of these occurrences at once.
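
For reference, a minimal standalone sketch of the effect described above. It assumes bfloat16 as the half-precision dtype; the exact numbers depend on the JAX version (see jax-ml/jax#13798).

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
latents_shape = (1, 4, 64, 64)

# Sampling directly in half precision draws the underlying uniform bits at that
# precision, so the resulting noise takes far fewer distinct values.
noise_half = jax.random.normal(key, shape=latents_shape, dtype=jnp.bfloat16)

# Sampling in float32 and casting afterwards (the workaround above) keeps the
# extra detail until the final cast.
noise_cast = jax.random.normal(key, shape=latents_shape, dtype=jnp.float32).astype(jnp.bfloat16)

# The difference shows up as a much smaller set of unique noise values,
# which is what appears as lost detail in the generated images.
print(jnp.unique(noise_half).size, jnp.unique(noise_cast).size)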

Member

That experiment was very cool and instructive, thanks a lot for taking the time to clarify the behaviour! I agree we should deal with this in another PR.

# run with python for loop
- for i in range(t_start, len(scheduler_state.timesteps)):
+ for i in range(start_timestep, num_inference_steps):
Contributor

I think it's safer to use len(scheduler_state.timesteps) than num_inference_steps because, depending on the scheduler, there could be some extra timesteps. For example, if using PNDMScheduler with PRK steps, it'll add some extra timesteps.

Contributor Author

🤔 I was not aware of that. Thanks for letting me know. In that case it would also have to be changed here.

Contributor Author

Hmm, it looks to me like here it makes sure the length is always equal to num_inference_steps by dropping some of the PLMS timesteps.

Or am I getting that wrong?

Contributor

Actually you are right, let's use num_inference_steps here, because the timesteps are extended for 2nd-order schedulers like Heun.

Member

Hmm, I'm not sure about this. Because the timesteps are extended for some schedulers, shouldn't we loop through the timesteps instead?

Contributor

Those schedulers are not part of Flax yet, no? (specifically the HeunScheduler)

Member

No, they are not.
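
For context, a minimal self-contained sketch of the loop shape being discussed. loop_body here is a stand-in, and DEBUG and the variable names are assumptions based on the Flax text-to-image pipeline, not code from this PR; the img2img change is simply that the loop starts at start_timestep instead of 0.

import jax
import jax.numpy as jnp

DEBUG = False
start_timestep = 10          # img2img skips the first `start_timestep` steps
num_inference_steps = 50
latents = jnp.zeros((1, 4, 64, 64))
scheduler_state = None       # placeholder for the real scheduler state

def loop_body(step, args):
    latents, scheduler_state = args
    # real pipeline: UNet noise prediction, classifier-free guidance,
    # then one scheduler step using `step` to index the timestep
    return latents, scheduler_state

if DEBUG:
    # run with python for loop (easier to debug, not jitted)
    for i in range(start_timestep, num_inference_steps):
        latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
    # jitted path: same bounds via lax.fori_loop
    latents, scheduler_state = jax.lax.fori_loop(
        start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)
    )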

Comment on lines 191 to 193
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
Contributor

We compute the height and width in __call__ as well, so here we could make them required.
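
As a quick worked example of the default computed above (sample_size=64 and a VAE downsampling factor of 8 are the usual Stable Diffusion v1 values, assumed here for illustration):

sample_size = 64        # self.unet.config.sample_size for SD v1 checkpoints
vae_scale_factor = 8    # typically 2 ** (len(vae.config.block_out_channels) - 1)

height = None
width = None
height = height or sample_size * vae_scale_factor
width = width or sample_size * vae_scale_factor
print(height, width)    # 512 512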

Contributor

@patil-suraj patil-suraj left a comment

All looks good now, thank you for addressing the comments! The quality checks are failing; run make style and make quality, push, and then it should be good to merge after a green light from @pcuenca.

@skirsten
Contributor Author

The failing CI seems to be unrelated to these changes (doc_builder)

@pcuenca
Member

pcuenca commented Dec 29, 2022

@skirsten This is what make style did for me: a47d1f0, you probably didn't install doc-builder.

Feel free to cherry-pick in your branch :)

Edit: I created this PR to your repo.

@skirsten
Contributor Author

@pcuenca Thanks, I added your commit. Yes, I was missing some packages (and make did not resolve python to python3), so I just assumed that the problem was not in this branch 🙈

@pcuenca
Member

pcuenca commented Dec 29, 2022

No worries!

@pcuenca
Member

pcuenca commented Dec 29, 2022

@skirsten @patil-suraj is this ready to merge then?

@pcuenca pcuenca merged commit ab0e92f into huggingface:main Dec 29, 2022
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Flax: Add components function

* Flax: Fix img2img and align with other pipeline

* Flax: Fix PRNGKey type

* Refactor strength to start_timestep

* Fix preprocess images

* Fix processed_images dimen

* latents.shape -> latents_shape

* Fix typo

* Remove "static" comment

* Remove unnecessary optional types in _generate

* Apply doc-builder code style.

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>