Skip to content

Commit 6a6b7ff

Browse files
authored
Refactor: multipart removed, deps reduced, more tests passing (#12)
* refactor: replace abandoned crate "multipart" with copied code * chore: fmt * refactor: duplicate use statement removed * test: update deprecated model ids * test: fix test by adding a likely word in the completion * test: permission field was removed by openai in 2023 * feat: edits no longer exists in the openai api * test: "this" more likely that "see" to be returned. Still a flaky test. * test: test_audio_transcription now passes * chore: removed mime_guess crate * chore: removed unused code * refactor: image_variation(), image_edit() no longer duplicate code
1 parent fc64725 commit 6a6b7ff

12 files changed

+356
-1098
lines changed

Cargo.lock

Lines changed: 15 additions & 966 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ ureq = { version = "^2.6", features = ["json"] }
1414
serde = { version = "^1.0", features = ["derive"] }
1515
serde_json = "^1.0"
1616
log = "^0.4"
17-
multipart = "^0.18.0"
1817
mime = "^0.3.16"
18+
rand = "0.8.5"

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ check [official API reference](https://platform.openai.com/docs/api-reference)
1616
|Models|✔️|
1717
|Completions|✔️|
1818
|Chat|✔️|
19-
|Edits|✔️|
2019
|Images|✔️|
2120
|Embeddings|✔️|
2221
|Audio|✔️|

src/apis/audio.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
55
use std::fs::File;
66

7-
use multipart::client::lazy::Multipart;
7+
use crate::mpart::Mpart as Multipart;
88
use serde::{Deserialize, Serialize};
99

1010
use crate::requests::Requests;
@@ -64,7 +64,7 @@ impl AudioApi for OpenAI {
6464
send_data.add_text("language", language);
6565
}
6666

67-
send_data.add_stream("file", audio_body.file, Some("mp3"), None);
67+
send_data.add_stream("file", audio_body.file, Some("audio.mp3"), None);
6868

6969
let res = self.post_multipart(AUDIO_TRANSCRIPTION_CREATE, send_data)?;
7070
let audio: Audio = serde_json::from_value(res.clone()).unwrap();
@@ -88,7 +88,7 @@ impl AudioApi for OpenAI {
8888
send_data.add_text("language", language);
8989
}
9090

91-
send_data.add_stream("file", audio_body.file, Some("mp3"), None);
91+
send_data.add_stream("file", audio_body.file, Some("audio.mp3"), None);
9292

9393
let res = self.post_multipart(AUDIO_TRANSLATIONS_CREATE, send_data)?;
9494
let audio: Audio = serde_json::from_value(res.clone()).unwrap();

src/apis/completions.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ mod tests {
150150
fn test_completions() {
151151
let openai = new_test_openai();
152152
let body = CompletionsBody {
153-
model: "babbage".to_string(),
153+
model: "babbage-002".to_string(),
154154
prompt: Some(vec!["Say this is a test".to_string()]),
155155
suffix: None,
156156
max_tokens: Some(7),
@@ -170,6 +170,6 @@ mod tests {
170170
let rs = openai.completion_create(&body);
171171
let choice = rs.unwrap().choices;
172172
let text = &choice[0].text.as_ref().unwrap();
173-
assert!(text.contains("of the new system"));
173+
assert!(text.contains("this"));
174174
}
175175
}

src/apis/edits.rs

Lines changed: 0 additions & 62 deletions
This file was deleted.

src/apis/images.rs

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
//! Images API
55
66
use super::{IMAGES_CREATE, IMAGES_EDIT, IMAGES_VARIATIONS};
7+
use crate::mpart::Mpart as Multipart;
78
use crate::requests::Requests;
89
use crate::*;
9-
use multipart::client::lazy::Multipart;
1010
use serde::{Deserialize, Serialize};
1111
use std::{fs::File, str};
1212

@@ -57,6 +57,12 @@ pub struct ImageData {
5757
pub trait ImagesApi {
5858
/// Given a prompt and/or an input image, the model will generate a new image.
5959
fn image_create(&self, images_body: &ImagesBody) -> ApiResult<Images>;
60+
/// generates multipart data for image fns
61+
fn image_build_send_data_from_body(
62+
&self,
63+
images_edit_body: ImagesEditBody,
64+
url: &str,
65+
) -> ApiResult<Images>;
6066
/// Creates an edited or extended image given an original image and a prompt.
6167
fn image_edit(&self, images_edit_body: ImagesEditBody) -> ApiResult<Images>;
6268
/// Creates a variation of a given image.
@@ -71,10 +77,16 @@ impl ImagesApi for OpenAI {
7177
Ok(images)
7278
}
7379

74-
fn image_edit(&self, images_edit_body: ImagesEditBody) -> ApiResult<Images> {
80+
fn image_build_send_data_from_body(
81+
&self,
82+
images_edit_body: ImagesEditBody,
83+
url: &str,
84+
) -> ApiResult<Images> {
7585
let mut send_data = Multipart::new();
7686

77-
send_data.add_text("prompt", images_edit_body.images_body.prompt);
87+
if IMAGES_EDIT == url {
88+
send_data.add_text("prompt", images_edit_body.images_body.prompt);
89+
}
7890
if let Some(n) = images_edit_body.images_body.n {
7991
send_data.add_text("n", n.to_string());
8092
}
@@ -92,31 +104,17 @@ impl ImagesApi for OpenAI {
92104
}
93105
send_data.add_stream("image", images_edit_body.image, Some("blob"), Some(mime::IMAGE_PNG));
94106

95-
let res = self.post_multipart(IMAGES_EDIT, send_data)?;
107+
let res = self.post_multipart(url, send_data)?;
96108
let images: Images = serde_json::from_value(res.clone()).unwrap();
97109
Ok(images)
98110
}
99111

100-
fn image_variation(&self, images_edit_body: ImagesEditBody) -> ApiResult<Images> {
101-
let mut send_data = Multipart::new();
102-
103-
if let Some(n) = images_edit_body.images_body.n {
104-
send_data.add_text("n", n.to_string());
105-
}
106-
if let Some(size) = images_edit_body.images_body.size {
107-
send_data.add_text("size", size.to_string());
108-
}
109-
if let Some(response_format) = images_edit_body.images_body.response_format {
110-
send_data.add_text("response_format", response_format.to_string());
111-
}
112-
if let Some(user) = images_edit_body.images_body.user {
113-
send_data.add_text("user", user.to_string());
114-
}
115-
send_data.add_stream("image", images_edit_body.image, Some("blob"), Some(mime::IMAGE_PNG));
112+
fn image_edit(&self, images_edit_body: ImagesEditBody) -> ApiResult<Images> {
113+
self.image_build_send_data_from_body(images_edit_body, IMAGES_EDIT)
114+
}
116115

117-
let res = self.post_multipart(IMAGES_VARIATIONS, send_data)?;
118-
let images: Images = serde_json::from_value(res.clone()).unwrap();
119-
Ok(images)
116+
fn image_variation(&self, images_edit_body: ImagesEditBody) -> ApiResult<Images> {
117+
self.image_build_send_data_from_body(images_edit_body, IMAGES_VARIATIONS)
120118
}
121119
}
122120

src/apis/mod.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize};
33
pub mod audio;
44
pub mod chat;
55
pub mod completions;
6-
pub mod edits;
76
pub mod embeddings;
87
pub mod images;
98
pub mod models;
@@ -15,8 +14,6 @@ const MODELS_RETRIEVE: &str = "models/";
1514
const COMPLETION_CREATE: &str = "completions";
1615
// Chat API
1716
const CHAT_COMPLETION_CREATE: &str = "chat/completions";
18-
// Edits API
19-
const EDIT_CREATE: &str = "edits";
2017
// Images API
2118
const IMAGES_CREATE: &str = "images/generations";
2219
const IMAGES_EDIT: &str = "images/edits";

src/apis/models.rs

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,6 @@ pub struct Model {
1919
pub id: String,
2020
pub object: Option<String>,
2121
pub owned_by: Option<String>,
22-
pub permission: Vec<Permission>,
23-
}
24-
25-
#[derive(Debug, Serialize, Deserialize)]
26-
pub struct Permission {
27-
pub id: String,
28-
pub object: Option<String>,
29-
pub created: u64,
30-
pub allow_create_engine: bool,
31-
pub allow_sampling: bool,
32-
pub allow_logprobs: bool,
33-
pub allow_search_indices: bool,
34-
pub allow_view: bool,
35-
pub allow_fine_tuning: bool,
36-
pub organization: Option<String>,
37-
pub group: Option<String>,
38-
pub is_blocking: bool,
3922
}
4023

4124
pub trait ModelsApi {
@@ -79,7 +62,7 @@ mod tests {
7962
#[test]
8063
fn test_get_model() {
8164
let openai = new_test_openai();
82-
let model = openai.models_retrieve("babbage").unwrap();
83-
assert_eq!("babbage", model.id);
65+
let model = openai.models_retrieve("babbage-002").unwrap();
66+
assert_eq!("babbage-002", model.id);
8467
}
8568
}

src/lib.rs

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
//! use openai_api_rust::*;
1212
//! use openai_api_rust::chat::*;
1313
//! use openai_api_rust::completions::*;
14-
//!
14+
//!
1515
//! fn main() {
1616
//! // Load API key from environment OPENAI_API_KEY.
1717
//! // You can also hadcode through `Auth::new(<your_api_key>)`, but it is not recommended.
@@ -37,38 +37,35 @@
3737
//! assert!(message.content.contains("Hello"));
3838
//! }
3939
//! ```
40-
//!
40+
//!
4141
//! ## Use proxy
42-
//!
42+
//!
4343
//! ```rust
4444
//! // Load proxy from env
4545
//! let openai = OpenAI::new(auth, "https://api.openai.com/v1/")
4646
//! .use_env_proxy();
47-
//!
47+
//!
4848
//! // Set the proxy manually
4949
//! let openai = OpenAI::new(auth, "https://api.openai.com/v1/")
5050
//! .set_proxy("http://127.0.0.1:1080");
5151
//! ```
5252
53-
54-
5553
#![warn(unused_crate_dependencies)]
5654

5755
pub mod apis;
58-
use std::fmt::{Display, Formatter, self};
56+
use std::fmt::{self, Display, Formatter};
5957

6058
pub use apis::*;
6159
pub mod openai;
6260
pub use openai::*;
61+
mod mpart;
6362
mod requests;
6463

6564
use log as _;
6665

6766
pub type Json = serde_json::Value;
6867
pub type ApiResult<T> = Result<T, Error>;
6968

70-
pub use openai::*;
71-
7269
#[derive(Debug)]
7370
pub enum Error {
7471
/// An Error returned by the API
@@ -78,10 +75,10 @@ pub enum Error {
7875
}
7976

8077
impl Display for Error {
81-
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
82-
match self {
83-
Error::ApiError(msg) => write!(f, "API error: {}", msg),
84-
Error::RequestError(msg) => write!(f, "Request error: {}", msg),
85-
}
86-
}
87-
}
78+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
79+
match self {
80+
Error::ApiError(msg) => write!(f, "API error: {}", msg),
81+
Error::RequestError(msg) => write!(f, "Request error: {}", msg),
82+
}
83+
}
84+
}

0 commit comments

Comments
 (0)