Skip to content

Commit 87252d8

Browse files
Anonym0u3Benny079
andauthored
Add pipeline_stable_diffusion_xl_attentive_eraser (#10579)
* add pipeline_stable_diffusion_xl_attentive_eraser * add pipeline_stable_diffusion_xl_attentive_eraser_make_style * make style and add example output * update Docs Co-authored-by: Other Contributor <a457435687@126.com> * add Oral Co-authored-by: Other Contributor <a457435687@126.com> * update_review Co-authored-by: Other Contributor <a457435687@126.com> * update_review_ms Co-authored-by: Other Contributor <a457435687@126.com> --------- Co-authored-by: Other Contributor <a457435687@126.com>
1 parent 5897137 commit 87252d8

File tree

2 files changed

+2408
-2
lines changed

2 files changed

+2408
-2
lines changed

examples/community/README.md

100755100644
Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
7777
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
7878
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
7979
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
80+
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
8081

8182
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
8283

@@ -4585,8 +4586,8 @@ image = pipe(
45854586
```
45864587

45874588
| ![Gradient](https://github.com/user-attachments/assets/e38ce4d5-1ae6-4df0-ab43-adc1b45716b5) | ![Input](https://github.com/user-attachments/assets/9c95679c-e9d7-4f5a-90d6-560203acd6b3) | ![Output](https://github.com/user-attachments/assets/5313ff64-a0c4-418b-8b55-a38f1a5e7532) |
4588-
| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
4589-
| Gradient | Input | Output |
4589+
| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
4590+
| Gradient | Input | Output |
45904591

45914592
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
45924593

@@ -4634,6 +4635,93 @@ make_image_grid(image, rows=1, cols=len(image))
46344635
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
46354636
```
46364637

4638+
### Stable Diffusion XL Attentive Eraser Pipeline
4639+
<img src="https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/fenmian.png" width="600" />
4640+
4641+
**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
4642+
4643+
#### Key features
4644+
4645+
- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
4646+
- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
4647+
- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
4648+
4649+
#### Usage example
4650+
To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
4651+
```py
4652+
import torch
4653+
from diffusers import DDIMScheduler, DiffusionPipeline
4654+
from diffusers.utils import load_image
4655+
import torch.nn.functional as F
4656+
from torchvision.transforms.functional import to_tensor, gaussian_blur
4657+
4658+
dtype = torch.float16
4659+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
4660+
4661+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
4662+
pipeline = DiffusionPipeline.from_pretrained(
4663+
"stabilityai/stable-diffusion-xl-base-1.0",
4664+
custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
4665+
scheduler=scheduler,
4666+
variant="fp16",
4667+
use_safetensors=True,
4668+
torch_dtype=dtype,
4669+
).to(device)
4670+
4671+
4672+
def preprocess_image(image_path, device):
4673+
image = to_tensor((load_image(image_path)))
4674+
image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
4675+
if image.shape[1] != 3:
4676+
image = image.expand(-1, 3, -1, -1)
4677+
image = F.interpolate(image, (1024, 1024))
4678+
image = image.to(dtype).to(device)
4679+
return image
4680+
4681+
def preprocess_mask(mask_path, device):
4682+
mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
4683+
mask = mask.unsqueeze_(0).float() # 0 or 1
4684+
mask = F.interpolate(mask, (1024, 1024))
4685+
mask = gaussian_blur(mask, kernel_size=(77, 77))
4686+
mask[mask < 0.1] = 0
4687+
mask[mask >= 0.1] = 1
4688+
mask = mask.to(dtype).to(device)
4689+
return mask
4690+
4691+
prompt = "" # Set prompt to null
4692+
seed=123
4693+
generator = torch.Generator(device=device).manual_seed(seed)
4694+
source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
4695+
mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
4696+
source_image = preprocess_image(source_image_path, device)
4697+
mask = preprocess_mask(mask_path, device)
4698+
4699+
image = pipeline(
4700+
prompt=prompt,
4701+
image=source_image,
4702+
mask_image=mask,
4703+
height=1024,
4704+
width=1024,
4705+
AAS=True, # enable AAS
4706+
strength=0.8, # inpainting strength
4707+
rm_guidance_scale=9, # removal guidance scale
4708+
ss_steps = 9, # similarity suppression steps
4709+
ss_scale = 0.3, # similarity suppression scale
4710+
AAS_start_step=0, # AAS start step
4711+
AAS_start_layer=34, # AAS start layer
4712+
AAS_end_layer=70, # AAS end layer
4713+
num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
4714+
generator=generator,
4715+
guidance_scale=1,
4716+
).images[0]
4717+
image.save('./removed_img.png')
4718+
print("Object removal completed")
4719+
```
4720+
4721+
| Source Image | Mask | Output |
4722+
| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
4723+
| ![Source Image](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png) | ![Mask](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png) | ![Output](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/AE_step40_layer34.png) |
4724+
46374725
# Perturbed-Attention Guidance
46384726

46394727
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)

0 commit comments

Comments
 (0)