From 7e0ba7151cc3cf94fc4829c8b4467d44e2564c9c Mon Sep 17 00:00:00 2001
From: bachvudinh
Date: Sun, 3 Mar 2024 06:04:44 +0700
Subject: [PATCH 1/2] add completion function

---
 src/gemma_binding.cpp | 99 +++++++++++++++++++++++++++++++++++++++++--
 tests/test_chat.py    | 44 ++++++++++++-------
 2 files changed, 125 insertions(+), 18 deletions(-)

diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp
index 2407cfc..608563f 100644
--- a/src/gemma_binding.cpp
+++ b/src/gemma_binding.cpp
@@ -135,7 +135,7 @@ namespace gcpp
         while (abs_pos < args.max_tokens)
         {
-            std::string prompt_string;
+            std::string prompt_string;
             std::vector<int> prompt;
             current_pos = 0;
             {
@@ -255,6 +255,72 @@ namespace gcpp
                 { return true; });
     }
+    std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool,
+                       hwy::ThreadPool &inner_pool, const InferenceArgs &args,
+                       int verbosity, const gcpp::AcceptFunc &accept_token, std::string &prompt_string)
+    {
+        std::string generated_text;
+        // Seed the random number generator
+        int current_pos = 0;
+        std::random_device rd;
+        std::mt19937 gen(rd());
+        int prompt_size{};
+        if (model.model_training == ModelTraining::GEMMA_IT)
+        {
+            // For instruction-tuned models: add control tokens.
+            prompt_string = "<start_of_turn>user\n" + prompt_string +
+                            "<end_of_turn>\n<start_of_turn>model\n";
+        }
+        // Encode the prompt string into tokens
+        std::vector<int> prompt;
+        HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+        // Placeholder for generated token IDs
+        std::vector<int> generated_tokens;
+        // Define lambda for token decoding
+        StreamFunc stream_token = [&generated_tokens,&current_pos](int token, float /* probability */) -> bool {
+            ++current_pos;
+            generated_tokens.push_back(token);
+            return true; // Continue generating
+        };
+        // Decode each token and concatenate
+        prompt_size = prompt.size();
+        GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+        // for (int token : generated_tokens) {
+        //     std::string token_text;
+        //     if (model.Tokenizer().Decode(std::vector<int>{token}, &token_text).ok()) {
+        //         generated_text += token_text; // Appending a space for readability
+        //     }
+        HWY_ASSERT(model.Tokenizer().Decode(generated_tokens, &generated_text).ok());
+        // for (int i = prompt_size; i < generated_tokens.size(); ++i) {
+        //     std::string token_text;
+        //     if (model.Tokenizer().Decode(std::vector<int>{generated_tokens[i]}, &token_text).ok()) {
+        //         generated_text += token_text; // Appending a space for readability
+        //     }
+        // }
+        // remove prompt from generated text
+        generated_text = generated_text.substr(prompt_string.size());
+
+        return generated_text;
+    }
+
+    std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, std::string &prompt_string)
+    {
+        hwy::ThreadPool inner_pool(0);
+        hwy::ThreadPool pool(app.num_threads);
+        if (app.num_threads > 10)
+        {
+            PinThreadToCore(app.num_threads - 1); // Main thread
+
+            pool.Run(0, pool.NumThreads(),
+                     [](uint64_t /*task*/, size_t thread)
+                     { PinThreadToCore(thread); });
+        }
+        gcpp::Gemma model(loader, pool);
+        return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int)
+                      { return true; }, prompt_string);
+
+    }
+
 } // namespace gcpp
 
 void chat_base(int argc, char **argv)
@@ -283,7 +349,32 @@ void chat_base(int argc, char **argv)
     PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
     // return 1;
 }
+std::string completion_base(int argc, char **argv)
+{
+    gcpp::LoaderArgs loader(argc, argv);
+    gcpp::InferenceArgs inference(argc, argv);
+    gcpp::AppArgs app(argc, argv);
+    std::string prompt_string = argv[argc-1];
+    std::string output_text = gcpp::completion(loader, inference, app, prompt_string);
+    return output_text;
+}
+std::string completion_base_wrapper(const std::vector<std::string> &args,std::string &prompt_string)
+{
+    int argc = args.size() + 2; // +1 for the program name, +1 for the prompt string
+    std::vector<char *> argv_vec;
+    argv_vec.reserve(argc);
+    argv_vec.push_back(const_cast<char *>("pygemma"));
+
+    for (const auto &arg : args)
+    {
+        argv_vec.push_back(const_cast<char *>(arg.c_str()));
+    }
+    argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
+    char **argv = argv_vec.data();
+    std::string output = completion_base(argc, argv);
+    return output;
+}
 
 void show_help_wrapper()
 {
     // Assuming ShowHelp does not critically depend on argv content
@@ -294,12 +385,11 @@ void show_help_wrapper()
     ShowHelp(loader, inference, app);
 }
 
-void chat_base_wrapper(const std::vector<std::string> &args)
+std::string chat_base_wrapper(const std::vector<std::string> &args)
 {
     int argc = args.size() + 1; // +1 for the program name
     std::vector<char *> argv_vec;
     argv_vec.reserve(argc);
-
     argv_vec.push_back(const_cast<char *>("pygemma"));
 
     for (const auto &arg : args)
@@ -308,12 +398,15 @@
     }
     char **argv = argv_vec.data();
+
     chat_base(argc, argv);
 }
+
 PYBIND11_MODULE(pygemma, m)
 {
     m.doc() = "Pybind11 integration for chat_base function";
     m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting Python list of strings as arguments");
     m.def("show_help", &show_help_wrapper, "A wrapper for show_help function");
+    m.def("completion", &completion_base_wrapper, "A wrapper for inference function");
 }
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 0554771..43795c7 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -18,25 +18,39 @@ def main():
     parser.add_argument(
         "--model", type=str, required=True, help="Model type identifier."
     )
-
-    args = parser.parse_args()
-
-    # Now using the parsed arguments
-    pygemma.chat_base(
-        [
-            "--tokenizer",
-            args.tokenizer,
-            "--compressed_weights",
-            args.compressed_weights,
-            "--model",
-            args.model,
-        ]
+    parser.add_argument(
+        "--input", type=str, required=False, help="Input text to chat with the model. If None, switch to chat mode.",
+        default="Hello."
     )
-
+    # Now using the parsed arguments
+    args = parser.parse_args()
+    if args.input is not None:
+        string = pygemma.completion(
+            [
+                "--tokenizer",
+                args.tokenizer,
+                "--compressed_weights",
+                args.compressed_weights,
+                "--model",
+                args.model,
+            ], args.input
+        )
+        print(string)
+    else:
+        return pygemma.chat_base(
+            [
+                "--tokenizer",
+                args.tokenizer,
+                "--compressed_weights",
+                args.compressed_weights,
+                "--model",
+                args.model,
+            ]
+        )
     # Optionally, show help if needed
     # pygemma.show_help()
 
 
 if __name__ == "__main__":
     main()
-    # python tests/test_chat.py --tokenizer /path/to/tokenizer.spm --compressed_weights /path/to/weights.sbs --model model_identifier
+    # python tests/test_chat.py --tokenizer ../Model_Weight/tokenizer.spm --compressed_weights ../Model_Weight/2b-it-sfp.sbs --model 2b-it

From df10530c03b7f2b248ecd2eb5811aa61ae859d3e Mon Sep 17 00:00:00 2001
From: bachvudinh
Date: Sun, 3 Mar 2024 06:48:32 +0700
Subject: [PATCH 2/2] clean code

---
 src/gemma_binding.cpp | 24 ++++--------------------
 1 file changed, 4 insertions(+), 20 deletions(-)

diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp
index 608563f..53607db 100644
--- a/src/gemma_binding.cpp
+++ b/src/gemma_binding.cpp
@@ -261,7 +261,6 @@ namespace gcpp
     {
         std::string generated_text;
         // Seed the random number generator
-        int current_pos = 0;
         std::random_device rd;
         std::mt19937 gen(rd());
         int prompt_size{};
@@ -277,27 +276,14 @@ namespace gcpp
         // Placeholder for generated token IDs
         std::vector<int> generated_tokens;
         // Define lambda for token decoding
-        StreamFunc stream_token = [&generated_tokens,&current_pos](int token, float /* probability */) -> bool {
-            ++current_pos;
+        StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool {
             generated_tokens.push_back(token);
             return true; // Continue generating
         };
-        // Decode each token and concatenate
+        // Decode tokens
         prompt_size = prompt.size();
         GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, accept_token, gen, verbosity);
-        // for (int token : generated_tokens) {
-        //     std::string token_text;
-        //     if (model.Tokenizer().Decode(std::vector<int>{token}, &token_text).ok()) {
-        //         generated_text += token_text; // Appending a space for readability
-        //     }
         HWY_ASSERT(model.Tokenizer().Decode(generated_tokens, &generated_text).ok());
-        // for (int i = prompt_size; i < generated_tokens.size(); ++i) {
-        //     std::string token_text;
-        //     if (model.Tokenizer().Decode(std::vector<int>{generated_tokens[i]}, &token_text).ok()) {
-        //         generated_text += token_text; // Appending a space for readability
-        //     }
-        // }
-        // remove prompt from generated text
         generated_text = generated_text.substr(prompt_string.size());
 
         return generated_text;
@@ -355,8 +341,7 @@ std::string completion_base(int argc, char **argv)
     gcpp::InferenceArgs inference(argc, argv);
     gcpp::AppArgs app(argc, argv);
     std::string prompt_string = argv[argc-1];
-    std::string output_text = gcpp::completion(loader, inference, app, prompt_string);
-    return output_text;
+    return gcpp::completion(loader, inference, app, prompt_string);
 }
 std::string completion_base_wrapper(const std::vector<std::string> &args,std::string &prompt_string)
 {
@@ -372,8 +357,7 @@ std::string completion_base_wrapper(const std::vector<std::string> &args,std::st
     }
     argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
     char **argv = argv_vec.data();
-    std::string output = completion_base(argc, argv);
-    return output;
+    return completion_base(argc, argv);
 }
 void show_help_wrapper()
 {
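
Note: a minimal usage sketch for the binding added by this series, assuming the pygemma module has been built and importable; the tokenizer/weight paths and the prompt below are placeholders, not files shipped with the patch:

    import pygemma

    # The flag list mirrors tests/test_chat.py; the prompt is passed as a
    # separate second argument, matching completion_base_wrapper's signature.
    text = pygemma.completion(
        [
            "--tokenizer", "/path/to/tokenizer.spm",
            "--compressed_weights", "/path/to/2b-it-sfp.sbs",
            "--model", "2b-it",
        ],
        "Write a short greeting.",
    )
    print(text)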