Description
Describe the bug
I'm encountering an issue when trying to use xformers with textual inversion. Using textual inversion works as expected with the following setup:
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("/content/8387", token="charturnerv2")
But when I enable xformers with:
pipe.enable_xformers_memory_efficient_attention()
The script fails with the following error:
TypeError: XFormersAttnProcessor.__call__() got an unexpected keyword argument 'temb'
Reproduction
from diffusers import StableDiffusionPipeline
import torch
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
pipe.load_textual_inversion("/content/8387", token="charturnerv2")
prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
generator = torch.Generator("cuda").manual_seed(1)
image = pipe(prompt, num_inference_steps=50, generator = generator).images[0]
display(image)
Colab: https://colab.research.google.com/drive/19_7eSlHqPc-I78JOYTaECQ0FTOQjH9zs?usp=sharing
Logs
in <cell line: 13>:13 │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115 in decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_dif │
│ fusion.py:719 in __call__ │
│ │
│ 716 │ │ │ │ │ │ callback(i, t, latents) │
│ 717 │ │ │
│ 718 │ │ if not output_type == "latent": │
│ ❱ 719 │ │ │ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dic │
│ 720 │ │ │ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embe │
│ 721 │ │ else: │
│ 722 │ │ │ image = latents │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py:46 in wrapper │
│ │
│ 43 │ def wrapper(self, *args, **kwargs): │
│ 44 │ │ if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): │
│ 45 │ │ │ self._hf_hook.pre_forward(self) │
│ ❱ 46 │ │ return method(self, *args, **kwargs) │
│ 47 │ │
│ 48 │ return wrapper │
│ 49 │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoder_kl.py:191 in decode │
│ │
│ 188 │ │ │ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] │
│ 189 │ │ │ decoded = torch.cat(decoded_slices) │
│ 190 │ │ else: │
│ ❱ 191 │ │ │ decoded = self._decode(z).sample │
│ 192 │ │ │
│ 193 │ │ if not return_dict: │
│ 194 │ │ │ return (decoded,) │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoder_kl.py:178 in _decode │
│ │
│ 175 │ │ │ return self.tiled_decode(z, return_dict=return_dict) │
│ 176 │ │ │
│ 177 │ │ z = self.post_quant_conv(z) │
│ ❱ 178 │ │ dec = self.decoder(z) │
│ 179 │ │ │
│ 180 │ │ if not return_dict: │
│ 181 │ │ │ return (dec,) │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/vae.py:265 in forward │
│ │
│ 262 │ │ │ │ │ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_ │
│ 263 │ │ else: │
│ 264 │ │ │ # middle │
│ ❱ 265 │ │ │ sample = self.mid_block(sample, latent_embeds) │
│ 266 │ │ │ sample = sample.to(upscale_dtype) │
│ 267 │ │ │ │
│ 268 │ │ │ # up │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_blocks.py:472 in forward │
│ │
│ 469 │ │ hidden_states = self.resnets[0](hidden_states, temb) │
│ 470 │ │ for attn, resnet in zip(self.attentions, self.resnets[1:]): │
│ 471 │ │ │ if attn is not None: │
│ ❱ 472 │ │ │ │ hidden_states = attn(hidden_states, temb=temb) │
│ 473 │ │ │ hidden_states = resnet(hidden_states, temb) │
│ 474 │ │ │
│ 475 │ │ return hidden_states │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:312 in forward │
│ │
│ 309 │ │ # The `Attention` class can call different attention processors / attention func │
│ 310 │ │ # here we simply pass along all tensors to the selected processor class │
│ 311 │ │ # For standard processors that are defined here, `**cross_attention_kwargs` is e │
│ ❱ 312 │ │ return self.processor( │
│ 313 │ │ │ self, │
│ 314 │ │ │ hidden_states, │
│ 315 │ │ │ encoder_hidden_states=encoder_hidden_states, │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: XFormersAttnProcessor.__call__() got an unexpected keyword argument 'temb'
System Info
diffusers 0.17.0.dev0
xformers 0.0.20
python 3.10