@@ -15,72 +15,48 @@
# See the License for the specific language governing permissions and

import argparse
-import copy
-import itertools
-import json
+import io
import logging
import math
import os
-import random
import shutil
-import warnings
from pathlib import Path
+from typing import Callable

import accelerate
-import io
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms as T
-import torchvision.transforms.functional as TF
import transformers
-import webdataset as wds
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
-from braceexpand import braceexpand
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from packaging import version
-from peft.utils import get_peft_model_state_dict
from PIL import Image
-from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file
from torch.nn.utils.spectral_norm import SpectralNorm
-from torch.utils.data import default_collate, Dataset, DataLoader
-from torchvision.transforms.functional import crop
+from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, Gemma2Model
-from typing import Callable, List, Union
-from webdataset.tariterators import (
-    base_plus_ext,
-    tar_file_expander,
-    url_opener,
-    valid_sample,
-)

import diffusers
from diffusers import (
    AutoencoderDC,
-    FlowMatchEulerDiscreteScheduler,
    SanaPipeline,
    SanaSprintPipeline,
    SanaTransformer2DModel,
-    SCMScheduler,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
-    cast_training_params,
-    compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3,
    free_memory,
)
from diffusers.utils import (
    check_min_version,
-    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -98,18 +74,18 @@

if is_torch_npu_available():
    torch.npu.config.allow_internal_format = False
-
+
COMPLEX_HUMAN_INSTRUCTION = [
    "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
    "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
    ...
    "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
    "User Prompt: ",
]
-
+


class Text2ImageDataset(Dataset):
@@ -140,17 +116,17 @@ def __init__(self, hf_dataset, resolution=1024):

    def __len__(self):
        return len(self.dataset)
-
+
    def __getitem__(self, idx):
        item = self.dataset[idx]
        text = item['llava']
        image_bytes = item['image']
-
+
        # Convert bytes to PIL Image
        image = Image.open(io.BytesIO(image_bytes))
-
+
        image_tensor = self.transform(image)
-
+
        return {
            'text': text,
            'image': image_tensor
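For reference, the dataset above drops straight into a vanilla `DataLoader`. A minimal usage sketch (the parquet source format, batch size, and printed shapes are illustrative assumptions; `Text2ImageDataset` and the `'llava'`/`'image'` columns come from the diff):

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

# Assumes rows carrying raw image bytes under 'image' and a caption under 'llava'.
hf_dataset = load_dataset("parquet", data_files="train.parquet", split="train")
train_dataset = Text2ImageDataset(hf_dataset=hf_dataset, resolution=1024)
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch["text"][:2])     # list of caption strings
print(batch["image"].shape)  # e.g. torch.Size([4, 3, 1024, 1024])
```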
@@ -768,7 +744,7 @@ def state_dict(self):

    def __getattr__(self, name):
        return getattr(self.disc, name)
-
+
class SanaTrigFlow(SanaTransformer2DModel):
    def __init__(self, original_model, guidance=False):
        self.__dict__ = original_model.__dict__
@@ -779,7 +755,7 @@ def __init__(self, original_model, guidance=False):
        self.logvar_linear = torch.nn.Linear(hidden_size, 1)
        torch.nn.init.xavier_uniform_(self.logvar_linear.weight)
        torch.nn.init.constant_(self.logvar_linear.bias, 0)
-
+
    def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None, jvp=False, return_logvar=False, **kwargs):
        batch_size = hidden_states.shape[0]
        latents = hidden_states
@@ -812,8 +788,8 @@ def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None,
        trigflow_model_out = ((1 - 2 * flow_timestep_expanded) * latent_model_input + (1 - 2 * flow_timestep_expanded + 2 * flow_timestep_expanded**2) * model_out) / torch.sqrt(
            flow_timestep_expanded**2 + (1 - flow_timestep_expanded) ** 2
        )
-
-
+
+
        if self.guidance and guidance is not None:
            timestep, embedded_timestep = self.time_embed(
                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
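The rescaling above is the heart of the wrapper: it converts a flow-matching prediction into a TrigFlow-parameterized one. Writing s for `flow_timestep_expanded`, z for `latent_model_input`, and F_flow for `model_out`, it computes F_trig = ((1 - 2s) * z + (1 - 2s + 2s^2) * F_flow) / sqrt(s^2 + (1 - s)^2). A standalone numeric check of the two boundary cases (the helper name `to_trigflow` is ours, not the PR's):

```python
import torch

def to_trigflow(z, f_flow, s):
    # Same algebra as trigflow_model_out in the diff.
    return ((1 - 2 * s) * z + (1 - 2 * s + 2 * s**2) * f_flow) / torch.sqrt(s**2 + (1 - s) ** 2)

z, f = torch.randn(2, 4), torch.randn(2, 4)
print(torch.allclose(to_trigflow(z, f, torch.tensor(0.0)), z + f))  # s=0: z + F_flow
print(torch.allclose(to_trigflow(z, f, torch.tensor(1.0)), f - z))  # s=1: F_flow - z
```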
@@ -822,15 +798,15 @@ def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None,
            timestep, embedded_timestep = self.time_embed(
                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )
-
+
        if return_logvar:
            logvar = self.logvar_linear(embedded_timestep)
            return trigflow_model_out, logvar
-
+

        return (trigflow_model_out,)

-
+

def compute_density_for_timestep_sampling_scm(
    batch_size: int, logit_mean: float = None, logit_std: float = None
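The hunk stops at the signature, but the name mirrors `compute_density_for_timestep_sampling` from `diffusers.training_utils` (removed from the imports above), adapted to sCM. A plausible sketch, assuming the usual sCM recipe of drawing a log-normal scale and mapping it to a TrigFlow time in (0, pi/2) via arctan (the body, function name, and default values here are assumptions, not the PR's code):

```python
import torch

def sample_trigflow_timesteps(batch_size: int, logit_mean: float = 0.0,
                              logit_std: float = 1.6, sigma_data: float = 0.5):
    # tau ~ N(logit_mean, logit_std), then t = arctan(exp(tau) / sigma_data),
    # which concentrates samples at mid-range TrigFlow times.
    tau = torch.randn(batch_size) * logit_std + logit_mean
    return torch.atan(tau.exp() / sigma_data)
```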
@@ -925,19 +901,19 @@ def main(args):
        revision=args.revision,
        variant=args.variant,
    )
-
+
    ori_transformer = SanaTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
        guidance_embeds=True, cross_attention_type='vanilla'
    )
-
+
    ori_transformer_no_guide = SanaTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
        guidance_embeds=False
    )
-
+
    original_state_dict = load_file(f"{args.pretrained_model_name_or_path}/transformer/diffusion_pytorch_model.safetensors")
-
+
    param_mapping = {
        'time_embed.emb.timestep_embedder.linear_1.weight': 'time_embed.timestep_embedder.linear_1.weight',
        'time_embed.emb.timestep_embedder.linear_1.bias': 'time_embed.timestep_embedder.linear_1.bias',
@@ -968,7 +944,7 @@ def main(args):

    transformer = SanaTrigFlow(ori_transformer, guidance=True).train()
    pretrained_model = SanaTrigFlow(ori_transformer_no_guide, guidance=False).eval()
-
+
    disc = SanaMSCMDiscriminator(
        pretrained_model,
        is_multiscale=args.ladd_multi_scale,
@@ -1134,7 +1110,7 @@ def load_model_hook(models, input_dir):
        data_files=args.file_path,
        split='train',
    )
-
+
    train_dataset = Text2ImageDataset(
        hf_dataset=hf_dataset,
        resolution=args.resolution,
@@ -1282,8 +1258,8 @@ def load_model_hook(models, input_dir):
                # Add noise according to TrigFlow.
                # zt = cos(t) * x + sin(t) * noise
                t = u.view(-1, 1, 1, 1)
-                noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise
-
+                noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise
+

                scm_cfg_scale = torch.tensor(
                    np.random.choice(args.scm_cfg_scale, size=bsz, replace=True),
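This cos/sin interpolation is what later makes x0 recoverable in closed form (the `pred_x_0` line further down). A self-contained check of the algebra (tensor names mirror the diff; `F_star` is the ideal prediction, constructed here only to verify the identity):

```python
import torch

sigma_data = 0.5  # TrigFlow data scale (assumed value)
model_input = torch.randn(2, 4, 8, 8) * sigma_data
noise = torch.randn_like(model_input)
t = torch.rand(2, 1, 1, 1) * torch.pi / 2

# Forward noising from the hunk above: z_t = cos(t)*x0 + sin(t)*noise
noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise

# With the ideal target F* = (cos(t)*noise - sin(t)*x0) / sigma_data,
# pred_x_0 = cos(t)*z_t - sin(t)*F**sigma_data recovers x0 exactly.
F_star = (torch.cos(t) * noise - torch.sin(t) * model_input) / sigma_data
pred_x_0 = torch.cos(t) * noisy_model_input - torch.sin(t) * F_star * sigma_data
print(torch.allclose(pred_x_0, model_input, atol=1e-6))  # True
```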
@@ -1295,7 +1271,7 @@ def model_wrapper(scaled_x_t, t):
                        hidden_states=scaled_x_t, timestep=t.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale), jvp=True, return_logvar=True
                    )
                    return pred, logvar
-
+
                if phase == "G":
                    transformer.train()
                    disc.eval()
@@ -1322,8 +1298,8 @@ def model_wrapper(scaled_x_t, t):

                    v_x = torch.cos(t) * torch.sin(t) * dxt_dt / sigma_data
                    v_t = torch.cos(t) * torch.sin(t)
-
-
+
+
                    # Adapted from https://github.com/xandergos/sCM-mnist/blob/master/train_consistency.py
                    with torch.no_grad():
                        F_theta, F_theta_grad, logvar = torch.func.jvp(
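`torch.func.jvp` does the heavy lifting here: it evaluates the network and its directional derivative along the tangent `(v_x, v_t)` in a single forward pass, which is how sCM obtains the time derivative of `F_theta` without an extra backward pass. A minimal standalone example of the API:

```python
import torch

def f(x, t):
    # Toy stand-in for model_wrapper(scaled_x_t, t).
    return torch.cos(t) * x

x, t = torch.randn(3), torch.tensor(0.5)
v_x, v_t = torch.ones(3), torch.tensor(1.0)  # tangent direction

# torch.func.jvp(f, primals, tangents) -> (f(*primals), J_f @ tangents)
out, dout = torch.func.jvp(f, (x, t), (v_x, v_t))

# Analytically: d/ds f(x + s*v_x, t + s*v_t) = cos(t)*v_x - sin(t)*x*v_t
print(torch.allclose(dout, torch.cos(t) * v_x - torch.sin(t) * x * v_t))  # True
```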
@@ -1371,8 +1347,8 @@ def model_wrapper(scaled_x_t, t):
                    loss_no_logvar = loss_no_logvar.mean()
                    loss_no_weight = l2_loss.mean()
                    g_norm = g_norm.mean()
-
-
+
+
                    pred_x_0 = torch.cos(t) * noisy_model_input - torch.sin(t) * F_theta * sigma_data

                    if args.train_largest_timestep:
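The `logvar` returned by the TrigFlow forward pass feeds an sCM-style adaptive weighting; the hunk above logs the unweighted diagnostics (`loss_no_logvar`, `loss_no_weight`, `g_norm`) separately from the trained loss. A sketch of the conventional form of that weighting (assumed from the sCM recipe, not copied from this PR):

```python
import torch

# Per-sample squared error and the logvar head's prediction (stand-in values).
l2_loss = torch.rand(8)
logvar = torch.randn(8)

# Assumed sCM adaptive weighting: down-weight samples the model deems noisy,
# with +logvar as the penalty that keeps the weights from collapsing to zero.
loss = (torch.exp(-logvar) * l2_loss + logvar).mean()
loss_no_logvar = l2_loss.mean()  # the diagnostic logged in the hunk above
```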
@@ -1414,7 +1390,7 @@ def model_wrapper(scaled_x_t, t):
                    # Add noise to predicted x0
                    z_D = torch.randn_like(model_input) * sigma_data
                    noised_predicted_x0 = torch.cos(t_D) * pred_x_0 + torch.sin(t_D) * z_D
-
+

                    # Calculate adversarial loss
                    pred_fake = disc(hidden_states=(noised_predicted_x0 / sigma_data), timestep=t_D.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
@@ -1445,7 +1421,7 @@ def model_wrapper(scaled_x_t, t):
                    optimizer_G.step()
                    lr_scheduler.step()
                    optimizer_G.zero_grad(set_to_none=True)
-
+
                elif phase == "D":
                    transformer.eval()
                    disc.train()
@@ -1515,7 +1491,7 @@ def model_wrapper(scaled_x_t, t):


                    # Calculate D loss
-
+
                    pred_fake = disc(hidden_states=(noised_predicted_x0 / sigma_data), timestep=t_D_fake.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
                    pred_true = disc(hidden_states=(noised_latents / sigma_data), timestep=t_D_real.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)

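The hunk cuts off before the reduction, but LADD-style latent discriminators are conventionally trained with a hinge objective over exactly these head outputs. A sketch under that assumption (the PR's actual reduction may differ):

```python
import torch
import torch.nn.functional as F

pred_true = torch.randn(8)  # disc heads on real noised latents
pred_fake = torch.randn(8)  # disc heads on the generator's noised pred_x_0

# Hinge loss: push real outputs above +1 and fake outputs below -1.
loss_D = F.relu(1.0 - pred_true).mean() + F.relu(1.0 + pred_fake).mean()
```

The matching generator term in the G phase is then typically `-pred_fake.mean()`.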
@@ -1542,7 +1518,7 @@ def model_wrapper(scaled_x_t, t):

                    optimizer_D.step()
                    optimizer_D.zero_grad(set_to_none=True)
-
+

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
@@ -1616,14 +1592,14 @@ def model_wrapper(scaled_x_t, t):
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
-
+
        # Unwrap the discriminator and collect its head weights
        disc = unwrap_model(disc)
        disc_heads_state_dict = disc.heads.state_dict()
-
+
        # Save transformer model
        transformer.save_pretrained(os.path.join(args.output_dir, "transformer"))
-
+
        # Save discriminator heads
        torch.save(disc_heads_state_dict, os.path.join(args.output_dir, "disc_heads.pt"))

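Because the transformer and the discriminator heads are saved through two different mechanisms above, restoring them is also two-step. A hedged sketch (the output path is a placeholder; note the saved weights include the TrigFlow additions such as `logvar_linear`, so a matching wrapper may be needed on load):

```python
import os
import torch
from diffusers import SanaTransformer2DModel

output_dir = "./sana-sprint-output"  # placeholder for args.output_dir

# save_pretrained above wrote a standard diffusers model folder.
transformer = SanaTransformer2DModel.from_pretrained(os.path.join(output_dir, "transformer"))

# The heads were saved with torch.save, so they come back as a plain state dict
# to load into a freshly constructed SanaMSCMDiscriminator's .heads module.
disc_heads_state_dict = torch.load(os.path.join(output_dir, "disc_heads.pt"))
```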
@@ -1677,4 +1653,4 @@ def model_wrapper(scaled_x_t, t):

if __name__ == "__main__":
    args = parse_args()
-    main(args)
+    main(args)