diff --git a/Cargo.toml b/Cargo.toml index 84bdc52..1e29a63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,3 +24,4 @@ tokio = { version = "1.36.0", features = ["full"] } tokio-stream = "0.1.14" base64 = "0.21.7" image = "0.24.9" +schemars = "0.8" diff --git a/examples/chat.rs b/examples/chat.rs index 8838204..6573f2a 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -10,7 +10,7 @@ fn test_invoke() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(), - name: None, + ..Default::default() }, ]; let options = vec![ @@ -29,7 +29,7 @@ fn test_stream() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using stream method".to_string(), - name: None, + ..Default::default() }, ]; let options = vec![ @@ -50,7 +50,7 @@ fn test_ainvoke() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using ainvoke method".to_string(), - name: None, + ..Default::default() }, ]; let options = Vec::new(); @@ -66,7 +66,7 @@ fn test_astream() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(), - name: None, + ..Default::default() }, ]; let options = Vec::new(); @@ -89,7 +89,7 @@ fn test_custom_endpoint() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(), - name: None, + ..Default::default() }, ]; let options = Vec::new(); diff --git a/examples/chat_with_function.rs b/examples/chat_with_function.rs new file mode 100644 index 0000000..e15b26c --- /dev/null +++ b/examples/chat_with_function.rs @@ -0,0 +1,57 @@ +use erniebot_rs::chat::{ChatEndpoint, ChatModel, ChatOpt, Example, Function, Message, Role}; +use schemars::{schema_for, JsonSchema}; +use serde::{Deserialize, Serialize}; + +#[derive(JsonSchema, Debug, Clone, Serialize, Deserialize, PartialEq)] +struct WeatherParameters { + place: String, +} + +fn get_weather_mock(params: WeatherParameters) -> String { + format!("The weather in {} is sunny", params.place) +} + +fn main() { + let messages = vec![Message { + role: Role::User, + content: "What's the weather in Beijing?".to_string(), + ..Default::default() + }]; + let examples = vec![vec![ + Example { + role: "user".to_string(), + content: Some("What's the weather in Shanghai?".to_string()), + ..Default::default() + }, + Example { + role: "assistant".to_string(), + content: None, + name: None, + function_call: Some(erniebot_rs::chat::FunctionCall { + name: "weather".to_string(), + arguments: r#"{"place":"Shanghai"}"#.to_string(), + thoughts: Some( + "I'm calling the weather function to get weather of Shanghai".to_string(), + ), + }), + }, + ]]; + let weather_function = Function { + name: "weather".to_string(), + description: "Get the weather of a place".to_string(), + parameters: schema_for!(WeatherParameters), + examples: Some(examples), + ..Default::default() + }; + let functions = vec![weather_function]; + let options = vec![ChatOpt::Functions(functions)]; + let chat = ChatEndpoint::new(ChatModel::ErnieBot).unwrap(); + let response = chat.invoke(messages, options).unwrap(); + let text_response = response.get_chat_result().unwrap(); + let function_call = response.get_function_call().unwrap(); + println!("text_response: {}", text_response); + println!("function_call: {:?}", function_call); + let weather_params: WeatherParameters = serde_json::from_str(&function_call.arguments).unwrap(); + let weather_result = get_weather_mock(weather_params); + println!("weather_result: {}", weather_result); +} diff --git a/readme.md b/readme.md index 80d0bb0..b99ae7a 100644 --- a/readme.md +++ b/readme.md @@ -35,7 +35,7 @@ fn test_invoke() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(), - name: None, + ..Default::default() }, ]; let options = vec![ @@ -58,7 +58,7 @@ fn test_custom_endpoint() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(), - name: None, + ..Default::default() }, ]; let options = Vec::new(); @@ -79,7 +79,7 @@ fn test_astream() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(), - name: None, + ..Default::default() }, ]; let options = Vec::new(); @@ -97,6 +97,8 @@ fn test_astream() { } ``` +For some models, such as ErnieBot, they support the option of passing in functions for invocation. You can refer to examples/chat_with_function.rs for an example. + Please note that due to varying parameter requirements for each specific model, this SDK does not perform local parameter validation but instead passes the parameters to the server for validation. Therefore, if the parameters do not meet the requirements, the server will return an error message. ## Embedding diff --git a/readme_zh.md b/readme_zh.md index c85f379..0616482 100644 --- a/readme_zh.md +++ b/readme_zh.md @@ -35,7 +35,7 @@ fn test_invoke() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(), - name: None, + ..Default::default() }, ]; let options = vec![ @@ -58,7 +58,7 @@ fn test_invoke() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(), - name: None, + ..Default::default() }, ]; let options = Vec::new(); @@ -79,7 +79,7 @@ fn test_astream() { Message { role: Role::User, content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(), - name: None, + ..Default::default() }, ]; let options = Vec::new(); @@ -97,6 +97,8 @@ fn test_astream() { } ``` +对于一些模型,如ErnieBot,支持传入functions进行调用的选择,可以参考examples/chat_with_function.rs + 注意,由于各个具体模型对参数的要求不同,所以本SDK并未在本地进行参数校验,而是将参数传递给服务端进行校验。因此,如果参数不符合要求,服务端会返回错误信息。 ## embedding diff --git a/src/chat/endpoint.rs b/src/chat/endpoint.rs index da51f2b..06e2bbb 100644 --- a/src/chat/endpoint.rs +++ b/src/chat/endpoint.rs @@ -184,7 +184,7 @@ mod tests { let messages = vec![Message { role: Role::User, content: "hello, I'm a user".to_string(), - name: None, + ..Default::default() }]; let options = vec![ ChatOpt::Temperature(0.5), diff --git a/src/chat/function.rs b/src/chat/function.rs new file mode 100644 index 0000000..cfb2730 --- /dev/null +++ b/src/chat/function.rs @@ -0,0 +1,96 @@ +use schemars::schema::RootSchema; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] +pub struct Function { + pub name: String, + pub description: String, + pub parameters: RootSchema, + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub examples: Option>>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] +pub struct Example { + pub role: String, + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] +pub struct FunctionCall { + pub name: String, + pub arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub thoughts: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] +pub struct ToolChoice { + pub r#type: String, //only one valid value: "function" + pub function: Value, +} + +impl ToolChoice { + pub fn new(function: Function) -> Self { + Self { + r#type: "function".to_string(), + function: serde_json::json!( + { + "name": function.name, + } + ), + } + } + pub fn from_function_name(name: String) -> Self { + Self { + r#type: "function".to_string(), + function: serde_json::json!( + { + "name": name, + } + ), + } + } +} + +#[cfg(test)] +mod tests { + use schemars::{schema::RootSchema, schema_for, JsonSchema}; + use serde::{Deserialize, Serialize}; + #[test] + fn test_schema() { + #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] + #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] + struct TestStruct { + pub date: String, + pub place: String, + } + let schema = schema_for!(TestStruct); + println!("{:?}", serde_json::to_string(&schema).unwrap()); + let default_schema = RootSchema::default(); + println!("{:?}", serde_json::to_string(&default_schema).unwrap()); + } + + #[test] + fn test_tool_choice() { + use super::Function; + let function = Function { + name: "test".to_string(), + description: "test".to_string(), + ..Default::default() + }; + let tool_choice = super::ToolChoice::new(function); + println!("{:?}", serde_json::to_string(&tool_choice).unwrap()); + } +} diff --git a/src/chat/message.rs b/src/chat/message.rs index 057b8e8..8826fc2 100644 --- a/src/chat/message.rs +++ b/src/chat/message.rs @@ -1,8 +1,10 @@ +use super::FunctionCall; use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] pub enum Role { + #[default] User, Assistant, Function, @@ -15,7 +17,7 @@ The "messages" member must not be empty. One member represents a single round of let message1 = Message { role: Role::User, content: "hello, I'm a user".to_string(), - name: None, + ..Default::default() }; let messages = vec![message1]; ``` @@ -24,17 +26,17 @@ The "messages" member must not be empty. One member represents a single round of let message1 = Message { role: Role::User, content: "hello, I'm a user".to_string(), - name: None, + ..Default::default() }; let message2 = Message { role: Role::Assistant, content: "hello, I'm a AI LLM model".to_string(), - name: None, + ..Default::default() }; let message3 = Message { role: Role::User, content: "hello, I want you to help me".to_string(), - name: None, + ..Default::default() }; let messages = vec![message1, message2, message3]; ``` @@ -44,12 +46,14 @@ The number of members must be odd. The role values of the messages in the member In the example, the role values of the messages are "user", "assistant", "user", "assistant", and "user" respectively. The role values of the messages at odd positions are "user", which means the role values of the 1st, 3rd, and 5th messages are "user". The role values at even positions are "assistant", which means the role values of the 2nd and 4th messages are "assistant". */ -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] pub struct Message { pub role: Role, pub content: String, #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, } #[cfg(test)] @@ -61,17 +65,17 @@ mod tests { let message1 = Message { role: Role::User, content: "hello, I'm a user".to_string(), - name: None, + ..Default::default() }; let message2 = Message { role: Role::Assistant, content: "hello, I'm a AI LLM model".to_string(), - name: None, + ..Default::default() }; let message3 = Message { role: Role::User, content: "hello, I want you to help me".to_string(), - name: None, + ..Default::default() }; let messages = vec![message1, message2, message3]; let messages_str = to_string(&messages).unwrap(); diff --git a/src/chat/mod.rs b/src/chat/mod.rs index db4187d..12cf08b 100644 --- a/src/chat/mod.rs +++ b/src/chat/mod.rs @@ -1,10 +1,12 @@ mod endpoint; +mod function; mod message; mod model; mod option; mod response; pub use endpoint::ChatEndpoint; +pub use function::{Example, Function, FunctionCall, ToolChoice}; pub use message::{Message, Role}; pub use model::ChatModel; pub use option::{ChatOpt, ResponseFormat}; diff --git a/src/chat/option.rs b/src/chat/option.rs index f7f6eaa..728b249 100644 --- a/src/chat/option.rs +++ b/src/chat/option.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use super::{Function, ToolChoice}; + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] pub enum ResponseFormat { @@ -21,4 +23,6 @@ pub enum ChatOpt { MaxOutputTokens(u32), ResponseFormat(ResponseFormat), UserId(String), + Functions(Vec), + ToolChoice(ToolChoice), } diff --git a/src/chat/response.rs b/src/chat/response.rs index 31e00da..be3cdf3 100644 --- a/src/chat/response.rs +++ b/src/chat/response.rs @@ -1,3 +1,4 @@ +use super::FunctionCall; use crate::errors::ErnieError; use serde::{Deserialize, Serialize}; use serde_json::value; @@ -58,6 +59,12 @@ impl Response { let total_tokens = usage.get("total_tokens")?.as_u64()?; Some(total_tokens) } + + pub fn get_function_call(&self) -> Option { + let value = self.get("function_call")?; + let function_call = serde_json::from_value(value.clone()).ok()?; + Some(function_call) + } } /// Responses is using for sync stream response