diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 964bbd84e7..e71eaed3a3 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -1,5 +1,6 @@ import asyncio import json +import time import platform import sys import threading @@ -9,6 +10,7 @@ from typing import ( AsyncGenerator, AsyncIterator, + Callable, Dict, Iterator, Optional, @@ -149,6 +151,70 @@ def format_app_info(cls, info): str += " (%s)" % (info["url"],) return str + def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]): + if not predicate(response): + return + error_data = response.data['error'] + message = error_data.get('message', 'Operation failed') + code = error_data.get('code') + raise error.OpenAIError(message=message, code=code) + + def _poll( + self, + method, + url, + until, + failed, + params = None, + headers = None, + interval = None, + delay = None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + time.sleep(delay) + + response, b, api_key = self.request(method, url, params, headers) + self._check_polling_response(response, failed) + start_time = time.time() + while not until(response): + if time.time() - start_time > TIMEOUT_SECS: + raise error.Timeout("Operation polling timed out.") + + time.sleep(interval or response.retry_after or 10) + response, b, api_key = self.request(method, url, params, headers) + self._check_polling_response(response, failed) + + response.data = response.data['result'] + return response, b, api_key + + async def _apoll( + self, + method, + url, + until, + failed, + params = None, + headers = None, + interval = None, + delay = None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + await asyncio.sleep(delay) + + response, b, api_key = await self.arequest(method, url, params, headers) + self._check_polling_response(response, failed) + start_time = time.time() + while not until(response): + if time.time() - start_time > TIMEOUT_SECS: + raise error.Timeout("Operation polling timed out.") + + await asyncio.sleep(interval or response.retry_after or 10) + response, b, api_key = await self.arequest(method, url, params, headers) + self._check_polling_response(response, failed) + + response.data = response.data['result'] + return response, b, api_key + @overload def request( self, diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 39a5b6f616..1522923510 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -2,7 +2,7 @@ from typing import Any, List import openai -from openai import api_requestor, util +from openai import api_requestor, error, util from openai.api_resources.abstract import APIResource @@ -10,8 +10,11 @@ class Image(APIResource): OBJECT_NAME = "images" @classmethod - def _get_url(cls, action): - return cls.class_url() + f"/{action}" + def _get_url(cls, action, azure_action, api_type, api_version): + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD) and azure_action is not None: + return f"/{cls.azure_api_prefix}{cls.class_url()}/{action}:{azure_action}?api-version={api_version}" + else: + return f"{cls.class_url()}/{action}" @classmethod def create( @@ -31,12 +34,20 @@ def create( organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) response, _, api_key = requestor.request( - "post", cls._get_url("generations"), params + "post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params ) + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + requestor.api_base = "" # operation_location is a full url + response, _, api_key = requestor._poll( + "get", response.operation_location, + until=lambda response: response.data['status'] in [ 'succeeded' ], + failed=lambda response: response.data['status'] in [ 'failed' ] + ) + return util.convert_to_openai_object( response, api_key, api_version, organization ) @@ -60,12 +71,20 @@ async def acreate( organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) response, _, api_key = await requestor.arequest( - "post", cls._get_url("generations"), params + "post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params ) + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + requestor.api_base = "" # operation_location is a full url + response, _, api_key = await requestor._apoll( + "get", response.operation_location, + until=lambda response: response.data['status'] in [ 'succeeded' ], + failed=lambda response: response.data['status'] in [ 'failed' ] + ) + return util.convert_to_openai_object( response, api_key, api_version, organization ) @@ -88,9 +107,9 @@ def _prepare_create_variation( api_version=api_version, organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) - url = cls._get_url("variations") + url = cls._get_url("variations", azure_action=None, api_type=api_type, api_version=api_version) files: List[Any] = [] for key, value in params.items(): @@ -109,6 +128,9 @@ def create_variation( organization=None, **params, ): + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.") + requestor, url, files = cls._prepare_create_variation( image, api_key, @@ -136,6 +158,9 @@ async def acreate_variation( organization=None, **params, ): + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.") + requestor, url, files = cls._prepare_create_variation( image, api_key, @@ -171,9 +196,9 @@ def _prepare_create_edit( api_version=api_version, organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) - url = cls._get_url("edits") + url = cls._get_url("edits", azure_action=None, api_type=api_type, api_version=api_version) files: List[Any] = [] for key, value in params.items(): @@ -195,6 +220,9 @@ def create_edit( organization=None, **params, ): + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.") + requestor, url, files = cls._prepare_create_edit( image, mask, @@ -224,6 +252,9 @@ async def acreate_edit( organization=None, **params, ): + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.") + requestor, url, files = cls._prepare_create_edit( image, mask, diff --git a/openai/openai_response.py b/openai/openai_response.py index 9954247319..d2230b1540 100644 --- a/openai/openai_response.py +++ b/openai/openai_response.py @@ -10,6 +10,17 @@ def __init__(self, data, headers): def request_id(self) -> Optional[str]: return self._headers.get("request-id") + @property + def retry_after(self) -> Optional[int]: + try: + return int(self._headers.get("retry-after")) + except TypeError: + return None + + @property + def operation_location(self) -> Optional[str]: + return self._headers.get("operation-location") + @property def organization(self) -> Optional[str]: return self._headers.get("OpenAI-Organization")