Skip to content

Commit 04bba38

Browse files
linoytsabanhlky
andauthored
[Flux Redux] add prompt & multiple image input (#10056)
* add multiple prompts to flux redux --------- Co-authored-by: hlky <hlky@hlky.ac>
1 parent a2d424e commit 04bba38

File tree

1 file changed

+92
-5
lines changed

1 file changed

+92
-5
lines changed

src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,45 @@ def __init__(
142142
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
143143
)
144144

145+
def check_inputs(
146+
self,
147+
image,
148+
prompt,
149+
prompt_2,
150+
prompt_embeds=None,
151+
pooled_prompt_embeds=None,
152+
prompt_embeds_scale=1.0,
153+
pooled_prompt_embeds_scale=1.0,
154+
):
155+
if prompt is not None and prompt_embeds is not None:
156+
raise ValueError(
157+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
158+
" only forward one of the two."
159+
)
160+
elif prompt_2 is not None and prompt_embeds is not None:
161+
raise ValueError(
162+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
163+
" only forward one of the two."
164+
)
165+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
166+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
167+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
168+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
169+
if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
170+
raise ValueError(
171+
f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
172+
)
173+
if prompt_embeds is not None and pooled_prompt_embeds is None:
174+
raise ValueError(
175+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
176+
)
177+
if isinstance(prompt_embeds_scale, list) and (
178+
isinstance(image, list) and len(prompt_embeds_scale) != len(image)
179+
):
180+
raise ValueError(
181+
f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
182+
)
183+
145184
def encode_image(self, image, device, num_images_per_prompt):
146185
dtype = next(self.image_encoder.parameters()).dtype
147186
image = self.feature_extractor.preprocess(
@@ -334,6 +373,12 @@ def encode_prompt(
334373
def __call__(
335374
self,
336375
image: PipelineImageInput,
376+
prompt: Union[str, List[str]] = None,
377+
prompt_2: Optional[Union[str, List[str]]] = None,
378+
prompt_embeds: Optional[torch.FloatTensor] = None,
379+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
380+
prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
381+
pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
337382
return_dict: bool = True,
338383
):
339384
r"""
@@ -345,6 +390,16 @@ def __call__(
345390
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
346391
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
347392
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
393+
prompt (`str` or `List[str]`, *optional*):
394+
The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
395+
make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders
396+
are not loaded.
397+
prompt_2 (`str` or `List[str]`, *optional*):
398+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
399+
prompt_embeds (`torch.FloatTensor`, *optional*):
400+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
401+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
402+
Pre-generated pooled text embeddings.
348403
return_dict (`bool`, *optional*, defaults to `True`):
349404
Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
350405
@@ -356,13 +411,31 @@ def __call__(
356411
returning a tuple, the first element is a list with the generated images.
357412
"""
358413

414+
# 1. Check inputs. Raise error if not correct
415+
self.check_inputs(
416+
image,
417+
prompt,
418+
prompt_2,
419+
prompt_embeds=prompt_embeds,
420+
pooled_prompt_embeds=pooled_prompt_embeds,
421+
prompt_embeds_scale=prompt_embeds_scale,
422+
pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
423+
)
424+
359425
# 2. Define call parameters
360426
if image is not None and isinstance(image, Image.Image):
361427
batch_size = 1
362428
elif image is not None and isinstance(image, list):
363429
batch_size = len(image)
364430
else:
365431
batch_size = image.shape[0]
432+
if prompt is not None and isinstance(prompt, str):
433+
prompt = batch_size * [prompt]
434+
if isinstance(prompt_embeds_scale, float):
435+
prompt_embeds_scale = batch_size * [prompt_embeds_scale]
436+
if isinstance(pooled_prompt_embeds_scale, float):
437+
pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
438+
366439
device = self._execution_device
367440

368441
# 3. Prepare image embeddings
@@ -378,24 +451,38 @@ def __call__(
378451
pooled_prompt_embeds,
379452
_,
380453
) = self.encode_prompt(
381-
prompt=[""] * batch_size,
382-
prompt_2=None,
383-
prompt_embeds=None,
384-
pooled_prompt_embeds=None,
454+
prompt=prompt,
455+
prompt_2=prompt_2,
456+
prompt_embeds=prompt_embeds,
457+
pooled_prompt_embeds=pooled_prompt_embeds,
385458
device=device,
386459
num_images_per_prompt=1,
387460
max_sequence_length=512,
388461
lora_scale=None,
389462
)
390463
else:
464+
if prompt is not None:
465+
logger.warning(
466+
"prompt input is ignored when text encoders are not loaded to the pipeline. "
467+
"Make sure to explicitly load the text encoders to enable prompt input. "
468+
)
391469
# max_sequence_length is 512, t5 encoder hidden size is 4096
392470
prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
393471
# pooled_prompt_embeds is 768, clip text encoder hidden size
394472
pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
395473

396-
# Concatenate image and text embeddings
474+
# scale & concatenate image and text embeddings
397475
prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
398476

477+
prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
478+
pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
479+
:, None
480+
]
481+
482+
# weighted sum
483+
prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
484+
pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
485+
399486
# Offload all models
400487
self.maybe_free_model_hooks()
401488

0 commit comments

Comments
 (0)