-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from chenwanqq/tools
update functions call
- Loading branch information
Showing
11 changed files
with
196 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
use erniebot_rs::chat::{ChatEndpoint, ChatModel, ChatOpt, Example, Function, Message, Role}; | ||
use schemars::{schema_for, JsonSchema}; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(JsonSchema, Debug, Clone, Serialize, Deserialize, PartialEq)] | ||
struct WeatherParameters { | ||
place: String, | ||
} | ||
|
||
fn get_weather_mock(params: WeatherParameters) -> String { | ||
format!("The weather in {} is sunny", params.place) | ||
} | ||
|
||
fn main() { | ||
let messages = vec![Message { | ||
role: Role::User, | ||
content: "What's the weather in Beijing?".to_string(), | ||
..Default::default() | ||
}]; | ||
let examples = vec![vec![ | ||
Example { | ||
role: "user".to_string(), | ||
content: Some("What's the weather in Shanghai?".to_string()), | ||
..Default::default() | ||
}, | ||
Example { | ||
role: "assistant".to_string(), | ||
content: None, | ||
name: None, | ||
function_call: Some(erniebot_rs::chat::FunctionCall { | ||
name: "weather".to_string(), | ||
arguments: r#"{"place":"Shanghai"}"#.to_string(), | ||
thoughts: Some( | ||
"I'm calling the weather function to get weather of Shanghai".to_string(), | ||
), | ||
}), | ||
}, | ||
]]; | ||
let weather_function = Function { | ||
name: "weather".to_string(), | ||
description: "Get the weather of a place".to_string(), | ||
parameters: schema_for!(WeatherParameters), | ||
examples: Some(examples), | ||
..Default::default() | ||
}; | ||
let functions = vec![weather_function]; | ||
let options = vec![ChatOpt::Functions(functions)]; | ||
let chat = ChatEndpoint::new(ChatModel::ErnieBot).unwrap(); | ||
let response = chat.invoke(messages, options).unwrap(); | ||
let text_response = response.get_chat_result().unwrap(); | ||
let function_call = response.get_function_call().unwrap(); | ||
println!("text_response: {}", text_response); | ||
println!("function_call: {:?}", function_call); | ||
let weather_params: WeatherParameters = serde_json::from_str(&function_call.arguments).unwrap(); | ||
let weather_result = get_weather_mock(weather_params); | ||
println!("weather_result: {}", weather_result); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
use schemars::schema::RootSchema; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::Value; | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] | ||
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] | ||
pub struct Function { | ||
pub name: String, | ||
pub description: String, | ||
pub parameters: RootSchema, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub response: Option<RootSchema>, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub examples: Option<Vec<Vec<Example>>>, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] | ||
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] | ||
pub struct Example { | ||
pub role: String, | ||
pub content: Option<String>, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub name: Option<String>, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub function_call: Option<FunctionCall>, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] | ||
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] | ||
pub struct FunctionCall { | ||
pub name: String, | ||
pub arguments: String, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub thoughts: Option<String>, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] | ||
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] | ||
pub struct ToolChoice { | ||
pub r#type: String, //only one valid value: "function" | ||
pub function: Value, | ||
} | ||
|
||
impl ToolChoice { | ||
pub fn new(function: Function) -> Self { | ||
Self { | ||
r#type: "function".to_string(), | ||
function: serde_json::json!( | ||
{ | ||
"name": function.name, | ||
} | ||
), | ||
} | ||
} | ||
pub fn from_function_name(name: String) -> Self { | ||
Self { | ||
r#type: "function".to_string(), | ||
function: serde_json::json!( | ||
{ | ||
"name": name, | ||
} | ||
), | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use schemars::{schema::RootSchema, schema_for, JsonSchema}; | ||
use serde::{Deserialize, Serialize}; | ||
#[test] | ||
fn test_schema() { | ||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] | ||
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))] | ||
struct TestStruct { | ||
pub date: String, | ||
pub place: String, | ||
} | ||
let schema = schema_for!(TestStruct); | ||
println!("{:?}", serde_json::to_string(&schema).unwrap()); | ||
let default_schema = RootSchema::default(); | ||
println!("{:?}", serde_json::to_string(&default_schema).unwrap()); | ||
} | ||
|
||
#[test] | ||
fn test_tool_choice() { | ||
use super::Function; | ||
let function = Function { | ||
name: "test".to_string(), | ||
description: "test".to_string(), | ||
..Default::default() | ||
}; | ||
let tool_choice = super::ToolChoice::new(function); | ||
println!("{:?}", serde_json::to_string(&tool_choice).unwrap()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
mod endpoint; | ||
mod function; | ||
mod message; | ||
mod model; | ||
mod option; | ||
mod response; | ||
|
||
pub use endpoint::ChatEndpoint; | ||
pub use function::{Example, Function, FunctionCall, ToolChoice}; | ||
pub use message::{Message, Role}; | ||
pub use model::ChatModel; | ||
pub use option::{ChatOpt, ResponseFormat}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.