Commit 1551ba1

committed
Added RouteErrorHandler for server
1 parent 6d8892f commit 1551ba1

File tree: 2 files changed (+256, -58 lines)

llama_cpp/llama.py

Lines changed: 1 addition & 1 deletion
@@ -845,7 +845,7 @@ def _create_completion(
 
         if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
-                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
 
         if max_tokens <= 0:
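
The only change to llama.py embeds the prompt token count in the ValueError text; the handler added in app.py relies on parsing exactly this message. A minimal sketch of that parse step, mirroring (but not importing) the regex registered below in RouteErrorHandler:

import re

# Regex mirroring the pattern added to RouteErrorHandler.pattern_and_formatters
pattern = re.compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")

error = ValueError("Requested tokens (5021) exceed context window of 2048")
match = pattern.search(str(error))
if match is not None:
    prompt_tokens = int(match.group(1))   # 5021
    context_window = int(match.group(2))  # 2048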

llama_cpp/server/app.py

Lines changed: 255 additions & 57 deletions
@@ -1,17 +1,20 @@
 import json
 import multiprocessing
+from re import compile, Match, Pattern
 from threading import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal
 
 import llama_cpp
 
 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request
+from fastapi import Depends, FastAPI, APIRouter, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.routing import APIRoute
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
@@ -92,7 +95,190 @@ class Settings(BaseSettings):
     )
 
 
-router = APIRouter()
+class ErrorResponse(TypedDict):
+    """OpenAI style error response"""
+
+    message: str
+    type: str
+    param: Optional[str]
+    code: Optional[str]
+
+
+class ErrorResponseFormatters:
+    """Collection of formatters for error responses.
+
+    Args:
+        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+            Request body
+        match (Match[str]): Match object from regex pattern
+
+    Returns:
+        tuple[int, ErrorResponse]: Status code and error response
+    """
+
+    @staticmethod
+    def context_length_exceeded(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for context length exceeded error"""
+
+        context_window = int(match.group(2))
+        prompt_tokens = int(match.group(1))
+        completion_tokens = request.max_tokens
+        if hasattr(request, "messages"):
+            # Chat completion
+            message = (
+                "This model's maximum context length is {} tokens. "
+                "However, you requested {} tokens "
+                "({} in the messages, {} in the completion). "
+                "Please reduce the length of the messages or completion."
+            )
+        else:
+            # Text completion
+            message = (
+                "This model's maximum context length is {} tokens, "
+                "however you requested {} tokens "
+                "({} in your prompt; {} for the completion). "
+                "Please reduce your prompt; or completion length."
+            )
+        return 400, ErrorResponse(
+            message=message.format(
+                context_window,
+                completion_tokens + prompt_tokens,
+                prompt_tokens,
+                completion_tokens,
+            ),
+            type="invalid_request_error",
+            param="messages",
+            code="context_length_exceeded",
+        )
+
+    @staticmethod
+    def model_not_found(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for model_not_found error"""
+
+        model_path = str(match.group(1))
+        message = f"The model `{model_path}` does not exist"
+        return 400, ErrorResponse(
+            message=message,
+            type="invalid_request_error",
+            param=None,
+            code="model_not_found",
+        )
+
+
+class RouteErrorHandler(APIRoute):
+    """Custom APIRoute that handles application errors and exceptions"""
+
+    # key: regex pattern for original error message from llama_cpp
+    # value: formatter function
+    pattern_and_formatters: dict[
+        "Pattern",
+        Callable[
+            [
+                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                Match[str],
+            ],
+            tuple[int, ErrorResponse],
+        ],
+    ] = {
+        compile(
+            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+        ): ErrorResponseFormatters.context_length_exceeded,
+        compile(
+            r"Model path does not exist: (.+)"
+        ): ErrorResponseFormatters.model_not_found,
+    }
+
+    def error_message_wrapper(
+        self,
+        error: Exception,
+        body: Optional[
+            Union[
+                "CreateChatCompletionRequest",
+                "CreateCompletionRequest",
+                "CreateEmbeddingRequest",
+            ]
+        ] = None,
+    ) -> tuple[int, ErrorResponse]:
+        """Wraps error message in OpenAI style error response"""
+
+        if body is not None and isinstance(
+            body,
+            (
+                CreateCompletionRequest,
+                CreateChatCompletionRequest,
+            ),
+        ):
+            # When text completion or chat completion
+            for pattern, callback in self.pattern_and_formatters.items():
+                match = pattern.search(str(error))
+                if match is not None:
+                    return callback(body, match)
+
+        # Wrap other errors as internal server error
+        return 500, ErrorResponse(
+            message=str(error),
+            type="internal_server_error",
+            param=None,
+            code=None,
+        )
+
+    def get_route_handler(
+        self,
+    ) -> Callable[[Request], Coroutine[None, None, Response]]:
+        """Defines custom route handler that catches exceptions and formats
+        in OpenAI style error response"""
+
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            try:
+                return await original_route_handler(request)
+            except Exception as exc:
+                json_body = await request.json()
+                try:
+                    if "messages" in json_body:
+                        # Chat completion
+                        body: Optional[
+                            Union[
+                                CreateChatCompletionRequest,
+                                CreateCompletionRequest,
+                                CreateEmbeddingRequest,
+                            ]
+                        ] = CreateChatCompletionRequest(**json_body)
+                    elif "prompt" in json_body:
+                        # Text completion
+                        body = CreateCompletionRequest(**json_body)
+                    else:
+                        # Embedding
+                        body = CreateEmbeddingRequest(**json_body)
+                except Exception:
+                    # Invalid request body
+                    body = None
+
+                # Get proper error message from the exception
+                (
+                    status_code,
+                    error_message,
+                ) = self.error_message_wrapper(error=exc, body=body)
+                return JSONResponse(
+                    {"error": error_message},
+                    status_code=status_code,
+                )
+
+        return custom_route_handler
+
+
+router = APIRouter(route_class=RouteErrorHandler)
 
 settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
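
The formatters can be exercised on their own, which is the quickest way to see the mapping from a raw llama_cpp error to an OpenAI-style payload. A hedged sketch, assuming llama_cpp.server.app is importable and that CreateCompletionRequest (defined elsewhere in this file) accepts a bare prompt:

from re import compile
from llama_cpp.server.app import ErrorResponseFormatters, CreateCompletionRequest

# Same pattern as the one registered in RouteErrorHandler.pattern_and_formatters
pattern = compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")
error = ValueError("Requested tokens (5021) exceed context window of 2048")
body = CreateCompletionRequest(prompt="example prompt", max_tokens=16)

match = pattern.search(str(error))
assert match is not None
status_code, payload = ErrorResponseFormatters.context_length_exceeded(body, match)
# status_code == 400, payload["code"] == "context_length_exceeded"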
@@ -179,10 +365,33 @@ def get_settings():
     yield settings
 
 
+async def get_event_publisher(
+    request: Request,
+    inner_send_chan: MemoryObjectSendStream,
+    iterator: Iterator,
+):
+    async with inner_send_chan:
+        try:
+            async for chunk in iterate_in_threadpool(iterator):
+                await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                if await request.is_disconnected():
+                    raise anyio.get_cancelled_exc_class()()
+                if settings.interrupt_requests and llama_outer_lock.locked():
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                    raise anyio.get_cancelled_exc_class()()
+            await inner_send_chan.send(dict(data="[DONE]"))
+        except anyio.get_cancelled_exc_class() as e:
+            print("disconnected")
+            with anyio.move_on_after(1, shield=True):
+                print(
+                    f"Disconnected from client (via refresh/close) {request.client}"
+                )
+            raise e
+
+
 model_field = Field(description="The model to use for generating completions.", default=None)
 
 max_tokens_field = Field(
-    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+    default=16, ge=1, description="The maximum number of tokens to generate."
 )
 
 temperature_field = Field(
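
get_event_publisher replaces the two per-endpoint event_publisher closures with a single coroutine that feeds an anyio memory object stream, which sse_starlette then drains. A self-contained sketch of that producer/consumer shape, with a made-up chunk list standing in for the llama_cpp iterator:

import json
import anyio

async def main():
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def producer(chunks):
        # Mirrors get_event_publisher: push each chunk as an SSE-style dict, then a sentinel.
        async with send_chan:
            for chunk in chunks:
                await send_chan.send(dict(data=json.dumps(chunk)))
            await send_chan.send(dict(data="[DONE]"))

    async def consumer():
        # Stands in for the EventSourceResponse reading from recv_chan.
        async with recv_chan:
            async for event in recv_chan:
                print(event["data"])

    async with anyio.create_task_group() as tg:
        tg.start_soon(producer, [{"text": "hello"}, {"text": "world"}])
        tg.start_soon(consumer)

anyio.run(main)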
@@ -370,35 +579,31 @@ async def create_completion(
         make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
     ])
 
-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+        llama_cpp.CompletionChunk
+    ]] = await run_in_threadpool(llama, **kwargs)
 
-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
-                    async for chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                    raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
 
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
+        )
     else:
-        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
-        return completion
+        return iterator_or_completion
 
 
 class CreateEmbeddingRequest(BaseModel):
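
Running the model before building the EventSourceResponse, and eagerly pulling the first chunk, is what lets RouteErrorHandler catch failures on streaming requests too: the exception surfaces while execution is still inside the route handler, before any SSE headers go out. A standalone sketch of that EAFP pattern (the failing generator is purely illustrative):

from typing import Iterator

def generate() -> Iterator[str]:
    raise ValueError("Requested tokens (5021) exceed context window of 2048")
    yield "never reached"  # unreachable, but makes this a generator function

stream = generate()
try:
    first = next(stream)  # the error fires here, before any response is started
except ValueError as exc:
    print("caught before streaming began:", exc)
else:
    def replay() -> Iterator[str]:
        yield first            # re-prepend the chunk consumed by next()
        yield from stream      # then continue with the rest of the stream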
@@ -501,38 +706,31 @@ async def create_chat_completion(
         make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
     ])
 
-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+        llama_cpp.ChatCompletionChunk
+    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
-                    async for chat_chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                    raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
 
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(event_publisher, send_chan),
-        )  # type: ignore
-    else:
-        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-            llama.create_chat_completion, **kwargs  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
         )
-        return completion
+    else:
+        return iterator_or_completion
 
 
 class ModelData(TypedDict):
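
From the client side, the net effect is that a bad streaming request now gets an ordinary JSON error instead of a connection that dies mid-stream. A hypothetical check against a locally running server (the endpoint path, port, and oversized max_tokens are assumptions, not part of this diff):

import requests  # assumption: requests is installed in the client environment

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
        "max_tokens": 100000,  # deliberately larger than the context window
    },
    stream=True,
) as resp:
    if resp.status_code != 200:
        # OpenAI-style error body produced by RouteErrorHandler
        print(resp.json()["error"]["code"])  # e.g. "context_length_exceeded"
    else:
        for line in resp.iter_lines():
            print(line)  # raw SSE "data: ..." lines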

0 commit comments
