
Commit c0f867a

Fix temb attention (#3607)
* Fix temb attention
* Apply suggestions from code review
* make style
* Add tests and fix docker
* Apply suggestions from code review
1 parent c6ae883 commit c0f867a

File tree: 4 files changed (+83, -5 lines)


docker/diffusers-pytorch-cuda/Dockerfile

Lines changed: 3 additions & 1 deletion
@@ -38,6 +38,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
     scipy \
     tensorboard \
     transformers \
-    omegaconf
+    omegaconf \
+    pytorch-lightning \
+    xformers

 CMD ["/bin/bash"]
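The image gains pytorch-lightning and xformers on top of the existing stack. A hypothetical smoke check (not part of this commit) to confirm the rebuilt image resolves the new dependencies:

# Hypothetical smoke check for the rebuilt image: confirm the newly added
# packages import cleanly and report their versions.
import pytorch_lightning
import xformers

print("pytorch-lightning:", pytorch_lightning.__version__)
print("xformers:", xformers.__version__)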

src/diffusers/models/attention_processor.py

Lines changed: 21 additions & 3 deletions
@@ -540,9 +540,14 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+    ):
         residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim

         if input_ndim == 4:
@@ -905,9 +910,13 @@ def __call__(
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
     ):
         residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim

         if input_ndim == 4:
@@ -1081,9 +1090,14 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+    ):
         residual = hidden_states

+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim

         if input_ndim == 4:
@@ -1334,8 +1348,12 @@ class SlicedAttnAddedKVProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size

-    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
         residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

         batch_size, sequence_length, _ = hidden_states.shape
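For reference, every processor touched above now follows the same two-step pattern: accept an optional temb and, only when the Attention module carries a spatial_norm, normalize the hidden states against temb before any projections run; when spatial_norm is None, temb is simply ignored. Below is a minimal, self-contained sketch of that pattern in plain torch; SpatialNormSketch and its 1x1-conv scale/shift are illustrative assumptions, not diffusers' actual SpatialNorm implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialNormSketch(nn.Module):
    """Toy stand-in for a temb-conditioned norm: GroupNorm modulated by 1x1 convs of temb (assumption, for illustration only)."""

    def __init__(self, channels: int, temb_channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=1, num_channels=channels, eps=1e-6)
        self.scale = nn.Conv2d(temb_channels, channels, kernel_size=1)
        self.shift = nn.Conv2d(temb_channels, channels, kernel_size=1)

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # Resize the conditioning map to the feature resolution, then scale/shift the normed features.
        temb = F.interpolate(temb, size=hidden_states.shape[-2:], mode="nearest")
        return self.norm(hidden_states) * self.scale(temb) + self.shift(temb)


def maybe_apply_spatial_norm(hidden_states, temb, spatial_norm=None):
    # The gating logic the commit adds to each processor, in isolation.
    if spatial_norm is not None:
        hidden_states = spatial_norm(hidden_states, temb)
    return hidden_states


x = torch.randn(1, 8, 16, 16)  # (batch, channels, height, width) feature map
z = torch.randn(1, 4, 8, 8)    # conditioning tensor passed around as `temb`
out = maybe_apply_spatial_norm(x, z, SpatialNormSketch(8, 4))
print(out.shape)  # torch.Size([1, 8, 16, 16])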

src/diffusers/utils/testing_utils.py

Lines changed: 6 additions & 0 deletions
@@ -577,3 +577,9 @@ def enable_full_determinism():
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
     torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+    torch.use_deterministic_algorithms(False)
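The new disable_full_determinism undoes the environment variables and torch flags set by enable_full_determinism just above, so a test that needs non-deterministic kernels (such as the xformers paths in the test added below) can opt out and restore strict determinism afterwards. A hypothetical usage sketch; the wrapper name and the try/finally are illustrative additions, while the commit's test simply calls the two helpers at the start and end of the method.

# Hypothetical usage of the paired helpers; the wrapper is illustrative, not part of this commit.
from diffusers.utils.testing_utils import disable_full_determinism, enable_full_determinism


def run_nondeterministic_block(fn):
    disable_full_determinism()  # e.g. xformers attention has no deterministic kernel
    try:
        return fn()
    finally:
        enable_full_determinism()  # restore strict determinism for the rest of the suite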

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 53 additions & 1 deletion
@@ -37,16 +37,18 @@
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAXFormersAttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    disable_full_determinism,
     enable_full_determinism,
     require_torch_2,
     require_torch_gpu,
     run_test_in_subprocess,
 )

+from ...models.test_lora_layers import create_unet_lora_layers
 from ...models.test_models_unet_2d_condition import create_lora_layers
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -366,6 +368,56 @@ def test_stable_diffusion_pndm(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda")
+    def test_stable_diffusion_attn_processors(self):
+        disable_full_determinism()
+        device = "cuda"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+
+        # run normal sd pipe
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run xformers attention
+        sd_pipe.enable_xformers_memory_efficient_attention()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run attention slicing
+        sd_pipe.enable_attention_slicing()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run vae attention slicing
+        sd_pipe.enable_vae_slicing()
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run lora attention
+        attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
+        attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
+        sd_pipe.unet.set_attn_processor(attn_processors)
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        # run lora xformers attention
+        attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
+        attn_processors = {
+            k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim)
+            for k, v in attn_processors.items()
+        }
+        attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
+        sd_pipe.unet.set_attn_processor(attn_processors)
+        image = sd_pipe(**inputs).images
+        assert image.shape == (1, 64, 64, 3)
+
+        enable_full_determinism()
+
     def test_stable_diffusion_no_safety_checker(self):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
