Commit a5892d7

fix imports in community pipeline for semantic guidance for flux
1 parent 7829f3d commit a5892d7
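
For context, this file is a community pipeline, so it is used through DiffusionPipeline with the custom_pipeline argument rather than imported directly. A minimal loading sketch, following the docstring example visible further down in the diff (the model id and dtype come from that example; the device move is an assumption):

import torch
from diffusers import DiffusionPipeline

# Load the community pipeline by name; the FLUX.1-dev checkpoint matches the
# file's own docstring example.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    custom_pipeline="pipeline_flux_semantic_guidance",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")  # assumes a CUDA device is available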


examples/community/pipeline_flux_semantic_guidance.py

Lines changed: 97 additions & 85 deletions
@@ -13,22 +13,25 @@
 # limitations under the License.
 
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, List, Optional, Union
+
 import numpy as np
 import torch
 from transformers import (
+    CLIPImageProcessor,
     CLIPTextModel,
     CLIPTokenizer,
+    CLIPVisionModelWithProjection,
     T5EncoderModel,
     T5TokenizerFast,
-    CLIPVisionModelWithProjection,
-    CLIPImageProcessor
 )
-from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
-from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
-from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from diffusers.models.autoencoders import AutoencoderKL
 from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
@@ -39,7 +42,7 @@
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -57,7 +60,7 @@
         >>> from diffusers import DiffusionPipeline
 
         >>> pipe = DiffusionPipeline.from_pretrained(
-        >>> "black-forest-labs/FLUX.1-dev",
+        >>> "black-forest-labs/FLUX.1-dev",
         >>> custom_pipeline="pipeline_flux_semantic_guidance",
         >>> torch_dtype=torch.bfloat16
         >>> )
@@ -319,7 +322,6 @@ def _get_clip_prompt_embeds(
 
         return prompt_embeds
 
-
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -400,18 +402,18 @@ def encode_prompt(
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
     def encode_text_with_editing(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        editing_prompt: Optional[List[str]] = None,
-        editing_prompt_2: Optional[List[str]] = None,
-        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 512,
-        lora_scale: Optional[float] = None,
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        editing_prompt: Optional[List[str]] = None,
+        editing_prompt_2: Optional[List[str]] = None,
+        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 512,
+        lora_scale: Optional[float] = None,
     ):
         """
         Encode text prompts with editing prompts and negative prompts for semantic guidance.
@@ -500,8 +502,15 @@ def encode_text_with_editing(
                 editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0)
                 pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0)
 
-        return (prompt_embeds, pooled_prompt_embeds, editing_prompt_embeds,
-                pooled_editing_prompt_embeds, text_ids, edit_text_ids, enabled_editing_prompts)
+        return (
+            prompt_embeds,
+            pooled_prompt_embeds,
+            editing_prompt_embeds,
+            pooled_editing_prompt_embeds,
+            text_ids,
+            edit_text_ids,
+            enabled_editing_prompts,
+        )
 
     def encode_image(self, image, device, num_images_per_prompt):
         dtype = next(self.image_encoder.parameters()).dtype
@@ -546,27 +555,27 @@ def prepare_ip_adapter_image_embeds(
         return ip_adapter_image_embeds
 
     def check_inputs(
-        self,
-        prompt,
-        prompt_2,
-        height,
-        width,
-        negative_prompt=None,
-        negative_prompt_2=None,
-        prompt_embeds=None,
-        negative_prompt_embeds=None,
-        pooled_prompt_embeds=None,
-        negative_pooled_prompt_embeds=None,
-        callback_on_step_end_tensor_inputs=None,
-        max_sequence_length=None,
+        self,
+        prompt,
+        prompt_2,
+        height,
+        width,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -743,47 +752,47 @@ def interrupt(self):
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
-        self,
-        prompt: Union[str, List[str]] = None,
-        prompt_2: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Union[str, List[str]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
-        true_cfg_scale: float = 1.0,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_inference_steps: int = 28,
-        sigmas: Optional[List[float]] = None,
-        guidance_scale: float = 3.5,
-        num_images_per_prompt: Optional[int] = 1,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
-        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 512,
-        editing_prompt: Optional[Union[str, List[str]]] = None,
-        editing_prompt_2: Optional[Union[str, List[str]]] = None,
-        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
-        reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
-        edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
-        edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
-        edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
-        edit_threshold: Optional[Union[float, List[float]]] = 0.9,
-        edit_momentum_scale: Optional[float] = 0.1,
-        edit_mom_beta: Optional[float] = 0.4,
-        edit_weights: Optional[List[float]] = None,
-        sem_guidance: Optional[List[torch.Tensor]] = None,
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 3.5,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+        editing_prompt: Optional[Union[str, List[str]]] = None,
+        editing_prompt_2: Optional[Union[str, List[str]]] = None,
+        editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+        reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
+        edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
+        edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
+        edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
+        edit_threshold: Optional[Union[float, List[float]]] = 0.9,
+        edit_momentum_scale: Optional[float] = 0.1,
+        edit_mom_beta: Optional[float] = 0.4,
+        edit_weights: Optional[List[float]] = None,
+        sem_guidance: Optional[List[torch.Tensor]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -1037,7 +1046,9 @@ def __call__(
            min_edit_warmup_steps = 0
 
        if edit_cooldown_steps:
-            tmp_e_cooldown_steps = edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
+            tmp_e_cooldown_steps = (
+                edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
+            )
            max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps)
        else:
            max_edit_cooldown_steps = num_inference_steps
@@ -1110,7 +1121,9 @@ def __call__(
 
                if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
                    noise_pred_edit_concepts = []
-                    for e_embed, pooled_e_embed, e_text_id in zip(editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids):
+                    for e_embed, pooled_e_embed, e_text_id in zip(
+                        editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids
+                    ):
                        noise_pred_edit = self.transformer(
                            hidden_states=latents,
                            timestep=timestep / 1000,
@@ -1160,7 +1173,6 @@ def __call__(
                    # noise_guidance_edit = torch.zeros_like(noise_guidance)
                    warmup_inds = []
                    for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
-
                        if isinstance(edit_guidance_scale, list):
                            edit_guidance_scale_c = edit_guidance_scale[c]
                        else:
@@ -1247,9 +1259,7 @@ def __call__(
                            concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
                            # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
 
-                            noise_guidance_edit_tmp = torch.index_select(
-                                noise_guidance_edit.to(device), 0, warmup_inds
-                            )
+                            noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
                            noise_guidance_edit_tmp = torch.einsum(
                                "cb,cbij->bij", concept_weights_tmp, noise_guidance_edit_tmp
                            )
@@ -1325,4 +1335,6 @@ def __call__(
        if not return_dict:
            return (image,)
 
-        return FluxPipelineOutput(image, )
+        return FluxPipelineOutput(
+            image,
+        )
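
One of the reformatting hunks above touches the warmup-concept weighting: torch.index_select picks a subset of edit directions along the concept axis (indexed by warmup_inds), and torch.einsum("cb,cbij->bij", ...) collapses them into a single guidance tensor per batch element. A small self-contained sketch of that contraction, with illustrative shapes rather than the pipeline's real ones:

import torch

num_concepts, batch = 3, 2
# One edit direction per (concept, batch element); trailing dims stand in for latent dims.
noise_guidance_edit = torch.randn(num_concepts, batch, 4, 8)
concept_weights = torch.rand(num_concepts, batch)
concept_weights = concept_weights / concept_weights.sum(dim=0)  # normalize over the concept axis

# Mirror the diff: select concepts along dim 0 (here, all of them) ...
warmup_inds = torch.tensor([0, 1, 2])
selected = torch.index_select(noise_guidance_edit, 0, warmup_inds)

# ... then weight each concept's direction and sum over the concept axis.
combined = torch.einsum("cb,cbij->bij", concept_weights[warmup_inds], selected)
assert combined.shape == (batch, 4, 8)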

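The __call__ signature shown in the diff exposes the semantic-guidance controls (editing_prompt, reverse_editing_direction, edit_guidance_scale, edit_warmup_steps, edit_threshold, and related arguments). A hedged end-to-end sketch that reuses the pipe object from the loading sketch near the top; prompt text and numeric values are illustrative, not recommendations:

# Assumes `pipe` was loaded as in the sketch near the top of this page.
result = pipe(
    prompt="a photo of a cat sitting on a bench",
    num_inference_steps=28,  # default from the signature above
    guidance_scale=3.5,  # default from the signature above
    editing_prompt=["sunglasses"],  # concept(s) used to steer the image
    reverse_editing_direction=[False],  # steer toward (rather than away from) the concept
    edit_guidance_scale=[5.0],
    edit_warmup_steps=[8],
    edit_threshold=[0.9],
)
result.images[0].save("flux_semantic_guidance_cat.png")  # FluxPipelineOutput.images holds the decoded images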