Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update functions call #2

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ tokio = { version = "1.36.0", features = ["full"] }
tokio-stream = "0.1.14"
base64 = "0.21.7"
image = "0.24.9"
schemars = "0.8"
10 changes: 5 additions & 5 deletions examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn test_invoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(),
name: None,
..Default::default()
},
];
let options = vec![
Expand All @@ -29,7 +29,7 @@ fn test_stream() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using stream method".to_string(),
name: None,
..Default::default()
},
];
let options = vec![
Expand All @@ -50,7 +50,7 @@ fn test_ainvoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using ainvoke method".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -66,7 +66,7 @@ fn test_astream() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -89,7 +89,7 @@ fn test_custom_endpoint() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand Down
57 changes: 57 additions & 0 deletions examples/chat_with_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use erniebot_rs::chat::{ChatEndpoint, ChatModel, ChatOpt, Example, Function, Message, Role};
use schemars::{schema_for, JsonSchema};
use serde::{Deserialize, Serialize};

#[derive(JsonSchema, Debug, Clone, Serialize, Deserialize, PartialEq)]
struct WeatherParameters {
place: String,
}

fn get_weather_mock(params: WeatherParameters) -> String {
format!("The weather in {} is sunny", params.place)
}

fn main() {
let messages = vec![Message {
role: Role::User,
content: "What's the weather in Beijing?".to_string(),
..Default::default()
}];
let examples = vec![vec![
Example {
role: "user".to_string(),
content: Some("What's the weather in Shanghai?".to_string()),
..Default::default()
},
Example {
role: "assistant".to_string(),
content: None,
name: None,
function_call: Some(erniebot_rs::chat::FunctionCall {
name: "weather".to_string(),
arguments: r#"{"place":"Shanghai"}"#.to_string(),
thoughts: Some(
"I'm calling the weather function to get weather of Shanghai".to_string(),
),
}),
},
]];
let weather_function = Function {
name: "weather".to_string(),
description: "Get the weather of a place".to_string(),
parameters: schema_for!(WeatherParameters),
examples: Some(examples),
..Default::default()
};
let functions = vec![weather_function];
let options = vec![ChatOpt::Functions(functions)];
let chat = ChatEndpoint::new(ChatModel::ErnieBot).unwrap();
let response = chat.invoke(messages, options).unwrap();
let text_response = response.get_chat_result().unwrap();
let function_call = response.get_function_call().unwrap();
println!("text_response: {}", text_response);
println!("function_call: {:?}", function_call);
let weather_params: WeatherParameters = serde_json::from_str(&function_call.arguments).unwrap();
let weather_result = get_weather_mock(weather_params);
println!("weather_result: {}", weather_result);
}
8 changes: 5 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn test_invoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(),
name: None,
..Default::default()
},
];
let options = vec![
Expand All @@ -58,7 +58,7 @@ fn test_custom_endpoint() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -79,7 +79,7 @@ fn test_astream() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -97,6 +97,8 @@ fn test_astream() {
}
```

For some models, such as ErnieBot, they support the option of passing in functions for invocation. You can refer to examples/chat_with_function.rs for an example.

Please note that due to varying parameter requirements for each specific model, this SDK does not perform local parameter validation but instead passes the parameters to the server for validation. Therefore, if the parameters do not meet the requirements, the server will return an error message.

## Embedding
Expand Down
8 changes: 5 additions & 3 deletions readme_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn test_invoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using invoke method".to_string(),
name: None,
..Default::default()
},
];
let options = vec![
Expand All @@ -58,7 +58,7 @@ fn test_invoke() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using a custom endpoint".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -79,7 +79,7 @@ fn test_astream() {
Message {
role: Role::User,
content: "hello, I'm a developer. I'm developing a rust SDK for qianfan LLM. If you get this message, that means I successfully send you this message using async stream method. Now reply to me a message as long as possible so that I can test if this function doing well".to_string(),
name: None,
..Default::default()
},
];
let options = Vec::new();
Expand All @@ -97,6 +97,8 @@ fn test_astream() {
}
```

对于一些模型,如ErnieBot,支持传入functions进行调用的选择,可以参考examples/chat_with_function.rs

注意,由于各个具体模型对参数的要求不同,所以本SDK并未在本地进行参数校验,而是将参数传递给服务端进行校验。因此,如果参数不符合要求,服务端会返回错误信息。

## embedding
Expand Down
2 changes: 1 addition & 1 deletion src/chat/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ mod tests {
let messages = vec![Message {
role: Role::User,
content: "hello, I'm a user".to_string(),
name: None,
..Default::default()
}];
let options = vec![
ChatOpt::Temperature(0.5),
Expand Down
96 changes: 96 additions & 0 deletions src/chat/function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use schemars::schema::RootSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct Function {
pub name: String,
pub description: String,
pub parameters: RootSchema,
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<RootSchema>,
#[serde(skip_serializing_if = "Option::is_none")]
pub examples: Option<Vec<Vec<Example>>>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct Example {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub thoughts: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub struct ToolChoice {
pub r#type: String, //only one valid value: "function"
pub function: Value,
}

impl ToolChoice {
pub fn new(function: Function) -> Self {
Self {
r#type: "function".to_string(),
function: serde_json::json!(
{
"name": function.name,
}
),
}
}
pub fn from_function_name(name: String) -> Self {
Self {
r#type: "function".to_string(),
function: serde_json::json!(
{
"name": name,
}
),
}
}
}

#[cfg(test)]
mod tests {
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::{Deserialize, Serialize};
#[test]
fn test_schema() {
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
struct TestStruct {
pub date: String,
pub place: String,
}
let schema = schema_for!(TestStruct);
println!("{:?}", serde_json::to_string(&schema).unwrap());
let default_schema = RootSchema::default();
println!("{:?}", serde_json::to_string(&default_schema).unwrap());
}

#[test]
fn test_tool_choice() {
use super::Function;
let function = Function {
name: "test".to_string(),
description: "test".to_string(),
..Default::default()
};
let tool_choice = super::ToolChoice::new(function);
println!("{:?}", serde_json::to_string(&tool_choice).unwrap());
}
}
22 changes: 13 additions & 9 deletions src/chat/message.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use super::FunctionCall;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub enum Role {
#[default]
User,
Assistant,
Function,
Expand All @@ -15,7 +17,7 @@ The "messages" member must not be empty. One member represents a single round of
let message1 = Message {
role: Role::User,
content: "hello, I'm a user".to_string(),
name: None,
..Default::default()
};
let messages = vec![message1];
```
Expand All @@ -24,17 +26,17 @@ The "messages" member must not be empty. One member represents a single round of
let message1 = Message {
role: Role::User,
content: "hello, I'm a user".to_string(),
name: None,
..Default::default()
};
let message2 = Message {
role: Role::Assistant,
content: "hello, I'm a AI LLM model".to_string(),
name: None,
..Default::default()
};
let message3 = Message {
role: Role::User,
content: "hello, I want you to help me".to_string(),
name: None,
..Default::default()
};
let messages = vec![message1, message2, message3];
```
Expand All @@ -44,12 +46,14 @@ The number of members must be odd. The role values of the messages in the member

In the example, the role values of the messages are "user", "assistant", "user", "assistant", and "user" respectively. The role values of the messages at odd positions are "user", which means the role values of the 1st, 3rd, and 5th messages are "user". The role values at even positions are "assistant", which means the role values of the 2nd and 4th messages are "assistant".
*/
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct Message {
pub role: Role,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
}

#[cfg(test)]
Expand All @@ -61,17 +65,17 @@ mod tests {
let message1 = Message {
role: Role::User,
content: "hello, I'm a user".to_string(),
name: None,
..Default::default()
};
let message2 = Message {
role: Role::Assistant,
content: "hello, I'm a AI LLM model".to_string(),
name: None,
..Default::default()
};
let message3 = Message {
role: Role::User,
content: "hello, I want you to help me".to_string(),
name: None,
..Default::default()
};
let messages = vec![message1, message2, message3];
let messages_str = to_string(&messages).unwrap();
Expand Down
2 changes: 2 additions & 0 deletions src/chat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod endpoint;
mod function;
mod message;
mod model;
mod option;
mod response;

pub use endpoint::ChatEndpoint;
pub use function::{Example, Function, FunctionCall, ToolChoice};
pub use message::{Message, Role};
pub use model::ChatModel;
pub use option::{ChatOpt, ResponseFormat};
4 changes: 4 additions & 0 deletions src/chat/option.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};

use super::{Function, ToolChoice};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all(serialize = "snake_case", deserialize = "snake_case"))]
pub enum ResponseFormat {
Expand All @@ -21,4 +23,6 @@ pub enum ChatOpt {
MaxOutputTokens(u32),
ResponseFormat(ResponseFormat),
UserId(String),
Functions(Vec<Function>),
ToolChoice(ToolChoice),
}
Loading
Loading