diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index ceae27b83..51afba9e2 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -71,6 +71,7 @@ struct SDParams { SDMode mode = TXT2IMG; std::string model_path; + std::string clip_g_path; std::string clip_l_path; std::string t5xxl_path; std::string diffusion_model_path; @@ -127,6 +128,7 @@ void print_params(SDParams params) { printf(" mode: %s\n", modes_str[params.mode]); printf(" model_path: %s\n", params.model_path.c_str()); printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified"); + printf(" clip_g_path: %s\n", params.clip_g_path.c_str()); printf(" clip_l_path: %s\n", params.clip_l_path.c_str()); printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); @@ -175,6 +177,7 @@ void print_usage(int argc, const char* argv[]) { printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" -m, --model [MODEL] path to full model\n"); printf(" --diffusion-model path to the standalone diffusion model\n"); + printf(" --clip_g path to the clip-g text encoder\n"); printf(" --clip_l path to the clip-l text encoder\n"); printf(" --t5xxl path to the the t5xxl text encoder.\n"); printf(" --vae [VAE] path to vae\n"); @@ -256,6 +259,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.model_path = argv[i]; + } else if (arg == "--clip_g") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_g_path = argv[i]; } else if (arg == "--clip_l") { if (++i >= argc) { invalid_arg = true; @@ -764,6 +773,7 @@ int main(int argc, const char* argv[]) { } sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), + params.clip_g_path.c_str(), params.clip_l_path.c_str(), params.t5xxl_path.c_str(), params.diffusion_model_path.c_str(), diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 07b59bb8a..87b8e3f3e 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -138,6 +138,7 @@ class StableDiffusionGGML { } bool load_from_file(const std::string& model_path, + const std::string& clip_g_path, const std::string& clip_l_path, const std::string& t5xxl_path, const std::string& diffusion_model_path, @@ -167,7 +168,7 @@ class StableDiffusionGGML { for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) { backend = ggml_backend_vk_init(device); } - if(!backend) { + if (!backend) { LOG_WARN("Failed to initialize Vulkan backend"); } #endif @@ -181,7 +182,7 @@ class StableDiffusionGGML { backend = ggml_backend_cpu_init(); } #ifdef SD_USE_FLASH_ATTENTION -#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN) +#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN) LOG_WARN("Flash Attention not supported with GPU Backend"); #else LOG_INFO("Flash Attention enabled"); @@ -198,24 +199,44 @@ class StableDiffusionGGML { } } + if (diffusion_model_path.size() > 0) { + LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); + if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { + LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); + } + } + version = model_loader.get_sd_version(); + + if (clip_g_path.size() > 0) { + LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str()); + std::string prefix = "text_encoders.clip_g."; + if (version == VERSION_SD3_2B ) { + prefix = "text_encoders.clip_g.transformer."; + } + if (!model_loader.init_from_file(clip_g_path, prefix)) { + LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str()); + } + } + if (clip_l_path.size() > 0) { LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); - if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.")) { + std::string prefix = "text_encoders.clip_l."; + if (version == VERSION_SD3_2B ) { + prefix = "text_encoders.clip_l.transformer."; + } + if (!model_loader.init_from_file(clip_l_path, prefix)) { LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); } } if (t5xxl_path.size() > 0) { LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str()); - if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.")) { - LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str()); + std::string prefix = "text_encoders.t5xxl."; + if (version == VERSION_SD3_2B ) { + prefix = "text_encoders.t5xxl.transformer."; } - } - - if (diffusion_model_path.size() > 0) { - LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); - if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { - LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); + if (!model_loader.init_from_file(t5xxl_path, prefix)) { + LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str()); } } @@ -226,7 +247,6 @@ class StableDiffusionGGML { } } - version = model_loader.get_sd_version(); if (version == VERSION_COUNT) { LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str()); return false; @@ -1007,6 +1027,7 @@ struct sd_ctx_t { }; sd_ctx_t* new_sd_ctx(const char* model_path_c_str, + const char* clip_g_path_c_str, const char* clip_l_path_c_str, const char* t5xxl_path_c_str, const char* diffusion_model_path_c_str, @@ -1031,6 +1052,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, return NULL; } std::string model_path(model_path_c_str); + std::string clip_g_path(clip_g_path_c_str); std::string clip_l_path(clip_l_path_c_str); std::string t5xxl_path(t5xxl_path_c_str); std::string diffusion_model_path(diffusion_model_path_c_str); @@ -1051,6 +1073,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, } if (!sd_ctx->sd->load_from_file(model_path, + clip_g_path, clip_l_path, t5xxl_path_c_str, diffusion_model_path, diff --git a/stable-diffusion.h b/stable-diffusion.h index 0d4cc1fda..d6b5a855c 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -123,6 +123,7 @@ typedef struct { typedef struct sd_ctx_t sd_ctx_t; SD_API sd_ctx_t* new_sd_ctx(const char* model_path, + const char* clip_g_path, const char* clip_l_path, const char* t5xxl_path, const char* diffusion_model_path,