server: probe /health from the task loop; add a manual request-cancel script

examples/server/server.cpp
  * server_queue gains hostname_health / port_health and check_server_health(),
    which sends "GET /health" over a synchronous Boost.Asio TCP connection and
    reports whether a status line with code 200, 500 or 503 came back.
  * The main loop of server_queue calls the probe at the top of every outer
    pass; when the probe fails it clears queue_tasks and leaves the loop.

examples/server/tests/req-cancel-testing.py
  * Interactive script: posts a user-chosen number of /completion requests one
    at a time, checks /health before each one, and prints a summary.

Review notes
  * check_server_health(): the return statement ends with ';', host lookup uses
    the error_code overload of resolve() so it cannot throw, a failed
    read_until() is reported like the other failures, and status_code is
    initialised and only compared after the status line parsed successfully.
  * The whitespace-only hunk in struct server_slot is not part of this patch.
  * Code is indented with four spaces; context-line whitespace must match the
    target file for the hunks to apply.
  * NOTE(review): the header names on the #include lines of the first hunk are
    missing from this copy of the patch; restore them before applying. The two
    added lines presumably provide boost::asio and std::cout -- confirm.
  * NOTE(review): the probe address is fixed at 127.0.0.1:8080; presumably it
    should follow the address the server was configured to listen on -- verify.
  * NOTE(review): no timeout is set on the probe's connect/write/read calls,
    and it runs inside the task loop -- confirm this is acceptable.
  * NOTE(review): queue_tasks is cleared without taking a lock in this hunk;
    confirm whether other accessors guard it with a mutex.
  * The post-image blob ids on the "index" lines predate the edits listed here.

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 48ef8ff2a237b..f4fb0ac1005d2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -29,6 +29,8 @@
 #include
 #include
 #include
+#include
+#include
 
 using json = nlohmann::ordered_json;
 
@@ -463,6 +465,70 @@ struct server_queue {
         condition_tasks.notify_all();
     }
 
