@@ -301,7 +301,7 @@ class StableDiffusionGGML {
301
301
// TODO: shift_factor
302
302
}
303
303
304
- if ( version == VERSION_FLEX_2) {
304
+ if ( sd_version_is_control ( version)) {
305
305
// Might need vae encode for control cond
306
306
vae_decode_only = false ;
307
307
}
@@ -1476,6 +1476,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1476
1476
int H = height / 8 ;
1477
1477
LOG_INFO (" sampling using %s method" , sampling_methods_str[sample_method]);
1478
1478
ggml_tensor* noise_mask = nullptr ;
1479
+
1480
+ struct ggml_tensor * control_latent = NULL ;
1481
+ if (sd_version_is_control (sd_ctx->sd ->version ) && image_hint != NULL ){
1482
+ if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1483
+ struct ggml_tensor * control_moments = sd_ctx->sd ->encode_first_stage (work_ctx, image_hint);
1484
+ control_latent = sd_ctx->sd ->get_first_stage_encoding (work_ctx, control_moments);
1485
+ } else {
1486
+ control_latent = sd_ctx->sd ->encode_first_stage (work_ctx, image_hint);
1487
+ }
1488
+ }
1489
+
1479
1490
if (sd_version_is_inpaint (sd_ctx->sd ->version )) {
1480
1491
int64_t mask_channels = 1 ;
1481
1492
if (sd_ctx->sd ->version == VERSION_FLUX_FILL) {
@@ -1508,46 +1519,48 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1508
1519
}
1509
1520
}
1510
1521
}
1511
- if (sd_ctx->sd ->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd ->control_net == NULL ) {
1522
+
1523
+ if (sd_ctx->sd ->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd ->control_net == NULL ) {
1512
1524
bool no_inpaint = concat_latent == NULL ;
1513
1525
if (no_inpaint) {
1514
1526
concat_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], mask_channels + init_latent->ne [2 ], 1 );
1515
1527
}
1516
1528
// fill in the control image here
1517
- struct ggml_tensor * control_latents = NULL ;
1518
- if (!sd_ctx->sd ->use_tiny_autoencoder ) {
1519
- struct ggml_tensor * control_moments = sd_ctx->sd ->encode_first_stage (work_ctx, image_hint);
1520
- control_latents = sd_ctx->sd ->get_first_stage_encoding (work_ctx, control_moments);
1521
- } else {
1522
- control_latents = sd_ctx->sd ->encode_first_stage (work_ctx, image_hint);
1523
- }
1524
- for (int64_t x = 0 ; x < concat_latent->ne [0 ]; x++) {
1525
- for (int64_t y = 0 ; y < concat_latent->ne [1 ]; y++) {
1529
+ for (int64_t x = 0 ; x < control_latent->ne [0 ]; x++) {
1530
+ for (int64_t y = 0 ; y < control_latent->ne [1 ]; y++) {
1526
1531
if (no_inpaint) {
1527
- for (int64_t c = 0 ; c < concat_latent->ne [2 ] - control_latents ->ne [2 ]; c++) {
1532
+ for (int64_t c = 0 ; c < concat_latent->ne [2 ] - control_latent ->ne [2 ]; c++) {
1528
1533
// 0x16,1x1,0x16
1529
1534
ggml_tensor_set_f32 (concat_latent, c == init_latent->ne [2 ], x, y, c);
1530
1535
}
1531
1536
}
1532
- for (int64_t c = 0 ; c < control_latents ->ne [2 ]; c++) {
1533
- float v = ggml_tensor_get_f32 (control_latents , x, y, c);
1534
- ggml_tensor_set_f32 (concat_latent, v, x, y, concat_latent->ne [2 ] - control_latents ->ne [2 ] + c);
1537
+ for (int64_t c = 0 ; c < control_latent ->ne [2 ]; c++) {
1538
+ float v = ggml_tensor_get_f32 (control_latent , x, y, c);
1539
+ ggml_tensor_set_f32 (concat_latent, v, x, y, concat_latent->ne [2 ] - control_latent ->ne [2 ] + c);
1535
1540
}
1536
1541
}
1537
1542
}
1538
- // Disable controlnet
1539
- image_hint = NULL ;
1540
1543
} else if (concat_latent == NULL ) {
1541
1544
concat_latent = empty_latent;
1542
1545
}
1543
1546
cond.c_concat = concat_latent;
1544
1547
uncond.c_concat = empty_latent;
1545
- // noise_mask = concat_latent;
1546
- } else if (sd_version_is_edit (sd_ctx->sd ->version )) {
1547
- cond.c_concat = concat_latent;
1548
- auto empty_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, concat_latent->ne [0 ], concat_latent->ne [1 ], concat_latent->ne [2 ], concat_latent->ne [3 ]);
1548
+ noise_mask = NULL ;
1549
+ } else if (sd_version_is_edit (sd_ctx->sd ->version ) || sd_version_is_control (sd_ctx->sd ->version )) {
1550
+ LOG_INFO (" HERE" );
1551
+ auto empty_latent = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, init_latent->ne [0 ], init_latent->ne [1 ], init_latent->ne [2 ], init_latent->ne [3 ]);
1552
+ LOG_INFO (" HERE" );
1549
1553
ggml_set_f32 (empty_latent, 0 );
1550
1554
uncond.c_concat = empty_latent;
1555
+ if (sd_version_is_control (sd_ctx->sd ->version ) && control_latent != NULL && sd_ctx->sd ->control_net == NULL ) {
1556
+ concat_latent = control_latent;
1557
+ }
1558
+ if (concat_latent == NULL ) {
1559
+ concat_latent = empty_latent;
1560
+ }
1561
+ LOG_INFO (" HERE" );
1562
+
1563
+ cond.c_concat = concat_latent;
1551
1564
} else {
1552
1565
noise_mask = concat_latent;
1553
1566
}
0 commit comments