Commit 365d9a4

Merge pull request #481 from c0sogi/main

Added `RouteErrorHandler` for server

2 parents a9cb645 + 1551ba1
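
The new handler uses FastAPI's custom APIRoute hook: subclass fastapi.routing.APIRoute, override get_route_handler(), and attach the class to the router with route_class=. A minimal sketch of that mechanism follows; TimingRoute and the /ping endpoint are illustrative names only, not part of this PR.

from typing import Callable, Coroutine

from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.routing import APIRoute


class TimingRoute(APIRoute):
    # Same hook RouteErrorHandler uses: wrap the handler FastAPI generates
    # for every endpoint registered on the router.
    def get_route_handler(self) -> Callable[[Request], Coroutine[None, None, Response]]:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            # Code here runs around each endpoint call; RouteErrorHandler
            # puts its try/except and JSONResponse formatting here instead.
            response = await original_route_handler(request)
            response.headers["x-handled-by"] = "TimingRoute"
            return response

        return custom_route_handler


router = APIRouter(route_class=TimingRoute)


@router.get("/ping")
async def ping() -> dict:
    return {"ok": True}


app = FastAPI()
app.include_router(router)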

File tree: 2 files changed (+256, -58 lines changed)

llama_cpp/llama.py (1 addition, 1 deletion)

@@ -845,7 +845,7 @@ def _create_completion(
 
         if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
-                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
 
         if max_tokens <= 0:
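
The reworded ValueError above now carries the token counts that the server's new error handler (next file) parses back out with a regex. A quick standalone check of that round trip, with made-up numbers:

from re import compile

# Pattern copied from RouteErrorHandler.pattern_and_formatters below;
# the message text mirrors the new ValueError in _create_completion.
pattern = compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")
error_text = "Requested tokens (2231) exceed context window of 2048"  # illustrative numbers

match = pattern.search(error_text)
assert match is not None
print(int(match.group(1)), int(match.group(2)))  # -> 2231 2048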

llama_cpp/server/app.py (255 additions, 57 deletions)

@@ -1,17 +1,20 @@
 import json
 import multiprocessing
+from re import compile, Match, Pattern
 from threading import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal
 
 import llama_cpp
 
 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request
+from fastapi import Depends, FastAPI, APIRouter, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.routing import APIRoute
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse

@@ -94,7 +97,190 @@ class Settings(BaseSettings):
     )
 
 
-router = APIRouter()
+class ErrorResponse(TypedDict):
+    """OpenAI style error response"""
+
+    message: str
+    type: str
+    param: Optional[str]
+    code: Optional[str]
+
+
+class ErrorResponseFormatters:
+    """Collection of formatters for error responses.
+
+    Args:
+        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+            Request body
+        match (Match[str]): Match object from regex pattern
+
+    Returns:
+        tuple[int, ErrorResponse]: Status code and error response
+    """
+
+    @staticmethod
+    def context_length_exceeded(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for context length exceeded error"""
+
+        context_window = int(match.group(2))
+        prompt_tokens = int(match.group(1))
+        completion_tokens = request.max_tokens
+        if hasattr(request, "messages"):
+            # Chat completion
+            message = (
+                "This model's maximum context length is {} tokens. "
+                "However, you requested {} tokens "
+                "({} in the messages, {} in the completion). "
+                "Please reduce the length of the messages or completion."
+            )
+        else:
+            # Text completion
+            message = (
+                "This model's maximum context length is {} tokens, "
+                "however you requested {} tokens "
+                "({} in your prompt; {} for the completion). "
+                "Please reduce your prompt; or completion length."
+            )
+        return 400, ErrorResponse(
+            message=message.format(
+                context_window,
+                completion_tokens + prompt_tokens,
+                prompt_tokens,
+                completion_tokens,
+            ),
+            type="invalid_request_error",
+            param="messages",
+            code="context_length_exceeded",
+        )
+
+    @staticmethod
+    def model_not_found(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for model_not_found error"""
+
+        model_path = str(match.group(1))
+        message = f"The model `{model_path}` does not exist"
+        return 400, ErrorResponse(
+            message=message,
+            type="invalid_request_error",
+            param=None,
+            code="model_not_found",
+        )
+
+
+class RouteErrorHandler(APIRoute):
+    """Custom APIRoute that handles application errors and exceptions"""
+
+    # key: regex pattern for original error message from llama_cpp
+    # value: formatter function
+    pattern_and_formatters: dict[
+        "Pattern",
+        Callable[
+            [
+                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                Match[str],
+            ],
+            tuple[int, ErrorResponse],
+        ],
+    ] = {
+        compile(
+            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+        ): ErrorResponseFormatters.context_length_exceeded,
+        compile(
+            r"Model path does not exist: (.+)"
+        ): ErrorResponseFormatters.model_not_found,
+    }
+
+    def error_message_wrapper(
+        self,
+        error: Exception,
+        body: Optional[
+            Union[
+                "CreateChatCompletionRequest",
+                "CreateCompletionRequest",
+                "CreateEmbeddingRequest",
+            ]
+        ] = None,
+    ) -> tuple[int, ErrorResponse]:
+        """Wraps error message in OpenAI style error response"""
+
+        if body is not None and isinstance(
+            body,
+            (
+                CreateCompletionRequest,
+                CreateChatCompletionRequest,
+            ),
+        ):
+            # When text completion or chat completion
+            for pattern, callback in self.pattern_and_formatters.items():
+                match = pattern.search(str(error))
+                if match is not None:
+                    return callback(body, match)
+
+        # Wrap other errors as internal server error
+        return 500, ErrorResponse(
+            message=str(error),
+            type="internal_server_error",
+            param=None,
+            code=None,
+        )
+
+    def get_route_handler(
+        self,
+    ) -> Callable[[Request], Coroutine[None, None, Response]]:
+        """Defines custom route handler that catches exceptions and formats
+        in OpenAI style error response"""
+
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            try:
+                return await original_route_handler(request)
+            except Exception as exc:
+                json_body = await request.json()
+                try:
+                    if "messages" in json_body:
+                        # Chat completion
+                        body: Optional[
+                            Union[
+                                CreateChatCompletionRequest,
+                                CreateCompletionRequest,
+                                CreateEmbeddingRequest,
+                            ]
+                        ] = CreateChatCompletionRequest(**json_body)
+                    elif "prompt" in json_body:
+                        # Text completion
+                        body = CreateCompletionRequest(**json_body)
+                    else:
+                        # Embedding
+                        body = CreateEmbeddingRequest(**json_body)
+                except Exception:
+                    # Invalid request body
+                    body = None
+
+                # Get proper error message from the exception
+                (
+                    status_code,
+                    error_message,
+                ) = self.error_message_wrapper(error=exc, body=body)
+                return JSONResponse(
+                    {"error": error_message},
+                    status_code=status_code,
+                )
+
+        return custom_route_handler
+
+
+router = APIRouter(route_class=RouteErrorHandler)
 
 settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None

@@ -183,10 +369,33 @@ def get_settings():
     yield settings
 
 
+async def get_event_publisher(
+    request: Request,
+    inner_send_chan: MemoryObjectSendStream,
+    iterator: Iterator,
+):
+    async with inner_send_chan:
+        try:
+            async for chunk in iterate_in_threadpool(iterator):
+                await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                if await request.is_disconnected():
+                    raise anyio.get_cancelled_exc_class()()
+                if settings.interrupt_requests and llama_outer_lock.locked():
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                    raise anyio.get_cancelled_exc_class()()
+            await inner_send_chan.send(dict(data="[DONE]"))
+        except anyio.get_cancelled_exc_class() as e:
+            print("disconnected")
+            with anyio.move_on_after(1, shield=True):
+                print(
+                    f"Disconnected from client (via refresh/close) {request.client}"
+                )
+                raise e
+
 model_field = Field(description="The model to use for generating completions.", default=None)
 
 max_tokens_field = Field(
-    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+    default=16, ge=1, description="The maximum number of tokens to generate."
 )
 
 temperature_field = Field(

@@ -374,35 +583,31 @@ async def create_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])
 
-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+        llama_cpp.CompletionChunk
+    ]] = await run_in_threadpool(llama, **kwargs)
 
-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs) # type: ignore
-                    async for chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                        raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
 
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        ) # type: ignore
+            recv_chan, data_sender_callable=partial( # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
+        )
     else:
-        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore
-        return completion
+        return iterator_or_completion
 
 
 class CreateEmbeddingRequest(BaseModel):

@@ -505,38 +710,31 @@ async def create_chat_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])
 
-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+        llama_cpp.ChatCompletionChunk
+    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs) # type: ignore
-                    async for chat_chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                        raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
 
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(event_publisher, send_chan),
-        ) # type: ignore
-    else:
-        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-            llama.create_chat_completion, **kwargs # type: ignore
+            recv_chan, data_sender_callable=partial( # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
         )
-        return completion
+    else:
+        return iterator_or_completion
 
 
 class ModelData(TypedDict):
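
Taken together, an oversized request to a running server should now come back as a structured 400 in OpenAI's error format instead of an unhandled 500. An illustrative client call; the host, port, and prompt length are assumptions, not part of the commit:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed default host/port
    json={"prompt": "word " * 4000, "max_tokens": 16},
)
print(resp.status_code)  # expected: 400 once the prompt exceeds the context window
print(resp.json())
# Expected shape, per ErrorResponseFormatters.context_length_exceeded:
# {"error": {"message": "This model's maximum context length is ... tokens, ...",
#            "type": "invalid_request_error",
#            "param": "messages",
#            "code": "context_length_exceeded"}}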
