Skip to content

Commit af48bf2

Browse files
authored
Add circular padding for artifact-free StableDiffusionPanoramaPipeline (#4025)
* Add circular padding option * Fix style with black * Fix corner case with small image size * Add circular padding test cases * Fix docstring * Improve docstring for circular padding, remove slow test case * Update docs for circular padding argument * Add images comparison for circular padding
1 parent 4b50ecc commit af48bf2

File tree

3 files changed

+117
-10
lines changed

3 files changed

+117
-10
lines changed

docs/source/en/api/pipelines/panorama.mdx

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ and increase the VRAM usage.
6060

6161
</Tip>
6262

63+
<Tip>
64+
65+
Circular padding is applied to ensure there are no stitching artifacts when working with
66+
panoramas that need to transition seamlessly from the rightmost part to the leftmost part.
67+
By enabling circular padding (set `circular_padding=True`), the operation applies additional
68+
crops after the rightmost point of the image, allowing the model to “see” the transition
69+
from the rightmost part to the leftmost part. This helps maintain visual consistency in
70+
a 360-degree sense and creates a proper “panorama” that can be viewed using 360-degree
71+
panorama viewers. When decoding latents in StableDiffusion, circular padding is applied
72+
to ensure that the decoded latents match in the RGB space.
73+
74+
Without circular padding, there is a stitching artifact (default):
75+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20no_circular_padding.png)
76+
77+
With circular padding, the right and left parts match (`circular_padding=True`):
78+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20circular_padding.png)
79+
80+
</Tip>
81+
6382
## StableDiffusionPanoramaPipeline
6483
[[autodoc]] StableDiffusionPanoramaPipeline
6584
- __call__

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,19 @@ def decode_latents(self, latents):
373373
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
374374
return image
375375

376+
def decode_latents_with_padding(self, latents, padding=8):
377+
# Add padding to latents for circular inference
378+
# padding is the number of latents to add on each side
379+
# it would slightly increase the memory usage, but remove the boundary artifacts
380+
latents = 1 / self.vae.config.scaling_factor * latents
381+
latents_left = latents[..., :padding]
382+
latents_right = latents[..., -padding:]
383+
latents = torch.cat((latents_right, latents, latents_left), axis=-1)
384+
image = self.vae.decode(latents, return_dict=False)[0]
385+
padding_pix = self.vae_scale_factor * padding
386+
image = image[..., padding_pix:-padding_pix]
387+
return image
388+
376389
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
377390
def prepare_extra_step_kwargs(self, generator, eta):
378391
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -457,13 +470,16 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
457470
latents = latents * self.scheduler.init_noise_sigma
458471
return latents
459472

