
Fix panorama to support all schedulers #3546


Merged · 8 commits · May 24, 2023
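The change hinges on one observation: multistep schedulers such as PNDMScheduler keep per-step state on the scheduler instance (a running history of model outputs and a step counter), while the panorama pipeline steps a single shared scheduler once per view at every timestep, so that state leaks between views. A minimal sketch of that statefulness, assuming only the public PNDMScheduler API (ets and counter are internal attribute names and may differ across diffusers versions):

import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(skip_prk_steps=True)
scheduler.set_timesteps(10)

sample = torch.randn(1, 4, 8, 8)
t = scheduler.timesteps[0]

# step() mutates instance state: the model-output history grows and the internal
# counter advances, so a second view stepped at the same timestep would see
# state left behind by the first view.
before = (scheduler.counter, len(scheduler.ets))
scheduler.step(torch.randn_like(sample), t, sample)
after = (scheduler.counter, len(scheduler.ets))
print(before, after)  # (0, 0) then (1, 1)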
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -21,7 +22,7 @@
 from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -96,9 +97,6 @@ def __init__(
     ):
         super().__init__()
 
-        if isinstance(scheduler, PNDMScheduler):
-            logger.error("PNDMScheduler for this pipeline is currently not supported.")
-
         if safety_checker is None and requires_safety_checker:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -612,7 +610,7 @@ def __call__(
 
         # 6. Define panorama grid and initialize views for synthesis.
         views = self.get_views(height, width)
-        blocks_model_outputs = [None] * len(views)
+        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
         count = torch.zeros_like(latents)
         value = torch.zeros_like(latents)
 
@@ -637,6 +635,9 @@
                     # get the latents corresponding to the current view coordinates
                     latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
 
+                    # rematch block's scheduler status
+                    self.scheduler.__dict__.update(views_scheduler_status[j])
+
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = (
                         torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
@@ -657,21 +658,13 @@
                         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                     # compute the previous noisy sample x_t -> x_t-1
-                    if hasattr(self.scheduler, "model_outputs"):
-                        # rematch model_outputs in each block
-                        if i >= 1:
-                            self.scheduler.model_outputs = blocks_model_outputs[j]
-                        latents_view_denoised = self.scheduler.step(
-                            noise_pred, t, latents_for_view, **extra_step_kwargs
-                        ).prev_sample
-                        # collect model_outputs
-                        blocks_model_outputs[j] = [
-                            output if output is not None else None for output in self.scheduler.model_outputs
-                        ]
-                    else:
-                        latents_view_denoised = self.scheduler.step(
-                            noise_pred, t, latents_for_view, **extra_step_kwargs
-                        ).prev_sample
+                    latents_view_denoised = self.scheduler.step(
+                        noise_pred, t, latents_for_view, **extra_step_kwargs
+                    ).prev_sample
+
+                    # save views scheduler status after sample
+                    views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
 
                     value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                     count[:, :, h_start:h_end, w_start:w_end] += 1
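Taken together, the hunks above implement a simple pattern: keep one saved copy of the scheduler's state per view, restore it before stepping that view, and save it back afterwards, so each view sees a scheduler that has only ever stepped its own latents. A self-contained sketch of the same pattern outside the pipeline (the view grid and random tensors are toy stand-ins for the real latents and UNet prediction; the names mirror the diff, but this is not pipeline code):

import copy

import torch
from diffusers import PNDMScheduler

# Toy stand-ins: two non-overlapping views over a 1x4x8x16 latent.
views = [(0, 8, 0, 8), (0, 8, 8, 16)]  # (h_start, h_end, w_start, w_end)
latents = torch.randn(1, 4, 8, 16)

scheduler = PNDMScheduler(skip_prk_steps=True)
scheduler.set_timesteps(5)

# One saved scheduler state per view.
views_scheduler_status = [copy.deepcopy(scheduler.__dict__)] * len(views)

for i, t in enumerate(scheduler.timesteps):
    for j, (h_start, h_end, w_start, w_end) in enumerate(views):
        latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
        noise_pred = torch.randn_like(latents_for_view)  # stands in for the UNet output

        scheduler.__dict__.update(views_scheduler_status[j])  # restore this view's state
        latents_view_denoised = scheduler.step(noise_pred, t, latents_for_view).prev_sample
        views_scheduler_status[j] = copy.deepcopy(scheduler.__dict__)  # save it back

        # the real pipeline averages overlapping views via count/value buffers;
        # these toy views do not overlap, so a direct write is enough
        latents[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised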
tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py (15 changes: 11 additions & 4 deletions)

@@ -174,15 +174,22 @@ def test_stable_diffusion_panorama_euler(self):
     def test_stable_diffusion_panorama_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
-        components["scheduler"] = PNDMScheduler()
+        components["scheduler"] = PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        )
         sd_pipe = StableDiffusionPanoramaPipeline(**components)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
-        # the pipeline does not expect pndm so test if it raises error.
-        with self.assertRaises(ValueError):
-            _ = sd_pipe(**inputs).images
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+
+        expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
 
 @slow
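With the per-view bookkeeping in place, the panorama pipeline can be driven by PNDM (or other stateful schedulers) directly. A usage sketch, not taken from this PR: the checkpoint id is assumed (it is the one commonly used in the pipeline's documentation), and the skip_prk_steps=True override mirrors the test above.

import torch
from diffusers import PNDMScheduler, StableDiffusionPanoramaPipeline

model_id = "stabilityai/stable-diffusion-2-base"  # assumed checkpoint
scheduler = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler", skip_prk_steps=True)

pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    model_id, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

image = pipe("a photo of the dolomites", height=512, width=2048).images[0]
image.save("panorama_pndm.png")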