Skip to content

Commit

Permalink
Merge pull request #6 from chenwanqq/to_refer
Browse files Browse the repository at this point in the history
To refer
  • Loading branch information
chenwanqq committed Apr 16, 2024
2 parents 621d8f0 + 39fa630 commit 1779ab1
Show file tree
Hide file tree
Showing 13 changed files with 46 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ license = "MIT"
description = "A unofficial Rust library for the Ernie API"
homepage = "https://github.com/chenwanqq/erniebot-rs"
repository = "https://github.com/chenwanqq/erniebot-rs"
version = "0.2.1"
version = "0.3.1"
edition = "2021"
exclude = [".github/",".vscode/",".gitignore"]

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Add the following to your Cargo.toml file:

```toml
[dependencies]
erniebot-rs = "0.2.1"
erniebot-rs = "0.3.1"
```

## Authentication
Expand Down
2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

```toml
[dependencies]
erniebot-rs = "0.2.1"
erniebot-rs = "0.3.1"
```

## 鉴权
Expand Down
10 changes: 5 additions & 5 deletions examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn test_invoke() {
ChatOpt::TopP(0.5),
ChatOpt::TopK(50),
];
let response = chat.invoke(messages, options).unwrap();
let response = chat.invoke(&messages, &options).unwrap();
let result = response.get_chat_result().unwrap();
println!("{}", result);
}
Expand All @@ -37,7 +37,7 @@ fn test_stream() {
ChatOpt::TopP(0.5),
ChatOpt::TopK(50),
];
let response = chat.stream(messages, options).unwrap();
let response = chat.stream(&messages, &options).unwrap();
let result_by_chunk = response.get_results().unwrap();
println!("{:?}", result_by_chunk);
let whole_result = response.get_whole_result().unwrap();
Expand All @@ -55,7 +55,7 @@ fn test_ainvoke() {
];
let options = Vec::new();
let rt = Runtime::new().unwrap();
let response = rt.block_on(chat.ainvoke(messages, options)).unwrap();
let response = rt.block_on(chat.ainvoke(&messages, &options)).unwrap();
let result = response.get_chat_result().unwrap();
println!("{}", result);
}
Expand All @@ -72,7 +72,7 @@ fn test_astream() {
let options = Vec::new();
let rt = Runtime::new().unwrap();
rt.block_on(async move {
let mut stream_response = chat.astream(messages, options).await.unwrap();
let mut stream_response = chat.astream(&messages, &options).await.unwrap();
while let Some(response) = stream_response.next().await {
let result = response.get_chat_result().unwrap();
print!("{}", result);
Expand All @@ -93,7 +93,7 @@ fn test_custom_endpoint() {
},
];
let options = Vec::new();
let response = chat.invoke(messages, options).unwrap();
let response = chat.invoke(&messages, &options).unwrap();
let result = response.get_chat_result().unwrap();
println!("{}", result);
}
Expand Down
2 changes: 1 addition & 1 deletion examples/chat_with_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn main() {
let functions = vec![weather_function];
let options = vec![ChatOpt::Functions(functions)];
let chat = ChatEndpoint::new(ChatModel::ErnieBot).unwrap();
let response = chat.invoke(messages, options).unwrap();
let response = chat.invoke(&messages, &options).unwrap();
let text_response = response.get_chat_result().unwrap();
let function_call = response.get_function_call().unwrap();
println!("text_response: {}", text_response);
Expand Down
4 changes: 2 additions & 2 deletions examples/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn test_embedding() {
"你叫什么名字".to_string(),
"你是谁".to_string(),
];
let embedding_response = embedding.invoke(input, None).unwrap();
let embedding_response = embedding.invoke(&input, None).unwrap();
let embedding_results = embedding_response.get_embedding_results().unwrap();
println!("{},{}", embedding_results.len(), embedding_results[0].len());
}
Expand All @@ -19,7 +19,7 @@ fn test_async_embedding() {
"你是谁".to_string(),
];
let rt = Runtime::new().unwrap();
let embedding_response = rt.block_on(embedding.ainvoke(input, None)).unwrap();
let embedding_response = rt.block_on(embedding.ainvoke(&input, None)).unwrap();
let embedding_results = embedding_response.get_embedding_results().unwrap();
println!("{},{}", embedding_results.len(), embedding_results[0].len());
}
Expand Down
9 changes: 6 additions & 3 deletions examples/rerank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ fn test_reranker() {
"你叫什么名字".to_string(),
"你是谁".to_string(),
];
let reranker_response = reranker.invoke(query, documents, None, None).unwrap();
let reranker_response = reranker.invoke(&query, &documents, None, None).unwrap();
let reranker_results = reranker_response.get_reranker_response().unwrap();
let reranked_documents = reranker_results.into_iter().map(|x|x.document).collect::<Vec<String>>();
let reranked_documents = reranker_results
.into_iter()
.map(|x| x.document)
.collect::<Vec<String>>();
println!("{},{:?}", reranked_documents.len(), reranked_documents);
}

