diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 4e283a15..52b9a985 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -657,12 +657,32 @@ void parse_args(int argc, const char** argv, SDParams& params) { std::stringstream ss(sigmas_str); std::string item; + unsigned char op; + float last_sigma, parsed_sigma; while(std::getline(ss, item, ',')) { item.erase(0, item.find_first_not_of(" \t\n\r\f\v")); item.erase(item.find_last_not_of(" \t\n\r\f\v") + 1); if (!item.empty()) { try { - params.custom_sigmas.push_back(std::stof(item)); + op = item[0]; // basic math handling + switch (op) { + case '+': case '-': case '*': case '/': case '^': + item.erase(0, 1); // "* 1.5" => "1.5" + item.erase(0, item.find_first_not_of(" \t\n\r\f\v")); + break; + default: + op = 0; // simply a value, or ignore unknown ops + } + parsed_sigma = (!item.empty()) ? std::stof(item) : 0.f; + switch (op) { + case '+': last_sigma += parsed_sigma; break; + case '-': last_sigma -= parsed_sigma; break; + case '*': last_sigma *= parsed_sigma; break; + case '/': last_sigma /= parsed_sigma; break; + case '^': last_sigma = std::pow(last_sigma, parsed_sigma); break; + default: last_sigma = parsed_sigma; // set as is + } + params.custom_sigmas.push_back(last_sigma); } catch (const std::invalid_argument& e) { fprintf(stderr, "error: invalid float value '%s' in --sigmas\n", item.c_str()); invalid_arg = true; @@ -680,6 +700,9 @@ void parse_args(int argc, const char** argv, SDParams& params) { invalid_arg = true; break; } + // use last 2 values as last op holder + params.custom_sigmas.push_back((float)op); + params.custom_sigmas.push_back(parsed_sigma); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index ee4006b0..0df0dc63 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1199,11 +1199,27 @@ static std::vector prepare_sigmas( std::vector sigmas_for_generation; if (custom_sigmas_count > 0 && custom_sigmas_arr != nullptr) { LOG_INFO("Using custom sigmas provided by user for %s.", mode_name); - sigmas_for_generation.assign(custom_sigmas_arr, custom_sigmas_arr + custom_sigmas_count); + sigmas_for_generation.assign(custom_sigmas_arr, custom_sigmas_arr + custom_sigmas_count - 2); // last 2 values hold operation + operand + unsigned u = sigmas_for_generation.size(); size_t target_len = static_cast(sample_steps) + 1; - if (sigmas_for_generation.size() < target_len) { - LOG_DEBUG("Custom sigmas count (%zu) is less than target steps + 1 (%zu). Padding with 0.0f.", sigmas_for_generation.size(), target_len); + if (u < target_len) { + unsigned char op = custom_sigmas_arr[u]; + float last_parsed_sigma = custom_sigmas_arr[u + 1]; + LOG_DEBUG("Custom sigmas count (%zu) is less than target steps + 1 (%zu). Propagating last operation (%c%f).", u, target_len, op, last_parsed_sigma); sigmas_for_generation.resize(target_len, 0.0f); + for (; u < target_len - 1; u++) { // last sigma will be zeroed anyway + switch(op) { + case '+': sigmas_for_generation[u] = sigmas_for_generation[u - 1] + last_parsed_sigma; break; + case '-': sigmas_for_generation[u] = sigmas_for_generation[u - 1] - last_parsed_sigma; break; + case '*': sigmas_for_generation[u] = sigmas_for_generation[u - 1] * last_parsed_sigma; break; + case '/': sigmas_for_generation[u] = sigmas_for_generation[u - 1] / last_parsed_sigma; break; + case '^': sigmas_for_generation[u] = std::pow(sigmas_for_generation[u - 1], last_parsed_sigma); break; + default: sigmas_for_generation[u] = last_parsed_sigma; + } + } + for (u = 0; u < target_len; u++) { + LOG_DEBUG("sigmas_for_generation[%u] = '%f'.", u, sigmas_for_generation[u]); + } } else if (sigmas_for_generation.size() > target_len) { LOG_DEBUG("Custom sigmas count (%zu) is greater than target steps + 1 (%zu). Truncating.", sigmas_for_generation.size(), target_len); sigmas_for_generation.resize(target_len);