diff --git a/openai/__init__.py b/openai/__init__.py index f80085eada..29d886f8cb 100644 --- a/openai/__init__.py +++ b/openai/__init__.py @@ -36,9 +36,9 @@ organization = os.environ.get("OPENAI_ORGANIZATION") api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1") api_type = os.environ.get("OPENAI_API_TYPE", "open_ai") -api_version = ( +api_version = os.environ.get("OPENAI_API_VERSION", ( "2023-03-15-preview" if api_type in ("azure", "azure_ad", "azuread") else None -) +)) verify_ssl_certs = True # No effect. Certificates are always verified. proxy = None app_info = None diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 8a0e6cabc2..7614e9a941 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -1,5 +1,6 @@ import asyncio import json +import time import platform import sys import threading @@ -144,6 +145,44 @@ def format_app_info(cls, info): if info["url"]: str += " (%s)" % (info["url"],) return str + + def poll( + self, + method, + url, + until, + params = None, + headers = None, + interval = None, + delay = None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + time.sleep(delay) + + response, b, api_key = self.request(method, url, params, headers) + while not until(response): + time.sleep(interval or response.retry_after or 1) + response, b, api_key = self.request(method, url) + return response, b, api_key + + async def apoll( + self, + method, + url, + until, + params = None, + headers = None, + interval = None, + delay = None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + await asyncio.sleep(delay) + + response, b, api_key = await self.arequest(method, url, params, headers) + while not until(response): + await asyncio.sleep(interval or response.retry_after or 1) + response, b, api_key = await self.arequest(method, url) + return response, b, api_key @overload def request( diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 39a5b6f616..8c5b808b65 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -10,8 +10,11 @@ class Image(APIResource): OBJECT_NAME = "images" @classmethod - def _get_url(cls, action): - return cls.class_url() + f"/{action}" + def _get_url(cls, openai_action, azure_action, api_type, api_version): + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + return f"/{cls.azure_api_prefix}{cls.class_url()}/{azure_action}?api-version={api_version}" + else: + return f"{cls.class_url()}/{openai_action}" @classmethod def create( @@ -31,12 +34,20 @@ def create( organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) response, _, api_key = requestor.request( - "post", cls._get_url("generations"), params + "post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params ) + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + requestor.api_base = "" # operation_location is a full url + response, _, api_key = requestor.poll( + "get", response.operation_location, + until=lambda response: response.data["status"] not in ["NotStarted", "Running"], + delay=response.retry_after + ) + return util.convert_to_openai_object( response, api_key, api_version, organization ) @@ -60,12 +71,20 @@ async def acreate( organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) response, _, api_key = await requestor.arequest( - "post", cls._get_url("generations"), params + "post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params ) + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + requestor.api_base = "" # operation_location is a full url + response, _, api_key = await requestor.apoll( + "get", response.operation_location, + until=lambda response: response.data["status"] not in ["NotStarted", "Running"], + delay=response.retry_after + ) + return util.convert_to_openai_object( response, api_key, api_version, organization ) @@ -88,9 +107,9 @@ def _prepare_create_variation( api_version=api_version, organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) - url = cls._get_url("variations") + url = cls._get_url("variations", None, api_type=api_type, api_version=api_version) files: List[Any] = [] for key, value in params.items(): @@ -171,9 +190,9 @@ def _prepare_create_edit( api_version=api_version, organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) - url = cls._get_url("edits") + url = cls._get_url("edits", None, api_type=api_type, api_version=api_version) files: List[Any] = [] for key, value in params.items(): diff --git a/openai/openai_response.py b/openai/openai_response.py index 9954247319..828f98cf34 100644 --- a/openai/openai_response.py +++ b/openai/openai_response.py @@ -9,6 +9,17 @@ def __init__(self, data, headers): @property def request_id(self) -> Optional[str]: return self._headers.get("request-id") + + @property + def retry_after(self) -> Optional[int]: + try: + return int(self._headers.get("retry-after")) + except ValueError: + return None + + @property + def operation_location(self) -> Optional[str]: + return self._headers.get("operation-location") @property def organization(self) -> Optional[str]: