|
11 | 11 | # include <curl/curl.h>
|
12 | 12 | #endif
|
13 | 13 |
|
| 14 | +#include <signal.h> |
| 15 | + |
14 | 16 | #include <climits>
|
15 | 17 | #include <cstdarg>
|
16 | 18 | #include <cstdio>
|
|
25 | 27 | #include "json.hpp"
|
26 | 28 | #include "llama-cpp.h"
|
27 | 29 |
|
| 30 | +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) |
| 31 | +[[noreturn]] static void sigint_handler(int) { |
| 32 | + printf("\n"); |
| 33 | + exit(0); // not ideal, but it's the only way to guarantee exit in all cases |
| 34 | +} |
| 35 | +#endif |
| 36 | + |
28 | 37 | GGML_ATTRIBUTE_FORMAT(1, 2)
|
29 | 38 | static std::string fmt(const char * fmt, ...) {
|
30 | 39 | va_list ap;
|
@@ -801,7 +810,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
801 | 810 |
|
802 | 811 | static int read_user_input(std::string & user) {
|
803 | 812 | std::getline(std::cin, user);
|
804 |
| - return user.empty(); // Should have data in happy path |
| 813 | + if (std::cin.eof()) { |
| 814 | + printf("\n"); |
| 815 | + return 1; |
| 816 | + } |
| 817 | + |
| 818 | + if (user == "/bye") { |
| 819 | + return 1; |
| 820 | + } |
| 821 | + |
| 822 | + if (user.empty()) { |
| 823 | + return 2; |
| 824 | + } |
| 825 | + |
| 826 | + return 0; // Should have data in happy path |
805 | 827 | }
|
806 | 828 |
|
807 | 829 | // Function to generate a response based on the prompt
|
@@ -868,15 +890,34 @@ static bool is_stdout_a_terminal() {
|
868 | 890 | #endif
|
869 | 891 | }
|
870 | 892 |
|
871 |
| -// Function to tokenize the prompt |
| 893 | +// Function to handle user input |
| 894 | +static int get_user_input(std::string & user_input, const std::string & user) { |
| 895 | + while (true) { |
| 896 | + const int ret = handle_user_input(user_input, user); |
| 897 | + if (ret == 1) { |
| 898 | + return 1; |
| 899 | + } |
| 900 | + |
| 901 | + if (ret == 2) { |
| 902 | + continue; |
| 903 | + } |
| 904 | + |
| 905 | + break; |
| 906 | + } |
| 907 | + |
| 908 | + return 0; |
| 909 | +} |
| 910 | + |
| 911 | +// Main chat loop function |
872 | 912 | static int chat_loop(LlamaData & llama_data, const std::string & user) {
|
873 | 913 | int prev_len = 0;
|
874 | 914 | llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
|
875 | 915 | static const bool stdout_a_terminal = is_stdout_a_terminal();
|
876 | 916 | while (true) {
|
877 | 917 | // Get user input
|
878 | 918 | std::string user_input;
|
879 |
| - while (handle_user_input(user_input, user)) { |
| 919 | + if (get_user_input(user_input, user) == 1) { |
| 920 | + return 0; |
880 | 921 | }
|
881 | 922 |
|
882 | 923 | add_message("user", user.empty() ? user_input : user, llama_data);
|
@@ -917,7 +958,23 @@ static std::string read_pipe_data() {
|
917 | 958 | return result.str();
|
918 | 959 | }
|
919 | 960 |
|
| 961 | +static void ctrl_c_handling() { |
| 962 | +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) |
| 963 | + struct sigaction sigint_action; |
| 964 | + sigint_action.sa_handler = sigint_handler; |
| 965 | + sigemptyset(&sigint_action.sa_mask); |
| 966 | + sigint_action.sa_flags = 0; |
| 967 | + sigaction(SIGINT, &sigint_action, NULL); |
| 968 | +#elif defined(_WIN32) |
| 969 | + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { |
| 970 | + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; |
| 971 | + }; |
| 972 | + SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true); |
| 973 | +#endif |
| 974 | +} |
| 975 | + |
920 | 976 | int main(int argc, const char ** argv) {
|
| 977 | + ctrl_c_handling(); |
921 | 978 | Opt opt;
|
922 | 979 | const int ret = opt.init(argc, argv);
|
923 | 980 | if (ret == 2) {
|
|
0 commit comments