@@ -126,9 +126,13 @@ struct SDParams {
126
126
int upscale_repeats = 1 ;
127
127
128
128
std::vector<int > skip_layers = {7 , 8 , 9 };
129
- float slg_scale = 0 .f ;
129
+ float slg_scale = 0 .0f ;
130
130
float skip_layer_start = 0 .01f ;
131
131
float skip_layer_end = 0 .2f ;
132
+
133
+ float apg_eta = 1 .0f ;
134
+ float apg_momentum = 0 .0f ;
135
+ float apg_norm_treshold = 0 .0f ;
132
136
};
133
137
134
138
void print_params (SDParams params) {
@@ -213,6 +217,9 @@ void print_usage(int argc, const char* argv[]) {
213
217
printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
214
218
printf (" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n " );
215
219
printf (" --guidance SCALE guidance scale for img2img (default: 3.5)\n " );
220
+ printf (" --apg-eta VALUE parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)\n " );
221
+ printf (" --apg-momentum VALUE CFG update direction momentum for APG (default: 0, recommended: around -0.5)\n " );
222
+ printf (" --apg-nt, --apg-rescale VALUE CFG update direction norm threshold for APG (default: 0 = disabled, recommended: 4-15)\n " );
216
223
printf (" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n " );
217
224
printf (" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n " );
218
225
printf (" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n " );
@@ -629,6 +636,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629
636
break ;
630
637
}
631
638
params.skip_layer_end = std::stof (argv[i]);
639
+ } else if (arg == " --apg-eta" ) {
640
+ if (++i >= argc) {
641
+ invalid_arg = true ;
642
+ break ;
643
+ }
644
+ params.apg_eta = std::stof (argv[i]);
645
+ } else if (arg == " --apg-momentum" ) {
646
+ if (++i >= argc) {
647
+ invalid_arg = true ;
648
+ break ;
649
+ }
650
+ params.apg_momentum = std::stof (argv[i]);
651
+ } else if (arg == " --apg-nt" || arg == " --apg-rescale" ) {
652
+ if (++i >= argc) {
653
+ invalid_arg = true ;
654
+ break ;
655
+ }
656
+ params.apg_norm_treshold = std::stof (argv[i]);
632
657
} else {
633
658
fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
634
659
print_usage (argc, argv);
@@ -968,7 +993,9 @@ int main(int argc, const char* argv[]) {
968
993
params.slg_scale ,
969
994
params.skip_layer_start ,
970
995
params.skip_layer_end },
971
- sd_apg_params_t {1 , 0 , 0 });
996
+ sd_apg_params_t {params.apg_eta ,
997
+ params.apg_momentum ,
998
+ params.apg_norm_treshold });
972
999
} else {
973
1000
sd_image_t input_image = {(uint32_t )params.width ,
974
1001
(uint32_t )params.height ,
@@ -1038,7 +1065,9 @@ int main(int argc, const char* argv[]) {
1038
1065
params.slg_scale ,
1039
1066
params.skip_layer_start ,
1040
1067
params.skip_layer_end },
1041
- sd_apg_params_t {1 , 0 , 0 });
1068
+ sd_apg_params_t {params.apg_eta ,
1069
+ params.apg_momentum ,
1070
+ params.apg_norm_treshold });
1042
1071
}
1043
1072
}
1044
1073
0 commit comments