Skip to content

Commit

Permalink
Update dependencies and fix compatibility issues
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwanqq committed Apr 21, 2024
1 parent 23bcdda commit 723189d
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 63 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ license = "MIT"
description = "A unofficial Rust library for the Ernie API"
homepage = "https://github.com/chenwanqq/erniebot-rs"
repository = "https://github.com/chenwanqq/erniebot-rs"
version = "0.3.2"
version = "0.4.1"
edition = "2021"
exclude = [".github/",".vscode/",".gitignore"]

Expand All @@ -17,7 +17,8 @@ strum_macros = "0.26.1"
serde = {version = "1.0.197", features = ["derive"]}
serde_json = "1.0.113"
url = "2.5.0"
reqwest = {version = "0.12.3", features = ["json","blocking"]}
reqwest = {version = "0.12.3", features = ["json"]}
ureq = { version = "2.9.6", features = ["json", "charset"] }
thiserror = "1.0.57"
json_value_merge = "2.0"
reqwest-eventsource = "0.6.0"
Expand Down
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
Unofficial Baidu Ernie(Wenxin Yiyan, Qianfan) Rust SDK, currently supporting three modules: chat, text embedding (embedding), and text-to-image generation (text2image).

**update in 2024/04/09**: Add support for the bce-reranker-base-v1 rerank model
**update in 2024/04/21** For sync mode, use ureq instead of reqwest_blocking, hence it can improve the compatibility with tokio.

## Installation

Add the following to your Cargo.toml file:

```toml
[dependencies]
erniebot-rs = "0.3.2"
erniebot-rs = "0.4.1"
```

