Skip to content

Commit 6f827eb

Browse files
committed
main: add apg support
1 parent c5c231d commit 6f827eb

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

examples/cli/main.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,13 @@ struct SDParams {
126126
int upscale_repeats = 1;
127127

128128
std::vector<int> skip_layers = {7, 8, 9};
129-
float slg_scale = 0.;
130-
float skip_layer_start = 0.01;
131-
float skip_layer_end = 0.2;
129+
float slg_scale = 0.0f;
130+
float skip_layer_start = 0.01f;
131+
float skip_layer_end = 0.2f;
132+
133+
float apg_eta = 1.0f;
134+
float apg_momentum = 0.0f;
135+
float apg_norm_treshold = 0.0f;
132136
};
133137

134138
void print_params(SDParams params) {
@@ -213,6 +217,9 @@ void print_usage(int argc, const char* argv[]) {
213217
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
214218
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
215219
printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n");
220+
printf(" --apg-eta VALUE parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)\n");
221+
printf(" --apg-momentum VALUE CFG update direction momentum for APG (default: 0, recommended: around -0.5)\n");
222+
printf(" --apg-nt, --apg-rescale VALUE CFG update direction norm threshold for APG (default: 0 = disabled, recommended: 4-15)\n");
216223
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
217224
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
218225
printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n");
@@ -629,6 +636,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629636
break;
630637
}
631638
params.skip_layer_end = std::stof(argv[i]);
639+
} else if (arg == "--apg-eta") {
640+
if (++i >= argc) {
641+
invalid_arg = true;
642+
break;
643+
}
644+
params.apg_eta = std::stof(argv[i]);
645+
} else if (arg == "--apg-momentum") {
646+
if (++i >= argc) {
647+
invalid_arg = true;
648+
break;
649+
}
650+
params.apg_momentum = std::stof(argv[i]);
651+
} else if (arg == "--apg-nt" || arg == "--apg-rescale") {
652+
if (++i >= argc) {
653+
invalid_arg = true;
654+
break;
655+
}
656+
params.apg_norm_treshold = std::stof(argv[i]);
632657
} else {
633658
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
634659
print_usage(argc, argv);
@@ -968,7 +993,9 @@ int main(int argc, const char* argv[]) {
968993
params.slg_scale,
969994
params.skip_layer_start,
970995
params.skip_layer_end},
971-
sd_apg_params_t{1, 0, 0});
996+
sd_apg_params_t{params.apg_eta,
997+
params.apg_momentum,
998+
params.apg_norm_treshold});
972999
} else {
9731000
sd_image_t input_image = {(uint32_t)params.width,
9741001
(uint32_t)params.height,
@@ -1038,7 +1065,9 @@ int main(int argc, const char* argv[]) {
10381065
params.slg_scale,
10391066
params.skip_layer_start,
10401067
params.skip_layer_end},
1041-
sd_apg_params_t{1, 0, 0});
1068+
sd_apg_params_t{params.apg_eta,
1069+
params.apg_momentum,
1070+
params.apg_norm_treshold});
10421071
}
10431072
}
10441073

0 commit comments

Comments
 (0)