Skip to content

Flax: Fix img2img and align with other pipeline #1824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 29, 2022
Merged
2 changes: 1 addition & 1 deletion src/diffusers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
```"""
return self._cast_floating_to(params, jnp.float16, mask)

def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/unet_2d_condition_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0

def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/vae_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def setup(self):
dtype=self.dtype,
)

def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
47 changes: 46 additions & 1 deletion src/diffusers/pipeline_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import importlib
import inspect
import os
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -475,6 +475,51 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model = pipeline_class(**init_kwargs, dtype=dtype)
return model, params

@staticmethod
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - set(["self"])
return expected_modules, optional_parameters

@property
def components(self) -> Dict[str, Any]:
r"""

The `self.components` property can be useful to run different pipelines with the same weights and
configurations to not have to re-allocate memory.

Examples:

```py
>>> from diffusers import (
... FlaxStableDiffusionPipeline,
... FlaxStableDiffusionImg2ImgPipeline,
... )

>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
... )
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
```

Returns:
A dictionary containing all the modules needed to initialize the pipeline.
"""
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}

if set(components.keys()) != expected_modules:
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components} are defined."
)

return components

@staticmethod
def numpy_to_pil(images):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def components(self) -> Dict[str, Any]:
```

Returns:
A dictionaly containing all the modules needed to initialize the pipeline.
A dictionary containing all the modules needed to initialize the pipeline.
"""
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,14 @@ def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: float = 7.5,
prng_seed: jax.random.KeyArray,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: jnp.array = None,
neg_prompt_ids: Optional[jnp.array] = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

Expand Down Expand Up @@ -281,15 +277,15 @@ def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
prng_seed: jax.random.KeyArray,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
neg_prompt_ids: jnp.array = None,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down
Loading