460-
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
473+
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
461474
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
462475
# if panorama's height/width < window_size, num_blocks of height/width should return 1
463476
panorama_height /= 8
464477
panorama_width /= 8
465478
num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
466-
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
479+
if circular_padding:
480+
num_blocks_width = panorama_width // stride if panorama_width > window_size else 1
481+
else:
482+
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
467483
total_num_blocks = int(num_blocks_height * num_blocks_width)
468484
views = []
469485
for i in range(total_num_blocks):
@@ -496,6 +512,7 @@ def __call__(
496512
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
497513
callback_steps: Optional[int] = 1,
498514
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
515+
circular_padding: bool = False,
499516
):
500517
r"""
501518
Function invoked when calling the pipeline for generation.
@@ -560,6 +577,10 @@ def __call__(
560577
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
561578
`self.processor` in
562579
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
580+
circular_padding (`bool`, *optional*, defaults to `False`):
581+
If set to True, circular padding is applied to ensure there are no stitching artifacts. Circular
582+
padding allows the model to seamlessly generate a transition from the rightmost part of the image to
583+
the leftmost part, maintaining consistency in a 360-degree sense.
563584
564585
Examples:
565586
@@ -627,10 +648,9 @@ def __call__(
627648

628649
# 6. Define panorama grid and initialize views for synthesis.
629650
# prepare batch grid
630-
views = self.get_views(height, width)
651+
views = self.get_views(height, width, circular_padding=circular_padding)
631652
views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
632653
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
633-
634654
count = torch.zeros_like(latents)
635655
value = torch.zeros_like(latents)
636656

@@ -655,9 +675,29 @@ def __call__(
655675
for j, batch_view in enumerate(views_batch):
656676
vb_size = len(batch_view)
657677
# get the latents corresponding to the current view coordinates
658-
latents_for_view = torch.cat(
659-
[latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
660-
)
678+
if circular_padding:
679+
latents_for_view = []
680+
for h_start, h_end, w_start, w_end in batch_view:
681+
if w_end > latents.shape[3]:
682+
# Add circular horizontal padding
683+
latent_view = torch.cat(
684+
(
685+
latents[:, :, h_start:h_end, w_start:],
686+
latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
687+
),
688+
axis=-1,
689+
)
690+
else:
691+
latent_view = latents[:, :, h_start:h_end, w_start:w_end]
692+
latents_for_view.append(latent_view)
693+
latents_for_view = torch.cat(latents_for_view)
694+
else:
695+
latents_for_view = torch.cat(
696+
[
697+
latents[:, :, h_start:h_end, w_start:w_end]
698+
for h_start, h_end, w_start, w_end in batch_view
699+
]
700+
)
661701

662702
# rematch block's scheduler status
663703
self.scheduler.__dict__.update(views_scheduler_status[j])
@@ -698,8 +738,19 @@ def __call__(
698738
for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
699739
latents_denoised_batch.chunk(vb_size), batch_view
700740
):
701-
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
702-
count[:, :, h_start:h_end, w_start:w_end] += 1
741+
if circular_padding and w_end > latents.shape[3]:
742+
# Case for circular padding
743+
value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
744+
:, :, h_start:h_end, : latents.shape[3] - w_start
745+
]
746+
value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
747+
:, :, h_start:h_end, latents.shape[3] - w_start :
748+
]
749+
count[:, :, h_start:h_end, w_start:] += 1
750+
count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
751+
else:
752+
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
753+
count[:, :, h_start:h_end, w_start:w_end] += 1
703754

704755
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
705756
latents = torch.where(count > 0, value / count, value)
@@ -711,7 +762,10 @@ def __call__(
711762
callback(i, t, latents)
712763

713764
if not output_type == "latent":
714-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
765+
if circular_padding:
766+
image = self.decode_latents_with_padding(latents)
767+
else:
768+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
715769
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
716770
else:
717771
image = latents

tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,22 @@ def test_stable_diffusion_panorama_default_case(self):
125125

126126
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
127127

128+
def test_stable_diffusion_panorama_circular_padding_case(self):
129+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
130+
components = self.get_dummy_components()
131+
sd_pipe = StableDiffusionPanoramaPipeline(**components)
132+
sd_pipe = sd_pipe.to(device)
133+
sd_pipe.set_progress_bar_config(disable=None)
134+
135+
inputs = self.get_dummy_inputs(device)
136+
image = sd_pipe(**inputs, circular_padding=True).images
137+
image_slice = image[0, -3:, -3:, -1]
138+
assert image.shape == (1, 64, 64, 3)
139+
140+
expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
141+
142+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
143+
128144
# override to speed the overall test timing up.
129145
def test_inference_batch_consistent(self):
130146
super().test_inference_batch_consistent(batch_sizes=[1, 2])
@@ -170,6 +186,24 @@ def test_stable_diffusion_panorama_views_batch(self):
170186

171187
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
172188

189+
def test_stable_diffusion_panorama_views_batch_circular_padding(self):
190+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
191+
components = self.get_dummy_components()
192+
sd_pipe = StableDiffusionPanoramaPipeline(**components)
193+
sd_pipe = sd_pipe.to(device)
194+
sd_pipe.set_progress_bar_config(disable=None)
195+
196+
inputs = self.get_dummy_inputs(device)
197+
output = sd_pipe(**inputs, circular_padding=True, view_batch_size=2)
198+
image = output.images
199+
image_slice = image[0, -3:, -3:, -1]
200+
201+
assert image.shape == (1, 64, 64, 3)
202+
203+
expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
204+
205+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
206+
173207
def test_stable_diffusion_panorama_euler(self):
174208
device = "cpu" # ensure determinism for the device-dependent torch.Generator
175209
components = self.get_dummy_components()

0 commit comments

Comments
 (0)