From 8b729adc7f6ab118b5d74a88331095071cd51e18 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 23 May 2023 20:54:52 +0200
Subject: [PATCH 01/13] Add default to inpaint

---
 .../pipeline_stable_diffusion_inpaint.py      | 82 ++++++++++---------
 1 file changed, 44 insertions(+), 38 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index f09db016d956..25e6ac5f4664 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -266,14 +266,10 @@ def __init__(
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
+        # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked_image == 4
         if unet.config.in_channels != 9:
-            logger.warning(
-                f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
-                f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`,"
-                " 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify"
-                " this behavior, please check whether you have loaded the right checkpoint."
-            )
+            logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels instead of the default 9. The pipeline will treat it as a regular text-to-image UNet and blend the unmasked image region back in at each denoising step.")
 
         self.register_modules(
             vae=vae,
@@ -642,16 +638,7 @@ def prepare_latents(
         else:
             # otherwise initialise latents as init image + noise
             image = image.to(device=device, dtype=dtype)
-            if isinstance(generator, list):
-                image_latents = [
-                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
-                    for i in range(batch_size)
-                ]
-            else:
-                image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
-
-            image_latents = self.vae.config.scaling_factor * image_latents
-
+            image_latents = self._encode_image(image=image, generator=generator)
             latents = self.scheduler.add_noise(image_latents, noise, timestep)
         else:
             latents = latents.to(device)
@@ -661,6 +648,19 @@ def prepare_latents(
 
         return latents
 
+    def _encode_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+        else:
+            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
     def prepare_mask_latents(
         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
     ):
@@ -673,17 +673,7 @@ def prepare_mask_latents(
 
         mask = mask.to(device=device, dtype=dtype)
         masked_image = masked_image.to(device=device, dtype=dtype)
-
-        # encode the mask image into latents space so we can concatenate it to the latents
-        if isinstance(generator, list):
-            masked_image_latents = [
-                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
-                for i in range(batch_size)
-            ]
-            masked_image_latents = torch.cat(masked_image_latents, dim=0)
-        else:
-            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+        masked_image_latents = self._encode_image(masked_image, generator=generator)
 
         # duplicate mask and masked_image_latents for each generation
per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -929,6 +919,7 @@ def __call__( timestep=latent_timestep, is_strength_max=is_strength_max, ) + noise = latents # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( @@ -944,16 +935,21 @@ def __call__( ) # 8. Check that sizes of mask, masked image and latents match - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) + num_channels_unet = self.unet.config.in_channels + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError(f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.") # 9. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -967,7 +963,9 @@ def __call__( # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual noise_pred = self.unet( @@ -986,6 +984,14 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if num_channels_unet == 4: + if i == len(timesteps) - 1: + init_latents_proper = masked_image_latents + else: + init_latents_proper = self.scheduler.add_noise(masked_image_latents, noise, torch.tensor([t])) + + latents = (1 - mask) * init_latents_proper + mask * latents + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() From 0ca6f74ad03f58de25fe0fc5e87a47040712c907 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 23 May 2023 20:58:28 +0200 Subject: [PATCH 02/13] Make sure controlnet also works with normal sd for inpaint --- .../controlnet/pipeline_controlnet_inpaint.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 27475dc5ef8b..2a92239c9eab 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -1161,6 +1161,7 @@ def __call__( generator, latents, ) + noise = latents # 7. Prepare mask latent variables mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) @@ -1180,6 +1181,7 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. 
Denoising loop + num_channels_unet = self.unet.config.in_channels num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1213,7 +1215,9 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + noise_pred = self.unet( latent_model_input, t, @@ -1232,6 +1236,14 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if num_channels_unet == 4: + if i == len(timesteps) - 1: + init_latents_proper = masked_image_latents + else: + init_latents_proper = self.scheduler.add_noise(masked_image_latents, noise, torch.tensor([t])) + + latents = (1 - mask) * init_latents_proper + mask * latents + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() From 0ecac31fb746af9bafdfa48aae9eac220e3c4847 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 23 May 2023 21:06:49 +0200 Subject: [PATCH 03/13] Add tests --- .../controlnet/test_controlnet_inpaint.py | 117 ++++++++++++++++++ .../test_stable_diffusion_inpaint.py | 70 +++++++++++ 2 files changed, 187 insertions(+) diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index 155286630c04..fefd94fb0219 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -163,6 +163,78 @@ def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) +class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTests): + pipeline_class = StableDiffusionControlNetInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + 
num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionControlNetInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS @@ -376,3 +448,48 @@ def test_canny(self): ) assert np.abs(expected_image - image).max() < 9e-2 + + def test_inpaint(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-inpaint") + + pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + image = load_image( + "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" + ).resize((512, 512)) + + mask_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" + "/stable_diffusion_inpaint/input_bench_mask.png" + ).resize((512, 512)) + + prompt = "pitch black hole" + + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + + output = pipe( + prompt, + image=image, + mask_image=mask_image, + control_image=control_image, + generator=generator, + output_type="np", + num_inference_steps=3, + ) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/inpaint.npy" + ) + + assert np.abs(expected_image - image).max() < 9e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 44de277ead07..fd55bbb6b539 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -192,6 +192,60 @@ def test_stable_diffusion_inpaint_lora(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) +class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests): + pipeline_class = StableDiffusionInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = PNDMScheduler(skip_prk_steps=True) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + 
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components @slow @require_torch_gpu @@ -378,6 +432,22 @@ def test_stable_diffusion_inpaint_strength_test(self): expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943]) assert np.abs(expected_slice - image_slice).max() < 3e-3 + def test_stable_diffusion_simple_inpaint_ddim(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + image = pipe(**inputs).images + image_slice = image[0, 253:256, 253:256, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0427, 0.0460, 0.0483, 0.0460, 0.0584, 0.0521, 0.1549, 0.1695, 0.1794]) + + assert np.abs(expected_slice - image_slice).max() < 6e-4 @nightly @require_torch_gpu From 4ed7e89e480e3af21901be38f1ff5c1957d6bd80 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 23 May 2023 21:36:33 +0200 Subject: [PATCH 04/13] improve --- .../controlnet/pipeline_controlnet_inpaint.py | 20 +++++++++++--- .../pipeline_stable_diffusion_inpaint.py | 27 ++++++++++++++----- .../test_stable_diffusion_inpaint.py | 23 +++++++++++++--- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 2a92239c9eab..299eb14ceba0 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -226,6 +226,17 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + + This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as + [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) + as well as default text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). + Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on + those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
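The hunk that follows, together with the matching change in PATCH 01, is the core of the new 4-channel code path: after every scheduler step, the known (unmasked) region is overwritten with a re-noised copy of the original image latents, so only the masked region is actually generated. The sketch below restates that update in isolation; it is a reading aid under assumed names (`image_latents`, `mask`, `scheduler`), not the shipped pipeline code.

```py
# Sketch of the per-step blending used when the UNet has only 4 input channels.
# Assumptions: `scheduler` is any diffusers scheduler exposing `add_noise`,
# `image_latents` is the VAE encoding of the original image, `noise` is the
# initial noise tensor, and `mask` is 1.0 where new content may be painted and
# 0.0 where the original image must be preserved. All names are illustrative.
import torch


def blend_known_region(latents, image_latents, mask, noise, scheduler, t, is_last_step):
    init_latents_proper = image_latents
    if not is_last_step:
        # Re-noise the clean init latents to the current timestep so both terms
        # of the blend sit at the same noise level; the final step uses the
        # clean latents directly.
        init_latents_proper = scheduler.add_noise(image_latents, noise, torch.tensor([t]))
    # Keep the unmasked region from the (re-noised) original and let the model's
    # prediction through only inside the mask.
    return (1 - mask) * init_latents_proper + mask * latents
```

Note that the shipped loops also slice `mask[:1]` and the init latents with `[:1]`, peeling off the classifier-free-guidance duplicate before blending.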
@@ -1237,12 +1248,13 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: - if i == len(timesteps) - 1: - init_latents_proper = masked_image_latents - else: + init_latents_proper = masked_image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: init_latents_proper = self.scheduler.add_noise(masked_image_latents, noise, torch.tensor([t])) - latents = (1 - mask) * init_latents_proper + mask * latents + latents = (1 - init_mask) * init_latents_proper + init_mask * latents # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 25e6ac5f4664..287b0c258898 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -155,7 +155,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + Pipeline for text-guided image inpainting using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) @@ -167,6 +167,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + + It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such + as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). Default + text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) are also compatible with + this pipeline, but might be less performant. + + + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. @@ -949,7 +959,9 @@ def __call__( " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: - raise ValueError(f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.") + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) # 9. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -985,12 +997,13 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: - if i == len(timesteps) - 1: - init_latents_proper = masked_image_latents - else: - init_latents_proper = self.scheduler.add_noise(masked_image_latents, noise, torch.tensor([t])) + init_latents_proper = masked_image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t])) - latents = (1 - mask) * init_latents_proper + mask * latents + latents = (1 - init_mask) * init_latents_proper + init_mask * latents # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index fd55bbb6b539..e3f8ea250f44 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -192,6 +192,7 @@ def test_stable_diffusion_inpaint_lora(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests): pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS @@ -247,6 +248,23 @@ def get_dummy_components(self): } return components + def test_stable_diffusion_inpaint(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @slow @require_torch_gpu class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase): @@ -433,9 +451,7 @@ def test_stable_diffusion_inpaint_strength_test(self): assert np.abs(expected_slice - image_slice).max() < 3e-3 def test_stable_diffusion_simple_inpaint_ddim(self): - pipe = StableDiffusionInpaintPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", safety_checker=None - ) + pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) pipe.enable_attention_slicing() @@ -449,6 +465,7 @@ def test_stable_diffusion_simple_inpaint_ddim(self): assert np.abs(expected_slice - image_slice).max() < 6e-4 + @nightly @require_torch_gpu class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase): From c9564bebd1a6f242899bf8d74fca3fa16aa0e639 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 23 May 2023 21:42:06 +0200 Subject: [PATCH 05/13] Correct encode images function --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 287b0c258898..c4a4fb11aa39 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -664,6 +664,7 @@ def _encode_image(self, image: torch.Tensor, generator: torch.Generator):
                 self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
                 for i in range(image.shape[0])
             ]
+            image_latents = torch.cat(image_latents, dim=0)
         else:
             image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
 
From 1c03c6063cd81c93d0e4826cca30e7f5883da1ae Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 23 May 2023 22:02:26 +0200
Subject: [PATCH 06/13] Correct inpaint controlnet

---
 .../pipelines/controlnet/pipeline_controlnet_inpaint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 299eb14ceba0..ba949e0ee952 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -1252,7 +1252,7 @@ def __call__(
                         init_mask = mask[:1]
 
                         if i < len(timesteps) - 1:
-                            init_latents_proper = self.scheduler.add_noise(masked_image_latents, noise, torch.tensor([t]))
+                            init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))
 
                         latents = (1 - init_mask) * init_latents_proper + init_mask * latents
 
From fd83e1df688e3b537f8c9286c8bad2802a8ac2fb Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 24 May 2023 14:01:41 +0000
Subject: [PATCH 07/13] Improve text2img inpaint

---
 .../pipeline_stable_diffusion_inpaint.py | 43 +++++++++++++------
 .../test_stable_diffusion_inpaint.py     |  3 +-
 2 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index c4a4fb11aa39..b904a56d75d7 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -626,6 +626,8 @@ def prepare_latents(
         image=None,
         timestep=None,
         is_strength_max=True,
+        return_noise=False,
+        return_image_latents=False,
     ):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -640,23 +642,28 @@ def prepare_latents(
                 "However, either the image or the noise timestep has not been provided."
) + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_image(image=image, generator=generator) + if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - if is_strength_max: - # if strength is 100% then simply initialise the latents to noise - latents = noise - else: - # otherwise initialise latents as init image + noise - image = image.to(device=device, dtype=dtype) - image_latents = self._encode_image(image=image, generator=generator) - latents = self.scheduler.add_noise(image_latents, noise, timestep) + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - return latents + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs def _encode_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): @@ -917,7 +924,10 @@ def __call__( # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -929,8 +939,14 @@ def __call__( image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, ) - noise = latents + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( @@ -944,9 +960,10 @@ def __call__( generator, do_classifier_free_guidance, ) + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image = self._encode_image(init_image, generator=generator) # 8. 
Check that sizes of mask, masked image and latents match - num_channels_unet = self.unet.config.in_channels if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] @@ -998,7 +1015,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: - init_latents_proper = masked_image_latents[:1] + init_latents_proper = image_latents[:1] init_mask = mask[:1] if i < len(timesteps) - 1: diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index e3f8ea250f44..f8bd20dd6c13 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -458,10 +458,11 @@ def test_stable_diffusion_simple_inpaint_ddim(self): inputs = self.get_inputs(torch_device) image = pipe(**inputs).images + image_slice = image[0, 253:256, 253:256, -1].flatten() assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.0427, 0.0460, 0.0483, 0.0460, 0.0584, 0.0521, 0.1549, 0.1695, 0.1794]) + expected_slice = np.array([0.5157, 0.6858, 0.6873, 0.4619, 0.6416, 0.6898, 0.3702, 0.5960, 0.6935]) assert np.abs(expected_slice - image_slice).max() < 6e-4 From 5ca6b79cec0a3812b9e6a12731bdf48fcf5bb227 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 24 May 2023 16:21:56 +0000 Subject: [PATCH 08/13] make style --- .../controlnet/pipeline_controlnet_inpaint.py | 110 ++++++++++++------ .../pipeline_paint_by_example.py | 12 +- .../controlnet/test_controlnet_inpaint.py | 40 +++++-- 3 files changed, 104 insertions(+), 58 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index ba949e0ee952..cb6857b7b808 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -608,6 +608,16 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + def check_inputs( self, prompt, @@ -823,6 +833,8 @@ def prepare_latents( image=None, timestep=None, is_strength_max=True, + return_noise=False, + return_image_latents=False, ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -837,32 +849,28 @@ def prepare_latents( "However, either the image or the noise timestep has not been provided." 
) + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_image(image=image, generator=generator) + if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - if is_strength_max: - # if strength is 100% then simply initialise the latents to noise - latents = noise - else: - # otherwise initialise latents as init image + noise - image = image.to(device=device, dtype=dtype) - if isinstance(generator, list): - image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) - - image_latents = self.vae.config.scaling_factor * image_latents - - latents = self.scheduler.add_noise(image_latents, noise, timestep) + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - return latents + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs def _default_height_width(self, height, width, image): # NOTE: It is possible that a list of images have different @@ -902,17 +910,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + masked_image_latents = self._encode_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -941,6 +939,21 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_image + def _encode_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + # override DiffusionPipeline def save_pretrained( self, @@ -965,6 +978,7 @@ def __call__( ] = None, height: Optional[int] = None, width: Optional[int] = None, + strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -1001,6 +1015,13 @@ def __call__( The height in pixels of the generated image. 
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -1156,13 +1177,25 @@ def __call__( assert False # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -1171,11 +1204,19 @@ def __call__( device, generator, latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, ) - noise = latents + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs # 7. Prepare mask latent variables - mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image, @@ -1192,7 +1233,6 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. 
Denoising loop - num_channels_unet = self.unet.config.in_channels num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1248,7 +1288,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: - init_latents_proper = masked_image_latents[:1] + init_latents_proper = image_latents[:1] init_mask = mask[:1] if i < len(timesteps) - 1: diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 24b05f36f913..c5dd08d74ddd 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -328,17 +328,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + masked_image_latents = self._encode_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index fefd94fb0219..d639461ac325 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -450,40 +450,56 @@ def test_canny(self): assert np.abs(expected_image - image).max() < 9e-2 def test_inpaint(self): - controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-inpaint") + controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint") pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) - generator = torch.Generator(device="cpu").manual_seed(0) - image = load_image( - "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" - ).resize((512, 512)) + generator = torch.Generator(device="cpu").manual_seed(33) + init_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" "/stable_diffusion_inpaint/boy.png" + ).resize((512, 512)) mask_image = load_image( "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" - "/stable_diffusion_inpaint/input_bench_mask.png" + "/stable_diffusion_inpaint/boy_mask.png" ).resize((512, 512)) - prompt = "pitch black hole" + prompt = "a handsome man with ray-ban sunglasses" - control_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ).resize((512, 512)) + def make_inpaint_condition(image, image_mask): + image = 
np.array(image.convert("RGB")).astype(np.float32) / 255.0 + image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + image[image_mask > 0.5] = -1.0 # set as masked pixel + image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image + + control_image = make_inpaint_condition(init_image, mask_image) output = pipe( prompt, - image=image, + image=init_image, mask_image=mask_image, control_image=control_image, + guidance_scale=9.0, + eta=1.0, generator=generator, + num_inference_steps=20, output_type="np", - num_inference_steps=3, ) + np.save("/home/patrick/diffusers-images/sd_controlnet/boy_ray_ban.npy", output.images[0]) + + output_pil = pipe.numpy_to_pil(output.images[0])[0] + output_pil.save("/home/patrick/diffusers-images/sd_controlnet/boy_ray_ban.png") + image = output.images[0] assert image.shape == (512, 512, 3) From 2ae19120a2f7a3e9af181a73c4c66f4228124c25 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 24 May 2023 16:47:58 +0000 Subject: [PATCH 09/13] up --- .../controlnet/pipeline_controlnet_inpaint.py | 53 +++++++++++-------- .../controlnet/test_controlnet_inpaint.py | 12 +++-- .../test_stable_diffusion_inpaint.py | 6 ++- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index cb6857b7b808..507b8c571f39 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -50,49 +50,58 @@ EXAMPLE_DOC_STRING = """ Examples: ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler >>> from diffusers.utils import load_image - >>> import numpy as np >>> import torch - >>> import cv2 - >>> from PIL import Image + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(33) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + - >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" - >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> def make_inpaint_condition(image, image_mask): + ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 - >>> init_image = load_image(img_url).resize((512, 512)) - >>> mask_image = load_image(mask_url).resize((512, 512)) + ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + ... image[image_mask > 0.5] = -1.0 # set as masked pixel + ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + ... 
image = torch.from_numpy(image) + ... return image - >>> image = np.array(init_image) - >>> # get canny image - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) + >>> control_image = make_inpaint_condition(init_image, mask_image) - >>> # load control net and stable diffusion inpainting - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( - ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16 + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> # speed up diffusion process with faster scheduler and memory optimization - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # generate image - >>> generator = torch.manual_seed(0) >>> image = pipe( ... "spiderman", - ... num_inference_steps=30, + ... num_inference_steps=20, ... generator=generator, + ... eta=1.0, ... image=init_image, ... mask_image=mask_image, - ... control_image=canny_image, + ... control_image=control_image, ... ).images[0] ``` """ diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index d639461ac325..9127a4cb2cb6 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -462,12 +462,14 @@ def test_inpaint(self): generator = torch.Generator(device="cpu").manual_seed(33) init_image = load_image( - "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" "/stable_diffusion_inpaint/boy.png" - ).resize((512, 512)) + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ) + init_image = init_image.resize((512, 512)) + mask_image = load_image( - "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" - "/stable_diffusion_inpaint/boy_mask.png" - ).resize((512, 512)) + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ) + mask_image = mask_image.resize((512, 512)) prompt = "a handsome man with ray-ban sunglasses" diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index f8bd20dd6c13..a9337417289a 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -260,10 +260,14 @@ def test_stable_diffusion_inpaint(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474]) + expected_slice = np.array([0.4925, 0.4967, 0.4100, 0.5234, 0.5322, 0.4532, 0.5805, 0.5877, 0.4151]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @unittest.skip("skipped here because area stays unchanged due to mask") + def test_stable_diffusion_inpaint_lora(self): + ... 
+ @slow @require_torch_gpu From c0179d42e529d068d1826062ce32bff609460d85 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 May 2023 10:01:35 +0000 Subject: [PATCH 10/13] up --- .../pipelines/controlnet/pipeline_controlnet_inpaint.py | 6 +++--- .../paint_by_example/pipeline_paint_by_example.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 507b8c571f39..545670baaa1b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -860,7 +860,7 @@ def prepare_latents( if return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) - image_latents = self._encode_image(image=image, generator=generator) + image_latents = self._encode_vae_image(image=image, generator=generator) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -919,7 +919,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -948,7 +948,7 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_image + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image def _encode_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index c5dd08d74ddd..bcb9d03a26e3 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -328,7 +328,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b904a56d75d7..5dbac9295800 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -644,7 +644,7 @@ def prepare_latents( if return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) - image_latents = self._encode_image(image=image, generator=generator) + 
image_latents = self._encode_vae_image(image=image, generator=generator) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -665,7 +665,7 @@ def prepare_latents( return outputs - def _encode_image(self, image: torch.Tensor, generator: torch.Generator): + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) @@ -691,7 +691,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - masked_image_latents = self._encode_image(masked_image, generator=generator) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -961,7 +961,7 @@ def __call__( do_classifier_free_guidance, ) init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) - init_image = self._encode_image(init_image, generator=generator) + init_image = self._encode_vae_image(init_image, generator=generator) # 8. Check that sizes of mask, masked image and latents match if num_channels_unet == 9: From 25db6e6a73c4560736be1be4a5a29fa58cea7d0f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 May 2023 10:10:34 +0000 Subject: [PATCH 11/13] up --- .../controlnet/pipeline_controlnet_inpaint.py | 5 +++-- .../paint_by_example/pipeline_paint_by_example.py | 15 +++++++++++++++ .../pipeline_stable_diffusion_inpaint_legacy.py | 7 +++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 545670baaa1b..ee7540ba7895 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -53,6 +53,7 @@ >>> # !pip install transformers accelerate >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler >>> from diffusers.utils import load_image + >>> import numpy as np >>> import torch >>> init_image = load_image( @@ -95,7 +96,7 @@ >>> # generate image >>> image = pipe( - ... "spiderman", + ... "a beautiful man", ... num_inference_steps=20, ... generator=generator, ... 
eta=1.0,
@@ -949,7 +950,7 @@ def prepare_mask_latents(
 
         return mask, masked_image_latents
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
-    def _encode_image(self, image: torch.Tensor, generator: torch.Generator):
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
                 self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index bcb9d03a26e3..c8f3e8a9ee11 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -357,6 +357,21 @@ def prepare_mask_latents(
         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
 
         return mask, masked_image_latents
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
     def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
         dtype = next(self.image_encoder.parameters()).dtype
 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 5a2329a5c51f..bda4c8013d5d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -137,6 +137,13 @@ def __init__(
     ):
         super().__init__()
 
+        deprecation_message = (
+            f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality"
+            " by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533"
+            " for more information."
+        )
+        deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False)
+
         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
From 3f1889e381e4a3e842b7eee4fb2a5015a9547520 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 25 May 2023 15:37:12 +0000
Subject: [PATCH 12/13] up

---
 .../pipelines/controlnet/pipeline_controlnet_inpaint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index ee7540ba7895..83ddd51c02f7 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -61,7 +61,7 @@
    ...
) >>> init_image = init_image.resize((512, 512)) - >>> generator = torch.Generator(device="cpu").manual_seed(33) + >>> generator = torch.Generator(device="cpu").manual_seed(1) >>> mask_image = load_image( ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" From d9cd074d414b3f9e45b1519e7fc7518267c94d78 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 May 2023 19:12:19 +0000 Subject: [PATCH 13/13] fix more --- .../pipeline_stable_diffusion_inpaint_legacy.py | 1 - tests/pipelines/controlnet/test_controlnet_inpaint.py | 8 +------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index bda4c8013d5d..c549d869e685 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -123,7 +123,6 @@ class StableDiffusionInpaintPipelineLegacy( """ _optional_components = ["feature_extractor"] - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( self, vae: AutoencoderKL, diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index 9127a4cb2cb6..f8cc881e8650 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -496,18 +496,12 @@ def make_inpaint_condition(image, image_mask): num_inference_steps=20, output_type="np", ) - - np.save("/home/patrick/diffusers-images/sd_controlnet/boy_ray_ban.npy", output.images[0]) - - output_pil = pipe.numpy_to_pil(output.images[0])[0] - output_pil.save("/home/patrick/diffusers-images/sd_controlnet/boy_ray_ban.png") - image = output.images[0] assert image.shape == (512, 512, 3) expected_image = load_numpy( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/inpaint.npy" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/boy_ray_ban.npy" ) assert np.abs(expected_image - image).max() < 9e-2
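Taken together, the series makes the standard inpainting pipeline usable with a plain text-to-image checkpoint. The snippet below is a usage sketch distilled from the new tests and docstrings above (the asset URLs are the test images referenced in the patches); with a 4-channel UNet the result is expected to be less faithful than with a dedicated 9-channel inpainting checkpoint such as runwayml/stable-diffusion-inpainting.

```py
# Usage sketch: inpainting with a plain text-to-image checkpoint (4-channel
# UNet) through StableDiffusionInpaintPipeline, as enabled by this series.
# The image/mask URLs are the test assets referenced in the patches above.
import torch
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # 4 input channels -> blending code path
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipe = pipe.to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
).resize((512, 512))
mask_image = load_image(
    "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
).resize((512, 512))

image = pipe(
    prompt="a handsome man with ray-ban sunglasses",
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=20,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("inpainted.png")
```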