Skip to content

Commit 41acd80

Browse files
codezjx authored and NeoZhangJianyu committed
llama.android: add field formatChat to control whether to parse special tokens when send message (ggml-org#11270)
1 parent 38b364f commit 41acd80

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -347,6 +347,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
347347
jlong context_pointer,
348348
jlong batch_pointer,
349349
jstring jtext,
350+
jboolean format_chat,
350351
jint n_len
351352
) {
352353

@@ -356,7 +357,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
356357
const auto context = reinterpret_cast<llama_context *>(context_pointer);
357358
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
358359

359-
const auto tokens_list = common_tokenize(context, text, 1);
360+
bool parse_special = (format_chat == JNI_TRUE);
361+
const auto tokens_list = common_tokenize(context, text, true, parse_special);
360362

361363
auto n_ctx = llama_n_ctx(context);
362364
auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
@@ -368,7 +370,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
368370
}
369371

370372
for (auto id : tokens_list) {
371-
LOGi("%s", common_token_to_piece(context, id).c_str());
373+
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
372374
}
373375

374376
common_batch_clear(*batch);

examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@ class LLamaAndroid {
6565
context: Long,
6666
batch: Long,
6767
text: String,
68+
formatChat: Boolean,
6869
nLen: Int
6970
): Int
7071

@@ -115,10 +116,10 @@ class LLamaAndroid {
115116
}
116117
}
117118

118-
fun send(message: String): Flow<String> = flow {
119+
fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
119120
when (val state = threadLocalState.get()) {
120121
is State.Loaded -> {
121-
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
122+
val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
122123
while (ncur.value <= nlen) {
123124
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
124125
if (str == null) {

0 commit comments

Comments
 (0)