@@ -135,7 +135,7 @@ namespace gcpp
135
135
136
136
while (abs_pos < args.max_tokens )
137
137
{
138
- std::string prompt_string;
138
+ std::string prompt_string;
139
139
std::vector<int > prompt;
140
140
current_pos = 0 ;
141
141
{
@@ -255,6 +255,72 @@ namespace gcpp
255
255
{ return true ; });
256
256
}
257
257
258
+ std::string decode (gcpp::Gemma &model, hwy::ThreadPool &pool,
259
+ hwy::ThreadPool &inner_pool, const InferenceArgs &args,
260
+ int verbosity, const gcpp::AcceptFunc &accept_token, std::string &prompt_string)
261
+ {
262
+ std::string generated_text;
263
+ // Seed the random number generator
264
+ int current_pos = 0 ;
265
+ std::random_device rd;
266
+ std::mt19937 gen (rd ());
267
+ int prompt_size{};
268
+ if (model.model_training == ModelTraining::GEMMA_IT)
269
+ {
270
+ // For instruction-tuned models: add control tokens.
271
+ prompt_string = " <start_of_turn>user\n " + prompt_string +
272
+ " <end_of_turn>\n <start_of_turn>model\n " ;
273
+ }
274
+ // Encode the prompt string into tokens
275
+ std::vector<int > prompt;
276
+ HWY_ASSERT (model.Tokenizer ().Encode (prompt_string, &prompt).ok ());
277
+ // Placeholder for generated token IDs
278
+ std::vector<int > generated_tokens;
279
+ // Define lambda for token decoding
280
+ StreamFunc stream_token = [&generated_tokens,¤t_pos](int token, float /* probability */ ) -> bool {
281
+ ++current_pos;
282
+ generated_tokens.push_back (token);
283
+ return true ; // Continue generating
284
+ };
285
+ // Decode each token and concatenate
286
+ prompt_size = prompt.size ();
287
+ GenerateGemma (model, args, prompt, /* start_pos=*/ 0 , pool, inner_pool, stream_token, accept_token, gen, verbosity);
288
+ // for (int token : generated_tokens) {
289
+ // std::string token_text;
290
+ // if (model.Tokenizer().Decode(std::vector<int>{token}, &token_text).ok()) {
291
+ // generated_text += token_text; // Appending a space for readability
292
+ // }
293
+ HWY_ASSERT (model.Tokenizer ().Decode (generated_tokens, &generated_text).ok ());
294
+ // for (int i = prompt_size; i < generated_tokens.size(); ++i) {
295
+ // std::string token_text;
296
+ // if (model.Tokenizer().Decode(std::vector<int>{generated_tokens[i]}, &token_text).ok()) {
297
+ // generated_text += token_text; // Appending a space for readability
298
+ // }
299
+ // }
300
+ // remove promp from generated text
301
+ generated_text = generated_text.substr (prompt_string.size ());
302
+
303
+ return generated_text;
304
+ }
305
+
306
+ std::string completion (LoaderArgs &loader, InferenceArgs &inference, AppArgs &app, std::string &prompt_string)
307
+ {
308
+ hwy::ThreadPool inner_pool (0 );
309
+ hwy::ThreadPool pool (app.num_threads );
310
+ if (app.num_threads > 10 )
311
+ {
312
+ PinThreadToCore (app.num_threads - 1 ); // Main thread
313
+
314
+ pool.Run (0 , pool.NumThreads (),
315
+ [](uint64_t /* task*/ , size_t thread)
316
+ { PinThreadToCore (thread); });
317
+ }
318
+ gcpp::Gemma model (loader, pool);
319
+ return decode (model, pool, inner_pool, inference, app.verbosity , /* accept_token=*/ [](int )
320
+ { return true ; }, prompt_string);
321
+
322
+ }
323
+
258
324
} // namespace gcpp
259
325
260
326
void chat_base (int argc, char **argv)
@@ -283,7 +349,32 @@ void chat_base(int argc, char **argv)
283
349
PROFILER_PRINT_RESULTS (); // Must call outside the zone above.
284
350
// return 1;
285
351
}
352
+ std::string completion_base (int argc, char **argv)
353
+ {
354
+ gcpp::LoaderArgs loader (argc, argv);
355
+ gcpp::InferenceArgs inference (argc, argv);
356
+ gcpp::AppArgs app (argc, argv);
357
+ std::string prompt_string = argv[argc-1 ];
358
+ std::string output_text = gcpp::completion (loader, inference, app, prompt_string);
359
+ return output_text;
360
+ }
361
+ std::string completion_base_wrapper (const std::vector<std::string> &args,std::string &prompt_string)
362
+ {
363
+ int argc = args.size () + 2 ; // +1 for the program name
364
+ std::vector<char *> argv_vec;
365
+ argv_vec.reserve (argc);
286
366
367
+ argv_vec.push_back (const_cast <char *>(" pygemma" ));
368
+
369
+ for (const auto &arg : args)
370
+ {
371
+ argv_vec.push_back (const_cast <char *>(arg.c_str ()));
372
+ }
373
+ argv_vec.push_back (const_cast <char *>(prompt_string.c_str ()));
374
+ char **argv = argv_vec.data ();
375
+ std::string output = completion_base (argc, argv);
376
+ return output;
377
+ }
287
378
void show_help_wrapper ()
288
379
{
289
380
// Assuming ShowHelp does not critically depend on argv content
@@ -294,12 +385,11 @@ void show_help_wrapper()
294
385
ShowHelp (loader, inference, app);
295
386
}
296
387
297
- void chat_base_wrapper (const std::vector<std::string> &args)
388
+ std::string chat_base_wrapper (const std::vector<std::string> &args)
298
389
{
299
390
int argc = args.size () + 1 ; // +1 for the program name
300
391
std::vector<char *> argv_vec;
301
392
argv_vec.reserve (argc);
302
-
303
393
argv_vec.push_back (const_cast <char *>(" pygemma" ));
304
394
305
395
for (const auto &arg : args)
@@ -308,12 +398,15 @@ void chat_base_wrapper(const std::vector<std::string> &args)
308
398
}
309
399
310
400
char **argv = argv_vec.data ();
401
+
311
402
chat_base (argc, argv);
312
403
}
313
404
405
+
314
406
PYBIND11_MODULE (pygemma, m)
{
  // Python module definition for the pygemma extension.
  // BUG FIX: the exported names and docstrings previously carried a
  // spurious leading space (e.g. " chat_base"), an extraction artifact that
  // would make the attributes unreachable from Python (`pygemma.chat_base`
  // would not exist); the names are stripped here.
  m.doc() = "Pybind11 integration for chat_base function";
  m.def("chat_base", &chat_base_wrapper,
        "A wrapper for the chat_base function accepting Python list of strings as arguments");
  m.def("show_help", &show_help_wrapper,
        "A wrapper for show_help function");
  m.def("completion", &completion_base_wrapper,
        "A wrapper for inference function");
}
0 commit comments