Skip to content

Commit ae83c5a

Browse files
committed
Fix Flex 2 inpaint
1 parent 6f5f9d6 commit ae83c5a

File tree

2 files changed

+35
-15
lines changed

2 files changed

+35
-15
lines changed

ggml_extend.hpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,18 +375,31 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
375375

376376
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
377377
struct ggml_tensor* mask,
378-
struct ggml_tensor* output) {
378+
struct ggml_tensor* output,
379+
float masked_value = 0.5f) {
379380
int64_t width = output->ne[0];
380381
int64_t height = output->ne[1];
381382
int64_t channels = output->ne[2];
383+
for (int ix = 0; ix < mask->ne[0]; ix++) {
384+
for (int iy = 0; iy < mask->ne[1]; iy++) {
385+
float m = ggml_tensor_get_f32(mask, ix, iy);
386+
m = round(m); // inpaint models need binary masks
387+
ggml_tensor_set_f32(mask, m, ix, iy);
388+
}
389+
}
390+
float rescale_mx = mask->ne[0]/output->ne[0];
391+
float rescale_my = mask->ne[1]/output->ne[1];
382392
GGML_ASSERT(output->type == GGML_TYPE_F32);
383393
for (int ix = 0; ix < width; ix++) {
384394
for (int iy = 0; iy < height; iy++) {
385-
float m = ggml_tensor_get_f32(mask, ix, iy);
395+
int mx = (int)(ix * rescale_mx);
396+
int my = (int)(iy * rescale_my);
397+
float m = ggml_tensor_get_f32(mask, mx, my);
386398
m = round(m); // inpaint models need binary masks
387-
ggml_tensor_set_f32(mask, m, ix, iy);
399+
ggml_tensor_set_f32(mask, m, mx, my);
388400
for (int k = 0; k < channels; k++) {
389-
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
401+
float value = ggml_tensor_get_f32(image_data, ix, iy, k);
402+
value = (1 - m) * (value - masked_value) + masked_value;
390403
ggml_tensor_set_f32(output, value, ix, iy, k);
391404
}
392405
}

stable-diffusion.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,10 +1477,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14771477
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
14781478

14791479
struct ggml_tensor* control_latent = NULL;
1480-
if(sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL){
1480+
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
14811481
if (!sd_ctx->sd->use_tiny_autoencoder) {
14821482
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1483-
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1483+
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
14841484
} else {
14851485
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
14861486
}
@@ -1560,7 +1560,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15601560
}
15611561
LOG_INFO("HERE");
15621562

1563-
cond.c_concat = concat_latent;
1563+
cond.c_concat = concat_latent;
15641564
}
15651565

15661566
for (int b = 0; b < batch_count; b++) {
@@ -1827,16 +1827,23 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18271827
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
18281828
mask_channels = 1 + init_latent->ne[2];
18291829
}
1830-
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
1831-
// Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1832-
sd_image_to_tensor(init_image.data, init_img);
1833-
sd_apply_mask(init_img, mask_img, masked_img);
18341830
ggml_tensor* masked_latent_0 = NULL;
1835-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1836-
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1837-
masked_latent_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1831+
if (sd_ctx->sd->version != VERSION_FLEX_2) {
1832+
// most inpaint models mask before vae
1833+
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
1834+
// Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
1835+
sd_image_to_tensor(init_image.data, init_img);
1836+
sd_apply_mask(init_img, mask_img, masked_img);
1837+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1838+
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1839+
masked_latent_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1840+
} else {
1841+
masked_latent_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1842+
}
18381843
} else {
1839-
masked_latent_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1844+
// mask after vae
1845+
masked_latent_0 = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
1846+
sd_apply_mask(init_latent, mask_img, masked_latent_0, 0.);
18401847
}
18411848
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent_0->ne[0], masked_latent_0->ne[1], mask_channels + masked_latent_0->ne[2], 1);
18421849
for (int ix = 0; ix < masked_latent_0->ne[0]; ix++) {

0 commit comments

Comments
 (0)