Skip to content

Commit

Permalink
Merge pull request #2 from chenwanqq/tools
Browse files Browse the repository at this point in the history
update functions call
  • Loading branch information
chenwanqq authored Feb 28, 2024
2 parents 723dd2f + 474aa7c commit 4bbcf72
Show file tree
Hide file tree
Showing 11 changed files with 196 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ tokio = { version = "1.36.0", features = ["full"] }
tokio-stream = "0.1.14"
base64 = "0.21.7"
image = "0.24.9"
schemars = "0.8"
10 changes: 5 additions & 5 deletions examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn test_invoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(),
name: None,
..Default::default()
},
];
let options = vec![
Expand All @@ -29,7 +29,7 @@ fn test_stream() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using stream method".to_string(),
name: None,
..Default::default()
},
];
let options = vec![
Expand All @@ -50,7 +50,7 @@ fn test_ainvoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using ainvoke method".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -66,7 +66,7 @@ fn test_astream() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -89,7 +89,7 @@ fn test_custom_endpoint() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand Down
57 changes: 57 additions & 0 deletions examples/chat_with_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use erniebot_rs::chat::{ChatEndpoint, ChatModel, ChatOpt, Example, Function, Message, Role};
use schemars::{schema_for, JsonSchema};
use serde::{Deserialize, Serialize};

#[derive(JsonSchema, Debug, Clone, Serialize, Deserialize, PartialEq)]
struct WeatherParameters {
place: String,
}

fn get_weather_mock(params: WeatherParameters) -> String {
format!("The weather in {} is sunny", params.place)
}

fn main() {
let messages = vec![Message {
role: Role::User,
content: "What's the weather in Beijing?".to_string(),
..Default::default()
}];
let examples = vec![vec![
Example {
role: "user".to_string(),
content: Some("What's the weather in Shanghai?".to_string()),
..Default::default()
},
Example {
role: "assistant".to_string(),
content: None,
name: None,
function_call: Some(erniebot_rs::chat::FunctionCall {
name: "weather".to_string(),
arguments: r#"{"place":"Shanghai"}"#.to_string(),
thoughts: Some(
"I'm calling the weather function to get weather of Shanghai".to_string(),
),
}),
},
]];
let weather_function = Function {
name: "weather".to_string(),
description: "Get the weather of a place".to_string(),
parameters: schema_for!(WeatherParameters),
examples: Some(examples),
..Default::default()
};
let functions = vec![weather_function];
let options = vec![ChatOpt::Functions(functions)];
let chat = ChatEndpoint::new(ChatModel::ErnieBot).unwrap();
let response = chat.invoke(messages, options).unwrap();
let text_response = response.get_chat_result().unwrap();
let function_call = response.get_function_call().unwrap();
println!("text_response: {}", text_response);
println!("function_call: {:?}", function_call);
let weather_params: WeatherParameters = serde_json::from_str(&function_call.arguments).unwrap();
let weather_result = get_weather_mock(weather_params);
println!("weather_result: {}", weather_result);
}
8 changes: 5 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn test_invoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(),
name: None,
..Default::default()
},
];
let options = vec![
Expand All @@ -58,7 +58,7 @@ fn test_custom_endpoint() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -79,7 +79,7 @@ fn test_astream() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -97,6 +97,8 @@ fn test_astream() {
}
```

For some models, such as ErnieBot, they support the option of passing in functions for invocation. You can refer to examples/chat_with_function.rs for an example.

Please note that due to varying parameter requirements for each specific model, this SDK does not perform local parameter validation but instead passes the parameters to the server for validation. Therefore, if the parameters do not meet the requirements, the server will return an error message.

## Embedding
Expand Down
8 changes: 5 additions & 3 deletions readme_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn test_invoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(),
name: None,
..Default::default()
},
];
let options = vec![
Expand All @@ -58,7 +58,7 @@ fn test_invoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -79,7 +79,7 @@ fn test_astream() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -97,6 +97,8 @@ fn test_astream() {
}
```

对于一些模型,如ErnieBot,支持传入functions进行调用的选择,可以参考examples/chat_with_function.rs

注意,由于各个具体模型对参数的要求不同,所以本SDK并未在本地进行参数校验,而是将参数传递给服务端进行校验。因此,如果参数不符合要求,服务端会返回错误信息。

## embedding
Expand Down
2 changes: 1 addition & 1 deletion src/chat/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ mod tests {
let messages = vec![Message {
role: Role::User,
content: "hello, I'm a user".to_string(),
name: None,
..Default::default()
}];
let options = vec![
ChatOpt::Temperature(0.5),
Expand Down
96 changes: 96 additions & 0 deletions src/chat/function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use schemars::schema::RootSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct Function {
pub name: String,
pub description: String,
pub parameters: RootSchema,
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<RootSchema>,
#[serde(skip_serializing_if = "Option::is_none")]
pub examples: Option<Vec<Vec<Example>>>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct Example {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub thoughts: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct ToolChoice {
pub r#type: String, //only one valid value: "function"
pub function: Value,
}

impl ToolChoice {
pub fn new(function: Function) -> Self {
Self {
r#type: "function".to_string(),
function: serde_json::json!(
{
"name": function.name,
}
),
}
}
pub fn from_function_name(name: String) -> Self {
Self {
r#type: "function".to_string(),
function: serde_json::json!(
{
"name": name,
}
),
}
}
}

#[cfg(test)]
mod tests {
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::{Deserialize, Serialize};
#[test]
fn test_schema() {
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
struct TestStruct {
pub date: String,
pub place: String,
}
let schema = schema_for!(TestStruct);
println!("{:?}", serde_json::to_string(&schema).unwrap());
let default_schema = RootSchema::default();
println!("{:?}", serde_json::to_string(&default_schema).unwrap());
}

#[test]
fn test_tool_choice() {
use super::Function;
let function = Function {
name: "test".to_string(),
description: "test".to_string(),
..Default::default()
};
let tool_choice = super::ToolChoice::new(function);
println!("{:?}", serde_json::to_string(&tool_choice).unwrap());
}
}
22 changes: 13 additions & 9 deletions src/chat/message.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use super::FunctionCall;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub enum Role {
#[default]
User,
Assistant,
Function,
Expand All @@ -15,7 +17,7 @@ The "messages" member must not be empty. One member represents a single round of
let message1 = Message {
role: Role::User,
content: "hello, I'm a user".to_string(),
name: None,
..Default::default()
};
let messages = vec![message1];
```
Expand All @@ -24,17 +26,17 @@ The "messages" member must not be empty. One member represents a single round of
let message1 = Message {
role: Role::User,
content: "hello, I'm a user".to_string(),
name: None,
..Default::default()
};
let message2 = Message {
role: Role::Assistant,
content: "hello, I'm a AI LLM model".to_string(),
name: None,
..Default::default()
};
let message3 = Message {
role: Role::User,
content: "hello, I want you to help me".to_string(),
name: None,
..Default::default()
};
let messages = vec![message1, message2, message3];
```
Expand All @@ -44,12 +46,14 @@ The number of members must be odd. The role values of the messages in the member
In the example, the role values of the messages are "user", "assistant", "user", "assistant", and "user" respectively. The role values of the messages at odd positions are "user", which means the role values of the 1st, 3rd, and 5th messages are "user". The role values at even positions are "assistant", which means the role values of the 2nd and 4th messages are "assistant".
*/
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct Message {
pub role: Role,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
}

#[cfg(test)]
Expand All @@ -61,17 +65,17 @@ mod tests {
let message1 = Message {
role: Role::User,
content: "hello, I'm a user".to_string(),
name: None,
..Default::default()
};
let message2 = Message {
role: Role::Assistant,
content: "hello, I'm a AI LLM model".to_string(),
name: None,
..Default::default()
};
let message3 = Message {
role: Role::User,
content: "hello, I want you to help me".to_string(),
name: None,
..Default::default()
};
let messages = vec![message1, message2, message3];
let messages_str = to_string(&messages).unwrap();
Expand Down
2 changes: 2 additions & 0 deletions src/chat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod endpoint;
mod function;
mod message;
mod model;
mod option;
mod response;

pub use endpoint::ChatEndpoint;
pub use function::{Example, Function, FunctionCall, ToolChoice};
pub use message::{Message, Role};
pub use model::ChatModel;
pub use option::{ChatOpt, ResponseFormat};
4 changes: 4 additions & 0 deletions src/chat/option.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};

use super::{Function, ToolChoice};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub enum ResponseFormat {
Expand All @@ -21,4 +23,6 @@ pub enum ChatOpt {
MaxOutputTokens(u32),
ResponseFormat(ResponseFormat),
UserId(String),
Functions(Vec<Function>),
ToolChoice(ToolChoice),
}
Loading

0 comments on commit 4bbcf72

Please sign in to comment.