diff --git a/examples/chat_with_function.rs b/examples/chat_with_function.rs index e15b26c..7fe3b4c 100644 --- a/examples/chat_with_function.rs +++ b/examples/chat_with_function.rs @@ -19,12 +19,12 @@ fn main() { }]; let examples = vec![vec![ Example { - role: "user".to_string(), + role: Role::User, content: Some("What's the weather in Shanghai?".to_string()), ..Default::default() }, Example { - role: "assistant".to_string(), + role: Role::Assistant, content: None, name: None, function_call: Some(erniebot_rs::chat::FunctionCall { diff --git a/src/chat/function.rs b/src/chat/function.rs index cfb2730..e5dbc63 100644 --- a/src/chat/function.rs +++ b/src/chat/function.rs @@ -2,40 +2,62 @@ use schemars::schema::RootSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; +use super::Role; + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] +/// Definition of the function structure that some models(like Erniebot) can select and call. pub struct Function { + /// The name of the function. pub name: String, + /// The description of the function. pub description: String, + /// The format of parameters of the function, following the JSON schema format. pub parameters: RootSchema, + /// The format of the response of the function, following the JSON schema format. #[serde(skip_serializing_if = "Option::is_none")] pub response: Option, + /// The examples of the function. each instance of the outer vector represents a round of conversation, and each instance of the inner vector represents a message in the round of conversation. More details can be found in the example of chat_with_function.rs. 
#[serde(skip_serializing_if = "Option::is_none")] pub examples: Option>>, } +/// Example of a message involved in a function calling process #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] pub struct Example { - pub role: String, + /// Same as the role in Message, can be "user", "assistant", or "function". + pub role: Role, + /// Dialog content instructions: + + /// (1) If the current message contains a function_call and the role is "assistant", the message can be empty. However, in other scenarios, it cannot be empty. + + /// (2) The content corresponding to the last message cannot be a blank character, including spaces, "\n", "\r", "\f", etc. pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] + /// The "author" of the message. This member is required when the role value is "function", and in this case it should be the name in the function_call in the response content pub name: Option, + /// This is the function calling result of the last round of function call, serving as chat history. #[serde(skip_serializing_if = "Option::is_none")] pub function_call: Option, } +/// This is the function calling result of the last round of function call #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] pub struct FunctionCall { + /// name of a function pub name: String, + /// arguments of a function call that the LLM outputs, following the JSON format. pub arguments: String, + /// The thinking process of the model #[serde(skip_serializing_if = "Option::is_none")] pub thoughts: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] +/// In the context of function calls, prompt the large model to select a specific function (not mandatory).
Note: The specified function name must exist within the list of functions. pub struct ToolChoice { pub r#type: String, //only one valid value: "function" pub function: Value, diff --git a/src/chat/message.rs b/src/chat/message.rs index dc95243..a1b7179 100644 --- a/src/chat/message.rs +++ b/src/chat/message.rs @@ -54,7 +54,7 @@ pub struct Message { pub role: Role, /// The content of the message. The value is a string. pub content: String, - /// The name of the function. The value is a string. This member is required when the role value is "function", and in this case is should be the name in the function_call in the response content + /// The "author" of the message. This member is required when the role value is "function", and in this case it should be the name in the function_call in the response content #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, /// this is function calling result of last round of function call, serving as chat history. diff --git a/src/chat/mod.rs b/src/chat/mod.rs index 12cf08b..4fb55bb 100644 --- a/src/chat/mod.rs +++ b/src/chat/mod.rs @@ -10,3 +10,4 @@ pub use function::{Example, Function, FunctionCall, ToolChoice}; pub use message::{Message, Role}; pub use model::ChatModel; pub use option::{ChatOpt, ResponseFormat}; +pub use response::{Response, Responses, StreamResponse}; diff --git a/src/chat/response.rs b/src/chat/response.rs index be3cdf3..94eeca1 100644 --- a/src/chat/response.rs +++ b/src/chat/response.rs @@ -5,7 +5,9 @@ use serde_json::value; use tokio::sync::mpsc::{self, UnboundedReceiver}; use tokio_stream::Stream; -/// Response is using for non-stream response +/// Response is a struct that represents the response of erniebot API. +/// +/// It is a wrapper of serde_json::Value.
In the non-stream case, the API will return a single response, and in the stream case, the API will return multiple responses (see the `Responses` struct). #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Response { raw_response: value::Value, @@ -28,6 +30,7 @@ impl Response { self.raw_response.get_mut(key) } + /// get the result of chat response pub fn get_chat_result(&self) -> Result { match self.raw_response.get("result") { Some(result) => match result.as_str() { @@ -42,24 +45,28 @@ impl Response { } } + /// get tokens used by prompt pub fn get_prompt_tokens(&self) -> Option { let usage = self.get("usage")?.as_object()?; let prompt_tokens = usage.get("prompt_tokens")?.as_u64()?; Some(prompt_tokens) } + /// get tokens used by completion pub fn get_completion_tokens(&self) -> Option { let usage = self.get("usage")?.as_object()?; let completion_tokens = usage.get("completion_tokens")?.as_u64()?; Some(completion_tokens) } + /// get total tokens used pub fn get_total_tokens(&self) -> Option { let usage = self.get("usage")?.as_object()?; let total_tokens = usage.get("total_tokens")?.as_u64()?; Some(total_tokens) } + /// get function call pub fn get_function_call(&self) -> Option { let value = self.get("function_call")?; let function_call = serde_json::from_value(value.clone()).ok()?; @@ -67,13 +74,14 @@ impl Response { } } -/// Responses is using for sync stream response +/// Responses is used for sync stream responses. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Responses { responses: Vec, } impl Responses { + /// get Responses from reqwest::blocking::client. In the response body, it contains multiple responses split by blank line. This method will parse the response body and return a Responses struct.
pub fn from_text(text: String) -> Result { let parts = text.split("\n\n").collect::>(); let mut result = Vec::new(); @@ -99,6 +107,8 @@ impl Responses { } Ok(Responses { responses: result }) } + + /// get chat result as a vector of string pub fn get_results(&self) -> Result, ErnieError> { let mut result = Vec::new(); for response in &self.responses { @@ -107,6 +117,7 @@ impl Responses { Ok(result) } + /// get whole chat result as a single string pub fn get_whole_result(&self) -> Result { let mut result = String::new(); for response in &self.responses { @@ -116,6 +127,7 @@ impl Responses { } } +/// StreamResponse is a struct that represents the response of erniebot API in async stream case. pub struct StreamResponse { receiver: UnboundedReceiver, } diff --git a/src/embedding/endpoint.rs b/src/embedding/endpoint.rs index 867a403..6dd34ed 100644 --- a/src/embedding/endpoint.rs +++ b/src/embedding/endpoint.rs @@ -9,12 +9,15 @@ use url::Url; static EMBEDDING_BASE_URL: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/"; +/** EmbeddingEndpoint is a struct that represents the embedding endpoint of erniebot API +*/ pub struct EmbeddingEndpoint { url: Url, access_token: String, } impl EmbeddingEndpoint { + /// create a new embedding instance using pre-defined model pub fn new(model: EmbeddingModel) -> Result { Ok(EmbeddingEndpoint { url: build_url(EMBEDDING_BASE_URL, model.to_string().as_str())?, diff --git a/src/embedding/response.rs b/src/embedding/response.rs index 72c65e7..58df236 100644 --- a/src/embedding/response.rs +++ b/src/embedding/response.rs @@ -25,6 +25,7 @@ impl EmbeddingResponse { self.raw_response.get_mut(key) } + /// get the result of embedding response pub fn get_embedding_results(&self) -> Result>, ErnieError> { match self.raw_response.get("data") { Some(data) => { @@ -54,13 +55,13 @@ impl EmbeddingResponse { )), } } - + /// get tokens used by prompt pub fn get_prompt_tokens(&self) -> Option { let usage =
self.get("usage")?.as_object()?; let prompt_tokens = usage.get("prompt_tokens")?.as_u64()?; Some(prompt_tokens) } - + /// get total tokens used pub fn get_total_tokens(&self) -> Option { let usage = self.get("usage")?.as_object()?; let total_tokens = usage.get("total_tokens")?.as_u64()?; diff --git a/src/text2image/endpoint.rs b/src/text2image/endpoint.rs index bd1f47e..08a4f25 100644 --- a/src/text2image/endpoint.rs +++ b/src/text2image/endpoint.rs @@ -10,13 +10,14 @@ use crate::utils::{build_url, get_access_token}; static TEXT2IMAGE_BASE_URL: &str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/text2image/"; - +/// Text2ImageEndpoint is a struct that represents the text2image endpoint of erniebot API pub struct Text2ImageEndpoint { url: Url, access_token: String, } impl Text2ImageEndpoint { + /// create a new text2image instance using pre-defined model pub fn new(model: Text2ImageModel) -> Result { Ok(Text2ImageEndpoint { url: build_url(TEXT2IMAGE_BASE_URL, model.to_string().as_str())?, @@ -24,6 +25,7 @@ impl Text2ImageEndpoint { }) } + /// create a new text2image instance using custom endpoint pub fn new_with_custom_endpoint(endpoint: &str) -> Result { Ok(Text2ImageEndpoint { url: build_url(TEXT2IMAGE_BASE_URL, endpoint)?, @@ -41,6 +43,7 @@ impl Text2ImageEndpoint { body } + /// sync invoke pub fn invoke( &self, prompt: String, diff --git a/src/text2image/response.rs b/src/text2image/response.rs index 2f4b9e5..7c7128c 100644 --- a/src/text2image/response.rs +++ b/src/text2image/response.rs @@ -2,7 +2,7 @@ use crate::errors::ErnieError; use serde::{Deserialize, Serialize}; use serde_json::value; -/// Response is using for non-stream response +/// Response of text2image endpoint #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Text2ImageResponse { raw_response: value::Value, @@ -25,7 +25,7 @@ impl Text2ImageResponse { self.raw_response.get_mut(key) } - //return list of image b64 strings + /// return list of image b64
strings pub fn get_image_results(&self) -> Result, ErnieError> { match self.raw_response.get("data") { Some(data) => { @@ -56,12 +56,14 @@ impl Text2ImageResponse { } } + /// get tokens used by prompt pub fn get_prompt_tokens(&self) -> Option { let usage = self.get("usage")?.as_object()?; let prompt_tokens = usage.get("prompt_tokens")?.as_u64()?; Some(prompt_tokens) } + /// get total tokens used pub fn get_total_tokens(&self) -> Option { let usage = self.get("usage")?.as_object()?; let total_tokens = usage.get("total_tokens")?.as_u64()?;