@@ -373,6 +373,19 @@ def decode_latents(self, latents):
373
373
image = image .cpu ().permute (0 , 2 , 3 , 1 ).float ().numpy ()
374
374
return image
375
375
376
def decode_latents_with_padding(self, latents, padding=8):
    """Decode latents with circular horizontal padding to remove seam artifacts.

    Wraps `padding` latent columns from each side onto the opposite side along
    the width axis before VAE decoding, then crops the decoded image back to
    the original width. This slightly increases memory usage, but removes the
    left/right boundary artifacts so a panorama tiles seamlessly (360-degree
    consistency).

    Args:
        latents (`torch.Tensor`):
            Latent tensor of shape `(batch, channels, height, width)`.
        padding (`int`, *optional*, defaults to 8):
            Number of latent columns to wrap around on each side. `0` decodes
            without any circular padding.

    Returns:
        `torch.Tensor`: The decoded image, cropped to the unpadded width.
    """
    # Undo the VAE scaling applied when the latents were produced.
    latents = 1 / self.vae.config.scaling_factor * latents
    if padding > 0:
        # Circular pad: rightmost columns prepended, leftmost columns appended.
        latents_left = latents[..., :padding]
        latents_right = latents[..., -padding:]
        latents = torch.cat((latents_right, latents, latents_left), dim=-1)
    image = self.vae.decode(latents, return_dict=False)[0]
    padding_pix = self.vae_scale_factor * padding
    # Guard padding_pix == 0: `image[..., 0:-0]` would be an empty slice.
    if padding_pix > 0:
        image = image[..., padding_pix:-padding_pix]
    return image
