@@ -13,22 +13,25 @@
 # limitations under the License.

 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, List, Optional, Union
+
 import numpy as np
 import torch
 from transformers import (
+    CLIPImageProcessor,
     CLIPTextModel,
     CLIPTokenizer,
+    CLIPVisionModelWithProjection,
     T5EncoderModel,
     T5TokenizerFast,
-    CLIPVisionModelWithProjection,
-    CLIPImageProcessor
 )
-from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
-from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
-from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from diffusers.models.autoencoders import AutoencoderKL
 from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
@@ -39,7 +42,7 @@
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+

 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -57,7 +60,7 @@
        >>> from diffusers import DiffusionPipeline

        >>> pipe = DiffusionPipeline.from_pretrained(
-        >>> "black-forest-labs/FLUX.1-dev",
+        >>>     "black-forest-labs/FLUX.1-dev",
        >>>     custom_pipeline="pipeline_flux_semantic_guidance",
        >>>     torch_dtype=torch.bfloat16
        >>> )
@@ -319,7 +322,6 @@ def _get_clip_prompt_embeds(

         return prompt_embeds

-
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -400,18 +402,18 @@ def encode_prompt(
         return prompt_embeds, pooled_prompt_embeds, text_ids

     def encode_text_with_editing(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        editing_prompt: Optional[List[str]] = None,
-        editing_prompt_2: Optional[List[str]] = None,
-        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 512,
-        lora_scale: Optional[float] = None,
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        editing_prompt: Optional[List[str]] = None,
+        editing_prompt_2: Optional[List[str]] = None,
+        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 512,
+        lora_scale: Optional[float] = None,
     ):
         """
         Encode text prompts with editing prompts and negative prompts for semantic guidance.
@@ -500,8 +502,15 @@ def encode_text_with_editing(
                 editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0)
                 pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0)

-        return (prompt_embeds, pooled_prompt_embeds, editing_prompt_embeds,
-                pooled_editing_prompt_embeds, text_ids, edit_text_ids, enabled_editing_prompts)
+        return (
+            prompt_embeds,
+            pooled_prompt_embeds,
+            editing_prompt_embeds,
+            pooled_editing_prompt_embeds,
+            text_ids,
+            edit_text_ids,
+            enabled_editing_prompts,
+        )

     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
@@ -546,27 +555,27 @@ def prepare_ip_adapter_image_embeds(
         return ip_adapter_image_embeds

     def check_inputs(
-        self,
-        prompt,
-        prompt_2,
-        height,
-        width,
-        negative_prompt=None,
-        negative_prompt_2=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-        pooled_prompt_embeds=None,
-        negative_pooled_prompt_embeds=None,
-        callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
+        self,
+        prompt,
+        prompt_2,
+        height,
+        width,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )

         if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -743,47 +752,47 @@ def interrupt(self):
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Union[str, List[str]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        true_cfg_scale: float = 1.0,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_inference_steps: int = 28,
-        sigmas: Optional[List[float]] = None,
-        guidance_scale: float = 3.5,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
-        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 512,
-        editing_prompt: Optional[Union[str, List[str]]] = None,
-        editing_prompt_2: Optional[Union[str, List[str]]] = None,
-        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-        reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
-        edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
-        edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
-        edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
-        edit_threshold: Optional[Union[float, List[float]]] = 0.9,
-        edit_momentum_scale: Optional[float] = 0.1,
-        edit_mom_beta: Optional[float] = 0.4,
-        edit_weights: Optional[List[float]] = None,
-        sem_guidance: Optional[List[torch.Tensor]] = None,
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 3.5,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+        editing_prompt: Optional[Union[str, List[str]]] = None,
+        editing_prompt_2: Optional[Union[str, List[str]]] = None,
+        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
+        edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
+        edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
+        edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
+        edit_threshold: Optional[Union[float, List[float]]] = 0.9,
+        edit_momentum_scale: Optional[float] = 0.1,
+        edit_mom_beta: Optional[float] = 0.4,
+        edit_weights: Optional[List[float]] = None,
+        sem_guidance: Optional[List[torch.Tensor]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1037,7 +1046,9 @@ def __call__(
             min_edit_warmup_steps = 0

         if edit_cooldown_steps:
-            tmp_e_cooldown_steps = edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
+            tmp_e_cooldown_steps = (
+                edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
+            )
             max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps)
         else:
             max_edit_cooldown_steps = num_inference_steps
@@ -1110,7 +1121,9 @@ def __call__(

                 if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
                     noise_pred_edit_concepts = []
-                    for e_embed, pooled_e_embed, e_text_id in zip(editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids):
+                    for e_embed, pooled_e_embed, e_text_id in zip(
+                        editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids
+                    ):
                         noise_pred_edit = self.transformer(
                             hidden_states=latents,
                             timestep=timestep / 1000,
@@ -1160,7 +1173,6 @@ def __call__(
                     # noise_guidance_edit = torch.zeros_like(noise_guidance)
                     warmup_inds = []
                     for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
-
                         if isinstance(edit_guidance_scale, list):
                             edit_guidance_scale_c = edit_guidance_scale[c]
                         else:
@@ -1247,9 +1259,7 @@ def __call__(
                         concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
                         # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)

-                        noise_guidance_edit_tmp = torch.index_select(
-                            noise_guidance_edit.to(device), 0, warmup_inds
-                        )
+                        noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
                         noise_guidance_edit_tmp = torch.einsum(
                             "cb,cbij->bij", concept_weights_tmp, noise_guidance_edit_tmp
                         )
@@ -1325,4 +1335,6 @@ def __call__(
         if not return_dict:
             return (image,)

-        return FluxPipelineOutput(image, )
+        return FluxPipelineOutput(
+            image,
+        )
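
For reference, a minimal usage sketch of the semantic-guidance `__call__` signature shown in this diff. The model is loaded as in the example docstring above; the prompt text, edit concepts, and guidance values below are illustrative placeholders, not values taken from this change.

```python
import torch
from diffusers import DiffusionPipeline

# Load the community pipeline, as in EXAMPLE_DOC_STRING above.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    custom_pipeline="pipeline_flux_semantic_guidance",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Semantic guidance: steer the generation toward (or away from) editing concepts
# on top of the base prompt. Parameter names follow the __call__ signature above;
# the concrete values are only an example.
image = pipe(
    prompt="a photo of a cat sitting on a bench",
    num_inference_steps=28,
    guidance_scale=3.5,
    editing_prompt=["sunglasses"],          # concept(s) to edit in
    reverse_editing_direction=[False],      # False adds the concept, True removes it
    edit_guidance_scale=[5.0],
    edit_warmup_steps=[8],
    edit_threshold=[0.9],
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]

image.save("flux_semantic_guidance.png")
```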