Skip to content

[Azure] Dall-E #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
organization = os.environ.get("OPENAI_ORGANIZATION")
api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
api_type = os.environ.get("OPENAI_API_TYPE", "open_ai")
api_version = (
api_version = os.environ.get("OPENAI_API_VERSION", (
"2023-03-15-preview" if api_type in ("azure", "azure_ad", "azuread") else None
)
))
verify_ssl_certs = True # No effect. Certificates are always verified.
proxy = None
app_info = None
Expand Down
39 changes: 39 additions & 0 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import time
import platform
import sys
import threading
Expand Down Expand Up @@ -144,6 +145,44 @@ def format_app_info(cls, info):
if info["url"]:
str += " (%s)" % (info["url"],)
return str

def poll(
self,
method,
url,
until,
params = None,
headers = None,
interval = None,
delay = None
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
if delay:
time.sleep(delay)

response, b, api_key = self.request(method, url, params, headers)
while not until(response):
time.sleep(interval or response.retry_after or 1)
response, b, api_key = self.request(method, url)
return response, b, api_key

async def apoll(
self,
method,
url,
until,
params = None,
headers = None,
interval = None,
delay = None
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
if delay:
await asyncio.sleep(delay)

response, b, api_key = await self.arequest(method, url, params, headers)
while not until(response):
await asyncio.sleep(interval or response.retry_after or 1)
response, b, api_key = await self.arequest(method, url)
return response, b, api_key

@overload
def request(
Expand Down
39 changes: 29 additions & 10 deletions openai/api_resources/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ class Image(APIResource):
OBJECT_NAME = "images"

@classmethod
def _get_url(cls, action):
return cls.class_url() + f"/{action}"
def _get_url(cls, openai_action, azure_action, api_type, api_version):
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
return f"/{cls.azure_api_prefix}{cls.class_url()}/{azure_action}?api-version={api_version}"
else:
return f"{cls.class_url()}/{openai_action}"

@classmethod
def create(
Expand All @@ -31,12 +34,20 @@ def create(
organization=organization,
)

_, api_version = cls._get_api_type_and_version(api_type, api_version)
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)

response, _, api_key = requestor.request(
"post", cls._get_url("generations"), params
"post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params
)

if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
requestor.api_base = "" # operation_location is a full url
response, _, api_key = requestor.poll(
"get", response.operation_location,
until=lambda response: response.data["status"] not in ["NotStarted", "Running"],
delay=response.retry_after
)

return util.convert_to_openai_object(
response, api_key, api_version, organization
)
Expand All @@ -60,12 +71,20 @@ async def acreate(
organization=organization,
)

_, api_version = cls._get_api_type_and_version(api_type, api_version)
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)

response, _, api_key = await requestor.arequest(
"post", cls._get_url("generations"), params
"post", cls._get_url("generations", "generate", api_type=api_type, api_version=api_version), params
)

if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
requestor.api_base = "" # operation_location is a full url
response, _, api_key = await requestor.apoll(
"get", response.operation_location,
until=lambda response: response.data["status"] not in ["NotStarted", "Running"],
delay=response.retry_after
)

return util.convert_to_openai_object(
response, api_key, api_version, organization
)
Expand All @@ -88,9 +107,9 @@ def _prepare_create_variation(
api_version=api_version,
organization=organization,
)
_, api_version = cls._get_api_type_and_version(api_type, api_version)
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)

url = cls._get_url("variations")
url = cls._get_url("variations", None, api_type=api_type, api_version=api_version)

files: List[Any] = []
for key, value in params.items():
Expand Down Expand Up @@ -171,9 +190,9 @@ def _prepare_create_edit(
api_version=api_version,
organization=organization,
)
_, api_version = cls._get_api_type_and_version(api_type, api_version)
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)

url = cls._get_url("edits")
url = cls._get_url("edits", None, api_type=api_type, api_version=api_version)

files: List[Any] = []
for key, value in params.items():
Expand Down
11 changes: 11 additions & 0 deletions openai/openai_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@ def __init__(self, data, headers):
@property
def request_id(self) -> Optional[str]:
return self._headers.get("request-id")

@property
def retry_after(self) -> Optional[int]:
try:
return int(self._headers.get("retry-after"))
except ValueError:
return None

@property
def operation_location(self) -> Optional[str]:
return self._headers.get("operation-location")

@property
def organization(self) -> Optional[str]:
Expand Down