diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp
index 2407cfc..53607db 100644
--- a/src/gemma_binding.cpp
+++ b/src/gemma_binding.cpp
@@ -135,7 +135,7 @@ namespace gcpp
         while (abs_pos < args.max_tokens)
         {
-            std::string prompt_string;
+            std::string prompt_string;
             std::vector<int> prompt;
             current_pos = 0;
             {
@@ -255,6 +255,58 @@ namespace gcpp
                 { return true; });
    }
+    std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool,
+                       hwy::ThreadPool &inner_pool, const InferenceArgs &args,
+                       int verbosity, const gcpp::AcceptFunc &accept_token, std::string &prompt_string)
+    {
+        std::string generated_text;
+        // Seed the random number generator
+        std::random_device rd;
+        std::mt19937 gen(rd());
+        int prompt_size{};
+        if (model.model_training == ModelTraining::GEMMA_IT)
+        {
+            // For instruction-tuned models: add control tokens.
+            prompt_string = "<start_of_turn>user\n" + prompt_string +
+                            "<end_of_turn>\n<start_of_turn>model\n";
+        }
+        // Encode the prompt string into tokens
+        std::vector<int> prompt;
+        HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+        // Storage for the generated token IDs
+        std::vector<int> generated_tokens;
+        // Stream callback that collects each generated token
+        StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool {
+            generated_tokens.push_back(token);
+            return true; // Continue generating
+        };
+        // Generate tokens, then decode them back into text
+        prompt_size = prompt.size();
+        GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+        HWY_ASSERT(model.Tokenizer().Decode(generated_tokens, &generated_text).ok());
+        generated_text = generated_text.substr(prompt_string.size());
+
+        return generated_text;
+    }
+
+    std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, std::string &prompt_string)
+    {
+        hwy::ThreadPool inner_pool(0);
+        hwy::ThreadPool pool(app.num_threads);
+        if (app.num_threads > 10)
+        {
+            PinThreadToCore(app.num_threads - 1); // Main thread
+
+            pool.Run(0, pool.NumThreads(),
+                     [](uint64_t /*task*/, size_t thread)
+                     { PinThreadToCore(thread); });
+        }
+        gcpp::Gemma model(loader, pool);
+        return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int)
+                      { return true; }, prompt_string);
+
+    }
+
 } // namespace gcpp

 void chat_base(int argc, char **argv)
@@ -283,7 +335,30 @@ void chat_base(int argc, char **argv)
     PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
    // return 1;
 }
+std::string completion_base(int argc, char **argv)
+{
+    gcpp::LoaderArgs loader(argc, argv);
+    gcpp::InferenceArgs inference(argc, argv);
+    gcpp::AppArgs app(argc, argv);
+    std::string prompt_string = argv[argc - 1];
+    return gcpp::completion(loader, inference, app, prompt_string);
+}
+std::string completion_base_wrapper(const std::vector<std::string> &args, std::string &prompt_string)
+{
+    int argc = args.size() + 2; // +2 for the program name and the prompt string
+    std::vector<char *> argv_vec;
+    argv_vec.reserve(argc);
+    argv_vec.push_back(const_cast<char *>("pygemma"));
+
+    for (const auto &arg : args)
+    {
+        argv_vec.push_back(const_cast<char *>(arg.c_str()));
+    }
+    argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
+    char **argv = argv_vec.data();
+    return completion_base(argc, argv);
+}
 void show_help_wrapper()
 {
     // Assuming ShowHelp does not critically depend on argv content
@@ -294,12 +369,11 @@ void show_help_wrapper()
     ShowHelp(loader, inference, app);
 }

-void chat_base_wrapper(const std::vector<std::string> &args)
+std::string chat_base_wrapper(const std::vector<std::string> &args)
 {
     int argc = args.size() + 1; // +1 for the program name
     std::vector<char *> argv_vec;
     argv_vec.reserve(argc);
-
     argv_vec.push_back(const_cast<char *>("pygemma"));

     for (const auto &arg : args)
@@ -308,12 +382,15 @@ void chat_base_wrapper(const std::vector<std::string> &args)
     {
         argv_vec.push_back(const_cast<char *>(arg.c_str()));
     }
     char **argv = argv_vec.data();
+
     chat_base(argc, argv);
 }
+
 PYBIND11_MODULE(pygemma, m)
 {
     m.doc() = "Pybind11 integration for chat_base function";
     m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting Python list of strings as arguments");
     m.def("show_help", &show_help_wrapper, "A wrapper for show_help function");
+    m.def("completion", &completion_base_wrapper, "A wrapper for inference function");
 }
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 0554771..43795c7 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -18,25 +18,39 @@ def main():
     parser.add_argument(
         "--model", type=str, required=True, help="Model type identifier."
     )
-
-    args = parser.parse_args()
-
-    # Now using the parsed arguments
-    pygemma.chat_base(
-        [
-            "--tokenizer",
-            args.tokenizer,
-            "--compressed_weights",
-            args.compressed_weights,
-            "--model",
-            args.model,
-        ]
+    parser.add_argument(
+        "--input", type=str, required=False, help="Input text to chat with the model. If None, switch to chat mode.",
+        default="Hello."
     )
-
+    # Now using the parsed arguments
+    args = parser.parse_args()
+    if args.input is not None:
+        string = pygemma.completion(
+            [
+                "--tokenizer",
+                args.tokenizer,
+                "--compressed_weights",
+                args.compressed_weights,
+                "--model",
+                args.model,
+            ], args.input
+        )
+        print(string)
+    else:
+        return pygemma.chat_base(
+            [
+                "--tokenizer",
+                args.tokenizer,
+                "--compressed_weights",
+                args.compressed_weights,
+                "--model",
+                args.model,
+            ]
+        )
     # Optionally, show help if needed
     # pygemma.show_help()

 if __name__ == "__main__":
     main()
-    # python tests/test_chat.py --tokenizer /path/to/tokenizer.spm --compressed_weights /path/to/weights.sbs --model model_identifier
+    # python tests/test_chat.py --tokenizer ../Model_Weight/tokenizer.spm --compressed_weights ../Model_Weight/2b-it-sfp.sbs --model 2b-it
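
For reference, a minimal usage sketch of the new `completion` binding from Python, assuming the `pygemma` module has been built and is importable; the tokenizer/weights paths are placeholders mirroring the comment in tests/test_chat.py and must point at real files on your machine:

    import pygemma

    # Flags are forwarded to the gemma.cpp argument parsers (LoaderArgs / InferenceArgs / AppArgs).
    # Paths below are placeholders; adjust them to your local tokenizer and compressed weights.
    model_args = [
        "--tokenizer", "../Model_Weight/tokenizer.spm",
        "--compressed_weights", "../Model_Weight/2b-it-sfp.sbs",
        "--model", "2b-it",
    ]

    # completion() takes the flag list plus a prompt string and returns the generated text,
    # while chat_base() starts the interactive loop with the same flag list.
    text = pygemma.completion(model_args, "Write a short poem about threads.")
    print(text)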