Skip to content

Commit 3ea04cb

Browse files
glecaroscmurtz-msft
authored andcommitted
Support for Azure Dall-e (openai#439)
* This PR updates openai#337 with updates for it to work with the latest API preview --------- Co-authored-by: Christian Mürtz <t-cmurtz@microsoft.com>
1 parent 778ef67 commit 3ea04cb

File tree

3 files changed

+119
-11
lines changed

3 files changed

+119
-11
lines changed

openai/api_requestor.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
import time
34
import platform
45
import sys
56
import threading
@@ -10,6 +11,7 @@
1011
from typing import (
1112
AsyncGenerator,
1213
AsyncIterator,
14+
Callable,
1315
Dict,
1416
Iterator,
1517
Optional,
@@ -151,6 +153,70 @@ def format_app_info(cls, info):
151153
str += " (%s)" % (info["url"],)
152154
return str
153155

156+
def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
157+
if not predicate(response):
158+
return
159+
error_data = response.data['error']
160+
message = error_data.get('message', 'Operation failed')
161+
code = error_data.get('code')
162+
raise error.OpenAIError(message=message, code=code)
163+
164+
def _poll(
165+
self,
166+
method,
167+
url,
168+
until,
169+
failed,
170+
params = None,
171+
headers = None,
172+
interval = None,
173+
delay = None
174+
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
175+
if delay:
176+
time.sleep(delay)
177+
178+
response, b, api_key = self.request(method, url, params, headers)
179+
self._check_polling_response(response, failed)
180+
start_time = time.time()
181+
while not until(response):
182+
if time.time() - start_time > TIMEOUT_SECS:
183+
raise error.Timeout("Operation polling timed out.")
184+
185+
time.sleep(interval or response.retry_after or 10)
186+
response, b, api_key = self.request(method, url, params, headers)
187+
self._check_polling_response(response, failed)
188+
189+
response.data = response.data['result']
190+
return response, b, api_key
191+
192+
async def _apoll(
193+
self,
194+
method,
195+
url,
196+
until,
197+
failed,
198+
params = None,
199+
headers = None,
200+
interval = None,
201+
delay = None
202+
) -> Tuple[Iterator[OpenAIResponse], bool, str]:
203+
if delay:
204+
await asyncio.sleep(delay)
205+
206+
response, b, api_key = await self.arequest(method, url, params, headers)
207+
self._check_polling_response(response, failed)
208+
start_time = time.time()
209+
while not until(response):
210+
if time.time() - start_time > TIMEOUT_SECS:
211+
raise error.Timeout("Operation polling timed out.")
212+
213+
await asyncio.sleep(interval or response.retry_after or 10)
214+
response, b, api_key = await self.arequest(method, url, params, headers)
215+
self._check_polling_response(response, failed)
216+
217+
response.data = response.data['result']
218+
return response, b, api_key
219+
154220
@overload
155221
def request(
156222
self,

openai/api_resources/image.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
from typing import Any, List
33

44
import openai
5-
from openai import api_requestor, util
5+
from openai import api_requestor, error, util
66
from openai.api_resources.abstract import APIResource
77

88

99
class Image(APIResource):
1010
OBJECT_NAME = "images"
1111

1212
@classmethod
13-
def _get_url(cls, action):
14-
return cls.class_url() + f"/{action}"
13+
def _get_url(cls, action, azure_action, api_type, api_version):
14+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD) and azure_action is not None:
15+
return f"/{cls.azure_api_prefix}{cls.class_url()}/{action}:{azure_action}?api-version={api_version}"
16+
else:
17+
return f"{cls.class_url()}/{action}"
1518

1619
@classmethod
1720
def create(
@@ -31,12 +34,20 @@ def create(
3134
organization=organization,
3235
)
3336

34-
_, api_version = cls._get_api_type_and_version(api_type, api_version)
37+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
3538

3639
response, _, api_key = requestor.request(
37-
"post", cls._get_url("generations"), params
40+
"post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
3841
)
3942

43+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
44+
requestor.api_base = "" # operation_location is a full url
45+
response, _, api_key = requestor._poll(
46+
"get", response.operation_location,
47+
until=lambda response: response.data['status'] in [ 'succeeded' ],
48+
failed=lambda response: response.data['status'] in [ 'failed' ]
49+
)
50+
4051
return util.convert_to_openai_object(
4152
response, api_key, api_version, organization
4253
)
@@ -60,12 +71,20 @@ async def acreate(
6071
organization=organization,
6172
)
6273

63-
_, api_version = cls._get_api_type_and_version(api_type, api_version)
74+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
6475

6576
response, _, api_key = await requestor.arequest(
66-
"post", cls._get_url("generations"), params
77+
"post", cls._get_url("generations", azure_action="submit", api_type=api_type, api_version=api_version), params
6778
)
6879

80+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
81+
requestor.api_base = "" # operation_location is a full url
82+
response, _, api_key = await requestor._apoll(
83+
"get", response.operation_location,
84+
until=lambda response: response.data['status'] in [ 'succeeded' ],
85+
failed=lambda response: response.data['status'] in [ 'failed' ]
86+
)
87+
6988
return util.convert_to_openai_object(
7089
response, api_key, api_version, organization
7190
)
@@ -88,9 +107,9 @@ def _prepare_create_variation(
88107
api_version=api_version,
89108
organization=organization,
90109
)
91-
_, api_version = cls._get_api_type_and_version(api_type, api_version)
110+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
92111

93-
url = cls._get_url("variations")
112+
url = cls._get_url("variations", azure_action=None, api_type=api_type, api_version=api_version)
94113

95114
files: List[Any] = []
96115
for key, value in params.items():
@@ -109,6 +128,9 @@ def create_variation(
109128
organization=None,
110129
**params,
111130
):
131+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
132+
raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
133+
112134
requestor, url, files = cls._prepare_create_variation(
113135
image,
114136
api_key,
@@ -136,6 +158,9 @@ async def acreate_variation(
136158
organization=None,
137159
**params,
138160
):
161+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
162+
raise error.InvalidAPIType("Variations are not supported by the Azure OpenAI API yet.")
163+
139164
requestor, url, files = cls._prepare_create_variation(
140165
image,
141166
api_key,
@@ -171,9 +196,9 @@ def _prepare_create_edit(
171196
api_version=api_version,
172197
organization=organization,
173198
)
174-
_, api_version = cls._get_api_type_and_version(api_type, api_version)
199+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
175200

176-
url = cls._get_url("edits")
201+
url = cls._get_url("edits", azure_action=None, api_type=api_type, api_version=api_version)
177202

178203
files: List[Any] = []
179204
for key, value in params.items():
@@ -195,6 +220,9 @@ def create_edit(
195220
organization=None,
196221
**params,
197222
):
223+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
224+
raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
225+
198226
requestor, url, files = cls._prepare_create_edit(
199227
image,
200228
mask,
@@ -224,6 +252,9 @@ async def acreate_edit(
224252
organization=None,
225253
**params,
226254
):
255+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
256+
raise error.InvalidAPIType("Edits are not supported by the Azure OpenAI API yet.")
257+
227258
requestor, url, files = cls._prepare_create_edit(
228259
image,
229260
mask,

openai/openai_response.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ def __init__(self, data, headers):
1010
def request_id(self) -> Optional[str]:
1111
return self._headers.get("request-id")
1212

13+
@property
14+
def retry_after(self) -> Optional[int]:
15+
try:
16+
return int(self._headers.get("retry-after"))
17+
except TypeError:
18+
return None
19+
20+
@property
21+
def operation_location(self) -> Optional[str]:
22+
return self._headers.get("operation-location")
23+
1324
@property
1425
def organization(self) -> Optional[str]:
1526
return self._headers.get("OpenAI-Organization")

0 commit comments

Comments
 (0)