Skip to content

Commit 8ead643

Browse files
andjoerAndreas Jörgsayakpaul
authored
[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)
Fix: dtype mismatch of prompt embeddings in sd3 controlnet training Co-authored-by: Andreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 124ac3e commit 8ead643

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/controlnet/train_controlnet_sd3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,8 +1283,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
12831283
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
12841284

12851285
# Get the text embedding for conditioning
1286-
prompt_embeds = batch["prompt_embeds"]
1287-
pooled_prompt_embeds = batch["pooled_prompt_embeds"]
1286+
prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
1287+
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
12881288

12891289
# controlnet(s) inference
12901290
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

0 commit comments

Comments
 (0)