Skip to content

Commit

Permalink
update a lot of comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwanqq committed Feb 29, 2024
1 parent a72bbb6 commit af550a8
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 11 deletions.
4 changes: 2 additions & 2 deletions examples/chat_with_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ fn main() {
}];
let examples = vec![vec![
Example {
role: "user".to_string(),
role: Role::User,
content: Some("What's the weather in Shanghai?".to_string()),
..Default::default()
},
Example {
role: "assistant".to_string(),
role: Role::Assistant,
content: None,
name: None,
function_call: Some(erniebot_rs::chat::FunctionCall {
Expand Down
25 changes: 24 additions & 1 deletion src/chat/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,63 @@ use schemars::schema::RootSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::Role;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
/// Definition of the function structure that some models(like Erniebot) can select and call.
pub struct Function {
/// The name of the function.
pub name: String,
/// The description of the function.
pub description: String,
/// The format of parameters of the function, following the JSON schema format.
pub parameters: RootSchema,
/// The format of the response of the function, following the JSON schema format.
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<RootSchema>,
/// The examples of the function. each instance of the outer vector represents a round of conversation, and each instance of the inner vector represents a message in the round of conversation. More details can be found in the example of chat_with_function.rs.
#[serde(skip_serializing_if = "Option::is_none")]
pub examples: Option<Vec<Vec<Example>>>,
}

/// Example of a message involved in a function calling process
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct Example {
pub role: String,
/// Same as the role in Message, can be "user", "assistant", or "function".
pub role: Role,
/**Dialog content instructions:
(1) When the current message contains a function_call and the role is "assistant", the message can be empty. However, in other scenarios, it cannot be empty.
(2) The content corresponding to the last message cannot be a blank character, including spaces, "\n", "\r", "\f", etc.
*/
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
/// The "author" of the message. the This member is required when the role value is "function", and in this case is should be the name in the function_call in the response content
pub name: Option<String>,
/// this is function calling result of last round of function call, serving as chat history.
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
}

/// This is function calling result of last round of function call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct FunctionCall {
/// name of a function
pub name: String,
/// arguments of a function call that LLM model outputs, following the JSON format.
pub arguments: String,
/// The thinking process of the model
#[serde(skip_serializing_if = "Option::is_none")]
pub thoughts: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
/// In the context of function calls, prompt the large model to select a specific function (not mandatory). Note: The specified function name must exist within the list of functions.
pub struct ToolChoice {
pub r#type: String, //only one valid value: "function"
pub function: Value,
Expand Down
2 changes: 1 addition & 1 deletion src/chat/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub struct Message {
pub role: Role,
/// The content of the message. The value is a string.
pub content: String,
/// The name of the function. The value is a string. This member is required when the role value is "function", and in this case is should be the name in the function_call in the response content
/// The "author" of the message. the This member is required when the role value is "function", and in this case is should be the name in the function_call in the response content
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// this is function calling result of last round of function call, serving as chat history.
Expand Down
1 change: 1 addition & 0 deletions src/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ pub use function::{Example, Function, FunctionCall, ToolChoice};
pub use message::{Message, Role};
pub use model::ChatModel;
pub use option::{ChatOpt, ResponseFormat};
pub use response::{Response, Responses, StreamResponse};
16 changes: 14 additions & 2 deletions src/chat/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use serde_json::value;
use tokio::sync::mpsc::{self, UnboundedReceiver};
use tokio_stream::Stream;

/// Response is using for non-stream response
/// Response is a struct that represents the response of erniebot API.
///
/// It is a wrapper of serde_json::Value. in non-stream case, the API will return a single response, and in stream case, the API will return multiple responses.(see in `Responses` struct)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Response {
raw_response: value::Value,
Expand All @@ -28,6 +30,7 @@ impl Response {
self.raw_response.get_mut(key)
}

/// get the result of chat response
pub fn get_chat_result(&self) -> Result<String, ErnieError> {
match self.raw_response.get("result") {
Some(result) => match result.as_str() {
Expand All @@ -42,38 +45,43 @@ impl Response {
}
}

/// get tokens used by prompt
pub fn get_prompt_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let prompt_tokens = usage.get("prompt_tokens")?.as_u64()?;
Some(prompt_tokens)
}

/// get tokens used by completion
pub fn get_completion_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let completion_tokens = usage.get("completion_tokens")?.as_u64()?;
Some(completion_tokens)
}

/// get total tokens used
pub fn get_total_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let total_tokens = usage.get("total_tokens")?.as_u64()?;
Some(total_tokens)
}

/// get function call
pub fn get_function_call(&self) -> Option<FunctionCall> {
let value = self.get("function_call")?;
let function_call = serde_json::from_value(value.clone()).ok()?;
Some(function_call)
}
}

/// Responses is using for sync stream response
/// Responses is using for sync stream response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Responses {
responses: Vec<Response>,
}

impl Responses {
/// get Responses from reqwest::blocking::client. In the response body, it contains multiple responses split by blank line. This method will parse the response body and return a Responses struct.
pub fn from_text(text: String) -> Result<Self, ErnieError> {
let parts = text.split("\n\n").collect::<Vec<&str>>();
let mut result = Vec::new();
Expand All @@ -99,6 +107,8 @@ impl Responses {
}
Ok(Responses { responses: result })
}

/// get chat result as a vector of string
pub fn get_results(&self) -> Result<Vec<String>, ErnieError> {
let mut result = Vec::new();
for response in &self.responses {
Expand All @@ -107,6 +117,7 @@ impl Responses {
Ok(result)
}

/// get whole chat result as a single string
pub fn get_whole_result(&self) -> Result<String, ErnieError> {
let mut result = String::new();
for response in &self.responses {
Expand All @@ -116,6 +127,7 @@ impl Responses {
}
}

/// StreamResponse is a struct that represents the response of erniebot API in async stream case.
pub struct StreamResponse {
receiver: UnboundedReceiver<Response>,
}
Expand Down
3 changes: 3 additions & 0 deletions src/embedding/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ use url::Url;
static EMBEDDING_BASE_URL: &str =
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/";

/** ChatEndpoint is a struct that represents the chat endpoint of erniebot API
*/
pub struct EmbeddingEndpoint {
url: Url,
access_token: String,
}

impl EmbeddingEndpoint {
// create a new embedding instance using pre-defined model
pub fn new(model: EmbeddingModel) -> Result<Self, ErnieError> {
Ok(EmbeddingEndpoint {
url: build_url(EMBEDDING_BASE_URL, model.to_string().as_str())?,
Expand Down
5 changes: 3 additions & 2 deletions src/embedding/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ impl EmbeddingResponse {
self.raw_response.get_mut(key)
}

/// get the result of embedding response
pub fn get_embedding_results(&self) -> Result<Vec<Vec<f64>>, ErnieError> {
match self.raw_response.get("data") {
Some(data) => {
Expand Down Expand Up @@ -54,13 +55,13 @@ impl EmbeddingResponse {
)),
}
}

/// get tokens used by prompt
pub fn get_prompt_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let prompt_tokens = usage.get("prompt_tokens")?.as_u64()?;
Some(prompt_tokens)
}

/// get tokens used by completion
pub fn get_total_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let total_tokens = usage.get("total_tokens")?.as_u64()?;
Expand Down
5 changes: 4 additions & 1 deletion src/text2image/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@ use crate::utils::{build_url, get_access_token};

static TEXT2IMAGE_BASE_URL: &str =
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/text2image/";

/// Text2ImageEndpoint is a struct that represents the text2image endpoint of erniebot API
pub struct Text2ImageEndpoint {
url: Url,
access_token: String,
}

impl Text2ImageEndpoint {
/// create a new text2image instance using pre-defined model
pub fn new(model: Text2ImageModel) -> Result<Self, ErnieError> {
Ok(Text2ImageEndpoint {
url: build_url(TEXT2IMAGE_BASE_URL, model.to_string().as_str())?,
access_token: get_access_token()?,
})
}

/// create a new text2image instance using custom endpoint
pub fn new_with_custom_endpoint(endpoint: &str) -> Result<Self, ErnieError> {
Ok(Text2ImageEndpoint {
url: build_url(TEXT2IMAGE_BASE_URL, endpoint)?,
Expand All @@ -41,6 +43,7 @@ impl Text2ImageEndpoint {
body
}

/// sync invoke
pub fn invoke(
&self,
prompt: String,
Expand Down
6 changes: 4 additions & 2 deletions src/text2image/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::errors::ErnieError;
use serde::{Deserialize, Serialize};
use serde_json::value;

/// Response is using for non-stream response
/// Response of text2image endpoint
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Text2ImageResponse {
raw_response: value::Value,
Expand All @@ -25,7 +25,7 @@ impl Text2ImageResponse {
self.raw_response.get_mut(key)
}

//return list of image b64 strings
/// return list of image b64 strings
pub fn get_image_results(&self) -> Result<Vec<String>, ErnieError> {
match self.raw_response.get("data") {
Some(data) => {
Expand Down Expand Up @@ -56,12 +56,14 @@ impl Text2ImageResponse {
}
}

/// get tokens used by prompt
pub fn get_prompt_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let prompt_tokens = usage.get("prompt_tokens")?.as_u64()?;
Some(prompt_tokens)
}

/// get total tokens used
pub fn get_total_tokens(&self) -> Option<u64> {
let usage = self.get("usage")?.as_object()?;
let total_tokens = usage.get("total_tokens")?.as_u64()?;
Expand Down

0 comments on commit af550a8

Please sign in to comment.