
Commit 7e0ba71

add completion function
1 parent ede8eba commit 7e0ba71

File tree

2 files changed: +125 −18 lines changed


src/gemma_binding.cpp

Lines changed: 96 additions & 3 deletions
@@ -135,7 +135,7 @@ namespace gcpp
 
     while (abs_pos < args.max_tokens)
     {
-      std::string prompt_string;
+      std::string prompt_string;
       std::vector<int> prompt;
       current_pos = 0;
       {
@@ -255,6 +255,72 @@ namespace gcpp
                          { return true; });
   }
 
+  std::string decode(gcpp::Gemma &model, hwy::ThreadPool &pool,
+                     hwy::ThreadPool &inner_pool, const InferenceArgs &args,
+                     int verbosity, const gcpp::AcceptFunc &accept_token, std::string &prompt_string)
+  {
+    std::string generated_text;
+    // Seed the random number generator
+    int current_pos = 0;
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    int prompt_size{};
+    if (model.model_training == ModelTraining::GEMMA_IT)
+    {
+      // For instruction-tuned models: add control tokens.
+      prompt_string = "<start_of_turn>user\n" + prompt_string +
+                      "<end_of_turn>\n<start_of_turn>model\n";
+    }
+    // Encode the prompt string into tokens
+    std::vector<int> prompt;
+    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    // Placeholder for generated token IDs
+    std::vector<int> generated_tokens;
+    // Define lambda for token decoding
+    StreamFunc stream_token = [&generated_tokens, &current_pos](int token, float /* probability */) -> bool {
+      ++current_pos;
+      generated_tokens.push_back(token);
+      return true; // Continue generating
+    };
+    // Decode each token and concatenate
+    prompt_size = prompt.size();
+    GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, accept_token, gen, verbosity);
+    // for (int token : generated_tokens) {
+    //   std::string token_text;
+    //   if (model.Tokenizer().Decode(std::vector<int>{token}, &token_text).ok()) {
+    //     generated_text += token_text; // Appending a space for readability
+    //   }
+    HWY_ASSERT(model.Tokenizer().Decode(generated_tokens, &generated_text).ok());
+    // for (int i = prompt_size; i < generated_tokens.size(); ++i) {
+    //   std::string token_text;
+    //   if (model.Tokenizer().Decode(std::vector<int>{generated_tokens[i]}, &token_text).ok()) {
+    //     generated_text += token_text; // Appending a space for readability
+    //   }
+    // }
+    // Remove the prompt from the generated text
+    generated_text = generated_text.substr(prompt_string.size());
+
+    return generated_text;
+  }
+
+  std::string completion(LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, std::string &prompt_string)
+  {
+    hwy::ThreadPool inner_pool(0);
+    hwy::ThreadPool pool(app.num_threads);
+    if (app.num_threads > 10)
+    {
+      PinThreadToCore(app.num_threads - 1); // Main thread
+
+      pool.Run(0, pool.NumThreads(),
+               [](uint64_t /*task*/, size_t thread)
+               { PinThreadToCore(thread); });
+    }
+    gcpp::Gemma model(loader, pool);
+    return decode(model, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int)
+                  { return true; }, prompt_string);
+
+  }
+
 } // namespace gcpp
 
 void chat_base(int argc, char **argv)
@@ -283,7 +349,32 @@ void chat_base(int argc, char **argv)
     PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
     // return 1;
 }
+std::string completion_base(int argc, char **argv)
+{
+    gcpp::LoaderArgs loader(argc, argv);
+    gcpp::InferenceArgs inference(argc, argv);
+    gcpp::AppArgs app(argc, argv);
+    std::string prompt_string = argv[argc - 1];
+    std::string output_text = gcpp::completion(loader, inference, app, prompt_string);
+    return output_text;
+}
+std::string completion_base_wrapper(const std::vector<std::string> &args, std::string &prompt_string)
+{
+    int argc = args.size() + 2; // +2: program name and the prompt string
+    std::vector<char *> argv_vec;
+    argv_vec.reserve(argc);
 
+    argv_vec.push_back(const_cast<char *>("pygemma"));
+
+    for (const auto &arg : args)
+    {
+        argv_vec.push_back(const_cast<char *>(arg.c_str()));
+    }
+    argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
+    char **argv = argv_vec.data();
+    std::string output = completion_base(argc, argv);
+    return output;
+}
 void show_help_wrapper()
 {
     // Assuming ShowHelp does not critically depend on argv content
@@ -294,12 +385,11 @@ void show_help_wrapper()
     ShowHelp(loader, inference, app);
 }
 
-void chat_base_wrapper(const std::vector<std::string> &args)
+std::string chat_base_wrapper(const std::vector<std::string> &args)
 {
     int argc = args.size() + 1; // +1 for the program name
     std::vector<char *> argv_vec;
     argv_vec.reserve(argc);
-
     argv_vec.push_back(const_cast<char *>("pygemma"));
 
     for (const auto &arg : args)
@@ -308,12 +398,15 @@ void chat_base_wrapper(const std::vector<std::string> &args)
     }
 
     char **argv = argv_vec.data();
+
     chat_base(argc, argv);
 }
 
+
 PYBIND11_MODULE(pygemma, m)
 {
     m.doc() = "Pybind11 integration for chat_base function";
     m.def("chat_base", &chat_base_wrapper, "A wrapper for the chat_base function accepting Python list of strings as arguments");
     m.def("show_help", &show_help_wrapper, "A wrapper for show_help function");
+    m.def("completion", &completion_base_wrapper, "A wrapper for inference function");
 }
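A note on the new decode() helper above: for instruction-tuned weights it wraps the raw prompt in Gemma's turn markers before tokenizing, then strips prompt_string.size() characters from the decoded output, which only yields the bare reply if the decoded stream reproduces the wrapped prompt verbatim. A minimal Python sketch of that wrapping (the prompt text is just an example):

prompt = "Hello."
# Same template the C++ code applies for ModelTraining::GEMMA_IT.
wrapped = "<start_of_turn>user\n" + prompt + "<end_of_turn>\n<start_of_turn>model\n"
# GenerateGemma() sees the tokens of `wrapped`; decode() later drops the
# first len(wrapped) characters of the decoded text to leave only the reply.
print(wrapped)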

tests/test_chat.py

Lines changed: 29 additions & 15 deletions
@@ -18,25 +18,39 @@ def main():
     parser.add_argument(
         "--model", type=str, required=True, help="Model type identifier."
     )
-
-    args = parser.parse_args()
-
-    # Now using the parsed arguments
-    pygemma.chat_base(
-        [
-            "--tokenizer",
-            args.tokenizer,
-            "--compressed_weights",
-            args.compressed_weights,
-            "--model",
-            args.model,
-        ]
+    parser.add_argument(
+        "--input", type=str, required=False, help="Input text to chat with the model. If None, switch to chat mode.",
+        default="Hello."
     )
-
+    # Now using the parsed arguments
+    args = parser.parse_args()
+    if args.input is not None:
+        string = pygemma.completion(
+            [
+                "--tokenizer",
+                args.tokenizer,
+                "--compressed_weights",
+                args.compressed_weights,
+                "--model",
+                args.model,
+            ], args.input
+        )
+        print(string)
+    else:
+        return pygemma.chat_base(
+            [
+                "--tokenizer",
+                args.tokenizer,
+                "--compressed_weights",
+                args.compressed_weights,
+                "--model",
+                args.model,
+            ]
+        )
     # Optionally, show help if needed
     # pygemma.show_help()
 
 
 if __name__ == "__main__":
     main()
-# python tests/test_chat.py --tokenizer /path/to/tokenizer.spm --compressed_weights /path/to/weights.sbs --model model_identifier
+# python tests/test_chat.py --tokenizer ../Model_Weight/tokenizer.spm --compressed_weights ../Model_Weight/2b-it-sfp.sbs --model 2b-it
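The test script drives the binding through argparse; for reference, a minimal direct call is sketched below, assuming the extension has been built and is importable as pygemma, and using placeholder tokenizer/weights paths:

import pygemma

# Placeholder paths; point these at a real tokenizer and compressed weights file.
flags = [
    "--tokenizer", "/path/to/tokenizer.spm",
    "--compressed_weights", "/path/to/2b-it-sfp.sbs",
    "--model", "2b-it",
]

# completion(flags, prompt) assembles argv as ["pygemma", *flags, prompt],
# runs one generation pass, and returns the generated text with the prompt removed.
text = pygemma.completion(flags, "Write a short greeting.")
print(text)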