## Authentication
Expand Down Expand Up @@ -45,7 +46,7 @@ fn test_invoke() {
ChatOpt::TopP(0.5),
ChatOpt::TopK(50),
];
let response = chat.invoke(messages, options).unwrap();
let response = chat.invoke(&messages, &options).unwrap();
let result = response.get_chat_result().unwrap();
println!("{}", result);
}
Expand All @@ -64,7 +65,7 @@ fn test_custom_endpoint() {
},
];
let options = Vec::new();
let response = chat.invoke(messages, options).unwrap();
let response = chat.invoke(&messages, &options).unwrap();
let result = response.get_chat_result().unwrap();
println!("{}", result);
}
Expand All @@ -87,7 +88,7 @@ fn test_astream() {
let options = Vec::new();
let rt = Runtime::new().unwrap();
rt.block_on(async move {
let mut stream_response = chat.astream(messages, options).await.unwrap();
let mut stream_response = chat.astream(&messages, &options).await.unwrap();
while let Some(response) = stream_response.next().await {
let result = response.get_chat_result().unwrap();
print!("{}", result);
Expand Down Expand Up @@ -123,7 +124,7 @@ fn test_async_embedding() {
"你是谁".to_string(),
];
let rt = Runtime::new().unwrap();
let embedding_response = rt.block_on(embedding.ainvoke(input, None)).unwrap();
let embedding_response = rt.block_on(embedding.ainvoke(&input, None)).unwrap();
let embedding_results = embedding_response.get_embedding_results().unwrap();
println!("{},{}", embedding_results.len(), embedding_results[0].len());
}
Expand All @@ -141,7 +142,7 @@ fn main() {
Text2ImageOpt::Style(Style::DigitalArt),
Text2ImageOpt::Size(Size::S1024x768),
];
let text2image_response = text2image.invoke(prompt, options).unwrap();
let text2image_response = text2image.invoke(&prompt, &options).unwrap();
let image_results = text2image_response.get_image_results().unwrap();
for (index, image_string) in image_results.into_iter().enumerate() {
let image = base64_to_image(image_string).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
非官方的百度千帆大模型(文心一言,或者是Ernie,随便啦)SDK, 目前支持对话(chat),文本嵌入(embedding)以及文生图(text2image)三个模块。

**2024/04/09更新**: 添加对bce-reranker-base-v1重排序模型的支持
**2024/04/21更新** 对于同步模式,使用ureq替代reqwest_blocking,因此可以提高与tokio的兼容性。

## 安装

`Cargo.toml`文件中添加以下内容:

```toml
[dependencies]
erniebot-rs = "0.3.2"
erniebot-rs = "0.4.1"
```

## 鉴权
Expand Down
30 changes: 12 additions & 18 deletions src/chat/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,12 @@ impl ChatEndpoint {
options: &Vec<ChatOpt>,
) -> Result<Response, ErnieError> {
let body = ChatEndpoint::generate_body(messages, options, false)?;
let client = reqwest::blocking::Client::new();
let response: Value = client
.post(self.url.as_str())
.header("Content-Type", "application/json")
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
let response: Value = ureq::post(self.url.as_str())
.set("Content-Type", "application/json")
.query("access_token", self.access_token.as_str())
.send_json(body)
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.json()
.into_json()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;

//if error_code key in response, means RemoteAPIError
Expand All @@ -93,16 +90,13 @@ impl ChatEndpoint {
options: &Vec<ChatOpt>,
) -> Result<Responses, ErnieError> {
let body = ChatEndpoint::generate_body(messages, options, true)?;
let client = reqwest::blocking::Client::new();
let response = client
.post(self.url.as_str())
.header("Content-Type", "application/json")
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
.map_err(|e| ErnieError::StreamError(e.to_string()))?
.text()
.map_err(|e| ErnieError::StreamError(e.to_string()))?;
let response: String = ureq::post(self.url.as_str())
.set("Content-Type", "application/json")
.query("access_token", self.access_token.as_str())
.send_json(body)
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.into_string()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;
let response = Responses::from_text(&response)?;
Ok(response)
}
Expand Down
2 changes: 1 addition & 1 deletion src/chat/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub struct Responses {
}

impl Responses {
/// get Responses from reqwest::blocking::client. In the response body, it contains multiple responses split by blank line. This method will parse the response body and return a Responses struct.
/// get Responses from blocking response. In the response body, it contains multiple responses split by blank line. This method will parse the response body and return a Responses struct.
pub fn from_text(text: &str) -> Result<Self, ErnieError> {
let parts = text.split("\n\n").collect::<Vec<&str>>();
let mut result = Vec::new();
Expand Down
11 changes: 4 additions & 7 deletions src/embedding/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,11 @@ impl EmbeddingEndpoint {
if let Some(user_id) = user_id {
body.merge(&serde_json::json!({"user_id": user_id}));
}
let client = reqwest::blocking::Client::new();
let response: Value = client
.post(self.url.as_str())
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
let response: Value = ureq::post(self.url.as_str())
.query("access_token", self.access_token.as_str())
.send_json(body)
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.json()
.into_json()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;

//if error_code key in response, means RemoteAPIError
Expand Down
12 changes: 4 additions & 8 deletions src/reranker/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,12 @@ impl RerankerEndpoint {
if let Some(user_id) = user_id {
body.merge(&serde_json::json!({"user_id": user_id}));
}
let client = reqwest::blocking::Client::new();
let response: Value = client
.post(self.url.as_str())
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
let response: Value = ureq::post(self.url.as_str())
.query("access_token", self.access_token.as_str())
.send_json(body)
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.json()
.into_json()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;

//if error_code key in response, means RemoteAPIError
if response.get("error_code").is_some() {
return Err(ErnieError::RemoteAPIError(response.to_string()));
Expand Down
12 changes: 4 additions & 8 deletions src/text2image/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,12 @@ impl Text2ImageEndpoint {
options: &Vec<Text2ImageOpt>,
) -> Result<Text2ImageResponse, ErnieError> {
let body = Text2ImageEndpoint::generate_body(prompt, options);
let client = reqwest::blocking::Client::new();
let response: Value = client
.post(self.url.as_str())
.query(&[("access_token", self.access_token.as_str())])
.json(&body)
.send()
let response: Value = ureq::post(self.url.as_str())
.query("access_token", self.access_token.as_str())
.send_json(body)
.map_err(|e| ErnieError::InvokeError(e.to_string()))?
.json()
.into_json()
.map_err(|e| ErnieError::InvokeError(e.to_string()))?;

//if error_code key in response, means RemoteAPIError
if response.get("error_code").is_some() {
return Err(ErnieError::RemoteAPIError(response.to_string()));
Expand Down
18 changes: 6 additions & 12 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,14 @@ pub fn get_access_token() -> Result<String, ErnieError> {
.map_err(|_| ErnieError::GetAccessTokenError("QIANFAN_AK is not set".to_string()))?;
let sk = var("QIANFAN_SK")
.map_err(|_| ErnieError::GetAccessTokenError("QIANFAN_SK is not set".to_string()))?;

let client = reqwest::blocking::Client::new();
let res: Value = client
.post(url)
.query(&[
("grant_type", "client_credentials"),
("client_id", ak.as_str()),
("client_secret", sk.as_str()),
])
.send()
let res: Value = ureq::post(url)
.query("grant_type", "client_credentials")
.query("client_id", ak.as_str())
.query("client_secret", sk.as_str())
.call()
.map_err(|e| ErnieError::GetAccessTokenError(e.to_string()))?
.json()
.into_json()
.map_err(|e| ErnieError::GetAccessTokenError(e.to_string()))?;

if let Some(error) = res.get("error") {
let error_description = res.get("error_description").unwrap();
Err(ErnieError::GetAccessTokenError(format!(
Expand Down

0 comments on commit 723189d

Please sign in to comment.