
Commit a6744e4

slaren and ngxson authored
llama : add simple-chat example (#10124)
* llama : add simple-chat example

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
1 parent e991e31 commit a6744e4

File tree

6 files changed: +220 -4 lines changed


Makefile

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,7 @@ BUILD_TARGETS = \
 	llama-save-load-state \
 	llama-server \
 	llama-simple \
+	llama-simple-chat \
 	llama-speculative \
 	llama-tokenize \
 	llama-vdot \

@@ -1287,6 +1288,11 @@ llama-simple: examples/simple/simple.cpp \
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

+llama-simple-chat: examples/simple-chat/simple-chat.cpp \
+	$(OBJ_ALL)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
 llama-tokenize: examples/tokenize/tokenize.cpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ else()
     endif()
     add_subdirectory(save-load-state)
     add_subdirectory(simple)
+    add_subdirectory(simple-chat)
     add_subdirectory(speculative)
     add_subdirectory(tokenize)
 endif()

examples/simple-chat/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
set(TARGET llama-simple-chat)
add_executable(${TARGET} simple-chat.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/simple-chat/README.md

Lines changed: 7 additions & 0 deletions
# llama.cpp/example/simple-chat

The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.

```bash
./llama-simple-chat -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
...
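For reference, the build wiring added elsewhere in this commit means the example can be built with `make llama-simple-chat` (via the new Makefile rule) or as the `llama-simple-chat` target of a CMake build, before running the command above.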

examples/simple-chat/simple-chat.cpp

Lines changed: 197 additions & 0 deletions
#include "llama.h"
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
    printf("\n");
}

int main(int argc, char ** argv) {
    std::string model_path;
    int ngl = 99;
    int n_ctx = 2048;

    // parse command line arguments
    for (int i = 1; i < argc; i++) {
        try {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-c") == 0) {
                if (i + 1 < argc) {
                    n_ctx = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    ngl = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                print_usage(argc, argv);
                return 1;
            }
        } catch (std::exception & e) {
            fprintf(stderr, "error: %s\n", e.what());
            print_usage(argc, argv);
            return 1;
        }
    }
    if (model_path.empty()) {
        print_usage(argc, argv);
        return 1;
    }

    // only print errors
    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
        if (level >= GGML_LOG_LEVEL_ERROR) {
            fprintf(stderr, "%s", text);
        }
    }, nullptr);

    // initialize the model
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
    if (!model) {
        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
        return 1;
    }

    // initialize the context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_ctx;
    ctx_params.n_batch = n_ctx;

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
        return 1;
    }

    // initialize the sampler
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // helper function to evaluate a prompt and generate a response
    auto generate = [&](const std::string & prompt) {
        std::string response;

        // tokenize the prompt
        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
        if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
            GGML_ABORT("failed to tokenize the prompt\n");
        }

        // prepare a batch for the prompt
        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
        llama_token new_token_id;
        while (true) {
            // check if we have enough space in the context to evaluate this batch
            int n_ctx = llama_n_ctx(ctx);
            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
            if (n_ctx_used + batch.n_tokens > n_ctx) {
                printf("\033[0m\n");
                fprintf(stderr, "context size exceeded\n");
                exit(0);
            }

            if (llama_decode(ctx, batch)) {
                GGML_ABORT("failed to decode\n");
            }

            // sample the next token
            new_token_id = llama_sampler_sample(smpl, ctx, -1);

            // is it an end of generation?
            if (llama_token_is_eog(model, new_token_id)) {
                break;
            }

            // convert the token to a string, print it and add it to the response
            char buf[256];
            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
            if (n < 0) {
                GGML_ABORT("failed to convert token to piece\n");
            }
            std::string piece(buf, n);
            printf("%s", piece.c_str());
            fflush(stdout);
            response += piece;

            // prepare the next batch with the sampled token
            batch = llama_batch_get_one(&new_token_id, 1);
        }

        return response;
    };

    std::vector<llama_chat_message> messages;
    std::vector<char> formatted(llama_n_ctx(ctx));
    int prev_len = 0;
    while (true) {
        // get user input
        printf("\033[32m> \033[0m");
        std::string user;
        std::getline(std::cin, user);

        if (user.empty()) {
            break;
        }

        // add the user input to the message list and format it
        messages.push_back({"user", strdup(user.c_str())});
        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        if (new_len > (int)formatted.size()) {
            formatted.resize(new_len);
            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        }
        if (new_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }

        // remove previous messages to obtain the prompt to generate the response
        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);

        // generate a response
        printf("\033[33m");
        std::string response = generate(prompt);
        printf("\n\033[0m");

        // add the response to the messages
        messages.push_back({"assistant", strdup(response.c_str())});
        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
        if (prev_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }
    }

    // free resources
    for (auto & msg : messages) {
        free(const_cast<char *>(msg.content));
    }
    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);

    return 0;
}
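A note on why the chat loop stays cheap as the conversation grows: `prev_len` records how long the formatted conversation was after the previous turn, so only the newly formatted tail (`formatted.begin() + prev_len` up to `new_len`) is tokenized and passed to `generate()`. Tokens from earlier turns are never re-decoded; they remain in the context's KV cache, which is also why the loop checks `llama_get_kv_cache_used_cells(ctx) + batch.n_tokens` against `llama_n_ctx(ctx)` before each decode.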

ggml/include/ggml.h

Lines changed: 4 additions & 4 deletions
@@ -558,10 +558,10 @@ extern "C" {

     enum ggml_log_level {
         GGML_LOG_LEVEL_NONE  = 0,
-        GGML_LOG_LEVEL_INFO  = 1,
-        GGML_LOG_LEVEL_WARN  = 2,
-        GGML_LOG_LEVEL_ERROR = 3,
-        GGML_LOG_LEVEL_DEBUG = 4,
+        GGML_LOG_LEVEL_DEBUG = 1,
+        GGML_LOG_LEVEL_INFO  = 2,
+        GGML_LOG_LEVEL_WARN  = 3,
+        GGML_LOG_LEVEL_ERROR = 4,
         GGML_LOG_LEVEL_CONT  = 5, // continue previous log
     };
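The levels are now ordered by increasing severity, which is what lets the new example silence everything below errors with a single comparison. A minimal sketch of that pattern, for illustration only (the helper name `setup_quiet_logging` is mine, and reading the reordering as motivated by this kind of filtering is an inference; the callback body mirrors the one installed in simple-chat.cpp above):

```cpp
#include "llama.h"

#include <cstdio>

// Install a llama.cpp log callback that keeps only errors. With the new
// ordering (DEBUG = 1 < INFO = 2 < WARN = 3 < ERROR = 4), the single
// comparison `level >= GGML_LOG_LEVEL_ERROR` drops debug/info/warn output;
// under the old ordering (DEBUG = 4 > ERROR = 3), the same comparison would
// also have let debug messages through. GGML_LOG_LEVEL_CONT (5) still passes
// this filter.
static void setup_quiet_logging() {
    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
        if (level >= GGML_LOG_LEVEL_ERROR) {
            fprintf(stderr, "%s", text);
        }
    }, nullptr);
}
```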
