@@ -531,6 +531,19 @@ struct server_response {
     std::mutex mutex_results;
     std::condition_variable condition_results;
 
+    // mapping task_id --> sink
+    // for tracking HTTP connection state
+    std::unordered_map<int, httplib::DataSink *> map_id_to_sink;
+
+    bool can_send(int id_task) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        auto it = map_id_to_sink.find(id_task);
+        if (it != map_id_to_sink.end()) {
+            return it->second->is_writable();
+        }
+        return false;
+    }
+
     // add the id_task to the list of tasks waiting for response
     void add_waiting_task_id(int id_task) {
         SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
@@ -548,12 +561,20 @@ struct server_response {
         }
     }
 
+    void add_waiting_sink(const std::unordered_set<int> & task_ids, httplib::DataSink * sink) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        for (auto id : task_ids) {
+            map_id_to_sink[id] = sink;
+        }
+    }
+
     // when the request is finished, we can remove task associated with it
     void remove_waiting_task_id(int id_task) {
         SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
 
         std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.erase(id_task);
+        map_id_to_sink.erase(id_task);
     }
 
     void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
@@ -562,6 +583,7 @@ struct server_response {
         for (const auto & id_task : id_tasks) {
             SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
             waiting_task_ids.erase(id_task);
+            map_id_to_sink.erase(id_task);
         }
     }
 
@@ -1117,6 +1139,12 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
+        // check if connection is still alive
+        if (!queue_results.can_send(slot.id_task)) {
+            slot.release();
+            return false;
+        }
+
         // remember which tokens were sampled - used for repetition penalties during sampling
        const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
        slot.sampled = result.tok;
@@ -2920,6 +2948,7 @@ int main(int argc, char ** argv) {
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
             const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
+                ctx_server.queue_results.add_waiting_sink(task_ids, &sink);
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
                     return server_sent_event(sink, "data", result.data);
                 }, [&](const json & error_data) {
@@ -2930,7 +2959,7 @@ int main(int argc, char ** argv) {
             };
 
             auto on_complete = [task_ids, &ctx_server] (bool) {
-                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+                ctx_server.queue_results.remove_waiting_task_ids(task_ids); // will also remove the sink
             };
 
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);