diff --git a/Cargo.toml b/Cargo.toml index 9a18ed7..ba2cebd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ license = "MIT" description = "A unofficial Rust library for the Ernie API" homepage = "https://github.com/chenwanqq/erniebot-rs" repository = "https://github.com/chenwanqq/erniebot-rs" -version = "0.3.2" +version = "0.4.1" edition = "2021" exclude = [".github/",".vscode/",".gitignore"] @@ -17,7 +17,8 @@ strum_macros = "0.26.1" serde = {version = "1.0.197", features = ["derive"]} serde_json = "1.0.113" url = "2.5.0" -reqwest = {version = "0.12.3", features = ["json","blocking"]} +reqwest = {version = "0.12.3", features = ["json"]} +ureq = { version = "2.9.6", features = ["json", "charset"] } thiserror = "1.0.57" json_value_merge = "2.0" reqwest-eventsource = "0.6.0" diff --git a/README.md b/README.md index 0ec644c..0015fa6 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ Unofficial Baidu Ernie(Wenxin Yiyan, Qianfan) Rust SDK, currently supporting three modules: chat, text embedding (embedding), and text-to-image generation (text2image). **update in 2024/04/09**: Add support for the bce-reranker-base-v1 rerank model +**update in 2024/04/21** For sync mode, use ureq instead of reqwest_blocking, hence it can improve the compatibility with tokio. ## Installation @@ -10,7 +11,7 @@ Add the following to your Cargo.toml file: ```toml [dependencies] -erniebot-rs = "0.3.2" +erniebot-rs = "0.4.1" ``` ## Authentication @@ -45,7 +46,7 @@ fn test_invoke() { ChatOpt::TopP(0.5), ChatOpt::TopK(50), ]; - let response = chat.invoke(messages, options).unwrap(); + let response = chat.invoke(&messages, &options).unwrap(); let result = response.get_chat_result().unwrap(); println!("{}", result); } @@ -64,7 +65,7 @@ fn test_custom_endpoint() { }, ]; let options = Vec::new(); - let response = chat.invoke(messages, options).unwrap(); + let response = chat.invoke(&messages, &options).unwrap(); let result = response.get_chat_result().unwrap(); println!("{}", result); } @@ -87,7 +88,7 @@ fn test_astream() { let options = Vec::new(); let rt = Runtime::new().unwrap(); rt.block_on(async move { - let mut stream_response = chat.astream(messages, options).await.unwrap(); + let mut stream_response = chat.astream(&messages, &options).await.unwrap(); while let Some(response) = stream_response.next().await { let result = response.get_chat_result().unwrap(); print!("{}", result); @@ -123,7 +124,7 @@ fn test_async_embedding() { "你是谁".to_string(), ]; let rt = Runtime::new().unwrap(); - let embedding_response = rt.block_on(embedding.ainvoke(input, None)).unwrap(); + let embedding_response = rt.block_on(embedding.ainvoke(&input, None)).unwrap(); let embedding_results = embedding_response.get_embedding_results().unwrap(); println!("{},{}", embedding_results.len(), embedding_results[0].len()); } @@ -141,7 +142,7 @@ fn main() { Text2ImageOpt::Style(Style::DigitalArt), Text2ImageOpt::Size(Size::S1024x768), ]; - let text2image_response = text2image.invoke(prompt, options).unwrap(); + let text2image_response = text2image.invoke(&prompt, &options).unwrap(); let image_results = text2image_response.get_image_results().unwrap(); for (index, image_string) in image_results.into_iter().enumerate() { let image = base64_to_image(image_string).unwrap(); diff --git a/README_zh.md b/README_zh.md index 70c4e81..f261a3a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -3,6 +3,7 @@ 非官方的百度千帆大模型(文心一言,或者是Ernie,随便啦)SDK, 目前支持对话(chat),文本嵌入(embedding)以及文生图(text2image)三个模块。 **2024/04/09更新**: 添加对bce-reranker-base-v1重排序模型的支持 +**2024/04/21更新** 对于同步模式,使用ureq替代reqwest_blocking,因此可以提高与tokio的兼容性。 ## 安装 @@ -10,7 +11,7 @@ ```toml [dependencies] -erniebot-rs = "0.3.2" +erniebot-rs = "0.4.1" ``` ## 鉴权 diff --git a/src/chat/endpoint.rs b/src/chat/endpoint.rs index 1f19a5a..51f5b28 100644 --- a/src/chat/endpoint.rs +++ b/src/chat/endpoint.rs @@ -69,15 +69,12 @@ impl ChatEndpoint { options: &Vec, ) -> Result { let body = ChatEndpoint::generate_body(messages, options, false)?; - let client = reqwest::blocking::Client::new(); - let response: Value = client - .post(self.url.as_str()) - .header("Content-Type", "application/json") - .query(&[("access_token", self.access_token.as_str())]) - .json(&body) - .send() + let response: Value = ureq::post(self.url.as_str()) + .set("Content-Type", "application/json") + .query("access_token", self.access_token.as_str()) + .send_json(body) .map_err(|e| ErnieError::InvokeError(e.to_string()))? - .json() + .into_json() .map_err(|e| ErnieError::InvokeError(e.to_string()))?; //if error_code key in response, means RemoteAPIError @@ -93,16 +90,13 @@ impl ChatEndpoint { options: &Vec, ) -> Result { let body = ChatEndpoint::generate_body(messages, options, true)?; - let client = reqwest::blocking::Client::new(); - let response = client - .post(self.url.as_str()) - .header("Content-Type", "application/json") - .query(&[("access_token", self.access_token.as_str())]) - .json(&body) - .send() - .map_err(|e| ErnieError::StreamError(e.to_string()))? - .text() - .map_err(|e| ErnieError::StreamError(e.to_string()))?; + let response: String = ureq::post(self.url.as_str()) + .set("Content-Type", "application/json") + .query("access_token", self.access_token.as_str()) + .send_json(body) + .map_err(|e| ErnieError::InvokeError(e.to_string()))? + .into_string() + .map_err(|e| ErnieError::InvokeError(e.to_string()))?; let response = Responses::from_text(&response)?; Ok(response) } diff --git a/src/chat/response.rs b/src/chat/response.rs index c11434b..a6988cf 100644 --- a/src/chat/response.rs +++ b/src/chat/response.rs @@ -81,7 +81,7 @@ pub struct Responses { } impl Responses { - /// get Responses from reqwest::blocking::client. In the response body, it contains multiple responses split by blank line. This method will parse the response body and return a Responses struct. + /// get Responses from blocking response. In the response body, it contains multiple responses split by blank line. This method will parse the response body and return a Responses struct. pub fn from_text(text: &str) -> Result { let parts = text.split("\n\n").collect::>(); let mut result = Vec::new(); diff --git a/src/embedding/endpoint.rs b/src/embedding/endpoint.rs index de808ec..4fd3494 100644 --- a/src/embedding/endpoint.rs +++ b/src/embedding/endpoint.rs @@ -37,14 +37,11 @@ impl EmbeddingEndpoint { if let Some(user_id) = user_id { body.merge(&serde_json::json!({"user_id": user_id})); } - let client = reqwest::blocking::Client::new(); - let response: Value = client - .post(self.url.as_str()) - .query(&[("access_token", self.access_token.as_str())]) - .json(&body) - .send() + let response: Value = ureq::post(self.url.as_str()) + .query("access_token", self.access_token.as_str()) + .send_json(body) .map_err(|e| ErnieError::InvokeError(e.to_string()))? - .json() + .into_json() .map_err(|e| ErnieError::InvokeError(e.to_string()))?; //if error_code key in response, means RemoteAPIError diff --git a/src/reranker/endpoint.rs b/src/reranker/endpoint.rs index 20e6783..3e8905a 100644 --- a/src/reranker/endpoint.rs +++ b/src/reranker/endpoint.rs @@ -43,16 +43,12 @@ impl RerankerEndpoint { if let Some(user_id) = user_id { body.merge(&serde_json::json!({"user_id": user_id})); } - let client = reqwest::blocking::Client::new(); - let response: Value = client - .post(self.url.as_str()) - .query(&[("access_token", self.access_token.as_str())]) - .json(&body) - .send() + let response: Value = ureq::post(self.url.as_str()) + .query("access_token", self.access_token.as_str()) + .send_json(body) .map_err(|e| ErnieError::InvokeError(e.to_string()))? - .json() + .into_json() .map_err(|e| ErnieError::InvokeError(e.to_string()))?; - //if error_code key in response, means RemoteAPIError if response.get("error_code").is_some() { return Err(ErnieError::RemoteAPIError(response.to_string())); diff --git a/src/text2image/endpoint.rs b/src/text2image/endpoint.rs index 6cbf3bb..055acf9 100644 --- a/src/text2image/endpoint.rs +++ b/src/text2image/endpoint.rs @@ -51,16 +51,12 @@ impl Text2ImageEndpoint { options: &Vec, ) -> Result { let body = Text2ImageEndpoint::generate_body(prompt, options); - let client = reqwest::blocking::Client::new(); - let response: Value = client - .post(self.url.as_str()) - .query(&[("access_token", self.access_token.as_str())]) - .json(&body) - .send() + let response: Value = ureq::post(self.url.as_str()) + .query("access_token", self.access_token.as_str()) + .send_json(body) .map_err(|e| ErnieError::InvokeError(e.to_string()))? - .json() + .into_json() .map_err(|e| ErnieError::InvokeError(e.to_string()))?; - //if error_code key in response, means RemoteAPIError if response.get("error_code").is_some() { return Err(ErnieError::RemoteAPIError(response.to_string())); diff --git a/src/utils.rs b/src/utils.rs index 46f41ee..15533ed 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -11,20 +11,14 @@ pub fn get_access_token() -> Result { .map_err(|_| ErnieError::GetAccessTokenError("QIANFAN_AK is not set".to_string()))?; let sk = var("QIANFAN_SK") .map_err(|_| ErnieError::GetAccessTokenError("QIANFAN_SK is not set".to_string()))?; - - let client = reqwest::blocking::Client::new(); - let res: Value = client - .post(url) - .query(&[ - ("grant_type", "client_credentials"), - ("client_id", ak.as_str()), - ("client_secret", sk.as_str()), - ]) - .send() + let res: Value = ureq::post(url) + .query("grant_type", "client_credentials") + .query("client_id", ak.as_str()) + .query("client_secret", sk.as_str()) + .call() .map_err(|e| ErnieError::GetAccessTokenError(e.to_string()))? - .json() + .into_json() .map_err(|e| ErnieError::GetAccessTokenError(e.to_string()))?; - if let Some(error) = res.get("error") { let error_description = res.get("error_description").unwrap(); Err(ErnieError::GetAccessTokenError(format!(