376
389
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
377
390
def prepare_extra_step_kwargs (self , generator , eta ):
378
391
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -457,13 +470,16 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
457
470
latents = latents * self .scheduler .init_noise_sigma
458
471
return latents
459
472
460
- def get_views (self , panorama_height , panorama_width , window_size = 64 , stride = 8 ):
473
+ def get_views (self , panorama_height , panorama_width , window_size = 64 , stride = 8 , circular_padding = False ):
461
474
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
462
475
# if panorama's height/width < window_size, num_blocks of height/width should return 1
463
476
panorama_height /= 8
464
477
panorama_width /= 8
465
478
num_blocks_height = (panorama_height - window_size ) // stride + 1 if panorama_height > window_size else 1
466
- num_blocks_width = (panorama_width - window_size ) // stride + 1 if panorama_width > window_size else 1
479
+ if circular_padding :
480
+ num_blocks_width = panorama_width // stride if panorama_width > window_size else 1
481
+ else :
482
+ num_blocks_width = (panorama_width - window_size ) // stride + 1 if panorama_width > window_size else 1
467
483
total_num_blocks = int (num_blocks_height * num_blocks_width )
468
484
views = []
469
485
for i in range (total_num_blocks ):
@@ -496,6 +512,7 @@ def __call__(
496
512
callback : Optional [Callable [[int , int , torch .FloatTensor ], None ]] = None ,
497
513
callback_steps : Optional [int ] = 1 ,
498
514
cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
515
+ circular_padding : bool = False ,
499
516
):
500
517
r"""
501
518
Function invoked when calling the pipeline for generation.
@@ -560,6 +577,10 @@ def __call__(
560
577
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
561
578
`self.processor` in
562
579
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
580
+ circular_padding (`bool`, *optional*, defaults to `False`):
581
+ If set to True, circular padding is applied to ensure there are no stitching artifacts. Circular
582
+ padding allows the model to seamlessly generate a transition from the rightmost part of the image to
583
+ the leftmost part, maintaining consistency in a 360-degree sense.
563
584
564
585
Examples:
565
586
@@ -627,10 +648,9 @@ def __call__(
627
648
628
649
# 6. Define panorama grid and initialize views for synthesis.
629
650
# prepare batch grid
630
- views = self .get_views (height , width )
651
+ views = self .get_views (height , width , circular_padding = circular_padding )
631
652
views_batch = [views [i : i + view_batch_size ] for i in range (0 , len (views ), view_batch_size )]
632
653
views_scheduler_status = [copy .deepcopy (self .scheduler .__dict__ )] * len (views_batch )
633
-
634
654
count = torch .zeros_like (latents )
635
655
value = torch .zeros_like (latents )
636
656
@@ -655,9 +675,29 @@ def __call__(
655
675
for j , batch_view in enumerate (views_batch ):
656
676
vb_size = len (batch_view )
657
677
# get the latents corresponding to the current view coordinates
658
- latents_for_view = torch .cat (
659
- [latents [:, :, h_start :h_end , w_start :w_end ] for h_start , h_end , w_start , w_end in batch_view ]
660
- )
678
+ if circular_padding :
679
+ latents_for_view = []
680
+ for h_start , h_end , w_start , w_end in batch_view :
681
+ if w_end > latents .shape [3 ]:
682
+ # Add circular horizontal padding
683
+ latent_view = torch .cat (
684
+ (
685
+ latents [:, :, h_start :h_end , w_start :],
686
+ latents [:, :, h_start :h_end , : w_end - latents .shape [3 ]],
687
+ ),
688
+ axis = - 1 ,
689
+ )
690
+ else :
691
+ latent_view = latents [:, :, h_start :h_end , w_start :w_end ]
692
+ latents_for_view .append (latent_view )
693
+ latents_for_view = torch .cat (latents_for_view )
694
+ else :
695
+ latents_for_view = torch .cat (
696
+ [
697
+ latents [:, :, h_start :h_end , w_start :w_end ]
698
+ for h_start , h_end , w_start , w_end in batch_view
699
+ ]
700
+ )
661
701
662
702
# rematch block's scheduler status
663
703
self .scheduler .__dict__ .update (views_scheduler_status [j ])
@@ -698,8 +738,19 @@ def __call__(
698
738
for latents_view_denoised , (h_start , h_end , w_start , w_end ) in zip (
699
739
latents_denoised_batch .chunk (vb_size ), batch_view
700
740
):
701
- value [:, :, h_start :h_end , w_start :w_end ] += latents_view_denoised
702
- count [:, :, h_start :h_end , w_start :w_end ] += 1
741
+ if circular_padding and w_end > latents .shape [3 ]:
742
+ # Case for circular padding
743
+ value [:, :, h_start :h_end , w_start :] += latents_view_denoised [
744
+ :, :, h_start :h_end , : latents .shape [3 ] - w_start
745
+ ]
746
+ value [:, :, h_start :h_end , : w_end - latents .shape [3 ]] += latents_view_denoised [
747
+ :, :, h_start :h_end , latents .shape [3 ] - w_start :
748
+ ]
749
+ count [:, :, h_start :h_end , w_start :] += 1
750
+ count [:, :, h_start :h_end , : w_end - latents .shape [3 ]] += 1
751
+ else :
752
+ value [:, :, h_start :h_end , w_start :w_end ] += latents_view_denoised
753
+ count [:, :, h_start :h_end , w_start :w_end ] += 1
703
754
704
755
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
705
756
latents = torch .where (count > 0 , value / count , value )
@@ -711,7 +762,10 @@ def __call__(
711
762
callback (i , t , latents )
712
763
713
764
if not output_type == "latent" :
714
- image = self .vae .decode (latents / self .vae .config .scaling_factor , return_dict = False )[0 ]
765
+ if circular_padding :
766
+ image = self .decode_latents_with_padding (latents )
767
+ else :
768
+ image = self .vae .decode (latents / self .vae .config .scaling_factor , return_dict = False )[0 ]
715
769
image , has_nsfw_concept = self .run_safety_checker (image , device , prompt_embeds .dtype )
716
770
else :
717
771
image = latents
0 commit comments