@@ -261,7 +261,6 @@ namespace gcpp
 {
     std::string generated_text;
     // Seed the random number generator
-    int current_pos = 0;
    std::random_device rd;
    std::mt19937 gen(rd());
    int prompt_size{};
@@ -277,27 +276,14 @@ namespace gcpp
    // Placeholder for generated token IDs
    std::vector<int> generated_tokens;
    // Define lambda for token decoding
-    StreamFunc stream_token = [&generated_tokens, &current_pos](int token, float /* probability */) -> bool {
-        ++current_pos;
+    StreamFunc stream_token = [&generated_tokens](int token, float /* probability */) -> bool {
        generated_tokens.push_back(token);
        return true; // Continue generating
    };
-    // Decode each token and concatenate
+    // Decode tokens
    prompt_size = prompt.size();
    GenerateGemma(model, args, prompt, /*start_pos=*/0, pool, inner_pool, stream_token, accept_token, gen, verbosity);
-    // for (int token : generated_tokens) {
-    //   std::string token_text;
-    //   if (model.Tokenizer().Decode(std::vector<int>{token}, &token_text).ok()) {
-    //     generated_text += token_text; // Appending a space for readability
-    //   }
    HWY_ASSERT(model.Tokenizer().Decode(generated_tokens, &generated_text).ok());
-    // for (int i = prompt_size; i < generated_tokens.size(); ++i) {
-    //   std::string token_text;
-    //   if (model.Tokenizer().Decode(std::vector<int>{generated_tokens[i]}, &token_text).ok()) {
-    //     generated_text += token_text; // Appending a space for readability
-    //   }
-    // }
-    // remove promp from generated text
    generated_text = generated_text.substr(prompt_string.size());

    return generated_text;
@@ -355,8 +341,7 @@ std::string completion_base(int argc, char **argv)
    gcpp::InferenceArgs inference(argc, argv);
    gcpp::AppArgs app(argc, argv);
    std::string prompt_string = argv[argc-1];
-    std::string output_text = gcpp::completion(loader, inference, app, prompt_string);
-    return output_text;
+    return gcpp::completion(loader, inference, app, prompt_string);
 }
 std::string completion_base_wrapper(const std::vector<std::string> &args,std::string &prompt_string)
 {
@@ -372,8 +357,7 @@ std::string completion_base_wrapper(const std::vector<std::string> &args,std::st
    }
    argv_vec.push_back(const_cast<char *>(prompt_string.c_str()));
    char **argv = argv_vec.data();
-    std::string output = completion_base(argc, argv);
-    return output;
+    return completion_base(argc, argv);
 }
 void show_help_wrapper()
 {
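// Usage sketch (hypothetical, not part of this change): a minimal caller for the
// simplified completion_base_wrapper return path. The flag names and file paths
// below are placeholder assumptions for illustration only.
#include <iostream>
#include <string>
#include <vector>

std::string completion_base_wrapper(const std::vector<std::string> &args, std::string &prompt_string);

int main()
{
    std::vector<std::string> args = {"--tokenizer", "tokenizer.spm", // placeholder path
                                     "--weights", "2b-it.sbs",       // placeholder path
                                     "--model", "2b-it"};            // placeholder model name
    std::string prompt = "Write a haiku about the sea.";
    // The wrapper builds argv from args, appends the prompt as the last argument,
    // and now returns the completion string directly.
    std::cout << completion_base_wrapper(args, prompt) << std::endl;
    return 0;
}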