From 3997436a07932a8fc39d1bdbfdd13860efeeef14 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 8 May 2025 11:31:32 +0530
Subject: [PATCH] fix audioldm2 for transformers main.

---
 .../pipelines/audioldm2/pipeline_audioldm2.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index 87d78646a966..eeabf0d248a0 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -40,6 +40,7 @@
     logging,
     replace_example_docstring,
 )
+from ...utils.import_utils import is_transformers_version
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
@@ -312,8 +313,19 @@ def generate_language_model(
             `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 The sequence of generated hidden-states.
         """
+        cache_position_kwargs = {}
+        if is_transformers_version("<", "4.52.0.dev0"):
+            cache_position_kwargs["input_ids"] = inputs_embeds
+            cache_position_kwargs["model_kwargs"] = model_kwargs
+        else:
+            cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
+            cache_position_kwargs["device"] = (
+                self.language_model.device if getattr(self, "language_model", None) is not None else self.device
+            )
+            cache_position_kwargs["model_kwargs"] = model_kwargs
         max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
-        model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
+        model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
+
         for _ in range(max_new_tokens):
             # prepare model inputs
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)