
Commit edb8c1b

[Flux] Improve true cfg condition (#10539)
* improve flux true cfg condition
* add test

1 parent: 0785dba

2 files changed: 15 additions, 1 deletion

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 4 additions & 1 deletion
@@ -790,7 +790,10 @@ def __call__(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
-        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
+        has_neg_prompt = negative_prompt is not None or (
+            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+        )
+        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
         (
             prompt_embeds,
             pooled_prompt_embeds,
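Before this change, true classifier-free guidance only activated when a raw `negative_prompt` string was passed; callers who supplied precomputed `negative_prompt_embeds` and `negative_pooled_prompt_embeds` silently got no guidance even with `true_cfg_scale > 1`. A minimal sketch of the newly enabled path follows; the model id, prompts, and device are illustrative assumptions, not part of the commit:

```python
# Sketch of the call path this commit enables; model id, prompts, and
# device are assumptions for illustration.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Precompute negative embeddings once and reuse them across calls.
# encode_prompt returns (prompt_embeds, pooled_prompt_embeds, text_ids).
neg_embeds, neg_pooled, _ = pipe.encode_prompt(
    prompt="bad quality, blurry", prompt_2=None, device="cuda"
)

# Before this commit, do_true_cfg evaluated to False here because no raw
# negative_prompt string is passed; with the fix, the embeddings alone
# (combined with true_cfg_scale > 1) enable true CFG.
image = pipe(
    prompt="a photo of a cat",
    true_cfg_scale=2.0,
    negative_prompt_embeds=neg_embeds,
    negative_pooled_prompt_embeds=neg_pooled,
    generator=torch.manual_seed(0),
).images[0]
```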

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,17 @@ def test_flux_image_output_shape(self):
209209
output_height, output_width, _ = image.shape
210210
assert (output_height, output_width) == (expected_height, expected_width)
211211

212+
def test_flux_true_cfg(self):
213+
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
214+
inputs = self.get_dummy_inputs(torch_device)
215+
inputs.pop("generator")
216+
217+
no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
218+
inputs["negative_prompt"] = "bad quality"
219+
inputs["true_cfg_scale"] = 2.0
220+
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
221+
assert not np.allclose(no_true_cfg_out, true_cfg_out)
222+
212223

213224
@nightly
214225
@require_big_gpu_with_torch_cuda
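The test fixes the random seed for both runs, so the only varying factor is true CFG; the `np.allclose` assertion then confirms that enabling it actually changes the output. It can be run in isolation with `pytest tests/pipelines/flux/test_pipeline_flux.py -k test_flux_true_cfg`.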
