Skip to content

Commit 835180f

Browse files
committed
Support Flex-2
1 parent a697518 commit 835180f

File tree

6 files changed

+128
-33
lines changed

6 files changed

+128
-33
lines changed

examples/cli/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,7 @@ int main(int argc, const char* argv[]) {
933933
}
934934

935935
sd_image_t* control_image = NULL;
936-
if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
936+
if (params.control_image_path.size() > 0) {
937937
int c = 0;
938938
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
939939
if (control_image_buffer == NULL) {

flux.hpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,8 @@ namespace Flux {
793793
struct ggml_tensor* y,
794794
struct ggml_tensor* guidance,
795795
struct ggml_tensor* pe,
796-
std::vector<int> skip_layers = std::vector<int>()) {
796+
std::vector<int> skip_layers = std::vector<int>(),
797+
SDVersion version = VERSION_FLUX) {
797798
// Forward pass of DiT.
798799
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
799800
// timestep: (N,) tensor of diffusion timesteps
@@ -817,7 +818,8 @@ namespace Flux {
817818
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
818819
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
819820

820-
if (c_concat != NULL) {
821+
if (version == VERSION_FLUX_FILL) {
822+
GGML_ASSERT(c_concat != NULL);
821823
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
822824
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
823825

@@ -828,6 +830,21 @@ namespace Flux {
828830
mask = patchify(ctx, mask, patch_size);
829831

830832
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
833+
} else if (version == VERSION_FLEX_2) {
834+
GGML_ASSERT(c_concat != NULL);
835+
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
836+
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
837+
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
838+
839+
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
840+
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
841+
control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
842+
843+
masked = patchify(ctx, masked, patch_size);
844+
mask = patchify(ctx, mask, patch_size);
845+
control = patchify(ctx, control, patch_size);
846+
847+
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
831848
}
832849

833850
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
@@ -846,19 +863,22 @@ namespace Flux {
846863
FluxParams flux_params;
847864
Flux flux;
848865
std::vector<float> pe_vec; // for cache
866+
SDVersion version;
849867

850868
FluxRunner(ggml_backend_t backend,
851869
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
852870
const std::string prefix = "",
853871
SDVersion version = VERSION_FLUX,
854872
bool flash_attn = false)
855-
: GGMLRunner(backend) {
873+
: GGMLRunner(backend), version(version) {
856874
flux_params.flash_attn = flash_attn;
857875
flux_params.guidance_embed = false;
858876
flux_params.depth = 0;
859877
flux_params.depth_single_blocks = 0;
860878
if (version == VERSION_FLUX_FILL) {
861879
flux_params.in_channels = 384;
880+
} else if (version == VERSION_FLEX_2) {
881+
flux_params.in_channels = 196;
862882
}
863883
for (auto pair : tensor_types) {
864884
std::string tensor_name = pair.first;
@@ -941,7 +961,8 @@ namespace Flux {
941961
y,
942962
guidance,
943963
pe,
944-
skip_layers);
964+
skip_layers,
965+
version);
945966

946967
ggml_build_forward_expand(gf, out);
947968

model.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,6 +1555,9 @@ SDVersion ModelLoader::get_sd_version() {
15551555
if (is_inpaint) {
15561556
return VERSION_FLUX_FILL;
15571557
}
1558+
if(input_block_weight.ne[0] == 196){
1559+
return VERSION_FLEX_2;
1560+
}
15581561
return VERSION_FLUX;
15591562
}
15601563

model.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLEX_2,
3435
VERSION_COUNT,
3536
};
3637

3738
static inline bool sd_version_is_flux(SDVersion version) {
38-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
39+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
3940
return true;
4041
}
4142
return false;
@@ -70,7 +71,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
7071
}
7172

7273
static inline bool sd_version_is_inpaint(SDVersion version) {
73-
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
74+
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
7475
return true;
7576
}
7677
return false;

stable-diffusion.cpp

Lines changed: 95 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class StableDiffusionGGML {
9595
std::shared_ptr<DiffusionModel> diffusion_model;
9696
std::shared_ptr<AutoEncoderKL> first_stage_model;
9797
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
98-
std::shared_ptr<ControlNet> control_net;
98+
std::shared_ptr<ControlNet> control_net = NULL;
9999
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
100100
std::shared_ptr<LoraModel> pmid_lora;
101101
std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -301,6 +301,11 @@ class StableDiffusionGGML {
301301
// TODO: shift_factor
302302
}
303303

304+
if(version == VERSION_FLEX_2){
305+
// Might need vae encode for control cond
306+
vae_decode_only = false;
307+
}
308+
304309
if (version == VERSION_SVD) {
305310
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types);
306311
clip_vision->alloc_params_buffer();
@@ -898,7 +903,7 @@ class StableDiffusionGGML {
898903

899904
std::vector<struct ggml_tensor*> controls;
900905

901-
if (control_hint != NULL) {
906+
if (control_hint != NULL && control_net != NULL) {
902907
control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
903908
controls = control_net->controls;
904909
// print_ggml_tensor(controls[12]);
@@ -935,7 +940,7 @@ class StableDiffusionGGML {
935940
float* negative_data = NULL;
936941
if (has_unconditioned) {
937942
// uncond
938-
if (control_hint != NULL) {
943+
if (control_hint != NULL && control_net != NULL) {
939944
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
940945
controls = control_net->controls;
941946
}
@@ -1283,7 +1288,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
12831288
float style_ratio,
12841289
bool normalize_input,
12851290
std::string input_id_images_path,
1286-
ggml_tensor* masked_latent = NULL) {
1291+
ggml_tensor* concat_latent = NULL) {
12871292
if (seed < 0) {
12881293
// Generally, when using the provided command line, the seed is always >0.
12891294
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1475,6 +1480,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14751480
int64_t mask_channels = 1;
14761481
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
14771482
mask_channels = 8 * 8; // flatten the whole mask
1483+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1484+
mask_channels = 1 + init_latent->ne[2];
14781485
}
14791486
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
14801487
// no mask, set the whole image as masked
@@ -1488,6 +1495,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14881495
for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
14891496
ggml_tensor_set_f32(empty_latent, 1, x, y, c);
14901497
}
1498+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1499+
for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
1500+
// 0x16,1x1,0x16
1501+
ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
1502+
}
14911503
} else {
14921504
ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
14931505
for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@@ -1496,19 +1508,48 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14961508
}
14971509
}
14981510
}
1499-
if (masked_latent == NULL) {
1500-
masked_latent = empty_latent;
1511+
if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
1512+
bool no_inpaint = concat_latent == NULL;
1513+
if (no_inpaint) {
1514+
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
1515+
}
1516+
// fill in the control image here
1517+
struct ggml_tensor* control_latents = NULL;
1518+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1519+
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1520+
control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1521+
} else {
1522+
control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1523+
}
1524+
for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
1525+
for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
1526+
if (no_inpaint) {
1527+
for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
1528+
// 0x16,1x1,0x16
1529+
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
1530+
}
1531+
}
1532+
for (int64_t c = 0; c < control_latents->ne[2]; c++) {
1533+
float v = ggml_tensor_get_f32(control_latents, x, y, c);
1534+
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
1535+
}
1536+
}
1537+
}
1538+
// Disable controlnet
1539+
image_hint = NULL;
1540+
} else if (concat_latent == NULL) {
1541+
concat_latent = empty_latent;
15011542
}
1502-
cond.c_concat = masked_latent;
1543+
cond.c_concat = concat_latent;
15031544
uncond.c_concat = empty_latent;
1504-
// noise_mask = masked_latent;
1545+
// noise_mask = concat_latent;
15051546
} else if (sd_version_is_edit(sd_ctx->sd->version)) {
1506-
cond.c_concat = masked_latent;
1507-
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], masked_latent->ne[2], masked_latent->ne[3]);
1547+
cond.c_concat = concat_latent;
1548+
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, concat_latent->ne[0], concat_latent->ne[1], concat_latent->ne[2], concat_latent->ne[3]);
15081549
ggml_set_f32(empty_latent, 0);
15091550
uncond.c_concat = empty_latent;
15101551
} else {
1511-
noise_mask = masked_latent;
1552+
noise_mask = concat_latent;
15121553
}
15131554

15141555
for (int b = 0; b < batch_count; b++) {
@@ -1756,7 +1797,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
17561797

17571798
sd_image_to_tensor(init_image.data, init_img);
17581799

1759-
ggml_tensor* masked_latent;
1800+
ggml_tensor* concat_latent;
17601801

17611802
ggml_tensor* init_latent = NULL;
17621803
ggml_tensor* init_moments = NULL;
@@ -1771,6 +1812,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
17711812
int64_t mask_channels = 1;
17721813
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
17731814
mask_channels = 8 * 8; // flatten the whole mask
1815+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1816+
mask_channels = 1 + init_latent->ne[2];
17741817
}
17751818
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
17761819
// Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
@@ -1783,56 +1826,82 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
17831826
} else {
17841827
masked_latent_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
17851828
}
1786-
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent_0->ne[0], masked_latent_0->ne[1], mask_channels + masked_latent_0->ne[2], 1);
1829+
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent_0->ne[0], masked_latent_0->ne[1], mask_channels + masked_latent_0->ne[2], 1);
17871830
for (int ix = 0; ix < masked_latent_0->ne[0]; ix++) {
17881831
for (int iy = 0; iy < masked_latent_0->ne[1]; iy++) {
17891832
int mx = ix * 8;
17901833
int my = iy * 8;
17911834
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
17921835
for (int k = 0; k < masked_latent_0->ne[2]; k++) {
17931836
float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
1794-
ggml_tensor_set_f32(masked_latent, v, ix, iy, k);
1837+
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
17951838
}
17961839
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
17971840
for (int x = 0; x < 8; x++) {
17981841
for (int y = 0; y < 8; y++) {
17991842
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
18001843
// TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
18011844
// python code was using "b (h 8) (w 8) -> b (8 8) h w"
1802-
ggml_tensor_set_f32(masked_latent, m, ix, iy, masked_latent_0->ne[2] + x * 8 + y);
1845+
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent_0->ne[2] + x * 8 + y);
18031846
}
18041847
}
1848+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1849+
float m = ggml_tensor_get_f32(mask_img, mx, my);
1850+
// masked image
1851+
for (int k = 0; k < masked_latent_0->ne[2]; k++) {
1852+
float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
1853+
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
1854+
}
1855+
// downsampled mask
1856+
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent_0->ne[2]);
1857+
// control (todo: support this)
1858+
for (int k = 0; k < masked_latent_0->ne[2]; k++) {
1859+
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent_0->ne[2] + 1 + k);
1860+
}
1861+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1862+
float m = ggml_tensor_get_f32(mask_img, mx, my);
1863+
// masked image
1864+
for (int k = 0; k < masked_latent_0->ne[2]; k++) {
1865+
float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
1866+
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
1867+
}
1868+
// downsampled mask
1869+
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent_0->ne[2]);
1870+
// control (todo: support this)
1871+
for (int k = 0; k < masked_latent_0->ne[2]; k++) {
1872+
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent_0->ne[2] + 1 + k);
1873+
}
18051874
} else {
18061875
float m = ggml_tensor_get_f32(mask_img, mx, my);
1807-
ggml_tensor_set_f32(masked_latent, m, ix, iy, 0);
1876+
ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
18081877
for (int k = 0; k < masked_latent_0->ne[2]; k++) {
18091878
float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
1810-
ggml_tensor_set_f32(masked_latent, v, ix, iy, k + mask_channels);
1879+
ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
18111880
}
18121881
}
18131882
}
18141883
}
18151884
} else if (sd_version_is_edit(sd_ctx->sd->version)) {
1816-
// Not actually masked, we're just highjacking the masked_latent variable since it will be used the same way
1885+
// Not actually masked, we're just highjacking the concat_latent variable since it will be used the same way
18171886
if (!sd_ctx->sd->use_tiny_autoencoder) {
18181887
if (sd_ctx->sd->is_using_edm_v_parameterization) {
18191888
// for CosXL edit
1820-
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
1889+
concat_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
18211890
} else {
1822-
masked_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments);
1891+
concat_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments);
18231892
}
18241893
} else {
1825-
masked_latent = init_latent;
1894+
concat_latent = init_latent;
18261895
}
18271896
} else {
18281897
// LOG_WARN("Inpainting with a base model is not great");
1829-
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
1830-
for (int ix = 0; ix < masked_latent->ne[0]; ix++) {
1831-
for (int iy = 0; iy < masked_latent->ne[1]; iy++) {
1898+
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
1899+
for (int ix = 0; ix < concat_latent->ne[0]; ix++) {
1900+
for (int iy = 0; iy < concat_latent->ne[1]; iy++) {
18321901
int mx = ix * 8;
18331902
int my = iy * 8;
18341903
float m = ggml_tensor_get_f32(mask_img, mx, my);
1835-
ggml_tensor_set_f32(masked_latent, m, ix, iy);
1904+
ggml_tensor_set_f32(concat_latent, m, ix, iy);
18361905
}
18371906
}
18381907
}
@@ -1868,7 +1937,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18681937
style_ratio,
18691938
normalize_input,
18701939
input_id_images_path_c_str,
1871-
masked_latent);
1940+
concat_latent);
18721941

18731942
size_t t2 = ggml_time_ms();
18741943

vae.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ struct AutoEncoderKL : public GGMLRunner {
559559
bool decode_graph,
560560
struct ggml_tensor** output,
561561
struct ggml_context* output_ctx = NULL) {
562+
GGML_ASSERT(!decode_only || decode_graph);
562563
auto get_graph = [&]() -> struct ggml_cgraph* {
563564
return build_graph(z, decode_graph);
564565
};

0 commit comments

Comments (0)