Expand All @@ -25,7 +28,7 @@ fn test_async_reranker() {
];
let rt = Runtime::new().unwrap();
let reranker_response = rt
.block_on(reranker.ainvoke(query, documents, None, None))
.block_on(reranker.ainvoke(&query, &documents, None, None))
.unwrap();
let reranker_results = reranker_response.get_reranker_response().unwrap();
println!("{},{:?}", reranker_results.len(), reranker_results);
Expand Down
2 changes: 1 addition & 1 deletion examples/text2image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn main() {
Text2ImageOpt::Style(Style::DigitalArt),
Text2ImageOpt::Size(Size::S1024x768),
];
let text2image_response = text2image.invoke(prompt, options).unwrap();
let text2image_response = text2image.invoke(&prompt, &options).unwrap();
let image_results = text2image_response.get_image_results().unwrap();
for (index, image_string) in image_results.into_iter().enumerate() {
let image = base64_to_image(image_string).unwrap();
Expand Down
24 changes: 12 additions & 12 deletions src/chat/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ impl ChatEndpoint {
}

fn generate_body(
messages: Vec<Message>,
options: Vec<ChatOpt>,
messages: &Vec<Message>,
options: &Vec<ChatOpt>,
stream: bool,
) -> Result<serde_json::Value, ErnieError> {
let mut body = serde_json::json!({
Expand All @@ -64,8 +64,8 @@ impl ChatEndpoint {
/// invoke method is used to send a request to erniebot chat endpoint. This is a blocking method that will return a full response from the chat endpoint
pub fn invoke(
&self,
messages: Vec<Message>,
options: Vec<ChatOpt>,
messages: &Vec<Message>,
options: &Vec<ChatOpt>,
) -> Result<Response, ErnieError> {
let body = ChatEndpoint::generate_body(messages, options, false)?;
let client = reqwest::blocking::Client::new();
Expand All @@ -88,8 +88,8 @@ impl ChatEndpoint {
/// stream method is used to send a request to erniebot chat endpoint. This is a blocking method that will return response in multiple chunks from the chat endpoint
pub fn stream(
&self,
messages: Vec<Message>,
options: Vec<ChatOpt>,
messages: &Vec<Message>,
options: &Vec<ChatOpt>,
) -> Result<Responses, ErnieError> {
let body = ChatEndpoint::generate_body(messages, options, true)?;
let client = reqwest::blocking::Client::new();
Expand All @@ -102,15 +102,15 @@ impl ChatEndpoint {
.map_err(|e| ErnieError::StreamError(e.to_string()))?
.text()
.map_err(|e| ErnieError::StreamError(e.to_string()))?;
let response = Responses::from_text(response)?;
let response = Responses::from_text(&response)?;
Ok(response)
}

/// ainvoke method is used to send a request to erniebot chat endpoint. This is an async method that will return a full response from the chat endpoint
pub async fn ainvoke(
&self,
messages: Vec<Message>,
options: Vec<ChatOpt>,
messages: &Vec<Message>,
options: &Vec<ChatOpt>,
) -> Result<Response, ErnieError> {
let body = ChatEndpoint::generate_body(messages, options, false)?;
let client = reqwest::Client::new();
Expand All @@ -136,8 +136,8 @@ impl ChatEndpoint {
/// astream method is used to send a request to erniebot chat endpoint. This is an async method that will return response in multiple chunks from the chat endpoint
pub async fn astream(
&self,
messages: Vec<Message>,
options: Vec<ChatOpt>,
messages: &Vec<Message>,
options: &Vec<ChatOpt>,
) -> Result<StreamResponse, ErnieError> {
let body = ChatEndpoint::generate_body(messages, options, true)?;
let client = reqwest::Client::new();
Expand Down Expand Up @@ -191,7 +191,7 @@ mod tests {
ChatOpt::TopP(0.5),
ChatOpt::TopK(50),
];
let result = ChatEndpoint::generate_body(messages, options, true).unwrap();
let result = ChatEndpoint::generate_body(&messages, &options, true).unwrap();
let s = serde_json::to_string(&result).unwrap();
println!("{}", s);
}
Expand Down
2 changes: 1 addition & 1 deletion src/chat/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub struct Responses {

impl Responses {
/// get Responses from reqwest::blocking::client. In the response body, it contains multiple responses split by blank line. This method will parse the response body and return a Responses struct.
pub fn from_text(text: String) -> Result<Self, ErnieError> {
pub fn from_text(text: &str) -> Result<Self, ErnieError> {
let parts = text.split("\n\n").collect::<Vec<&str>>();
let mut result = Vec::new();
for part in parts {
Expand Down
8 changes: 4 additions & 4 deletions src/embedding/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ impl EmbeddingEndpoint {
/// sync invoke
pub fn invoke(
&self,
input: Vec<String>,
user_id: Option<String>,
input: &Vec<String>,
user_id: Option<&str>,
) -> Result<EmbeddingResponse, ErnieError> {
let mut body = serde_json::json!({
"input": input,
Expand Down Expand Up @@ -56,8 +56,8 @@ impl EmbeddingEndpoint {
///async invoke
pub async fn ainvoke(
&self,
input: Vec<String>,
user_id: Option<String>,
input: &Vec<String>,
user_id: Option<&str>,
) -> Result<EmbeddingResponse, ErnieError> {
let mut body = serde_json::json!({
"input": input,
Expand Down
12 changes: 6 additions & 6 deletions src/reranker/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ impl RerankerEndpoint {
/// sync invoke
pub fn invoke(
&self,
query: String,
documents: Vec<String>,
query: &str,
documents: &Vec<String>,
top_n: Option<u64>,
user_id: Option<String>,
user_id: Option<&str>,
) -> Result<RerankerResponse, ErnieError> {
let mut body = serde_json::json!({
"query": query,
Expand Down Expand Up @@ -62,10 +62,10 @@ impl RerankerEndpoint {
///async invoke
pub async fn ainvoke(
&self,
query: String,
documents: Vec<String>,
query: &str,
documents: &Vec<String>,
top_n: Option<u64>,
user_id: Option<String>,
user_id: Option<&str>,
) -> Result<RerankerResponse, ErnieError> {
let mut body = serde_json::json!({
"query": query,
Expand Down
10 changes: 5 additions & 5 deletions src/text2image/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Text2ImageEndpoint {
})
}

fn generate_body(prompt: String, options: Vec<Text2ImageOpt>) -> serde_json::Value {
fn generate_body(prompt: &str, options: &Vec<Text2ImageOpt>) -> serde_json::Value {
let mut body = serde_json::json!({
"prompt": prompt,
});
Expand All @@ -46,8 +46,8 @@ impl Text2ImageEndpoint {
/// sync invoke
pub fn invoke(
&self,
prompt: String,
options: Vec<Text2ImageOpt>,
prompt: &str,
options: &Vec<Text2ImageOpt>,
) -> Result<Text2ImageResponse, ErnieError> {
let body = Text2ImageEndpoint::generate_body(prompt, options);
let client = reqwest::blocking::Client::new();
Expand All @@ -71,8 +71,8 @@ impl Text2ImageEndpoint {
///async invoke
pub async fn ainvoke(
&self,
prompt: String,
options: Vec<Text2ImageOpt>,
prompt: &str,
options: &Vec<Text2ImageOpt>,
) -> Result<Text2ImageResponse, ErnieError> {
let body = Text2ImageEndpoint::generate_body(prompt, options);
let client = reqwest::Client::new();
Expand Down

0 comments on commit 1779ab1

Please sign in to comment.