Skip to content

Commit 462b58c

Browse files
committed
feat(serialization): add two-decimal precision serialization for float fields
Added a new `serialize_f32_two_decimals` helper function to serialize Option<f32> fields with exactly two decimal places. Applied this serialization to all float fields in ChatBody (temperature, top_p, n, presence_penalty, frequency_penalty) to ensure consistent numeric formatting in API requests. Also improved error handling in requests.rs by: 1. Adding fallback for JSON parsing errors 2. Enhancing error messages with status codes 3. Making error messages more descriptive Signed-off-by: jinlong <jinlong@tencent.com>
1 parent 2f2c874 commit 462b58c

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

src/apis/chat.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@ use std::collections::HashMap;
77

88
use crate::requests::Requests;
99
use crate::*;
10-
use serde::{Deserialize, Serialize};
10+
use serde::{Deserialize, Serialize, Serializer};
11+
12+
fn serialize_f32_two_decimals<S>(value: &Option<f32>, serializer: S) -> Result<S::Ok, S::Error>
13+
where
14+
S: Serializer,
15+
{
16+
match value {
17+
Some(v) => serializer.serialize_f64((*v as f64 * 100.0).round() / 100.0),
18+
None => serializer.serialize_none(),
19+
}
20+
}
21+
1122

1223
use super::{completions::Completion, CHAT_COMPLETION_CREATE};
1324

@@ -23,14 +34,14 @@ pub struct ChatBody {
2334
/// while lower values like 0.2 will make it more focused and deterministic.
2435
/// We generally recommend altering this or top_p but not both.
2536
/// Defaults to 1
26-
#[serde(skip_serializing_if = "Option::is_none")]
37+
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_f32_two_decimals")]
2738
pub temperature: Option<f32>,
2839
/// An alternative to sampling with temperature, called nucleus sampling,
2940
/// where the model considers the results of the tokens with top_p probability mass.
3041
/// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
3142
/// We generally recommend altering this or temperature but not both.
3243
/// Defaults to 1
33-
#[serde(skip_serializing_if = "Option::is_none")]
44+
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_f32_two_decimals")]
3445
pub top_p: Option<f32>,
3546
/// How many chat completion choices to generate for each input message.
3647
/// Defaults to 1
@@ -55,13 +66,13 @@ pub struct ChatBody {
5566
/// Positive values penalize new tokens based on whether they appear in the text so far,
5667
/// increasing the model's likelihood to talk about new topics.
5768
/// Defaults to 0
58-
#[serde(skip_serializing_if = "Option::is_none")]
69+
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_f32_two_decimals")]
5970
pub presence_penalty: Option<f32>,
6071
/// Number between -2.0 and 2.0.
6172
/// Positive values penalize new tokens based on their existing frequency in the text so far,
6273
/// decreasing the model's likelihood to repeat the same line verbatim.
6374
/// Defaults to 0
64-
#[serde(skip_serializing_if = "Option::is_none")]
75+
#[serde(skip_serializing_if = "Option::is_none", serialize_with = "serialize_f32_two_decimals")]
6576
pub frequency_penalty: Option<f32>,
6677
/// Modify the likelihood of specified tokens appearing in the completion.
6778
/// Accepts a json object that maps tokens (specified by their token ID in the tokenizer)

src/requests.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,12 @@ fn deal_response(response: Result<ureq::Response, ureq::Error>, sub_url: &str) -
7070
},
7171
Err(err) => match err {
7272
ureq::Error::Status(status, response) => {
73-
let error_msg = response.into_json::<Json>().unwrap();
73+
let mut error_msg = response
74+
.into_json::<Json>()
75+
.unwrap_or_else(|x| serde_json::Value::String(x.to_string()));
76+
if let serde_json::Value::String(ref mut s) = error_msg {
77+
*s = format!("status: {}, msg: {}", status, s);
78+
}
7479
error!("<== ❌\n\tError api: {sub_url}, status: {status}, error: {error_msg}");
7580
return Err(Error::ApiError(format!("{error_msg}")));
7681
},

0 commit comments

Comments
 (0)