@@ -1,17 +1,20 @@
 import json
 import multiprocessing
+from re import compile, Match, Pattern
 from threading import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal

 import llama_cpp

 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request
+from fastapi import Depends, FastAPI, APIRouter, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.routing import APIRoute
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
@@ -94,7 +97,190 @@ class Settings(BaseSettings):
     )


-router = APIRouter()
+class ErrorResponse(TypedDict):
+    """OpenAI style error response"""
+
+    message: str
+    type: str
+    param: Optional[str]
+    code: Optional[str]
+
+
+class ErrorResponseFormatters:
+    """Collection of formatters for error responses.
+
+    Args:
+        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+            Request body
+        match (Match[str]): Match object from regex pattern
+
+    Returns:
+        tuple[int, ErrorResponse]: Status code and error response
+    """
+
+    @staticmethod
+    def context_length_exceeded(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for context length exceeded error"""
+
+        context_window = int(match.group(2))
+        prompt_tokens = int(match.group(1))
+        completion_tokens = request.max_tokens
+        if hasattr(request, "messages"):
+            # Chat completion
+            message = (
+                "This model's maximum context length is {} tokens. "
+                "However, you requested {} tokens "
+                "({} in the messages, {} in the completion). "
+                "Please reduce the length of the messages or completion."
+            )
+        else:
+            # Text completion
+            message = (
+                "This model's maximum context length is {} tokens, "
+                "however you requested {} tokens "
+                "({} in your prompt; {} for the completion). "
+                "Please reduce your prompt; or completion length."
+            )
+        return 400, ErrorResponse(
+            message=message.format(
+                context_window,
+                completion_tokens + prompt_tokens,
+                prompt_tokens,
+                completion_tokens,
+            ),
+            type="invalid_request_error",
+            param="messages",
+            code="context_length_exceeded",
+        )
+
+    @staticmethod
+    def model_not_found(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for model_not_found error"""
+
+        model_path = str(match.group(1))
+        message = f"The model `{model_path}` does not exist"
+        return 400, ErrorResponse(
+            message=message,
+            type="invalid_request_error",
+            param=None,
+            code="model_not_found",
+        )
+
+
+class RouteErrorHandler(APIRoute):
+    """Custom APIRoute that handles application errors and exceptions"""
+
+    # key: regex pattern for original error message from llama_cpp
+    # value: formatter function
+    pattern_and_formatters: dict[
+        "Pattern",
+        Callable[
+            [
+                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                Match[str],
+            ],
+            tuple[int, ErrorResponse],
+        ],
+    ] = {
+        compile(
+            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+        ): ErrorResponseFormatters.context_length_exceeded,
+        compile(
+            r"Model path does not exist: (.+)"
+        ): ErrorResponseFormatters.model_not_found,
+    }
+
+    def error_message_wrapper(
+        self,
+        error: Exception,
+        body: Optional[
+            Union[
+                "CreateChatCompletionRequest",
+                "CreateCompletionRequest",
+                "CreateEmbeddingRequest",
+            ]
+        ] = None,
+    ) -> tuple[int, ErrorResponse]:
+        """Wraps error message in OpenAI style error response"""
+
+        if body is not None and isinstance(
+            body,
+            (
+                CreateCompletionRequest,
+                CreateChatCompletionRequest,
+            ),
+        ):
+            # When text completion or chat completion
+            for pattern, callback in self.pattern_and_formatters.items():
+                match = pattern.search(str(error))
+                if match is not None:
+                    return callback(body, match)
+
+        # Wrap other errors as internal server error
+        return 500, ErrorResponse(
+            message=str(error),
+            type="internal_server_error",
+            param=None,
+            code=None,
+        )
+
+    def get_route_handler(
+        self,
+    ) -> Callable[[Request], Coroutine[None, None, Response]]:
+        """Defines custom route handler that catches exceptions and formats
+        them as OpenAI style error responses"""
+
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            try:
+                return await original_route_handler(request)
+            except Exception as exc:
+                json_body = await request.json()
+                try:
+                    if "messages" in json_body:
+                        # Chat completion
+                        body: Optional[
+                            Union[
+                                CreateChatCompletionRequest,
+                                CreateCompletionRequest,
+                                CreateEmbeddingRequest,
+                            ]
+                        ] = CreateChatCompletionRequest(**json_body)
+                    elif "prompt" in json_body:
+                        # Text completion
+                        body = CreateCompletionRequest(**json_body)
+                    else:
+                        # Embedding
+                        body = CreateEmbeddingRequest(**json_body)
+                except Exception:
+                    # Invalid request body
+                    body = None
+
+                # Get proper error message from the exception
+                (
+                    status_code,
+                    error_message,
+                ) = self.error_message_wrapper(error=exc, body=body)
+                return JSONResponse(
+                    {"error": error_message},
+                    status_code=status_code,
+                )
+
+        return custom_route_handler
+
+
+router = APIRouter(route_class=RouteErrorHandler)

 settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
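For context, the sketch below shows the same APIRoute-subclass technique in isolation: a route class whose get_route_handler wraps every endpoint and converts uncaught exceptions into an OpenAI style error body. The route class, endpoint, and error text here are illustrative stand-ins, not part of llama_cpp.

# Minimal, self-contained sketch of the custom-route-class technique.
# CatchAllRoute and /boom are invented for illustration only.
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient


class CatchAllRoute(APIRoute):
    """Turn any uncaught handler exception into an OpenAI-style error body."""

    def get_route_handler(self):
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                return await original_route_handler(request)
            except Exception as exc:
                # Generic fallback; the real RouteErrorHandler also matches
                # known llama_cpp messages and returns 400s for those.
                return JSONResponse(
                    {
                        "error": {
                            "message": str(exc),
                            "type": "internal_server_error",
                            "param": None,
                            "code": None,
                        }
                    },
                    status_code=500,
                )

        return custom_route_handler


router = APIRouter(route_class=CatchAllRoute)


@router.get("/boom")
def boom():
    raise RuntimeError("Requested tokens (600) exceed context window of 512")


app = FastAPI()
app.include_router(router)

if __name__ == "__main__":
    client = TestClient(app)
    response = client.get("/boom")
    print(response.status_code)              # 500
    print(response.json()["error"]["message"])

In the patch itself, RouteErrorHandler goes further: it re-parses the request body into the pydantic request models so the regex formatters can return context_length_exceeded or model_not_found responses instead of a generic 500.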
@@ -183,10 +369,33 @@ def get_settings():
     yield settings


+async def get_event_publisher(
+    request: Request,
+    inner_send_chan: MemoryObjectSendStream,
+    iterator: Iterator,
+):
+    async with inner_send_chan:
+        try:
+            async for chunk in iterate_in_threadpool(iterator):
+                await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                if await request.is_disconnected():
+                    raise anyio.get_cancelled_exc_class()()
+                if settings.interrupt_requests and llama_outer_lock.locked():
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                    raise anyio.get_cancelled_exc_class()()
+            await inner_send_chan.send(dict(data="[DONE]"))
+        except anyio.get_cancelled_exc_class() as e:
+            print("disconnected")
+            with anyio.move_on_after(1, shield=True):
+                print(
+                    f"Disconnected from client (via refresh/close) {request.client}"
+                )
+                raise e
+
 model_field = Field(description="The model to use for generating completions.", default=None)

 max_tokens_field = Field(
-    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+    default=16, ge=1, description="The maximum number of tokens to generate."
 )

 temperature_field = Field(
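The new get_event_publisher centralizes the streaming plumbing that the two endpoints previously duplicated: sse_starlette's EventSourceResponse reads from a bounded anyio memory channel while data_sender_callable fills it from a blocking iterator. A stripped-down sketch of the pattern, using only libraries already imported above; the endpoint, iterator, and chunk shape are invented for illustration.

# Hedged sketch of the channel-plus-publisher streaming pattern.
import json
from functools import partial
from typing import Iterator

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool

app = FastAPI()


def fake_completion_chunks() -> Iterator[dict]:
    # Stand-in for the blocking llama_cpp completion iterator.
    for i in range(3):
        yield {"choices": [{"text": f"token-{i}"}]}


async def event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Iterator,
):
    async with inner_send_chan:
        # iterate_in_threadpool keeps the blocking iterator off the event loop.
        async for chunk in iterate_in_threadpool(iterator):
            await inner_send_chan.send(dict(data=json.dumps(chunk)))
            if await request.is_disconnected():
                raise anyio.get_cancelled_exc_class()()
        await inner_send_chan.send(dict(data="[DONE]"))


@app.get("/demo-stream")
async def demo_stream(request: Request):
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    return EventSourceResponse(
        recv_chan,
        data_sender_callable=partial(
            event_publisher,
            request=request,
            inner_send_chan=send_chan,
            iterator=fake_completion_chunks(),
        ),
    )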
@@ -374,35 +583,31 @@ async def create_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])

-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+        llama_cpp.CompletionChunk
+    ]] = await run_in_threadpool(llama, **kwargs)

-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
-                    async for chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                        raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)

+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
+        )
     else:
-        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
-        return completion
+        return iterator_or_completion


 class CreateEmbeddingRequest(BaseModel):
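The isinstance/next block above is what lets a streaming request still fail with a proper HTTP error: the first chunk is pulled before any SSE bytes are sent, so an exception from llama_cpp propagates to RouteErrorHandler instead of aborting mid-stream, and the consumed chunk is re-attached by the local iterator(). A toy illustration of the pattern; generate and its error text are invented.

# Toy illustration of the EAFP first-chunk pattern; `generate` is a stand-in.
from typing import Dict, Iterator, Union


def generate(prompt: str, stream: bool) -> Union[Dict[str, str], Iterator[Dict[str, str]]]:
    if len(prompt) > 8:
        raise ValueError("Requested tokens (9) exceed context window of 8")
    if not stream:
        return {"text": prompt[::-1]}
    return ({"text": ch} for ch in prompt)


result = generate("hi", stream=True)

if isinstance(result, Iterator):
    # Pull the first chunk eagerly: a bad request raises here, where it can
    # still be turned into a normal error response.
    first_chunk = next(result)

    def chained() -> Iterator[Dict[str, str]]:
        # Re-attach the chunk we already consumed, then stream the rest.
        yield first_chunk
        yield from result

    for chunk in chained():
        print(chunk)
else:
    print(result)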
@@ -505,38 +710,31 @@ async def create_chat_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])

-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+        llama_cpp.ChatCompletionChunk
+    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
-                    async for chat_chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                        raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion

+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(event_publisher, send_chan),
-        )  # type: ignore
-    else:
-        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-            llama.create_chat_completion, **kwargs  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
         )
-        return completion
+    else:
+        return iterator_or_completion


 class ModelData(TypedDict):