Skip to content

Commit 6ea8122

Browse files
stduhpfleejet
andauthored
feat: add flux 1 lite 8B (freepik) support (#474)
* Flux Lite (Freepik) support * format code --------- Co-authored-by: leejet <leejet714@gmail.com>
1 parent 9b1d90b commit 6ea8122

File tree

7 files changed

+38
-24
lines changed

7 files changed

+38
-24
lines changed

clip.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ class CLIPTextModel : public GGMLBlock {
712712
auto text_projection = params["text_projection"];
713713
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
714714
if (text_projection != NULL) {
715-
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
715+
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
716716
} else {
717717
LOG_DEBUG("Missing text_projection matrix, assuming identity...");
718718
}

conditioner.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ struct SD3CLIPEmbedder : public Conditioner {
798798
}
799799

800800
if (chunk_idx == 0) {
801-
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
801+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
802802
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
803803
clip_l->compute(n_threads,
804804
input_ids,
@@ -808,7 +808,6 @@ struct SD3CLIPEmbedder : public Conditioner {
808808
true,
809809
&pooled_l,
810810
work_ctx);
811-
812811
}
813812
}
814813

@@ -848,7 +847,7 @@ struct SD3CLIPEmbedder : public Conditioner {
848847
}
849848

850849
if (chunk_idx == 0) {
851-
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
850+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
852851
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
853852
clip_g->compute(n_threads,
854853
input_ids,
@@ -858,7 +857,6 @@ struct SD3CLIPEmbedder : public Conditioner {
858857
true,
859858
&pooled_g,
860859
work_ctx);
861-
862860
}
863861
}
864862

@@ -1096,9 +1094,9 @@ struct FluxCLIPEmbedder : public Conditioner {
10961094
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
10971095
size_t max_token_idx = 0;
10981096

1099-
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1097+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
11001098
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1101-
1099+
11021100
clip_l->compute(n_threads,
11031101
input_ids,
11041102
0,
@@ -1107,7 +1105,6 @@ struct FluxCLIPEmbedder : public Conditioner {
11071105
true,
11081106
&pooled,
11091107
work_ctx);
1110-
11111108
}
11121109

11131110
// t5

flux.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,9 @@ namespace Flux {
822822
if (version == VERSION_FLUX_SCHNELL) {
823823
flux_params.guidance_embed = false;
824824
}
825+
if (version == VERSION_FLUX_LITE) {
826+
flux_params.depth = 8;
827+
}
825828
flux = Flux(flux_params);
826829
flux.init(params_ctx, wtype);
827830
}

model.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,15 +1364,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
13641364

13651365
SDVersion ModelLoader::get_sd_version() {
13661366
TensorStorage token_embedding_weight;
1367-
bool is_flux = false;
1368-
bool is_sd3 = false;
1367+
bool is_flux = false;
1368+
bool is_schnell = true;
1369+
bool is_lite = true;
1370+
bool is_sd3 = false;
13691371
for (auto& tensor_storage : tensor_storages) {
13701372
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
1371-
return VERSION_FLUX_DEV;
1373+
is_schnell = false;
13721374
}
13731375
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
13741376
is_flux = true;
13751377
}
1378+
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
1379+
is_lite = false;
1380+
}
13761381
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
13771382
return VERSION_SD3_5_2B;
13781383
}
@@ -1403,7 +1408,14 @@ SDVersion ModelLoader::get_sd_version() {
14031408
}
14041409
}
14051410
if (is_flux) {
1406-
return VERSION_FLUX_SCHNELL;
1411+
if (is_schnell) {
1412+
GGML_ASSERT(!is_lite);
1413+
return VERSION_FLUX_SCHNELL;
1414+
} else if (is_lite) {
1415+
return VERSION_FLUX_LITE;
1416+
} else {
1417+
return VERSION_FLUX_DEV;
1418+
}
14071419
}
14081420
if (is_sd3) {
14091421
return VERSION_SD3_2B;

model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ enum SDVersion {
2727
VERSION_FLUX_SCHNELL,
2828
VERSION_SD3_5_8B,
2929
VERSION_SD3_5_2B,
30+
VERSION_FLUX_LITE,
3031
VERSION_COUNT,
3132
};
3233

stable-diffusion.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ const char* model_version_to_str[] = {
3333
"Flux Dev",
3434
"Flux Schnell",
3535
"SD3.5 8B",
36-
"SD3.5 2B"};
36+
"SD3.5 2B",
37+
"Flux Lite 8B"};
3738

3839
const char* sampling_methods_str[] = {
3940
"Euler A",
@@ -291,7 +292,7 @@ class StableDiffusionGGML {
291292
}
292293
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
293294
scale_factor = 1.5305f;
294-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
295+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
295296
scale_factor = 0.3611;
296297
// TODO: shift_factor
297298
}
@@ -312,7 +313,7 @@ class StableDiffusionGGML {
312313
} else {
313314
clip_backend = backend;
314315
bool use_t5xxl = false;
315-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
316+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
316317
use_t5xxl = true;
317318
}
318319
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@@ -326,7 +327,7 @@ class StableDiffusionGGML {
326327
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
327328
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
328329
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
329-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
330+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
330331
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
331332
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
332333
} else {
@@ -524,7 +525,7 @@ class StableDiffusionGGML {
524525
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
525526
LOG_INFO("running in FLOW mode");
526527
denoiser = std::make_shared<DiscreteFlowDenoiser>();
527-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
528+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
528529
LOG_INFO("running in Flux FLOW mode");
529530
float shift = 1.15f;
530531
if (version == VERSION_FLUX_SCHNELL) {
@@ -991,7 +992,7 @@ class StableDiffusionGGML {
991992
} else {
992993
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
993994
C = 32;
994-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
995+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
995996
C = 32;
996997
}
997998
}
@@ -1328,7 +1329,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13281329
int C = 4;
13291330
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
13301331
C = 16;
1331-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1332+
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
13321333
C = 16;
13331334
}
13341335
int W = width / 8;
@@ -1450,7 +1451,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14501451
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
14511452
params.mem_size *= 3;
14521453
}
1453-
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1454+
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
14541455
params.mem_size *= 4;
14551456
}
14561457
if (sd_ctx->sd->stacked_id) {
@@ -1475,15 +1476,15 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14751476
int C = 4;
14761477
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
14771478
C = 16;
1478-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1479+
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
14791480
C = 16;
14801481
}
14811482
int W = width / 8;
14821483
int H = height / 8;
14831484
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
14841485
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
14851486
ggml_set_f32(init_latent, 0.0609f);
1486-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1487+
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
14871488
ggml_set_f32(init_latent, 0.1159f);
14881489
} else {
14891490
ggml_set_f32(init_latent, 0.f);
@@ -1553,7 +1554,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
15531554
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
15541555
params.mem_size *= 2;
15551556
}
1556-
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
1557+
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
15571558
params.mem_size *= 3;
15581559
}
15591560
if (sd_ctx->sd->stacked_id) {

vae.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock {
457457
bool use_video_decoder = false,
458458
SDVersion version = VERSION_SD1)
459459
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
460-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
460+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
461461
dd_config.z_channels = 16;
462462
use_quant = false;
463463
}

0 commit comments

Comments
 (0)