2
2
from typing import Any , List
3
3
4
4
import openai
5
- from openai import api_requestor , util
5
+ from openai import api_requestor , error , util
6
6
from openai .api_resources .abstract import APIResource
7
7
8
8
9
9
class Image (APIResource ):
10
10
OBJECT_NAME = "images"
11
11
12
12
@classmethod
13
- def _get_url (cls , action ):
14
- return cls .class_url () + f"/{ action } "
13
+ def _get_url (cls , action , azure_action , api_type , api_version ):
14
+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ) and azure_action is not None :
15
+ return f"/{ cls .azure_api_prefix } { cls .class_url ()} /{ action } :{ azure_action } ?api-version={ api_version } "
16
+ else :
17
+ return f"{ cls .class_url ()} /{ action } "
15
18
16
19
@classmethod
17
20
def create (
@@ -31,12 +34,20 @@ def create(
31
34
organization = organization ,
32
35
)
33
36
34
- _ , api_version = cls ._get_api_type_and_version (api_type , api_version )
37
+ api_type , api_version = cls ._get_api_type_and_version (api_type , api_version )
35
38
36
39
response , _ , api_key = requestor .request (
37
- "post" , cls ._get_url ("generations" ), params
40
+ "post" , cls ._get_url ("generations" , azure_action = "submit" , api_type = api_type , api_version = api_version ), params
38
41
)
39
42
43
+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
44
+ requestor .api_base = "" # operation_location is a full url
45
+ response , _ , api_key = requestor ._poll (
46
+ "get" , response .operation_location ,
47
+ until = lambda response : response .data ['status' ] in [ 'succeeded' ],
48
+ failed = lambda response : response .data ['status' ] in [ 'failed' ]
49
+ )
50
+
40
51
return util .convert_to_openai_object (
41
52
response , api_key , api_version , organization
42
53
)
@@ -60,12 +71,20 @@ async def acreate(
60
71
organization = organization ,
61
72
)
62
73
63
- _ , api_version = cls ._get_api_type_and_version (api_type , api_version )
74
+ api_type , api_version = cls ._get_api_type_and_version (api_type , api_version )
64
75
65
76
response , _ , api_key = await requestor .arequest (
66
- "post" , cls ._get_url ("generations" ), params
77
+ "post" , cls ._get_url ("generations" , azure_action = "submit" , api_type = api_type , api_version = api_version ), params
67
78
)
68
79
80
+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
81
+ requestor .api_base = "" # operation_location is a full url
82
+ response , _ , api_key = await requestor ._apoll (
83
+ "get" , response .operation_location ,
84
+ until = lambda response : response .data ['status' ] in [ 'succeeded' ],
85
+ failed = lambda response : response .data ['status' ] in [ 'failed' ]
86
+ )
87
+
69
88
return util .convert_to_openai_object (
70
89
response , api_key , api_version , organization
71
90
)
@@ -88,9 +107,9 @@ def _prepare_create_variation(
88
107
api_version = api_version ,
89
108
organization = organization ,
90
109
)
91
- _ , api_version = cls ._get_api_type_and_version (api_type , api_version )
110
+ api_type , api_version = cls ._get_api_type_and_version (api_type , api_version )
92
111
93
- url = cls ._get_url ("variations" )
112
+ url = cls ._get_url ("variations" , azure_action = None , api_type = api_type , api_version = api_version )
94
113
95
114
files : List [Any ] = []
96
115
for key , value in params .items ():
@@ -109,6 +128,9 @@ def create_variation(
109
128
organization = None ,
110
129
** params ,
111
130
):
131
+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
132
+ raise error .InvalidAPIType ("Variations are not supported by the Azure OpenAI API yet." )
133
+
112
134
requestor , url , files = cls ._prepare_create_variation (
113
135
image ,
114
136
api_key ,
@@ -136,6 +158,9 @@ async def acreate_variation(
136
158
organization = None ,
137
159
** params ,
138
160
):
161
+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
162
+ raise error .InvalidAPIType ("Variations are not supported by the Azure OpenAI API yet." )
163
+
139
164
requestor , url , files = cls ._prepare_create_variation (
140
165
image ,
141
166
api_key ,
@@ -171,9 +196,9 @@ def _prepare_create_edit(
171
196
api_version = api_version ,
172
197
organization = organization ,
173
198
)
174
- _ , api_version = cls ._get_api_type_and_version (api_type , api_version )
199
+ api_type , api_version = cls ._get_api_type_and_version (api_type , api_version )
175
200
176
- url = cls ._get_url ("edits" )
201
+ url = cls ._get_url ("edits" , azure_action = None , api_type = api_type , api_version = api_version )
177
202
178
203
files : List [Any ] = []
179
204
for key , value in params .items ():
@@ -195,6 +220,9 @@ def create_edit(
195
220
organization = None ,
196
221
** params ,
197
222
):
223
+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
224
+ raise error .InvalidAPIType ("Edits are not supported by the Azure OpenAI API yet." )
225
+
198
226
requestor , url , files = cls ._prepare_create_edit (
199
227
image ,
200
228
mask ,
@@ -224,6 +252,9 @@ async def acreate_edit(
224
252
organization = None ,
225
253
** params ,
226
254
):
255
+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
256
+ raise error .InvalidAPIType ("Edits are not supported by the Azure OpenAI API yet." )
257
+
227
258
requestor , url , files = cls ._prepare_create_edit (
228
259
image ,
229
260
mask ,
0 commit comments