Skip to content

Commit 19b3de4

Browse files
committed
Remove unnecessary optional types in _generate
1 parent 74128b2 commit 19b3de4

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,13 @@ def _generate(
185185
prompt_ids: jnp.array,
186186
params: Union[Dict, FrozenDict],
187187
prng_seed: jax.random.KeyArray,
188-
num_inference_steps: int = 50,
189-
height: Optional[int] = None,
190-
width: Optional[int] = None,
191-
guidance_scale: float = 7.5,
188+
num_inference_steps: int,
189+
height: int,
190+
width: int,
191+
guidance_scale: float,
192192
latents: Optional[jnp.array] = None,
193-
neg_prompt_ids: jnp.array = None,
193+
neg_prompt_ids: Optional[jnp.array] = None,
194194
):
195-
# 0. Default height and width to unet
196-
height = height or self.unet.config.sample_size * self.vae_scale_factor
197-
width = width or self.unet.config.sample_size * self.vae_scale_factor
198-
199195
if height % 8 != 0 or width % 8 != 0:
200196
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
201197

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,17 +181,13 @@ def _generate(
181181
params: Union[Dict, FrozenDict],
182182
prng_seed: jax.random.KeyArray,
183183
start_timestep: int,
184-
num_inference_steps: int = 50,
185-
height: Optional[int] = None,
186-
width: Optional[int] = None,
187-
guidance_scale: float = 7.5,
184+
num_inference_steps: int,
185+
height: int,
186+
width: int,
187+
guidance_scale: float,
188188
noise: Optional[jnp.array] = None,
189-
neg_prompt_ids: jnp.array = None,
189+
neg_prompt_ids: Optional[jnp.array] = None,
190190
):
191-
# 0. Default height and width to unet
192-
height = height or self.unet.config.sample_size * self.vae_scale_factor
193-
width = width or self.unet.config.sample_size * self.vae_scale_factor
194-
195191
if height % 8 != 0 or width % 8 != 0:
196192
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
197193

0 commit comments

Comments
 (0)