From 917302be4cba988ca97e7f3588f223b7b9488d5a Mon Sep 17 00:00:00 2001 From: Eliseu Silva Date: Sat, 8 Mar 2025 12:18:02 -0300 Subject: [PATCH] small fix on generating time_ids & embeddings --- examples/community/mixture_tiling_sdxl.py | 44 +++++++++++------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/community/mixture_tiling_sdxl.py b/examples/community/mixture_tiling_sdxl.py index f7b971bae841..bd56ddb3d61d 100644 --- a/examples/community/mixture_tiling_sdxl.py +++ b/examples/community/mixture_tiling_sdxl.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1070,32 +1070,32 @@ def __call__( text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left[row][col], - target_size, + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left[row][col], + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left[row][col], + negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = self._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left[row][col], - negative_target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - else: - negative_add_time_ids = add_time_ids + else: + negative_add_time_ids = add_time_ids - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids)) embeddings_and_added_time.append(addition_embed_type_row)