 import json
 import multiprocessing
+from re import compile, Match, Pattern
 from threading import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal

 import llama_cpp

 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request
+from fastapi import Depends, FastAPI, APIRouter, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.routing import APIRoute
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
@@ -92,7 +95,190 @@ class Settings(BaseSettings):
     )


-router = APIRouter()
+class ErrorResponse(TypedDict):
+    """OpenAI style error response"""
+
+    message: str
+    type: str
+    param: Optional[str]
+    code: Optional[str]
+
+
+class ErrorResponseFormatters:
+    """Collection of formatters for error responses.
+
+    Args:
+        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+            Request body
+        match (Match[str]): Match object from regex pattern
+
+    Returns:
+        tuple[int, ErrorResponse]: Status code and error response
+    """
+
+    @staticmethod
+    def context_length_exceeded(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for context length exceeded error"""
+
+        context_window = int(match.group(2))
+        prompt_tokens = int(match.group(1))
+        completion_tokens = request.max_tokens
+        if hasattr(request, "messages"):
+            # Chat completion
+            message = (
+                "This model's maximum context length is {} tokens. "
+                "However, you requested {} tokens "
+                "({} in the messages, {} in the completion). "
+                "Please reduce the length of the messages or completion."
+            )
+        else:
+            # Text completion
+            message = (
+                "This model's maximum context length is {} tokens, "
+                "however you requested {} tokens "
+                "({} in your prompt; {} for the completion). "
+                "Please reduce your prompt; or completion length."
+            )
+        return 400, ErrorResponse(
+            message=message.format(
+                context_window,
+                completion_tokens + prompt_tokens,
+                prompt_tokens,
+                completion_tokens,
+            ),
+            type="invalid_request_error",
+            param="messages",
+            code="context_length_exceeded",
+        )
+
+    @staticmethod
+    def model_not_found(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for model_not_found error"""
+
+        model_path = str(match.group(1))
+        message = f"The model `{model_path}` does not exist"
+        return 400, ErrorResponse(
+            message=message,
+            type="invalid_request_error",
+            param=None,
+            code="model_not_found",
+        )
+
+
+class RouteErrorHandler(APIRoute):
+    """Custom APIRoute that handles application errors and exceptions"""
+
+    # key: regex pattern for original error message from llama_cpp
+    # value: formatter function
+    pattern_and_formatters: dict[
+        "Pattern",
+        Callable[
+            [
+                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                Match[str],
+            ],
+            tuple[int, ErrorResponse],
+        ],
+    ] = {
+        compile(
+            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+        ): ErrorResponseFormatters.context_length_exceeded,
+        compile(
+            r"Model path does not exist: (.+)"
+        ): ErrorResponseFormatters.model_not_found,
+    }
+
+    def error_message_wrapper(
+        self,
+        error: Exception,
+        body: Optional[
+            Union[
+                "CreateChatCompletionRequest",
+                "CreateCompletionRequest",
+                "CreateEmbeddingRequest",
+            ]
+        ] = None,
+    ) -> tuple[int, ErrorResponse]:
+        """Wraps error message in OpenAI style error response"""
+
+        if body is not None and isinstance(
+            body,
+            (
+                CreateCompletionRequest,
+                CreateChatCompletionRequest,
+            ),
+        ):
+            # When text completion or chat completion
+            for pattern, callback in self.pattern_and_formatters.items():
+                match = pattern.search(str(error))
+                if match is not None:
+                    return callback(body, match)
+
+        # Wrap other errors as internal server error
+        return 500, ErrorResponse(
+            message=str(error),
+            type="internal_server_error",
+            param=None,
+            code=None,
+        )
+
+    def get_route_handler(
+        self,
+    ) -> Callable[[Request], Coroutine[None, None, Response]]:
+        """Defines a custom route handler that catches exceptions and formats
+        them as an OpenAI style error response"""
+
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            try:
+                return await original_route_handler(request)
+            except Exception as exc:
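+                # Rebuild a typed request object from the raw JSON body so the
+                # error formatters can inspect fields such as max_tokens.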
+                json_body = await request.json()
+                try:
+                    if "messages" in json_body:
+                        # Chat completion
+                        body: Optional[
+                            Union[
+                                CreateChatCompletionRequest,
+                                CreateCompletionRequest,
+                                CreateEmbeddingRequest,
+                            ]
+                        ] = CreateChatCompletionRequest(**json_body)
+                    elif "prompt" in json_body:
+                        # Text completion
+                        body = CreateCompletionRequest(**json_body)
+                    else:
+                        # Embedding
+                        body = CreateEmbeddingRequest(**json_body)
+                except Exception:
+                    # Invalid request body
+                    body = None
+
+                # Get proper error message from the exception
+                (
+                    status_code,
+                    error_message,
+                ) = self.error_message_wrapper(error=exc, body=body)
+                return JSONResponse(
+                    {"error": error_message},
+                    status_code=status_code,
+                )
+
+        return custom_route_handler
+
+
+router = APIRouter(route_class=RouteErrorHandler)

 settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
@@ -179,10 +365,33 @@ def get_settings():
     yield settings


+async def get_event_publisher(
+    request: Request,
+    inner_send_chan: MemoryObjectSendStream,
+    iterator: Iterator,
+):
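+    # Forward chunks from the blocking llama iterator to the SSE channel,
+    # ending the stream early if the client disconnects.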
+    async with inner_send_chan:
+        try:
+            async for chunk in iterate_in_threadpool(iterator):
+                await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                if await request.is_disconnected():
+                    raise anyio.get_cancelled_exc_class()()
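+                # interrupt_requests: another caller holds the outer model
+                # lock, so send a final [DONE] and close this stream early.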
+                if settings.interrupt_requests and llama_outer_lock.locked():
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                    raise anyio.get_cancelled_exc_class()()
+            await inner_send_chan.send(dict(data="[DONE]"))
+        except anyio.get_cancelled_exc_class() as e:
+            print("disconnected")
+            with anyio.move_on_after(1, shield=True):
+                print(
+                    f"Disconnected from client (via refresh/close) {request.client}"
+                )
+                raise e
+
 model_field = Field(description="The model to use for generating completions.", default=None)

 max_tokens_field = Field(
-    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+    default=16, ge=1, description="The maximum number of tokens to generate."
 )

 temperature_field = Field(
@@ -370,35 +579,31 @@ async def create_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])

-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+        llama_cpp.CompletionChunk
+    ]] = await run_in_threadpool(llama, **kwargs)

-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
-                    async for chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                        raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
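+        # Pulling the first chunk here lets any llama_cpp error surface before
+        # the SSE stream starts, so RouteErrorHandler can still return a JSON error.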
+        first_response = await run_in_threadpool(next, iterator_or_completion)

+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
+        )
     else:
-        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
-        return completion
+        return iterator_or_completion


 class CreateEmbeddingRequest(BaseModel):
@@ -501,38 +706,31 @@ async def create_chat_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])

-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+        llama_cpp.ChatCompletionChunk
+    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
-                    async for chat_chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                        raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion

+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(event_publisher, send_chan),
-        )  # type: ignore
-    else:
-        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-            llama.create_chat_completion, **kwargs  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
         )
-        return completion
+    else:
+        return iterator_or_completion


 class ModelData(TypedDict):