+    // Health checking: address of the HTTP endpoint probed by check_server_health().
+    std::string hostname_health = "127.0.0.1";
+    std::string port_health = "8080";
+
+    /**
+     * Probe `GET /health` on server:port using synchronous Boost.Asio calls.
+     *
+     * @param server  host name or IP address to connect to
+     * @param port    service name or port number, as a string
+     * @return true when a status line was received and its code is 200, 500 or 503;
+     *         false on resolve/connect/write/read failure, on an unparsable status
+     *         line, or on any other status code
+     */
+    bool check_server_health(const std::string& server, const std::string& port) {
+        namespace asio = boost::asio;
+
+        asio::io_context ctx;
+        asio::ip::tcp::socket socket(ctx);
+        asio::ip::tcp::resolver resolver(ctx);
+        boost::system::error_code ec;
+
+        // Resolve with the error_code overload so a lookup failure is reported
+        // as `false` instead of throwing out of the task loop.
+        const auto endpoints = resolver.resolve(server, port, ec);
+        if (ec) {
+            std::cout << "Resolve failed: " << ec.message() << std::endl;
+            return false;
+        }
+
+        // Try to connect
+        asio::connect(socket, endpoints, ec);
+        if (ec) {
+            std::cout << "Connection failed: " << ec.message() << std::endl;
+            return false;
+        }
+
+        // Send HTTP GET request to /health endpoint
+        const std::string request = "GET /health HTTP/1.1\r\nHost: " + server + "\r\n\r\n";
+        asio::write(socket, asio::buffer(request), ec);
+        if (ec) {
+            std::cout << "Write failed: " << ec.message() << std::endl;
+            return false;
+        }
+
+        // Read only the status line, e.g. "HTTP/1.1 200 OK"
+        asio::streambuf response;
+        asio::read_until(socket, response, "\r\n", ec);
+        if (ec) {
+            std::cout << "Read failed: " << ec.message() << std::endl;
+            return false;
+        }
+
+        std::istream response_stream(&response);
+        std::string http_version;
+        unsigned int status_code = 0;
+        if (!(response_stream >> http_version >> status_code)) {
+            // Malformed status line: do not compare an unset status code.
+            return false;
+        }
+
+        // 200, 500 and 503 are the codes accepted as "server is answering".
+        return status_code == 200 || status_code == 500 || status_code == 503;
+    }
+
     /**
      * Main loop consists of these steps:
      * - Wait until a new task arrives
@@ -474,6 +540,12 @@ struct server_queue {
         running = true;
 
         while (true) {
+            // Probe the /health endpoint at hostname_health:port_health before each
+            // pass; when the probe fails, drop every queued task and leave the loop.
+            if (!check_server_health(hostname_health, port_health)) {
+                queue_tasks.clear();
+                break;
+            }
             LOG_VERBOSE("new task may arrive", {});
 
             while (true) {
diff --git a/examples/server/tests/req-cancel-testing.py b/examples/server/tests/req-cancel-testing.py
new file mode 100644
index 0000000000000..9b08192de40e2
--- /dev/null
+++ b/examples/server/tests/req-cancel-testing.py
@@ -0,0 +1,106 @@
+"""Manual test driver for request cancellation against a local server.
+
+Posts a user-chosen number of /completion requests to 127.0.0.1:8080, one at a
+time, checking /health before each one, and prints a summary at the end.
+"""
+
+import threading
+
+import requests
+
+# Stats
+total_requests = 0
+requests_executed = 0
+requests_cancelled = 0
+requests_remaining = 0
+
+# Status codes from /health that are treated as "server is running".
+HEALTHY_STATUS_CODES = (200, 500, 503)
+
+
+class StoppableThread(threading.Thread):
+    """Thread that carries its own stop flag."""
+
+    def __init__(self, *args, **kwargs):
+        super(StoppableThread, self).__init__(*args, **kwargs)
+        self.stop_event = threading.Event()
+
+    def stop(self):
+        """Set the stop flag."""
+        self.stop_event.set()
+
+    def stopped(self):
+        """Return True once stop() has been called."""
+        return self.stop_event.is_set()
+
+
+def send_request(stop_event):
+    """POST one completion request unless stop_event is already set.
+
+    Increments the module-level requests_executed counter when a response is
+    received. Timeouts and other errors are printed, not raised.
+    """
+    global requests_executed
+    try:
+        url = 'http://127.0.0.1:8080/completion'
+        data = {
+            'prompt': 'Hello llama',
+            'n_predict': 2
+        }
+        # The flag is only consulted before the request is sent.
+        if not stop_event.is_set():
+            response = requests.post(url, json=data, timeout=60)  # Reduced timeout for testing
+            print('Response:', response.text)
+            requests_executed += 1
+    except requests.exceptions.Timeout:
+        print('Request timed out')
+    except Exception as e:
+        print('An error occurred:', str(e))
+
+
+def get_health():
+    """Return the HTTP status code of GET /health, or None on failure."""
+    try:
+        url = 'http://127.0.0.1:8080/health'
+        response = requests.get(url, timeout=10)
+        return response.status_code
+    except requests.exceptions.Timeout:
+        print('Health check timed out')
+        return None
+    except Exception as e:
+        print('An error occurred during health check:', str(e))
+        return None
+
+
+# User input for the number of requests
+num_requests = int(input("How many requests would you like to post?\n"))
+
+total_requests = num_requests
+
+# Launching multiple requests
+for i in range(num_requests):
+    health = get_health()
+
+    if health not in HEALTHY_STATUS_CODES:
+        print(f"Server is not running. Status:{health}. Exiting now...\n")
+        # Request i and every later one were never sent.
+        requests_cancelled = total_requests - i
+        break
+
+    stop_event = threading.Event()
+    req_thread = StoppableThread(target=send_request, args=(stop_event,))
+    req_thread.start()
+
+    input("Press Enter when request is complete or you would like to stop the request!\n")
+    if not stop_event.is_set():
+        stop_event.set()
+
+    req_thread.join()  # Ensure the thread finishes
+
+requests_remaining = total_requests - requests_executed - requests_cancelled
+
+print("\nSummary:")
+print(f"Total requests: {total_requests}")
+print(f"Requests executed: {requests_executed}")
+print(f"Requests cancelled: {requests_cancelled}")
+print(f"Requests remaining: {requests_remaining}")