Skip to content

Commit c2256c0

Browse files
committed
apg: first implementation
1 parent 30b3ac8 commit c2256c0

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

stable-diffusion.cpp

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,15 @@ class StableDiffusionGGML {
847847
}
848848
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
849849

850+
// TODO do not hardcode
851+
float apg_eta = .08f;
852+
float apg_momentum = -.5f;
853+
float apg_norm_treshold = 15.0f;
854+
855+
std::vector<float> apg_momentum_buffer;
856+
if (apg_momentum != 0)
857+
apg_momentum_buffer.resize((size_t)ggml_nelements(denoised));
858+
850859
auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
851860
if (step == 1) {
852861
pretty_progress(0, (int)steps, 0);
@@ -951,6 +960,50 @@ class StableDiffusionGGML {
951960
float* vec_input = (float*)input->data;
952961
float* positive_data = (float*)out_cond->data;
953962
int ne_elements = (int)ggml_nelements(denoised);
963+
964+
float* deltas = vec_denoised;
965+
966+
// https://arxiv.org/pdf/2410.02416
967+
float apg_scale_factor = 1.;
968+
float diff_norm = 0;
969+
float cond_norm_sq = 0;
970+
float dot = 0;
971+
for (int i = 0; i < ne_elements; i++) {
972+
float delta = positive_data[i] - negative_data[i];
973+
if (apg_momentum != 0) {
974+
delta += apg_momentum * apg_momentum_buffer[i];
975+
apg_momentum_buffer[i] = delta;
976+
}
977+
if (apg_norm_treshold > 0) {
978+
diff_norm += delta * delta;
979+
}
980+
if (apg_eta != 1.0f) {
981+
cond_norm_sq += positive_data[i] * positive_data[i];
982+
dot += positive_data[i] * delta;
983+
}
984+
deltas[i] = delta;
985+
}
986+
if (apg_norm_treshold > 0) {
987+
diff_norm = std::sqrtf(diff_norm);
988+
apg_scale_factor = std::min(1.0f, apg_norm_treshold / diff_norm);
989+
}
990+
if (apg_eta != 1.0f) {
991+
dot *= apg_scale_factor;
992+
// pre-normalize (avoids one square root and ne_elements extra divs)
993+
dot /= cond_norm_sq;
994+
}
995+
996+
for (int i = 0; i < ne_elements; i++) {
997+
deltas[i] *= apg_scale_factor;
998+
if (apg_eta != 1.0f) {
999+
float apg_parallel = dot * positive_data[i];
1000+
float apg_orthogonal = deltas[i] - apg_parallel;
1001+
1002+
// tweak deltas
1003+
deltas[i] = apg_orthogonal + apg_eta * apg_parallel;
1004+
}
1005+
}
1006+
9541007
for (int i = 0; i < ne_elements; i++) {
9551008
float latent_result = positive_data[i];
9561009
if (has_unconditioned) {
@@ -960,7 +1013,9 @@ class StableDiffusionGGML {
9601013
int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2];
9611014
float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3);
9621015
} else {
963-
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
1016+
float delta = deltas[i];
1017+
1018+
latent_result = positive_data[i] + (cfg_scale - 1) * delta;
9641019
}
9651020
}
9661021
if (is_skiplayer_step) {
@@ -1004,7 +1059,8 @@ class StableDiffusionGGML {
10041059
}
10051060

10061061
// ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding
1007-
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* moments) {
1062+
ggml_tensor*
1063+
get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* moments) {
10081064
// ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample
10091065
ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
10101066
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);

0 commit comments

Comments
 (0)