
Commit 9cb050b

make style && make quality;
1 parent acefec8 commit 9cb050b

2 files changed: 39 additions & 63 deletions

examples/research_projects/sana/train_sana_sprint_diffusers.py

Lines changed: 37 additions & 61 deletions
@@ -15,72 +15,48 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import copy
-import itertools
-import json
+import io
 import logging
 import math
 import os
-import random
 import shutil
-import warnings
 from pathlib import Path
+from typing import Callable
 
 import accelerate
-import io
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
 import torchvision.transforms as T
-import torchvision.transforms.functional as TF
 import transformers
-import webdataset as wds
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
-from braceexpand import braceexpand
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
 from packaging import version
-from peft.utils import get_peft_model_state_dict
 from PIL import Image
-from PIL.ImageOps import exif_transpose
 from safetensors.torch import load_file
 from torch.nn.utils.spectral_norm import SpectralNorm
-from torch.utils.data import default_collate, Dataset, DataLoader
-from torchvision.transforms.functional import crop
+from torch.utils.data import DataLoader, Dataset
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, Gemma2Model
-from typing import Callable, List, Union
-from webdataset.tariterators import (
-    base_plus_ext,
-    tar_file_expander,
-    url_opener,
-    valid_sample,
-)
 
 import diffusers
 from diffusers import (
     AutoencoderDC,
-    FlowMatchEulerDiscreteScheduler,
     SanaPipeline,
     SanaSprintPipeline,
     SanaTransformer2DModel,
-    SCMScheduler,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
-    cast_training_params,
-    compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3,
     free_memory,
 )
 from diffusers.utils import (
     check_min_version,
-    convert_unet_state_dict_to_peft,
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -98,7 +74,7 @@
 
 if is_torch_npu_available():
     torch.npu.config.allow_internal_format = False
-
+
 COMPLEX_HUMAN_INSTRUCTION = [
     "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
     "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
@@ -109,7 +85,7 @@
     "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
     "User Prompt: ",
 ]
-
+
 
 
 class Text2ImageDataset(Dataset):
@@ -140,17 +116,17 @@ def __init__(self, hf_dataset, resolution=1024):
 
     def __len__(self):
         return len(self.dataset)
-
+
     def __getitem__(self, idx):
         item = self.dataset[idx]
         text = item['llava']
         image_bytes = item['image']
-
+
         # Convert bytes to PIL Image
         image = Image.open(io.BytesIO(image_bytes))
-
+
         image_tensor = self.transform(image)
-
+
         return {
             'text': text,
             'image': image_tensor
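
For context, the dataset above yields {'text': str, 'image': Tensor} items, so it drops straight into a standard DataLoader with default collation. A minimal usage sketch (batch size and worker count are hypothetical):

train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,   # hypothetical value
    shuffle=True,
    num_workers=2,  # hypothetical value
)
batch = next(iter(train_dataloader))
# batch['text'] is a list of strings; batch['image'] is stacked to (B, C, H, W)
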
@@ -768,7 +744,7 @@ def state_dict(self):
 
     def __getattr__(self, name):
         return getattr(self.disc, name)
-
+
 class SanaTrigFlow(SanaTransformer2DModel):
     def __init__(self, original_model, guidance=False):
         self.__dict__ = original_model.__dict__
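
The `self.__dict__ = original_model.__dict__` line visible in this hunk is the instance-rebranding idiom: the subclass shares the original module's attribute storage (parameters included) instead of copying it. A self-contained sketch of the pattern with toy classes, not the trainer's API:

class Base:
    def __init__(self):
        self.weights = [1.0, 2.0]

class Wrapper(Base):
    def __init__(self, original):
        # Share the original instance's attributes rather than copying them.
        self.__dict__ = original.__dict__

base = Base()
wrapped = Wrapper(base)
assert wrapped.weights is base.weights  # same underlying storage
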
@@ -779,7 +755,7 @@ def __init__(self, original_model, guidance=False):
         self.logvar_linear = torch.nn.Linear(hidden_size, 1)
         torch.nn.init.xavier_uniform_(self.logvar_linear.weight)
         torch.nn.init.constant_(self.logvar_linear.bias, 0)
-
+
     def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None, jvp=False, return_logvar=False, **kwargs):
         batch_size = hidden_states.shape[0]
         latents = hidden_states
@@ -812,8 +788,8 @@ def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None,
         trigflow_model_out = ((1 - 2 * flow_timestep_expanded) * latent_model_input + (1 - 2 * flow_timestep_expanded + 2 * flow_timestep_expanded**2) * model_out) / torch.sqrt(
             flow_timestep_expanded**2 + (1 - flow_timestep_expanded) ** 2
         )
-
-
+
+
         if self.guidance and guidance is not None:
             timestep, embedded_timestep = self.time_embed(
                 timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
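
A quick check that the conversion above does what its name suggests, assuming the usual flow-matching conventions x_t' = (1 - t') * x0 + t' * noise and model_out = noise - x0 (with t' being flow_timestep_expanded): the numerator expands to -t' * x0 + (1 - t') * noise, and dividing by r = sqrt(t'^2 + (1 - t')^2) gives -sin(s) * x0 + cos(s) * noise, where sin(s) = t'/r and cos(s) = (1 - t')/r, which is exactly the TrigFlow velocity at angle s.
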
@@ -822,15 +798,15 @@ def forward(self, hidden_states, encoder_hidden_states, timestep, guidance=None,
             timestep, embedded_timestep = self.time_embed(
                 timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
             )
-
+
         if return_logvar:
             logvar = self.logvar_linear(embedded_timestep)
             return trigflow_model_out, logvar
-
+
 
         return (trigflow_model_out,)
 
-
+
 
 def compute_density_for_timestep_sampling_scm(
     batch_size: int, logit_mean: float = None, logit_std: float = None
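
Only the signature of compute_density_for_timestep_sampling_scm appears in this hunk. For orientation, a logit-normal sampler in the sCM/TrigFlow convention (tau ~ N(logit_mean, logit_std), then t = arctan(e^tau) in (0, pi/2)) would look roughly like the sketch below; this is inferred from the signature, not the verified function body:

import torch

def sample_trigflow_timesteps(batch_size: int, logit_mean: float = 0.0, logit_std: float = 1.0):
    # tau is logit-normal; atan(exp(tau)) maps it into the TrigFlow range (0, pi/2).
    tau = torch.randn(batch_size) * logit_std + logit_mean
    return torch.atan(tau.exp())
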
@@ -925,19 +901,19 @@ def main(args):
         revision=args.revision,
         variant=args.variant,
     )
-
+
     ori_transformer = SanaTransformer2DModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
         guidance_embeds=True, cross_attention_type='vanilla'
     )
-
+
     ori_transformer_no_guide = SanaTransformer2DModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant,
         guidance_embeds=False
     )
-
+
     original_state_dict = load_file(f"{args.pretrained_model_name_or_path}/transformer/diffusion_pytorch_model.safetensors")
-
+
     param_mapping = {
         'time_embed.emb.timestep_embedder.linear_1.weight': 'time_embed.timestep_embedder.linear_1.weight',
         'time_embed.emb.timestep_embedder.linear_1.bias': 'time_embed.timestep_embedder.linear_1.bias',
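
The param_mapping dict pairs checkpoint keys that moved with their new names; applying it is a plain key rename over the loaded state dict. A minimal sketch of the pattern (the helper name is hypothetical):

def remap_state_dict(state_dict, param_mapping):
    # Rename relocated keys; pass every other entry through unchanged.
    return {param_mapping.get(key, key): value for key, value in state_dict.items()}

remapped_state_dict = remap_state_dict(original_state_dict, param_mapping)
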
@@ -968,7 +944,7 @@ def main(args):
 
     transformer = SanaTrigFlow(ori_transformer, guidance=True).train()
     pretrained_model = SanaTrigFlow(ori_transformer_no_guide, guidance=False).eval()
-
+
     disc = SanaMSCMDiscriminator(
         pretrained_model,
         is_multiscale=args.ladd_multi_scale,
@@ -1134,7 +1110,7 @@ def load_model_hook(models, input_dir):
         data_files=args.file_path,
         split='train',
     )
-
+
     train_dataset = Text2ImageDataset(
         hf_dataset=hf_dataset,
         resolution=args.resolution,
@@ -1282,8 +1258,8 @@
                 # Add noise according to TrigFlow.
                 # zt = cos(t) * x + sin(t) * noise
                 t = u.view(-1, 1, 1, 1)
-                noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise
-
+                noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise
+
 
                 scm_cfg_scale = torch.tensor(
                     np.random.choice(args.scm_cfg_scale, size=bsz, replace=True),
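
As a sanity check on the noising above: at t = 0 the result is the clean latent, at t = pi/2 it is pure noise, and since cos^2(t) + sin^2(t) = 1 the marginal variance stays at sigma_data^2 whenever model_input and noise are both scaled to sigma_data.
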
@@ -1295,7 +1271,7 @@ def model_wrapper(scaled_x_t, t):
                         hidden_states=scaled_x_t, timestep=t.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale), jvp=True, return_logvar=True
                     )
                     return pred, logvar
-
+
                 if phase == "G":
                     transformer.train()
                     disc.eval()
@@ -1322,8 +1298,8 @@ def model_wrapper(scaled_x_t, t):
 
                     v_x = torch.cos(t) * torch.sin(t) * dxt_dt / sigma_data
                     v_t = torch.cos(t) * torch.sin(t)
-
-
+
+
                     # Adapt from https://github.com/xandergos/sCM-mnist/blob/master/train_consistency.py
                     with torch.no_grad():
                         F_theta, F_theta_grad, logvar = torch.func.jvp(
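
The torch.func.jvp call that this hunk cuts off evaluates the model and its directional derivative along the tangents (v_x, v_t) in a single pass; with has_aux=True, the logvar output is passed through untouched. A self-contained sketch of the mechanism, with a toy function standing in for model_wrapper:

import torch

def toy_wrapper(x, t):
    out = torch.sin(t) * x  # stand-in for the consistency model
    logvar = t.mean()       # stand-in for the auxiliary logvar output
    return out, logvar

x, t = torch.randn(4), torch.rand(4)
v_x, v_t = torch.randn(4), torch.rand(4)
out, dout, logvar = torch.func.jvp(toy_wrapper, (x, t), (v_x, v_t), has_aux=True)
# dout == cos(t) * x * v_t + sin(t) * v_x, the JVP of toy_wrapper along (v_x, v_t)
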
@@ -1371,8 +1347,8 @@ def model_wrapper(scaled_x_t, t):
                     loss_no_logvar = loss_no_logvar.mean()
                     loss_no_weight = l2_loss.mean()
                     g_norm = g_norm.mean()
-
-
+
+
                     pred_x_0 = torch.cos(t) * noisy_model_input - torch.sin(t) * F_theta * sigma_data
 
                     if args.train_largest_timestep:
@@ -1414,7 +1390,7 @@ def model_wrapper(scaled_x_t, t):
                     # Add noise to predicted x0
                     z_D = torch.randn_like(model_input) * sigma_data
                     noised_predicted_x0 = torch.cos(t_D) * pred_x_0 + torch.sin(t_D) * z_D
-
+
 
                     # Calculate adversarial loss
                     pred_fake = disc(hidden_states=(noised_predicted_x0 / sigma_data), timestep=t_D.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
@@ -1445,7 +1421,7 @@ def model_wrapper(scaled_x_t, t):
                         optimizer_G.step()
                         lr_scheduler.step()
                         optimizer_G.zero_grad(set_to_none=True)
-
+
                 elif phase == "D":
                     transformer.eval()
                     disc.train()
@@ -1515,7 +1491,7 @@ def model_wrapper(scaled_x_t, t):
 
 
                     # Calculate D loss
-
+
                     pred_fake = disc(hidden_states=(noised_predicted_x0 / sigma_data), timestep=t_D_fake.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
                     pred_true = disc(hidden_states=(noised_latents / sigma_data), timestep=t_D_real.flatten(), encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask)
 
@@ -1542,7 +1518,7 @@ def model_wrapper(scaled_x_t, t):
 
                         optimizer_D.step()
                         optimizer_D.zero_grad(set_to_none=True)
-
+
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -1616,14 +1592,14 @@ def model_wrapper(scaled_x_t, t):
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
-
+
         # Save discriminator heads
         disc = unwrap_model(disc)
         disc_heads_state_dict = disc.heads.state_dict()
-
+
         # Save transformer model
         transformer.save_pretrained(os.path.join(args.output_dir, "transformer"))
-
+
         # Save discriminator heads
         torch.save(disc_heads_state_dict, os.path.join(args.output_dir, "disc_heads.pt"))
 
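A hypothetical reload of the two artifacts saved above: the transformer directory via from_pretrained, and the discriminator heads via torch.load. Note that the saved transformer carries the TrigFlow additions (e.g. logvar_linear), so loading it back into a stock SanaTransformer2DModel is an assumption here, not a verified round trip:

import os
import torch
from diffusers import SanaTransformer2DModel

transformer = SanaTransformer2DModel.from_pretrained(os.path.join(args.output_dir, "transformer"))
disc_heads_state_dict = torch.load(os.path.join(args.output_dir, "disc_heads.pt"))
# disc.heads.load_state_dict(disc_heads_state_dict)  # after rebuilding the discriminator
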
@@ -1677,4 +1653,4 @@ def model_wrapper(scaled_x_t, t):
 
 if __name__ == "__main__":
     args = parse_args()
-    main(args)
+    main(args)

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from typing import Any, Dict, Optional, Tuple, Union
 
-import math
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -185,7 +185,7 @@ def __call__(
 
         return hidden_states
 
-
+
 class SanaAttnProcessor3_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
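
For reference, the PyTorch 2.0 primitive such a processor is built around, as a minimal sketch (tensor shapes are illustrative: batch, heads, sequence, head_dim):

import torch
import torch.nn.functional as F

query = torch.randn(2, 8, 64, 32)
key = torch.randn(2, 8, 64, 32)
value = torch.randn(2, 8, 64, 32)
# Dispatches to a fused kernel (FlashAttention / memory-efficient / math) when available.
out = F.scaled_dot_product_attention(query, key, value)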
