From e20ff3436e8246e8428d7882b325f8f04179b639 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCrtz?= Date: Fri, 10 Mar 2023 14:13:08 +0100 Subject: [PATCH 1/8] WIP --- openai/__init__.py | 4 +-- openai/api_requestor.py | 31 ++++++++++++++++ openai/api_resources/abstract/api_resource.py | 1 + openai/api_resources/image.py | 36 +++++++++++++------ openai/openai_response.py | 7 ++++ 5 files changed, 67 insertions(+), 12 deletions(-) diff --git a/openai/__init__.py b/openai/__init__.py index 879fe33d04..c801e8bac8 100644 --- a/openai/__init__.py +++ b/openai/__init__.py @@ -34,9 +34,9 @@ organization = os.environ.get("OPENAI_ORGANIZATION") api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1") api_type = os.environ.get("OPENAI_API_TYPE", "open_ai") -api_version = ( +api_version = os.environ.get("OPENAI_API_VERSION", ( "2022-12-01" if api_type in ("azure", "azure_ad", "azuread") else None -) +)) verify_ssl_certs = True # No effect. Certificates are always verified. proxy = None app_info = None diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 04b65fdcdb..b218334501 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -1,5 +1,6 @@ import asyncio import json +import time import platform import sys import threading @@ -144,6 +145,36 @@ def format_app_info(cls, info): if info["url"]: str += " (%s)" % (info["url"],) return str + + def poll( + self, + method, + url, + until, + params = None, + headers = None, + interval = None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + response, b, api_key = self.request(method, url, params, headers) + while not until(response): + time.sleep(interval or response.retry_after or 1) + response, b, api_key = self.request(method, url) + return response, b, api_key + + async def apoll( + self, + method, + url, + until, + params = None, + headers = None, + interval = None + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + response, b, api_key = await self.arequest(method, url, params, headers) + while not until(response): + await asyncio.sleep(interval or response.retry_after or 1) + response, b, api_key = await self.arequest(method, url) + return response, b, api_key @overload def request( diff --git a/openai/api_resources/abstract/api_resource.py b/openai/api_resources/abstract/api_resource.py index 53a7dec799..70b46e7f84 100644 --- a/openai/api_resources/abstract/api_resource.py +++ b/openai/api_resources/abstract/api_resource.py @@ -10,6 +10,7 @@ class APIResource(OpenAIObject): api_prefix = "" azure_api_prefix = "openai" + azure_dalle_prefix = "dalle" azure_deployments_prefix = "deployments" @classmethod diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 39a5b6f616..1bde2880c0 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -4,14 +4,22 @@ import openai from openai import api_requestor, util from openai.api_resources.abstract import APIResource +from openai.error import APIError class Image(APIResource): OBJECT_NAME = "images" @classmethod - def _get_url(cls, action): - return cls.class_url() + f"/{action}" + def _get_url(cls, openai_action, azure_action, api_type, api_version): + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + return f"/{cls.azure_dalle_prefix}{cls.class_url()}/{azure_action}?api-version={api_version}" + else: + return cls.class_url() + f"/{openai_action}" + + @classmethod + def _get_azure_operations_url(cls, operation_id, api_version): + return "/%s/operations/%s?api-version=%s" % (cls.azure_dalle_prefix, operation_id, api_version) @classmethod def create( @@ -31,12 +39,16 @@ def create( organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) response, _, api_key = requestor.request( - "post", cls._get_url("generations"), params + "post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params ) + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + url = cls._get_azure_operations_url(response.data['id'], api_version) + response, _, api_key = requestor.poll("get", url, until=lambda response: response.data["status"] not in ["NotStarted", "Running"]) + return util.convert_to_openai_object( response, api_key, api_version, organization ) @@ -60,12 +72,16 @@ async def acreate( organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) response, _, api_key = await requestor.arequest( - "post", cls._get_url("generations"), params + "post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params ) + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + url = cls._get_azure_operations_url(response.data['id'], api_version) + response, _, api_key = requestor.poll("get", url, until=lambda response: response.data["status"] not in ["NotStarted", "Running"]) + return util.convert_to_openai_object( response, api_key, api_version, organization ) @@ -88,9 +104,9 @@ def _prepare_create_variation( api_version=api_version, organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) - url = cls._get_url("variations") + url = cls._get_url("variations", None, api_type=api_type, api_version=api_version) files: List[Any] = [] for key, value in params.items(): @@ -171,9 +187,9 @@ def _prepare_create_edit( api_version=api_version, organization=organization, ) - _, api_version = cls._get_api_type_and_version(api_type, api_version) + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) - url = cls._get_url("edits") + url = cls._get_url("edits", None, api_type=api_type, api_version=api_version) files: List[Any] = [] for key, value in params.items(): diff --git a/openai/openai_response.py b/openai/openai_response.py index 9954247319..75ff3cc07d 100644 --- a/openai/openai_response.py +++ b/openai/openai_response.py @@ -9,6 +9,13 @@ def __init__(self, data, headers): @property def request_id(self) -> Optional[str]: return self._headers.get("request-id") + + @property + def retry_after(self) -> Optional[str]: + try: + return int(self._headers.get("retry-after")) + except ValueError: + return None @property def organization(self) -> Optional[str]: From ab699b89855c49012ada4906abc3baceb6def7d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCrtz?= Date: Tue, 14 Mar 2023 15:36:36 +0100 Subject: [PATCH 2/8] WIP --- openai/api_resources/image.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 1bde2880c0..151b7c04cd 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -4,12 +4,25 @@ import openai from openai import api_requestor, util from openai.api_resources.abstract import APIResource -from openai.error import APIError +from openai.util import ApiType class Image(APIResource): OBJECT_NAME = "images" + _azure_preview_version = "2022-11-23-preview" + + @classmethod + def _get_api_type_and_version( + cls, api_type = None, api_version = None + ): + api_type, base_api_version = super()._get_api_type_and_version() + if api_type in (ApiType.AZURE, ApiType.AZURE_AD): + # This override is only temporary: DallE and GPT endpoint versioning is currently out of sync but will be aligned soon. + return (api_type, api_version or Image._azure_preview_version) + else: + return (api_type, base_api_version) + @classmethod def _get_url(cls, openai_action, azure_action, api_type, api_version): if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): @@ -19,7 +32,7 @@ def _get_url(cls, openai_action, azure_action, api_type, api_version): @classmethod def _get_azure_operations_url(cls, operation_id, api_version): - return "/%s/operations/%s?api-version=%s" % (cls.azure_dalle_prefix, operation_id, api_version) + return f"/{cls.azure_dalle_prefix}/operations/{operation_id}?api-version={api_version}" @classmethod def create( From a076188da9dd8b53bc83527bf40f78abfcf1c0d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCrtz?= Date: Tue, 14 Mar 2023 15:39:47 +0100 Subject: [PATCH 3/8] Fix type --- openai/openai_response.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openai/openai_response.py b/openai/openai_response.py index 75ff3cc07d..0b61a646de 100644 --- a/openai/openai_response.py +++ b/openai/openai_response.py @@ -11,7 +11,7 @@ def request_id(self) -> Optional[str]: return self._headers.get("request-id") @property - def retry_after(self) -> Optional[str]: + def retry_after(self) -> Optional[int]: try: return int(self._headers.get("retry-after")) except ValueError: From 25e118f72cf4399aa45bdab01f34cec698a9b516 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCrtz?= Date: Tue, 14 Mar 2023 16:03:56 +0100 Subject: [PATCH 4/8] Fixes --- openai/api_requestor.py | 12 ++++++++++-- openai/api_resources/image.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/openai/api_requestor.py b/openai/api_requestor.py index b218334501..f01856c0f3 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -153,8 +153,12 @@ def poll( until, params = None, headers = None, - interval = None + interval = None, + delay = None ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + time.sleep(delay) + response, b, api_key = self.request(method, url, params, headers) while not until(response): time.sleep(interval or response.retry_after or 1) @@ -168,8 +172,12 @@ async def apoll( until, params = None, headers = None, - interval = None + interval = None, + delay = None ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + if delay: + await asyncio.sleep(delay) + response, b, api_key = await self.arequest(method, url, params, headers) while not until(response): await asyncio.sleep(interval or response.retry_after or 1) diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 151b7c04cd..9548c6d941 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -60,7 +60,11 @@ def create( if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): url = cls._get_azure_operations_url(response.data['id'], api_version) - response, _, api_key = requestor.poll("get", url, until=lambda response: response.data["status"] not in ["NotStarted", "Running"]) + response, _, api_key = requestor.poll( + "get", url, + until=lambda response: response.data["status"] not in ["NotStarted", "Running"], + delay=response.retry_after + ) return util.convert_to_openai_object( response, api_key, api_version, organization @@ -93,7 +97,11 @@ async def acreate( if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): url = cls._get_azure_operations_url(response.data['id'], api_version) - response, _, api_key = requestor.poll("get", url, until=lambda response: response.data["status"] not in ["NotStarted", "Running"]) + response, _, api_key = await requestor.apoll( + "get", url, + until=lambda response: response.data["status"] not in ["NotStarted", "Running"], + delay=response.retry_after + ) return util.convert_to_openai_object( response, api_key, api_version, organization From 9903e4874d27b1c762bf28c654832063b79dee48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCrtz?= Date: Fri, 17 Mar 2023 17:24:08 +0100 Subject: [PATCH 5/8] Revert api_version override --- openai/api_resources/image.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 9548c6d941..4af3bd8f03 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -4,25 +4,11 @@ import openai from openai import api_requestor, util from openai.api_resources.abstract import APIResource -from openai.util import ApiType class Image(APIResource): OBJECT_NAME = "images" - _azure_preview_version = "2022-11-23-preview" - - @classmethod - def _get_api_type_and_version( - cls, api_type = None, api_version = None - ): - api_type, base_api_version = super()._get_api_type_and_version() - if api_type in (ApiType.AZURE, ApiType.AZURE_AD): - # This override is only temporary: DallE and GPT endpoint versioning is currently out of sync but will be aligned soon. - return (api_type, api_version or Image._azure_preview_version) - else: - return (api_type, base_api_version) - @classmethod def _get_url(cls, openai_action, azure_action, api_type, api_version): if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): From 1f66fb6da9a8e6105be2f94bf8943fa46b2f6d23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCrtz?= Date: Thu, 23 Mar 2023 16:01:54 +0100 Subject: [PATCH 6/8] Fix --- openai/api_resources/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 4af3bd8f03..4d3de96b56 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -14,7 +14,7 @@ def _get_url(cls, openai_action, azure_action, api_type, api_version): if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): return f"/{cls.azure_dalle_prefix}{cls.class_url()}/{azure_action}?api-version={api_version}" else: - return cls.class_url() + f"/{openai_action}" + return f"{cls.class_url()}/{openai_action}" @classmethod def _get_azure_operations_url(cls, operation_id, api_version): From 71c3fd81450943798dd39ea0a648eb9684c77102 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCrtz?= Date: Sat, 25 Mar 2023 22:14:39 +0100 Subject: [PATCH 7/8] Remove azure_dalle_prefix --- openai/api_resources/abstract/api_resource.py | 1 - openai/api_resources/image.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/openai/api_resources/abstract/api_resource.py b/openai/api_resources/abstract/api_resource.py index 70b46e7f84..53a7dec799 100644 --- a/openai/api_resources/abstract/api_resource.py +++ b/openai/api_resources/abstract/api_resource.py @@ -10,7 +10,6 @@ class APIResource(OpenAIObject): api_prefix = "" azure_api_prefix = "openai" - azure_dalle_prefix = "dalle" azure_deployments_prefix = "deployments" @classmethod diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 4d3de96b56..946bf335cd 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -12,13 +12,13 @@ class Image(APIResource): @classmethod def _get_url(cls, openai_action, azure_action, api_type, api_version): if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): - return f"/{cls.azure_dalle_prefix}{cls.class_url()}/{azure_action}?api-version={api_version}" + return f"/{cls.azure_api_prefix}{cls.class_url()}/{azure_action}?api-version={api_version}" else: return f"{cls.class_url()}/{openai_action}" @classmethod def _get_azure_operations_url(cls, operation_id, api_version): - return f"/{cls.azure_dalle_prefix}/operations/{operation_id}?api-version={api_version}" + return f"/{cls.azure_api_prefix}/operations/{operation_id}?api-version={api_version}" @classmethod def create( From 985329fba99c64c0362c98db5588812679c45c95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20M=C3=BCrtz?= Date: Sat, 25 Mar 2023 22:34:55 +0100 Subject: [PATCH 8/8] Use operation-location header --- openai/api_resources/image.py | 14 +++++--------- openai/openai_response.py | 4 ++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/openai/api_resources/image.py b/openai/api_resources/image.py index 946bf335cd..8c5b808b65 100644 --- a/openai/api_resources/image.py +++ b/openai/api_resources/image.py @@ -15,10 +15,6 @@ def _get_url(cls, openai_action, azure_action, api_type, api_version): return f"/{cls.azure_api_prefix}{cls.class_url()}/{azure_action}?api-version={api_version}" else: return f"{cls.class_url()}/{openai_action}" - - @classmethod - def _get_azure_operations_url(cls, operation_id, api_version): - return f"/{cls.azure_api_prefix}/operations/{operation_id}?api-version={api_version}" @classmethod def create( @@ -45,9 +41,9 @@ def create( ) if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): - url = cls._get_azure_operations_url(response.data['id'], api_version) + requestor.api_base = "" # operation_location is a full url response, _, api_key = requestor.poll( - "get", url, + "get", response.operation_location, until=lambda response: response.data["status"] not in ["NotStarted", "Running"], delay=response.retry_after ) @@ -81,10 +77,10 @@ async def acreate( "post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params ) - if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): - url = cls._get_azure_operations_url(response.data['id'], api_version) + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + requestor.api_base = "" # operation_location is a full url response, _, api_key = await requestor.apoll( - "get", url, + "get", response.operation_location, until=lambda response: response.data["status"] not in ["NotStarted", "Running"], delay=response.retry_after ) diff --git a/openai/openai_response.py b/openai/openai_response.py index 0b61a646de..828f98cf34 100644 --- a/openai/openai_response.py +++ b/openai/openai_response.py @@ -16,6 +16,10 @@ def retry_after(self) -> Optional[int]: return int(self._headers.get("retry-after")) except ValueError: return None + + @property + def operation_location(self) -> Optional[str]: + return self._headers.get("operation-location") @property def organization(self) -> Optional[str]: