
Commit efc7225

server : add "/chat/completions" alias for "/v1/..." (#5722)

* Add "/chat/completions" as alias for "/v1/chat/completions"
* merge to upstream master
* minor : fix trailing whitespace

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 7c4263d commit efc7225

3 files changed: +114 -67 lines changed

examples/server/server.cpp

Lines changed: 66 additions & 65 deletions
@@ -3211,87 +3211,88 @@ int main(int argc, char **argv)
             res.set_content(models.dump(), "application/json; charset=utf-8");
         });

+    const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
+            {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+                if (!validate_api_key(req, res)) {
+                    return;
+                }
+                json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);

-    // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
-            {
-                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-                if (!validate_api_key(req, res)) {
-                    return;
-                }
-                json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
-
-                const int task_id = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, data, false, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, false, false, -1);

-                if (!json_value(data, "stream", false)) {
-                    std::string completion_text;
-                    task_result result = llama.queue_results.recv(task_id);
+                if (!json_value(data, "stream", false)) {
+                    std::string completion_text;
+                    task_result result = llama.queue_results.recv(task_id);

-                    if (!result.error && result.stop) {
-                        json oaicompat_result = format_final_response_oaicompat(data, result);
+                    if (!result.error && result.stop) {
+                        json oaicompat_result = format_final_response_oaicompat(data, result);

-                        res.set_content(oaicompat_result.dump(-1, ' ', false,
-                                            json::error_handler_t::replace),
-                                            "application/json; charset=utf-8");
-                    } else {
-                        res.status = 500;
-                        res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                    }
-                    llama.queue_results.remove_waiting_task_id(task_id);
-                } else {
-                    const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
-                        while (true) {
-                            task_result llama_result = llama.queue_results.recv(task_id);
-                            if (!llama_result.error) {
-                                std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
+                        res.set_content(oaicompat_result.dump(-1, ' ', false,
+                                            json::error_handler_t::replace),
+                                            "application/json; charset=utf-8");
+                    } else {
+                        res.status = 500;
+                        res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
+                    }
+                    llama.queue_results.remove_waiting_task_id(task_id);
+                } else {
+                    const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
+                        while (true) {
+                            task_result llama_result = llama.queue_results.recv(task_id);
+                            if (!llama_result.error) {
+                                std::vector<json> result_array = format_partial_response_oaicompat( llama_result);

-                                for (auto it = result_array.begin(); it != result_array.end(); ++it)
-                                {
-                                    if (!it->empty()) {
-                                        const std::string str =
-                                            "data: " +
-                                            it->dump(-1, ' ', false, json::error_handler_t::replace) +
-                                            "\n\n";
-                                        LOG_VERBOSE("data stream", {{"to_send", str}});
-                                        if (!sink.write(str.c_str(), str.size())) {
-                                            llama.queue_results.remove_waiting_task_id(task_id);
-                                            return false;
-                                        }
-                                    }
-                                }
-                                if (llama_result.stop) {
-                                    break;
-                                }
-                            } else {
+                                for (auto it = result_array.begin(); it != result_array.end(); ++it)
+                                {
+                                    if (!it->empty()) {
                                         const std::string str =
-                                            "error: " +
-                                            llama_result.result_json.dump(-1, ' ', false,
-                                                    json::error_handler_t::replace) +
+                                            "data: " +
+                                            it->dump(-1, ' ', false, json::error_handler_t::replace) +
                                             "\n\n";
                                         LOG_VERBOSE("data stream", {{"to_send", str}});
                                         if (!sink.write(str.c_str(), str.size())) {
                                             llama.queue_results.remove_waiting_task_id(task_id);
                                             return false;
                                         }
-                                        break;
                                     }
                                 }
-                        sink.done();
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                        return true;
-                    };
+                                if (llama_result.stop) {
+                                    break;
+                                }
+                            } else {
+                                const std::string str =
+                                    "error: " +
+                                    llama_result.result_json.dump(-1, ' ', false,
+                                            json::error_handler_t::replace) +
+                                    "\n\n";
+                                LOG_VERBOSE("data stream", {{"to_send", str}});
+                                if (!sink.write(str.c_str(), str.size())) {
+                                    llama.queue_results.remove_waiting_task_id(task_id);
+                                    return false;
+                                }
+                                break;
+                            }
+                        }
+                        sink.done();
+                        llama.queue_results.remove_waiting_task_id(task_id);
+                        return true;
+                    };

-                    auto on_complete = [task_id, &llama](bool) {
-                        // cancel request
-                        llama.request_cancel(task_id);
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                    };
+                    auto on_complete = [task_id, &llama](bool) {
+                        // cancel request
+                        llama.request_cancel(task_id);
+                        llama.queue_results.remove_waiting_task_id(task_id);
+                    };

-                    res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-                }
-            });
+                    res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+                }
+            };
+
+    svr.Post("/chat/completions", chat_completions);
+    svr.Post("/v1/chat/completions", chat_completions);

     svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
     {
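Both mount points now dispatch to the same chat_completions handler, so an identical OpenAI-style request body should work with or without the "/v1" prefix. A minimal smoke-test sketch (not from this commit; the server address and model name are assumptions for illustration):

# Hypothetical check that both routes return the same OpenAI-compatible shape.
import requests

BASE_URL = "http://localhost:8080"  # assumed llama.cpp server address

payload = {
    "model": "tinyllama-2",  # placeholder model name
    "messages": [{"role": "user", "content": "Write one short sentence."}],
    "max_tokens": 16,
}

for path in ("/v1/chat/completions", "/chat/completions"):
    # Both paths are registered on the same handler in server.cpp above.
    r = requests.post(f"{BASE_URL}{path}", json=payload, timeout=60)
    r.raise_for_status()
    print(path, "->", r.json()["choices"][0]["message"]["content"])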

examples/server/tests/features/parallel.feature

Lines changed: 22 additions & 0 deletions
@@ -54,6 +54,28 @@ Feature: Parallel
       | disabled  | 128       |
       | enabled   | 64        |

+  Scenario Outline: Multi users OAI completions compatibility no v1
+    Given a system prompt You are a writer.
+    And a model tinyllama-2
+    Given a prompt:
+      """
+      Write a very long book.
+      """
+    And a prompt:
+      """
+      Write another a poem.
+      """
+    And <n_predict> max tokens to predict
+    And streaming is <streaming>
+    Given concurrent OAI completions requests no v1
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted with <n_predict> tokens
+    Examples:
+      | streaming | n_predict |
+      | disabled  | 128       |
+      | enabled   | 64        |
+
   Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969
     Given a prompt:
       """

examples/server/tests/features/steps/steps.py

Lines changed: 26 additions & 2 deletions
@@ -231,6 +231,7 @@ async def step_oai_chat_completions(context, api_error):
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
                                             context.base_url,
+                                            '/v1/chat',
                                             False,
                                             model=context.model if hasattr(context, 'model') else None,
@@ -288,6 +289,28 @@ async def step_oai_chat_completions(context):
                               # user_prompt is inserted automatically
                               context.system_prompt,
                               context.base_url,
+                              '/v1/chat/completions',
+                              True,  # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)
+
+
+@step(u'concurrent OAI completions requests no v1')
+@async_run_until_complete
+async def step_oai_chat_completions(context):
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              '/chat/completions',
                               True,  # async_client
                               model=context.model
                               if hasattr(context, 'model') else None,
@@ -497,6 +520,7 @@ async def request_completion(prompt,
 async def oai_chat_completions(user_prompt,
                                system_prompt,
                                base_url,
+                               base_path,
                                async_client,
                                debug=False,
                                model=None,
@@ -537,7 +561,7 @@ async def oai_chat_completions(user_prompt,
         origin = 'llama.cpp'
         headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
-            async with session.post(f'{base_url}/v1/chat/completions',
+            async with session.post(f'{base_url}{base_path}',
                                     json=payload,
                                     headers=headers) as response:
                 if enable_streaming:
@@ -579,7 +603,7 @@ async def oai_chat_completions(user_prompt,
     else:
         try:
             openai.api_key = user_api_key
-            openai.api_base = f'{base_url}/v1/chat'
+            openai.api_base = f'{base_url}{base_path}'
             chat_completion = openai.Completion.create(
                 messages=payload['messages'],
                 model=model,
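The test helper now composes its URL as f'{base_url}{base_path}', so callers choose the mount point by passing '/chat/completions' or '/v1/chat/completions'. A minimal standalone sketch of the same pattern with aiohttp (not from this commit; the server address is an assumption and no API key is sent):

# Hypothetical illustration of the base_path plumbing added above.
import asyncio

import aiohttp

BASE_URL = "http://localhost:8080"  # assumed llama.cpp server address


async def post_chat(base_path: str) -> dict:
    payload = {
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 8,
    }
    async with aiohttp.ClientSession() as session:
        # Same f'{base_url}{base_path}' composition used by oai_chat_completions().
        async with session.post(f"{BASE_URL}{base_path}", json=payload) as response:
            return await response.json()


async def main() -> None:
    for base_path in ("/v1/chat/completions", "/chat/completions"):
        result = await post_chat(base_path)
        print(base_path, "->", result.get("object"))


if __name__ == "__main__":
    asyncio.run(main())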
