
TypeError when using XFormers with Textual Inversion #3576

Closed
@realimposter

Description

Describe the bug

I'm encountering an issue when trying to use xformers with textual inversion. Using textual inversion works as expected with the following setup:
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("/content/8387", token="charturnerv2")

But when I enable xformers with:
pipe.enable_xformers_memory_efficient_attention()

The script fails with the following error:
TypeError: XFormersAttnProcessor.__call__() got an unexpected keyword argument 'temb'
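
A possible workaround sketch (untested, and it assumes AutoencoderKL exposes set_attn_processor in this dev version) is to keep xformers attention on the UNet but revert the VAE to the default processor, since the traceback below points at the VAE decoder:

from diffusers.models.attention_processor import AttnProcessor

pipe.enable_xformers_memory_efficient_attention()
# Revert only the VAE to the default processor; the non-xformers path is the
# one that works in the setup above, so it should tolerate the extra `temb`
# keyword that the VAE mid-block forwards.
pipe.vae.set_attn_processor(AttnProcessor())
pipe.load_textual_inversion("/content/8387", token="charturnerv2")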

Reproduction

from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
pipe.load_textual_inversion("/content/8387", token="charturnerv2")

prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
generator = torch.Generator("cuda").manual_seed(1)

image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]
display(image)

Colab: https://colab.research.google.com/drive/19_7eSlHqPc-I78JOYTaECQ0FTOQjH9zs?usp=sharing
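
For what it's worth, the traceback points at the VAE decode step rather than at textual inversion itself, so the error may be reproducible with just the VAE once xformers is enabled (a hypothetical isolation, not part of the original repro):

import torch

# Decode an arbitrary latent; the shape below matches the usual 512x512 SD latents.
latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
with torch.no_grad():
    pipe.vae.decode(latents / pipe.vae.config.scaling_factor)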

Logs

in <cell line: 13>:13                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115 in decorate_context       │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_dif │
│ fusion.py:719 in __call__                                                                        │
│                                                                                                  │
│   716 │   │   │   │   │   │   callback(i, t, latents)                                            │
│   717 │   │                                                                                      │
│   718 │   │   if not output_type == "latent":                                                    │
│ ❱ 719 │   │   │   image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dic   │
│   720 │   │   │   image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embe   │
│   721 │   │   else:                                                                              │
│   722 │   │   │   image = latents                                                                │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py:46 in wrapper        │
│                                                                                                  │
│   43 │   def wrapper(self, *args, **kwargs):                                                     │
│   44 │   │   if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):             │
│   45 │   │   │   self._hf_hook.pre_forward(self)                                                 │
│ ❱ 46 │   │   return method(self, *args, **kwargs)                                                │
│   47 │                                                                                           │
│   48 │   return wrapper                                                                          │
│   49                                                                                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoder_kl.py:191 in decode         │
│                                                                                                  │
│   188 │   │   │   decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]      │
│   189 │   │   │   decoded = torch.cat(decoded_slices)                                            │
│   190 │   │   else:                                                                              │
│ ❱ 191 │   │   │   decoded = self._decode(z).sample                                               │
│   192 │   │                                                                                      │
│   193 │   │   if not return_dict:                                                                │
│   194 │   │   │   return (decoded,)                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/autoencoder_kl.py:178 in _decode        │
│                                                                                                  │
│   175 │   │   │   return self.tiled_decode(z, return_dict=return_dict)                           │
│   176 │   │                                                                                      │
│   177 │   │   z = self.post_quant_conv(z)                                                        │
│ ❱ 178 │   │   dec = self.decoder(z)                                                              │
│   179 │   │                                                                                      │
│   180 │   │   if not return_dict:                                                                │
│   181 │   │   │   return (dec,)                                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/vae.py:265 in forward                   │
│                                                                                                  │
│   262 │   │   │   │   │   sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_   │
│   263 │   │   else:                                                                              │
│   264 │   │   │   # middle                                                                       │
│ ❱ 265 │   │   │   sample = self.mid_block(sample, latent_embeds)                                 │
│   266 │   │   │   sample = sample.to(upscale_dtype)                                              │
│   267 │   │   │                                                                                  │
│   268 │   │   │   # up                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_blocks.py:472 in forward        │
│                                                                                                  │
│    469 │   │   hidden_states = self.resnets[0](hidden_states, temb)                              │
│    470 │   │   for attn, resnet in zip(self.attentions, self.resnets[1:]):                       │
│    471 │   │   │   if attn is not None:                                                          │
│ ❱  472 │   │   │   │   hidden_states = attn(hidden_states, temb=temb)                            │
│    473 │   │   │   hidden_states = resnet(hidden_states, temb)                                   │
│    474 │   │                                                                                     │
│    475 │   │   return hidden_states                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:312 in forward   │
│                                                                                                  │
│    309 │   │   # The `Attention` class can call different attention processors / attention func  │
│    310 │   │   # here we simply pass along all tensors to the selected processor class           │
│    311 │   │   # For standard processors that are defined here, `**cross_attention_kwargs` is e  │
│ ❱  312 │   │   return self.processor(                                                            │
│    313 │   │   │   self,                                                                         │
│    314 │   │   │   hidden_states,                                                                │
│    315 │   │   │   encoder_hidden_states=encoder_hidden_states,                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: XFormersAttnProcessor.__call__() got an unexpected keyword argument 'temb'
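
The traceback shows the VAE mid-block calling attn(hidden_states, temb=temb), while XFormersAttnProcessor.__call__() only accepts attn, hidden_states, encoder_hidden_states, and attention_mask. A minimal shim that swallows the extra keyword (my own sketch, not the library's fix, and again assuming the VAE exposes set_attn_processor here) would look like:

from diffusers.models.attention_processor import XFormersAttnProcessor

class XFormersAttnProcessorIgnoreTemb(XFormersAttnProcessor):
    # Drop `temb` (and any other unexpected keywords) before delegating
    # to the stock xformers processor.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        return super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )

pipe.vae.set_attn_processor(XFormersAttnProcessorIgnoreTemb())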

System Info

diffusers 0.17.0.dev0
xformers 0.0.20
python 3.10
