@@ -142,6 +142,45 @@ def __init__(
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )

+    def check_inputs(
+        self,
+        image,
+        prompt,
+        prompt_2,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        prompt_embeds_scale=1.0,
+        pooled_prompt_embeds_scale=1.0,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is not None and not isinstance(prompt, (str, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
+            raise ValueError(
+                f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
+            )
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+        if isinstance(prompt_embeds_scale, list) and (
+            isinstance(image, list) and len(prompt_embeds_scale) != len(image)
+        ):
+            raise ValueError(
+                f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
+            )
+
    def encode_image(self, image, device, num_images_per_prompt):
        dtype = next(self.image_encoder.parameters()).dtype
        image = self.feature_extractor.preprocess(
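A minimal sketch of the validation the new `check_inputs` enforces (the pipeline instance `pipe` and the PIL image `img` are hypothetical stand-ins, not part of this diff):

import torch

# Passing both a raw prompt and precomputed embeddings is ambiguous, so it raises.
try:
    pipe.check_inputs(img, "a photo of a cat", None, prompt_embeds=torch.zeros((1, 512, 4096)))
except ValueError as err:
    print(err)  # Cannot forward both `prompt`: ... and `prompt_embeds`: ...

# A list of prompts must line up one-to-one with a list of images.
try:
    pipe.check_inputs([img, img, img], ["cat", "dog"], None)
except ValueError as err:
    print(err)  # number of prompts must be equal to number of images, ...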
@@ -334,6 +373,12 @@ def encode_prompt(
    def __call__(
        self,
        image: PipelineImageInput,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+        pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
        return_dict: bool = True,
    ):
        r"""
@@ -345,6 +390,16 @@ def __call__(
                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
+                make sure to explicitly load text encoders into the pipeline. Prompts will be ignored if text encoders
+                are not loaded.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.

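A hedged usage sketch of the new call arguments (the repo id, image URL, and device are assumptions for illustration; by default the Redux checkpoint ships without text encoders, so `prompt` only takes effect if they are loaded explicitly):

import torch
from diffusers import FluxPriorReduxPipeline
from diffusers.utils import load_image

# Assumed repo id; pass text_encoder/tokenizer explicitly to enable `prompt`.
pipe_prior = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("https://example.com/reference.png")  # placeholder URL
# Without text encoders this warns and ignores the prompt (see the else branch below).
out = pipe_prior(image=image, prompt="a futuristic cityscape")
# Assumes the output dataclass exposes these two fields, per the pipeline's return value.
prompt_embeds, pooled_prompt_embeds = out.prompt_embeds, out.pooled_prompt_embeds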
@@ -356,13 +411,31 @@ def __call__(
                returning a tuple, the first element is a list with the generated images.
        """

+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            image,
+            prompt,
+            prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt_embeds_scale=prompt_embeds_scale,
+            pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
+        )
+
        # 2. Define call parameters
        if image is not None and isinstance(image, Image.Image):
            batch_size = 1
        elif image is not None and isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
+        if prompt is not None and isinstance(prompt, str):
+            prompt = batch_size * [prompt]
+        if isinstance(prompt_embeds_scale, float):
+            prompt_embeds_scale = batch_size * [prompt_embeds_scale]
+        if isinstance(pooled_prompt_embeds_scale, float):
+            pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
+
        device = self._execution_device

        # 3. Prepare image embeddings
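A standalone sketch of the broadcasting rule above: a scalar scale becomes a per-image list so each image can later receive its own weight (toy values, no pipeline needed):

batch_size = 3                          # e.g. three input images
prompt_embeds_scale = 0.8               # a single float from the caller
if isinstance(prompt_embeds_scale, float):
    prompt_embeds_scale = batch_size * [prompt_embeds_scale]
print(prompt_embeds_scale)              # [0.8, 0.8, 0.8] -> one weight per image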
@@ -378,24 +451,38 @@ def __call__(
                pooled_prompt_embeds,
                _,
            ) = self.encode_prompt(
-                prompt=[""] * batch_size,
-                prompt_2=None,
-                prompt_embeds=None,
-                pooled_prompt_embeds=None,
+                prompt=prompt,
+                prompt_2=prompt_2,
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=1,
                max_sequence_length=512,
                lora_scale=None,
            )
        else:
+            if prompt is not None:
+                logger.warning(
+                    "prompt input is ignored when text encoders are not loaded to the pipeline. "
+                    "Make sure to explicitly load the text encoders to enable prompt input."
+                )
            # max_sequence_length is 512, t5 encoder hidden size is 4096
            prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
            # pooled_prompt_embeds is 768, clip text encoder hidden size
            pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)

-        # Concatenate image and text embeddings
+        # scale & concatenate image and text embeddings
        prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)

+        prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
+        pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
+            :, None
+        ]
+
+        # weighted sum
+        prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
+        pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
+
        # Offload all models
        self.maybe_free_model_hooks()
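A standalone sketch of the scale-and-sum step just above: per-image weights scale each embedding, then the batch dimension is summed away, blending N reference images into a single prompt embedding (toy shapes; the real sequence is the 512 text tokens plus the image tokens, width 4096):

import torch

N, seq, dim = 2, 4, 8                        # toy sizes
prompt_embeds = torch.randn(N, seq, dim)     # stand-in for concatenated text+image embeds
scales = torch.tensor([0.7, 0.3])            # per-image weights, e.g. prompt_embeds_scale

prompt_embeds = prompt_embeds * scales[:, None, None]
blended = prompt_embeds.sum(dim=0, keepdim=True)
print(blended.shape)                         # torch.Size([1, 4, 8]): one mixed embedding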