Skip to content

Commit 9dd1ff8

Browse files
committed
support for flux controls
1 parent 42b8fe8 commit 9dd1ff8

File tree

4 files changed

+60
-29
lines changed

4 files changed

+60
-29
lines changed

flux.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,14 @@ namespace Flux {
845845
control = patchify(ctx, control, patch_size);
846846

847847
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
848+
} else if (version == VERSION_FLUX_CONTROLS) {
849+
GGML_ASSERT(c_concat != NULL);
850+
851+
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
852+
853+
control = patchify(ctx, control, patch_size);
854+
855+
img = ggml_concat(ctx, img, control, 0);
848856
}
849857

850858
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
@@ -877,6 +885,8 @@ namespace Flux {
877885
flux_params.depth_single_blocks = 0;
878886
if (version == VERSION_FLUX_FILL) {
879887
flux_params.in_channels = 384;
888+
} else if (version == VERSION_FLUX_CONTROLS) {
889+
flux_params.in_channels = 128;
880890
} else if (version == VERSION_FLEX_2) {
881891
flux_params.in_channels = 196;
882892
}

model.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,10 +1551,12 @@ SDVersion ModelLoader::get_sd_version() {
15511551
}
15521552

15531553
if (is_flux) {
1554-
is_inpaint = input_block_weight.ne[0] == 384;
1555-
if (is_inpaint) {
1554+
if (input_block_weight.ne[0] == 384) {
15561555
return VERSION_FLUX_FILL;
15571556
}
1557+
if (input_block_weight.ne[0] == 128) {
1558+
return VERSION_FLUX_CONTROLS;
1559+
}
15581560
if(input_block_weight.ne[0] == 196){
15591561
return VERSION_FLEX_2;
15601562
}

model.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLUX_CONTROLS,
3435
VERSION_FLEX_2,
3536
VERSION_COUNT,
3637
};
3738

3839
static inline bool sd_version_is_flux(SDVersion version) {
39-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
40+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 ) {
4041
return true;
4142
}
4243
return false;
@@ -70,15 +71,16 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
7071
return false;
7172
}
7273

73-
static inline bool sd_version_is_inpaint(SDVersion version) {
74-
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
74+
75+
static inline bool sd_version_is_dit(SDVersion version) {
76+
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
7577
return true;
7678
}
7779
return false;
7880
}
7981

80-
static inline bool sd_version_is_dit(SDVersion version) {
81-
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
82+
static inline bool sd_version_is_inpaint(SDVersion version) {
83+
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
8284
return true;
8385
}
8486
return false;
@@ -88,8 +90,12 @@ static inline bool sd_version_is_edit(SDVersion version) {
8890
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
8991
}
9092

93+
static inline bool sd_version_is_control(SDVersion version) {  // Predicate: is `version` one of the control-capable Flux-family versions?
94+
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;  // true only for these two enum values; elsewhere in this diff the result gates VAE-encoding of the control hint image
95+
}
96+
9197
static bool sd_version_use_concat(SDVersion version) {  // Predicate combining the edit / inpaint / control version checks (control is the case added by this change)
92-
return sd_version_is_edit(version) || sd_version_is_inpaint(version);
98+
return sd_version_is_edit(version) || sd_version_is_inpaint(version)|| sd_version_is_control(version);
9399
}
94100

95101
enum PMVersion {

stable-diffusion.cpp

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class StableDiffusionGGML {
301301
// TODO: shift_factor
302302
}
303303

304-
if(version == VERSION_FLEX_2){
304+
if (sd_version_is_control(version)) {
305305
// Might need vae encode for control cond
306306
vae_decode_only = false;
307307
}
@@ -1476,6 +1476,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14761476
int H = height / 8;
14771477
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
14781478
ggml_tensor* noise_mask = nullptr;
1479+
1480+
struct ggml_tensor* control_latent = NULL;
1481+
if(sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL){
1482+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1483+
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1484+
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1485+
} else {
1486+
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1487+
}
1488+
}
1489+
14791490
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
14801491
int64_t mask_channels = 1;
14811492
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
@@ -1508,46 +1519,48 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15081519
}
15091520
}
15101521
}
1511-
if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
1522+
1523+
if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
15121524
bool no_inpaint = concat_latent == NULL;
15131525
if (no_inpaint) {
15141526
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
15151527
}
15161528
// fill in the control image here
1517-
struct ggml_tensor* control_latents = NULL;
1518-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1519-
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1520-
control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1521-
} else {
1522-
control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1523-
}
1524-
for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
1525-
for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
1529+
for (int64_t x = 0; x < control_latent->ne[0]; x++) {
1530+
for (int64_t y = 0; y < control_latent->ne[1]; y++) {
15261531
if (no_inpaint) {
1527-
for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
1532+
for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
15281533
// 0x16,1x1,0x16
15291534
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
15301535
}
15311536
}
1532-
for (int64_t c = 0; c < control_latents->ne[2]; c++) {
1533-
float v = ggml_tensor_get_f32(control_latents, x, y, c);
1534-
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
1537+
for (int64_t c = 0; c < control_latent->ne[2]; c++) {
1538+
float v = ggml_tensor_get_f32(control_latent, x, y, c);
1539+
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
15351540
}
15361541
}
15371542
}
1538-
// Disable controlnet
1539-
image_hint = NULL;
15401543
} else if (concat_latent == NULL) {
15411544
concat_latent = empty_latent;
15421545
}
15431546
cond.c_concat = concat_latent;
15441547
uncond.c_concat = empty_latent;
1545-
// noise_mask = concat_latent;
1546-
} else if (sd_version_is_edit(sd_ctx->sd->version)) {
1547-
cond.c_concat = concat_latent;
1548-
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, concat_latent->ne[0], concat_latent->ne[1], concat_latent->ne[2], concat_latent->ne[3]);
1548+
noise_mask = NULL;
1549+
} else if (sd_version_is_edit(sd_ctx->sd->version) || sd_version_is_control(sd_ctx->sd->version)) {
1550+
LOG_INFO("HERE");
1551+
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], init_latent->ne[3]);
1552+
LOG_INFO("HERE");
15491553
ggml_set_f32(empty_latent, 0);
15501554
uncond.c_concat = empty_latent;
1555+
if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
1556+
concat_latent = control_latent;
1557+
}
1558+
if (concat_latent == NULL) {
1559+
concat_latent = empty_latent;
1560+
}
1561+
LOG_INFO("HERE");
1562+
1563+
cond.c_concat = concat_latent;
15511564
} else {
15521565
noise_mask = concat_latent;
15531566
}

0 commit comments

Comments
 (0)