From 518494565c23498f11e955268ce467aaf6e88ae1 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 23 May 2023 18:14:12 +0800 Subject: [PATCH 1/8] refactor blocks init --- .../pipeline_stable_diffusion_panorama.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 223f8a236efa..f57d648319d1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -466,6 +466,13 @@ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): w_end = w_start + window_size views.append((h_start, h_end, w_start, w_end)) return views + + def init_views_output(self,views): + if hasattr(self.scheduler, "model_outputs"): + # init schedulers: dpmsolver, unipc + return [self.scheduler.model_outputs] * len(views) + else: + return None @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -612,7 +619,7 @@ def __call__( # 6. Define panorama grid and initialize views for synthesis. views = self.get_views(height, width) - blocks_model_outputs = [None] * len(views) + blocks_model_outputs = self.init_views_output(views) count = torch.zeros_like(latents) value = torch.zeros_like(latents) @@ -659,8 +666,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 if hasattr(self.scheduler, "model_outputs"): # rematch model_outputs in each block - if i >= 1: - self.scheduler.model_outputs = blocks_model_outputs[j] + self.scheduler.model_outputs = blocks_model_outputs[j] latents_view_denoised = self.scheduler.step( noise_pred, t, latents_for_view, **extra_step_kwargs ).prev_sample From 9df07c099c9fdc261e271e771a5904dfe3e31c12 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 23 May 2023 21:19:44 +0800 Subject: [PATCH 2/8] refactor blocks loop --- .../pipeline_stable_diffusion_panorama.py | 29 +++++-------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index f57d648319d1..2dc6ece0dfa9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -466,13 +466,6 @@ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): w_end = w_start + window_size views.append((h_start, h_end, w_start, w_end)) return views - - def init_views_output(self,views): - if hasattr(self.scheduler, "model_outputs"): - # init schedulers: dpmsolver, unipc - return [self.scheduler.model_outputs] * len(views) - else: - return None @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -619,7 +612,7 @@ def __call__( # 6. Define panorama grid and initialize views for synthesis. views = self.get_views(height, width) - blocks_model_outputs = self.init_views_output(views) + views_scheduler_status = [self.scheduler.__dict__] * len(views) count = torch.zeros_like(latents) value = torch.zeros_like(latents) @@ -664,20 +657,12 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if hasattr(self.scheduler, "model_outputs"): - # rematch model_outputs in each block - self.scheduler.model_outputs = blocks_model_outputs[j] - latents_view_denoised = self.scheduler.step( - noise_pred, t, latents_for_view, **extra_step_kwargs - ).prev_sample - # collect model_outputs - blocks_model_outputs[j] = [ - output if output is not None else None for output in self.scheduler.model_outputs - ] - else: - latents_view_denoised = self.scheduler.step( - noise_pred, t, latents_for_view, **extra_step_kwargs - ).prev_sample + self.scheduler.__dict__.update(views_scheduler_status[j]) + latents_view_denoised = self.scheduler.step( + noise_pred, t, latents_for_view, **extra_step_kwargs + ).prev_sample + views_scheduler_status[j] = self.scheduler.__dict__ + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] += 1 From 6614a97fea875747428269f68131e0aaa92af938 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 24 May 2023 11:14:31 +0800 Subject: [PATCH 3/8] remove unused function and warnings --- .../pipeline_stable_diffusion_panorama.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 2dc6ece0dfa9..d3049357a0be 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union @@ -96,9 +97,6 @@ def __init__( ): super().__init__() - if isinstance(scheduler, PNDMScheduler): - logger.error("PNDMScheduler for this pipeline is currently not supported.") - if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" @@ -612,7 +610,7 @@ def __call__( # 6. Define panorama grid and initialize views for synthesis. views = self.get_views(height, width) - views_scheduler_status = [self.scheduler.__dict__] * len(views) + views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views) count = torch.zeros_like(latents) value = torch.zeros_like(latents) @@ -657,11 +655,12 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 + # restore/save views scheduler status before/after sample self.scheduler.__dict__.update(views_scheduler_status[j]) latents_view_denoised = self.scheduler.step( noise_pred, t, latents_for_view, **extra_step_kwargs ).prev_sample - views_scheduler_status[j] = self.scheduler.__dict__ + views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] += 1 From 25aa807164189fbecd5d60a1f2fc64bdc4ebe928 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 24 May 2023 12:38:30 +0800 Subject: [PATCH 4/8] fix scheduler update location --- .../stable_diffusion/pipeline_stable_diffusion_panorama.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index d3049357a0be..d4909eb87c5e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -635,6 +635,9 @@ def __call__( # get the latents corresponding to the current view coordinates latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] + # rematch block's scheduler status + self.scheduler.__dict__.update(views_scheduler_status[j]) + # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view @@ -655,11 +658,11 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - # restore/save views scheduler status before/after sample - self.scheduler.__dict__.update(views_scheduler_status[j]) latents_view_denoised = self.scheduler.step( noise_pred, t, latents_for_view, **extra_step_kwargs ).prev_sample + + # save views scheduler status after sample views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised From 3bbb51ca889b35db13a064f91d9a1ebece7784dc Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 24 May 2023 17:53:21 +0800 Subject: [PATCH 5/8] reformat code --- .../stable_diffusion/pipeline_stable_diffusion_panorama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index d4909eb87c5e..f17dec97f22d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -662,7 +662,7 @@ def __call__( noise_pred, t, latents_for_view, **extra_step_kwargs ).prev_sample - # save views scheduler status after sample + # save views scheduler status after sample views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised From 2fd5d981732d0f2978cf620c51f77f37d0fea19c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 24 May 2023 17:56:31 +0800 Subject: [PATCH 6/8] reformat code again --- .../stable_diffusion/pipeline_stable_diffusion_panorama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index f17dec97f22d..66706c806a81 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -22,7 +22,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, PNDMScheduler +from ...schedulers import DDIMScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput From f8be630777c1f7ce8eff32aed9c82f0e6543bca2 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 24 May 2023 19:28:33 +0800 Subject: [PATCH 7/8] fix PNDM test case --- .../test_stable_diffusion_panorama.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 02a15b2a29dc..0a52e23e195e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -174,15 +174,22 @@ def test_stable_diffusion_panorama_euler(self): def test_stable_diffusion_panorama_pndm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - components["scheduler"] = PNDMScheduler() + components["scheduler"] = PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps = True + ) sd_pipe = StableDiffusionPanoramaPipeline(**components) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - # the pipeline does not expect pndm so test if it raises error. - with self.assertRaises(ValueError): - _ = sd_pipe(**inputs).images + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow From 47447830bb52b5a6cc2947fee9aefb153c61419d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 24 May 2023 19:31:10 +0800 Subject: [PATCH 8/8] reformat pndm test case --- .../stable_diffusion/test_stable_diffusion_panorama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 0a52e23e195e..021065416838 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -175,7 +175,7 @@ def test_stable_diffusion_panorama_pndm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() components["scheduler"] = PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps = True + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True ) sd_pipe = StableDiffusionPanoramaPipeline(**components) sd_pipe = sd_pipe.to(device) @@ -187,7 +187,7 @@ def test_stable_diffusion_panorama_pndm(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539]) + expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2