Skip to content

Commit 2e1c355

Browse files
committed
(poc) track connection state in server
1 parent 1788077 commit 2e1c355

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

examples/server/server.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,19 @@ struct server_response {
531531
std::mutex mutex_results;
532532
std::condition_variable condition_results;
533533

534+
// mapping task_id --> sink
535+
// for tracking HTTP connection state
536+
std::unordered_map<int, httplib::DataSink *> map_id_to_sink;
537+
538+
bool can_send(int id_task) {
539+
std::unique_lock<std::mutex> lock(mutex_results);
540+
auto it = map_id_to_sink.find(id_task);
541+
if (it != map_id_to_sink.end()) {
542+
return it->second->is_writable();
543+
}
544+
return false;
545+
}
546+
534547
// add the id_task to the list of tasks waiting for response
535548
void add_waiting_task_id(int id_task) {
536549
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
@@ -548,12 +561,20 @@ struct server_response {
548561
}
549562
}
550563

564+
void add_waiting_sink(const std::unordered_set<int> & task_ids, httplib::DataSink * sink) {
565+
std::unique_lock<std::mutex> lock(mutex_results);
566+
for (auto id : task_ids) {
567+
map_id_to_sink[id] = sink;
568+
}
569+
}
570+
551571
// when the request is finished, we can remove task associated with it
552572
void remove_waiting_task_id(int id_task) {
553573
SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
554574

555575
std::unique_lock<std::mutex> lock(mutex_results);
556576
waiting_task_ids.erase(id_task);
577+
map_id_to_sink.erase(id_task);
557578
}
558579

559580
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
@@ -562,6 +583,7 @@ struct server_response {
562583
for (const auto & id_task : id_tasks) {
563584
SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
564585
waiting_task_ids.erase(id_task);
586+
map_id_to_sink.erase(id_task);
565587
}
566588
}
567589

@@ -1117,6 +1139,12 @@ struct server_context {
11171139
}
11181140

11191141
bool process_token(completion_token_output & result, server_slot & slot) {
1142+
// check if connection is still alive
1143+
if (!queue_results.can_send(slot.id_task)) {
1144+
slot.release();
1145+
return false;
1146+
}
1147+
11201148
// remember which tokens were sampled - used for repetition penalties during sampling
11211149
const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
11221150
slot.sampled = result.tok;
@@ -2920,6 +2948,7 @@ int main(int argc, char ** argv) {
29202948
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
29212949
} else {
29222950
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
2951+
ctx_server.queue_results.add_waiting_sink(task_ids, &sink);
29232952
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
29242953
return server_sent_event(sink, "data", result.data);
29252954
}, [&](const json & error_data) {
@@ -2930,7 +2959,7 @@ int main(int argc, char ** argv) {
29302959
};
29312960

29322961
auto on_complete = [task_ids, &ctx_server] (bool) {
2933-
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2962+
ctx_server.queue_results.remove_waiting_task_ids(task_ids); // will also remove the sink
29342963
};
29352964

29362965
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);

0 commit comments

Comments
 (0)