Skip to content

Commit 7d63182

Browse files
Make Dreambooth SD Training Script torch.compile compatible (#6532)
* support compile
* make style
* move unwrap_model inside function
* change unwrap call
* run make style
* Update examples/dreambooth/train_dreambooth.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Revert "Update examples/dreambooth/train_dreambooth.py" — this reverts commit 70ab097.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 33d2b5b commit 7d63182

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 21 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 if is_wandb_available():
@@ -129,15 +130,12 @@ def log_validation(
     if vae is not None:
         pipeline_args["vae"] = vae

-    if text_encoder is not None:
-        text_encoder = accelerator.unwrap_model(text_encoder)
-
     # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        unet=accelerator.unwrap_model(unet),
+        unet=unet,
         revision=args.revision,
         variant=args.variant,
         torch_dtype=weight_dtype,
@@ -794,6 +792,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
     prompt_embeds = text_encoder(
         text_input_ids,
         attention_mask=attention_mask,
+        return_dict=False,
     )
     prompt_embeds = prompt_embeds[0]

@@ -931,11 +930,16 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             for model in models:
-                sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+                sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
                 model.save_pretrained(os.path.join(output_dir, sub_dir))

                 # make sure to pop weight so that corresponding model is not saved again
@@ -946,7 +950,7 @@ def load_model_hook(models, input_dir):
             # pop models so that they are not loaded again
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+            if isinstance(model, type(unwrap_model(text_encoder))):
                 # load transformers style into model
                 load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
                 model.config = load_model.config
@@ -991,15 +995,12 @@ def load_model_hook(models, input_dir):
             " doing mixed precision training. copy of the weights should still be float32."
         )

-    if accelerator.unwrap_model(unet).dtype != torch.float32:
-        raise ValueError(
-            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
-        )
+    if unwrap_model(unet).dtype != torch.float32:
+        raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")

-    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
+    if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
         raise ValueError(
-            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
-            f" {low_precision_error_string}"
+            f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
         )

     # Enable TF32 for faster training on Ampere GPUs,
@@ -1246,7 +1247,7 @@ def compute_text_embeddings(prompt):
                     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                 )

-                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+                if unwrap_model(unet).config.in_channels == channels * 2:
                     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                 if args.class_labels_conditioning == "timesteps":
@@ -1256,8 +1257,8 @@ def compute_text_embeddings(prompt):

                 # Predict the noise residual
                 model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
-                ).sample
+                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
+                )[0]

                 if model_pred.shape[1] == 6:
                     model_pred, _ = torch.chunk(model_pred, 2, dim=1)
@@ -1350,9 +1351,9 @@ def compute_text_embeddings(prompt):

                 if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                     images = log_validation(
-                        text_encoder,
+                        unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
                         tokenizer,
-                        unet,
+                        unwrap_model(unet),
                         vae,
                         args,
                         accelerator,
@@ -1375,14 +1376,14 @@ def compute_text_embeddings(prompt):
         pipeline_args = {}

         if text_encoder is not None:
-            pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+            pipeline_args["text_encoder"] = unwrap_model(text_encoder)

         if args.skip_save_text_encoder:
             pipeline_args["text_encoder"] = None

         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=accelerator.unwrap_model(unet),
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
             **pipeline_args,

0 commit comments

Comments (0)