@@ -95,7 +95,7 @@ class StableDiffusionGGML {
     std::shared_ptr<DiffusionModel> diffusion_model;
     std::shared_ptr<AutoEncoderKL> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
     std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
     std::shared_ptr<LoraModel> pmid_lora;
     std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
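Note: a default-constructed std::shared_ptr is already null, so the explicit = NULL is belt-and-braces; it documents that control_net stays empty unless a ControlNet is actually loaded, which is what the new null checks further down rely on. A minimal standalone sketch:

#include <cassert>
#include <memory>

struct ControlNet;  // stand-in for the real class

int main() {
    std::shared_ptr<ControlNet> control_net;  // default-constructs to null
    assert(control_net == nullptr);           // holds, so the initializer is purely documentary
    return 0;
}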
@@ -301,6 +301,11 @@ class StableDiffusionGGML {
             // TODO: shift_factor
         }
 
+        if (version == VERSION_FLEX_2) {
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
+
         if (version == VERSION_SVD) {
             clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types);
             clip_vision->alloc_params_buffer();
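Note: vae_decode_only is decided before the VAE is constructed, and a decode-only VAE does not keep its encoder half, so encode_first_stage() would have nothing to run on the Flex.2 control image. A rough illustration of that dependency (hypothetical constructor shape, not the repo's exact AutoEncoderKL API):

// Hypothetical sketch: a decode-only autoencoder skips its encoder weights,
// which is why VERSION_FLEX_2 must clear the flag before construction.
struct VaeSketch {
    bool has_encoder;
    explicit VaeSketch(bool decode_only) : has_encoder(!decode_only) {}
};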
@@ -898,7 +903,7 @@ class StableDiffusionGGML {
 
         std::vector<struct ggml_tensor*> controls;
 
-        if (control_hint != NULL) {
+        if (control_hint != NULL && control_net != NULL) {
             control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
             controls = control_net->controls;
             // print_ggml_tensor(controls[12]);
@@ -935,7 +940,7 @@ class StableDiffusionGGML {
         float* negative_data = NULL;
         if (has_unconditioned) {
             // uncond
-            if (control_hint != NULL) {
+            if (control_hint != NULL && control_net != NULL) {
                 control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                 controls = control_net->controls;
             }
@@ -1283,7 +1288,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                            float style_ratio,
                            bool normalize_input,
                            std::string input_id_images_path,
-                           ggml_tensor* masked_latent = NULL) {
+                           ggml_tensor* concat_latent = NULL) {
     if (seed < 0) {
         // Generally, when using the provided command line, the seed is always >0.
         // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1475,6 +1480,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8;  // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
         // no mask, set the whole image as masked
@@ -1488,6 +1495,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                     for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                         ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                        // 0x16,1x1,0x16
+                        ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                    }
                 } else {
                     ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                     for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
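For reference, the "0x16,1x1,0x16" comment abbreviates the per-position channel layout, assuming the usual 16-channel Flux latent (so mask_channels == 17 and empty_latent carries 33 channels): 16 zeros for the image latent, a single 1 marking the whole image as masked, then 16 zeros for the control latent. A standalone sketch that prints that layout:

#include <cstdio>

int main() {
    const int latent_channels = 16;                               // assumption: Flux-style VAE latent
    const int mask_channels   = 1 + latent_channels;              // 17, as computed above for VERSION_FLEX_2
    const int total_channels  = latent_channels + mask_channels;  // 33 channels in c_concat
    for (int c = 0; c < total_channels; c++) {
        // only the single mask channel (c == latent_channels) is 1, everything else 0
        printf("%d", c == latent_channels ? 1 : 0);
    }
    printf("\n");  // prints 16 zeros, a single 1, then 16 more zeros
    return 0;
}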
@@ -1496,19 +1508,48 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                 }
             }
         }
-        if (masked_latent == NULL) {
-            masked_latent = empty_latent;
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            struct ggml_tensor* control_latents = NULL;
+            if (!sd_ctx->sd->use_tiny_autoencoder) {
+                struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+                control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+            } else {
+                control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            }
+            for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
+                            // 0x16,1x1,0x16
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latents->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latents, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
+                    }
+                }
+            }
+            // Disable controlnet
+            image_hint = NULL;
+        } else if (concat_latent == NULL) {
+            concat_latent = empty_latent;
         }
-        cond.c_concat = masked_latent;
+        cond.c_concat = concat_latent;
         uncond.c_concat = empty_latent;
-        // noise_mask = masked_latent;
+        // noise_mask = concat_latent;
     } else if (sd_version_is_edit(sd_ctx->sd->version)) {
-        cond.c_concat = masked_latent;
-        auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], masked_latent->ne[2], masked_latent->ne[3]);
+        cond.c_concat = concat_latent;
+        auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, concat_latent->ne[0], concat_latent->ne[1], concat_latent->ne[2], concat_latent->ne[3]);
         ggml_set_f32(empty_latent, 0);
         uncond.c_concat = empty_latent;
     } else {
-        noise_mask = masked_latent;
+        noise_mask = concat_latent;
     }
 
     for (int b = 0; b < batch_count; b++) {
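Taken together: when Flex.2 gets a control image but no ControlNet is loaded, the hint is VAE-encoded and written into the trailing channels of c_concat, and image_hint is then nulled so the ControlNet path stays disabled. A minimal sketch of the resulting per-position packing, with flat vectors standing in for ggml tensors (pack_flex2_concat is a hypothetical helper, not the repo's API):

#include <vector>

// Packs one spatial position of the Flex.2 c_concat:
// [0, C): masked-image latent | C: downsampled mask | [C+1, 2C+1): control latent
std::vector<float> pack_flex2_concat(const std::vector<float>& inpaint_latent,
                                     float mask,
                                     const std::vector<float>& control_latent) {
    std::vector<float> out(inpaint_latent);  // image latent channels first
    out.push_back(mask);                     // then the single mask channel
    out.insert(out.end(), control_latent.begin(), control_latent.end());  // control last
    return out;                              // 33 values for a 16-channel latent
}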
@@ -1756,7 +1797,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
 
     sd_image_to_tensor(init_image.data, init_img);
 
-    ggml_tensor* masked_latent;
+    ggml_tensor* concat_latent;
 
     ggml_tensor* init_latent = NULL;
     ggml_tensor* init_moments = NULL;
@@ -1771,6 +1812,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8;  // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
         // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
@@ -1783,56 +1826,82 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
         } else {
             masked_latent_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
         }
-        masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent_0->ne[0], masked_latent_0->ne[1], mask_channels + masked_latent_0->ne[2], 1);
+        concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent_0->ne[0], masked_latent_0->ne[1], mask_channels + masked_latent_0->ne[2], 1);
         for (int ix = 0; ix < masked_latent_0->ne[0]; ix++) {
             for (int iy = 0; iy < masked_latent_0->ne[1]; iy++) {
                 int mx = ix * 8;
                 int my = iy * 8;
                 if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
                     for (int k = 0; k < masked_latent_0->ne[2]; k++) {
                         float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
-                        ggml_tensor_set_f32(masked_latent, v, ix, iy, k);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
                     }
                     // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
                     for (int x = 0; x < 8; x++) {
                         for (int y = 0; y < 8; y++) {
                             float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
                             // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
                             // python code was using "b (h 8) (w 8) -> b (8 8) h w"
-                            ggml_tensor_set_f32(masked_latent, m, ix, iy, masked_latent_0->ne[2] + x * 8 + y);
+                            ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent_0->ne[2] + x * 8 + y);
                         }
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    // masked image
+                    for (int k = 0; k < masked_latent_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                    }
+                    // downsampled mask
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent_0->ne[2]);
+                    // control (todo: support this)
+                    for (int k = 0; k < masked_latent_0->ne[2]; k++) {
+                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent_0->ne[2] + 1 + k);
+                    }
                 } else {
                     float m = ggml_tensor_get_f32(mask_img, mx, my);
-                    ggml_tensor_set_f32(masked_latent, m, ix, iy, 0);
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
                     for (int k = 0; k < masked_latent_0->ne[2]; k++) {
                         float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
-                        ggml_tensor_set_f32(masked_latent, v, ix, iy, k + mask_channels);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
                     }
                 }
             }
         }
     } else if (sd_version_is_edit(sd_ctx->sd->version)) {
-        // Not actually masked, we're just highjacking the masked_latent variable since it will be used the same way
+        // Not actually masked, we're just hijacking the concat_latent variable since it will be used the same way
         if (!sd_ctx->sd->use_tiny_autoencoder) {
             if (sd_ctx->sd->is_using_edm_v_parameterization) {
                 // for CosXL edit
-                masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
+                concat_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
             } else {
-                masked_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments);
+                concat_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments);
             }
         } else {
-            masked_latent = init_latent;
+            concat_latent = init_latent;
         }
     } else {
         // LOG_WARN("Inpainting with a base model is not great");
-        masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
-        for (int ix = 0; ix < masked_latent->ne[0]; ix++) {
-            for (int iy = 0; iy < masked_latent->ne[1]; iy++) {
+        concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
+        for (int ix = 0; ix < concat_latent->ne[0]; ix++) {
+            for (int iy = 0; iy < concat_latent->ne[1]; iy++) {
                 int mx = ix * 8;
                 int my = iy * 8;
                 float m = ggml_tensor_get_f32(mask_img, mx, my);
-                ggml_tensor_set_f32(masked_latent, m, ix, iy);
+                ggml_tensor_set_f32(concat_latent, m, ix, iy);
             }
         }
     }
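On the x*8+y vs x+8*y TODO above: if ggml tensors here are indexed (width, height, channel), then x is the horizontal sub-index inside the 8x8 chunk and y the vertical one, while the einops pattern "b (h 8) (w 8) -> b (8 8) h w" flattens the (vertical, horizontal) pair with the vertical index as the slower axis, i.e. channel = y*8 + x. Under that reading, the patch's x*8 + y is the transpose of the Python layout. A standalone check of the two orderings:

#include <cstdio>

int main() {
    // sub-pixel (x, y) = (1, 0): one step to the right inside an 8x8 mask chunk
    int x = 1, y = 0;
    int patch_channel  = x * 8 + y;  // ordering used in the patch -> 8
    int einops_channel = y * 8 + x;  // "(8 8)" ordering with the vertical index slower -> 1
    printf("patch: %d, einops: %d\n", patch_channel, einops_channel);
    return 0;
}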
@@ -1868,7 +1937,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
                                      style_ratio,
                                      normalize_input,
                                      input_id_images_path_c_str,
-                                     masked_latent);
+                                     concat_latent);
 
     size_t t2 = ggml_time_ms();
 