Skip to content

Commit ab0e92f

Browse files
skirstenpcuenca
andauthored
Flax: Fix img2img and align with other pipeline (#1824)
* Flax: Add components function * Flax: Fix img2img and align with other pipeline * Flax: Fix PRNGKey type * Refactor strength to start_timestep * Fix preprocess images * Fix processed_images dimen * latents.shape -> latents_shape * Fix typo * Remove "static" comment * Remove unnecessary optional types in _generate * Apply doc-builder code style. Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent 9ea7052 commit ab0e92f

File tree

8 files changed

+179
-84
lines changed

8 files changed

+179
-84
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
189189
```"""
190190
return self._cast_floating_to(params, jnp.float16, mask)
191191

192-
def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
192+
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
193193
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
194194

195195
@classmethod

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
112112
flip_sin_to_cos: bool = True
113113
freq_shift: int = 0
114114

115-
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
115+
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
116116
# init input tensors
117117
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
118118
sample = jnp.zeros(sample_shape, dtype=jnp.float32)

src/diffusers/models/vae_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,7 @@ def setup(self):
806806
dtype=self.dtype,
807807
)
808808

809-
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
809+
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
810810
# init input tensors
811811
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
812812
sample = jnp.zeros(sample_shape, dtype=jnp.float32)

src/diffusers/pipeline_flax_utils.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import importlib
1818
import inspect
1919
import os
20-
from typing import Dict, List, Optional, Union
20+
from typing import Any, Dict, List, Optional, Union
2121

2222
import numpy as np
2323

@@ -475,6 +475,51 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
475475
model = pipeline_class(**init_kwargs, dtype=dtype)
476476
return model, params
477477

478+
@staticmethod
479+
def _get_signature_keys(obj):
480+
parameters = inspect.signature(obj.__init__).parameters
481+
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
482+
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
483+
expected_modules = set(required_parameters.keys()) - set(["self"])
484+
return expected_modules, optional_parameters
485+
486+
@property
487+
def components(self) -> Dict[str, Any]:
488+
r"""
489+
490+
The `self.components` property can be useful to run different pipelines with the same weights and
491+
configurations to not have to re-allocate memory.
492+
493+
Examples:
494+
495+
```py
496+
>>> from diffusers import (
497+
... FlaxStableDiffusionPipeline,
498+
... FlaxStableDiffusionImg2ImgPipeline,
499+
... )
500+
501+
>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
502+
... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
503+
... )
504+
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
505+
```
506+
507+
Returns:
508+
A dictionary containing all the modules needed to initialize the pipeline.
509+
"""
510+
expected_modules, optional_parameters = self._get_signature_keys(self)
511+
components = {
512+
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
513+
}
514+
515+
if set(components.keys()) != expected_modules:
516+
raise ValueError(
517+
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
518+
f" {expected_modules} to be defined, but {components} are defined."
519+
)
520+
521+
return components
522+
478523
@staticmethod
479524
def numpy_to_pil(images):
480525
"""

src/diffusers/pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ def components(self) -> Dict[str, Any]:
764764
```
765765
766766
Returns:
767-
A dictionaly containing all the modules needed to initialize the pipeline.
767+
A dictionary containing all the modules needed to initialize the pipeline.
768768
"""
769769
expected_modules, optional_parameters = self._get_signature_keys(self)
770770
components = {

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -184,18 +184,14 @@ def _generate(
184184
self,
185185
prompt_ids: jnp.array,
186186
params: Union[Dict, FrozenDict],
187-
prng_seed: jax.random.PRNGKey,
188-
num_inference_steps: int = 50,
189-
height: Optional[int] = None,
190-
width: Optional[int] = None,
191-
guidance_scale: float = 7.5,
187+
prng_seed: jax.random.KeyArray,
188+
num_inference_steps: int,
189+
height: int,
190+
width: int,
191+
guidance_scale: float,
192192
latents: Optional[jnp.array] = None,
193-
neg_prompt_ids: jnp.array = None,
193+
neg_prompt_ids: Optional[jnp.array] = None,
194194
):
195-
# 0. Default height and width to unet
196-
height = height or self.unet.config.sample_size * self.vae_scale_factor
197-
width = width or self.unet.config.sample_size * self.vae_scale_factor
198-
199195
if height % 8 != 0 or width % 8 != 0:
200196
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
201197

@@ -281,15 +277,15 @@ def __call__(
281277
self,
282278
prompt_ids: jnp.array,
283279
params: Union[Dict, FrozenDict],
284-
prng_seed: jax.random.PRNGKey,
280+
prng_seed: jax.random.KeyArray,
285281
num_inference_steps: int = 50,
286282
height: Optional[int] = None,
287283
width: Optional[int] = None,
288284
guidance_scale: Union[float, jnp.array] = 7.5,
289285
latents: jnp.array = None,
286+
neg_prompt_ids: jnp.array = None,
290287
return_dict: bool = True,
291288
jit: bool = False,
292-
neg_prompt_ids: jnp.array = None,
293289
):
294290
r"""
295291
Function invoked when calling the pipeline for generation.

0 commit comments

Comments
 (0)