diff --git a/Cargo.toml b/Cargo.toml index b5bfde5..baf6747 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ license = "MIT" description = "A unofficial Rust library for the Ernie API" homepage = "https://github.com/chenwanqq/erniebot-rs" repository = "https://github.com/chenwanqq/erniebot-rs" -version = "0.2.1" +version = "0.3.1" edition = "2021" exclude = [".github/",".vscode/",".gitignore"] diff --git a/README.md b/README.md index 3c5c37c..6f0f256 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Add the following to your Cargo.toml file: ```toml [dependencies] -erniebot-rs = "0.2.1" +erniebot-rs = "0.3.1" ``` ## Authentication diff --git a/README_zh.md b/README_zh.md index 09b535c..1dbb354 100644 --- a/README_zh.md +++ b/README_zh.md @@ -10,7 +10,7 @@ ```toml [dependencies] -erniebot-rs = "0.2.1" +erniebot-rs = "0.3.1" ``` ## 鉴权 diff --git a/examples/chat.rs b/examples/chat.rs index 6573f2a..ee70032 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -18,7 +18,7 @@ fn test_invoke() { ChatOpt::TopP(0.5), ChatOpt::TopK(50), ]; - let response = chat.invoke(messages, options).unwrap(); + let response = chat.invoke(&messages, &options).unwrap(); let result = response.get_chat_result().unwrap(); println!("{}", result); } @@ -37,7 +37,7 @@ fn test_stream() { ChatOpt::TopP(0.5), ChatOpt::TopK(50), ]; - let response = chat.stream(messages, options).unwrap(); + let response = chat.stream(&messages, &options).unwrap(); let result_by_chunk = response.get_results().unwrap(); println!("{:?}", result_by_chunk); let whole_result = response.get_whole_result().unwrap(); @@ -55,7 +55,7 @@ fn test_ainvoke() { ]; let options = Vec::new(); let rt = Runtime::new().unwrap(); - let response = rt.block_on(chat.ainvoke(messages, options)).unwrap(); + let response = rt.block_on(chat.ainvoke(&messages, &options)).unwrap(); let result = response.get_chat_result().unwrap(); println!("{}", result); } @@ -72,7 +72,7 @@ fn test_astream() { let options = Vec::new(); let rt = Runtime::new().unwrap(); rt.block_on(async move { - let mut stream_response = chat.astream(messages, options).await.unwrap(); + let mut stream_response = chat.astream(&messages, &options).await.unwrap(); while let Some(response) = stream_response.next().await { let result = response.get_chat_result().unwrap(); print!("{}", result); @@ -93,7 +93,7 @@ fn test_custom_endpoint() { }, ]; let options = Vec::new(); - let response = chat.invoke(messages, options).unwrap(); + let response = chat.invoke(&messages, &options).unwrap(); let result = response.get_chat_result().unwrap(); println!("{}", result); } diff --git a/examples/chat_with_function.rs b/examples/chat_with_function.rs index 7fe3b4c..658fdd0 100644 --- a/examples/chat_with_function.rs +++ b/examples/chat_with_function.rs @@ -46,7 +46,7 @@ fn main() { let functions = vec![weather_function]; let options = vec![ChatOpt::Functions(functions)]; let chat = ChatEndpoint::new(ChatModel::ErnieBot).unwrap(); - let response = chat.invoke(messages, options).unwrap(); + let response = chat.invoke(&messages, &options).unwrap(); let text_response = response.get_chat_result().unwrap(); let function_call = response.get_function_call().unwrap(); println!("text_response: {}", text_response); diff --git a/examples/embedding.rs b/examples/embedding.rs index ca03097..f10eaa4 100644 --- a/examples/embedding.rs +++ b/examples/embedding.rs @@ -7,7 +7,7 @@ fn test_embedding() { "你叫什么名字".to_string(), "你是谁".to_string(), ]; - let embedding_response = embedding.invoke(input, None).unwrap(); + let embedding_response = embedding.invoke(&input, None).unwrap(); let embedding_results = embedding_response.get_embedding_results().unwrap(); println!("{},{}", embedding_results.len(), embedding_results[0].len()); } @@ -19,7 +19,7 @@ fn test_async_embedding() { "你是谁".to_string(), ]; let rt = Runtime::new().unwrap(); - let embedding_response = rt.block_on(embedding.ainvoke(input, None)).unwrap(); + let embedding_response = rt.block_on(embedding.ainvoke(&input, None)).unwrap(); let embedding_results = embedding_response.get_embedding_results().unwrap(); println!("{},{}", embedding_results.len(), embedding_results[0].len()); } diff --git a/examples/rerank.rs b/examples/rerank.rs index ecdf161..e6dc8ae 100644 --- a/examples/rerank.rs +++ b/examples/rerank.rs @@ -9,9 +9,12 @@ fn test_reranker() { "你叫什么名字".to_string(), "你是谁".to_string(), ]; - let reranker_response = reranker.invoke(query, documents, None, None).unwrap(); + let reranker_response = reranker.invoke(&query, &documents, None, None).unwrap(); let reranker_results = reranker_response.get_reranker_response().unwrap(); - let reranked_documents = reranker_results.into_iter().map(|x|x.document).collect::>(); + let reranked_documents = reranker_results + .into_iter() + .map(|x| x.document) + .collect::>(); println!("{},{:?}", reranked_documents.len(), reranked_documents); } @@ -25,7 +28,7 @@ fn test_async_reranker() { ]; let rt = Runtime::new().unwrap(); let reranker_response = rt - .block_on(reranker.ainvoke(query, documents, None, None)) + .block_on(reranker.ainvoke(&query, &documents, None, None)) .unwrap(); let reranker_results = reranker_response.get_reranker_response().unwrap(); println!("{},{:?}", reranker_results.len(), reranker_results); diff --git a/examples/text2image.rs b/examples/text2image.rs index 83657e5..0047e93 100644 --- a/examples/text2image.rs +++ b/examples/text2image.rs @@ -7,7 +7,7 @@ fn main() { Text2ImageOpt::Style(Style::DigitalArt), Text2ImageOpt::Size(Size::S1024x768), ]; - let text2image_response = text2image.invoke(prompt, options).unwrap(); + let text2image_response = text2image.invoke(&prompt, &options).unwrap(); let image_results = text2image_response.get_image_results().unwrap(); for (index, image_string) in image_results.into_iter().enumerate() { let image = base64_to_image(image_string).unwrap(); diff --git a/src/chat/endpoint.rs b/src/chat/endpoint.rs index 06e2bbb..dc1e5ca 100644 --- a/src/chat/endpoint.rs +++ b/src/chat/endpoint.rs @@ -38,8 +38,8 @@ impl ChatEndpoint { } fn generate_body( - messages: Vec, - options: Vec, + messages: &Vec, + options: &Vec, stream: bool, ) -> Result { let mut body = serde_json::json!({ @@ -64,8 +64,8 @@ impl ChatEndpoint { /// invoke method is used to send a request to erniebot chat endpoint. This is a blocking method that will return a full response from the chat endpoint pub fn invoke( &self, - messages: Vec, - options: Vec, + messages: &Vec, + options: &Vec, ) -> Result { let body = ChatEndpoint::generate_body(messages, options, false)?; let client = reqwest::blocking::Client::new(); @@ -88,8 +88,8 @@ impl ChatEndpoint { /// stream method is used to send a request to erniebot chat endpoint. This is a blocking method that will return response in multiple chunks from the chat endpoint pub fn stream( &self, - messages: Vec, - options: Vec, + messages: &Vec, + options: &Vec, ) -> Result { let body = ChatEndpoint::generate_body(messages, options, true)?; let client = reqwest::blocking::Client::new(); @@ -102,15 +102,15 @@ impl ChatEndpoint { .map_err(|e| ErnieError::StreamError(e.to_string()))? .text() .map_err(|e| ErnieError::StreamError(e.to_string()))?; - let response = Responses::from_text(response)?; + let response = Responses::from_text(&response)?; Ok(response) } /// ainvoke method is used to send a request to erniebot chat endpoint. This is an async method that will return a full response from the chat endpoint pub async fn ainvoke( &self, - messages: Vec, - options: Vec, + messages: &Vec, + options: &Vec, ) -> Result { let body = ChatEndpoint::generate_body(messages, options, false)?; let client = reqwest::Client::new(); @@ -136,8 +136,8 @@ impl ChatEndpoint { /// astream method is used to send a request to erniebot chat endpoint. This is an async method that will return response in multiple chunks from the chat endpoint pub async fn astream( &self, - messages: Vec, - options: Vec, + messages: &Vec, + options: &Vec, ) -> Result { let body = ChatEndpoint::generate_body(messages, options, true)?; let client = reqwest::Client::new(); @@ -191,7 +191,7 @@ mod tests { ChatOpt::TopP(0.5), ChatOpt::TopK(50), ]; - let result = ChatEndpoint::generate_body(messages, options, true).unwrap(); + let result = ChatEndpoint::generate_body(&messages, &options, true).unwrap(); let s = serde_json::to_string(&result).unwrap(); println!("{}", s); } diff --git a/src/chat/response.rs b/src/chat/response.rs index 94eeca1..c11434b 100644 --- a/src/chat/response.rs +++ b/src/chat/response.rs @@ -82,7 +82,7 @@ pub struct Responses { impl Responses { /// get Responses from reqwest::blocking::client. In the response body, it contains multiple responses split by blank line. This method will parse the response body and return a Responses struct. - pub fn from_text(text: String) -> Result { + pub fn from_text(text: &str) -> Result { let parts = text.split("\n\n").collect::>(); let mut result = Vec::new(); for part in parts { diff --git a/src/embedding/endpoint.rs b/src/embedding/endpoint.rs index 6dd34ed..fa79d3e 100644 --- a/src/embedding/endpoint.rs +++ b/src/embedding/endpoint.rs @@ -27,8 +27,8 @@ impl EmbeddingEndpoint { /// sync invoke pub fn invoke( &self, - input: Vec, - user_id: Option, + input: &Vec, + user_id: Option<&str>, ) -> Result { let mut body = serde_json::json!({ "input": input, @@ -56,8 +56,8 @@ impl EmbeddingEndpoint { ///async invoke pub async fn ainvoke( &self, - input: Vec, - user_id: Option, + input: &Vec, + user_id: Option<&str>, ) -> Result { let mut body = serde_json::json!({ "input": input, diff --git a/src/reranker/endpoint.rs b/src/reranker/endpoint.rs index 0d7a853..8467bfe 100644 --- a/src/reranker/endpoint.rs +++ b/src/reranker/endpoint.rs @@ -27,10 +27,10 @@ impl RerankerEndpoint { /// sync invoke pub fn invoke( &self, - query: String, - documents: Vec, + query: &str, + documents: &Vec, top_n: Option, - user_id: Option, + user_id: Option<&str>, ) -> Result { let mut body = serde_json::json!({ "query": query, @@ -62,10 +62,10 @@ impl RerankerEndpoint { ///async invoke pub async fn ainvoke( &self, - query: String, - documents: Vec, + query: &str, + documents: &Vec, top_n: Option, - user_id: Option, + user_id: Option<&str>, ) -> Result { let mut body = serde_json::json!({ "query": query, diff --git a/src/text2image/endpoint.rs b/src/text2image/endpoint.rs index 08a4f25..3a394c6 100644 --- a/src/text2image/endpoint.rs +++ b/src/text2image/endpoint.rs @@ -33,7 +33,7 @@ impl Text2ImageEndpoint { }) } - fn generate_body(prompt: String, options: Vec) -> serde_json::Value { + fn generate_body(prompt: &str, options: &Vec) -> serde_json::Value { let mut body = serde_json::json!({ "prompt": prompt, }); @@ -46,8 +46,8 @@ impl Text2ImageEndpoint { /// sync invoke pub fn invoke( &self, - prompt: String, - options: Vec, + prompt: &str, + options: &Vec, ) -> Result { let body = Text2ImageEndpoint::generate_body(prompt, options); let client = reqwest::blocking::Client::new(); @@ -71,8 +71,8 @@ impl Text2ImageEndpoint { ///async invoke pub async fn ainvoke( &self, - prompt: String, - options: Vec, + prompt: &str, + options: &Vec, ) -> Result { let body = Text2ImageEndpoint::generate_body(prompt, options); let client = reqwest::Client::new();