From b6c9b539c9d60ab0d6d747ed9f3cf0db6bbdc3a8 Mon Sep 17 00:00:00 2001
From: HanClinto
Date: Tue, 23 Jul 2024 12:57:09 -0400
Subject: [PATCH] Updated Swift and Android bindings to use the new
 llama_sampling_* refactor from #8643

---
 examples/batched.swift/Sources/main.swift        | 14 ++++++++------
 .../llama/src/main/cpp/llama-android.cpp         |  3 ++-
 .../llama.swiftui/llama.cpp.swift/LibLlama.swift |  5 ++++-
 include/llama.h                                  |  2 +-
 4 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 616494d2d841d..9a62c5bb7ebfe 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -44,6 +44,8 @@ context_params.n_threads = 8
 context_params.n_threads_batch = 8
 
 let context = llama_new_context_with_model(model, context_params)
+let smpl = llama_get_sampling(context)
+
 guard context != nil else {
     print("Failed to initialize context")
     exit(1)
@@ -144,13 +146,13 @@ while n_cur <= n_len {
         let top_p: Float = 0.9
         let temp: Float = 0.4
 
-        llama_sample_top_k(context, &candidates_p, top_k, 1)
-        llama_sample_top_p(context, &candidates_p, top_p, 1)
-        llama_sample_temp(context, &candidates_p, temp)
+        llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
+        llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
+        llama_sampling_temp(smpl, &candidates_p, temp)
 
-        let new_token_id = llama_sample_token(context, &candidates_p)
+        let new_token_id = llama_sampling_sample(smpl, &candidates_p)
 
-        // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+        // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
 
         // is it an end of stream? -> mark the stream as finished
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
@@ -212,7 +214,7 @@ let t_main_end = ggml_time_us()
 
 print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
 
-llama_print_timings(context)
+llama_print_timings(context, smpl, nil)
 
 private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Count = text.utf8.count
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 2aafe23167557..af3b356bc04b7 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -385,6 +385,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
     const auto model = llama_get_model(context);
 
@@ -405,7 +406,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
     // sample the most likely token
-    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+    const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index 58c32ca533bb1..bfd273072e627 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
+    private var sampling: OpaquePointer
     private var batch: llama_batch
     private var tokens_list: [llama_token]
     var is_done: Bool = false
@@ -42,12 +43,14 @@ actor LlamaContext {
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
         self.temporary_invalid_cchars = []
+        self.sampling = llama_get_sampling(context)
     }
 
     deinit {
         llama_batch_free(batch)
         llama_free(context)
         llama_free_model(model)
+        llama_sampling_free(sampling)
         llama_backend_free()
     }
 
@@ -156,7 +159,7 @@ actor LlamaContext {
         candidates.withUnsafeMutableBufferPointer() { buffer in
             var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
 
-            new_token_id = llama_sample_token_greedy(context, &candidates_p)
+            new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p)
         }
 
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
diff --git a/include/llama.h b/include/llama.h
index 57937ac104148..4ecb1aa0b4c8e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1144,7 +1144,7 @@ extern "C" {
             float * mu);
 
     /// @details Selects the token with the highest probability.
-    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
+    ///          Does not compute the token probabilities. Use llama_sampling_softmax() instead.
     LLAMA_API llama_token llama_sampling_sample_greedy(
             struct llama_sampling * smpl,
             llama_token_data_array * candidates);
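
Note: below is a minimal Swift sketch of the greedy sampling flow these bindings
migrate to. It is illustrative only and not part of the patch: the `llama` module
import, the helper name `greedyToken`, and the use of llama_n_vocab() /
llama_get_logits_ith() to build the candidate list are assumptions carried over
from the existing Swift examples; the llama_get_sampling() and
llama_sampling_sample_greedy() calls are the ones introduced by the #8643 refactor.

    import llama  // assumption: the SwiftPM module exposing llama.cpp's C API, as in the bundled examples

    // Build a candidate list from the logits of the token at `batchIndex` and pick
    // the most likely token via the refactored llama_sampling_* API. `context` and
    // `model` are assumed to come from llama_new_context_with_model() and
    // llama_load_model_from_file(), as in batched.swift above.
    func greedyToken(context: OpaquePointer, model: OpaquePointer, batchIndex: Int32) -> llama_token {
        let sampling = llama_get_sampling(context)      // sampling state attached to the context
        let nVocab   = llama_n_vocab(model)
        let logits   = llama_get_logits_ith(context, batchIndex)

        // one candidate per vocabulary entry, probabilities left unset (greedy uses logits only)
        var candidates: [llama_token_data] = []
        candidates.reserveCapacity(Int(nVocab))
        for tokenId in 0 ..< nVocab {
            candidates.append(llama_token_data(id: tokenId, logit: logits![Int(tokenId)], p: 0.0))
        }

        var newTokenId: llama_token = 0
        candidates.withUnsafeMutableBufferPointer { buffer in
            var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
            newTokenId = llama_sampling_sample_greedy(sampling, &candidates_p)
        }
        return newTokenId
    }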