Skip to content

Commit

Permalink
refactor: role/session/agent should not inherit the global use_tools (
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Sep 13, 2024
1 parent 0fb403a commit 5a26c59
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
51 changes: 28 additions & 23 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,12 @@ impl Config {
role.clone()
} else {
let mut role = Role::default();
role.batch_set(&self.model, self.temperature, self.top_p, None);
role.batch_set(
&self.model,
self.temperature,
self.top_p,
self.use_tools.clone(),
);
role
};
if role.temperature().is_none() && self.temperature.is_some() {
Expand All @@ -474,9 +479,6 @@ impl Config {
if role.top_p().is_none() && self.top_p.is_some() {
role.set_top_p(self.top_p);
}
if role.use_tools().is_none() && self.use_tools.is_some() {
role.set_use_tools(self.use_tools.clone())
}
role
}

Expand Down Expand Up @@ -544,12 +546,12 @@ impl Config {
("rag_top_k", rag_top_k.to_string()),
("highlight", self.highlight.to_string()),
("light_theme", self.light_theme.to_string()),
("env_file", display_path(&Self::env_file()?)),
("config_file", display_path(&Self::config_file()?)),
("roles_dir", display_path(&Self::roles_dir()?)),
("env_file", display_path(&Self::env_file()?)),
("functions_dir", display_path(&Self::functions_dir()?)),
("rags_dir", display_path(&Self::rags_dir()?)),
("sessions_dir", display_path(&self.sessions_dir()?)),
("rags_dir", display_path(&Self::rags_dir()?)),
("functions_dir", display_path(&Self::functions_dir()?)),
("messages_file", display_path(&self.messages_file()?)),
];
if let Ok((_, Some(log_path))) = Self::log(self.working_mode.is_serve()) {
Expand Down Expand Up @@ -1367,20 +1369,21 @@ impl Config {
.iter()
.map(|v| v.name.to_string())
.collect();
for item in use_tools.split(',') {
let item = item.trim();
if item == "all" {
tool_names.extend(declaration_names);
break;
} else if let Some(values) = self.mapping_tools.get(item) {
tool_names.extend(
values
.split(',')
.map(|v| v.to_string())
.filter(|v| declaration_names.contains(v)),
)
} else if declaration_names.contains(item) {
tool_names.insert(item.to_string());
if use_tools == "all" {
tool_names.extend(declaration_names);
} else {
for item in use_tools.split(',') {
let item = item.trim();
if let Some(values) = self.mapping_tools.get(item) {
tool_names.extend(
values
.split(',')
.map(|v| v.to_string())
.filter(|v| declaration_names.contains(v)),
)
} else if declaration_names.contains(item) {
tool_names.insert(item.to_string());
}
}
}
functions = self
Expand Down Expand Up @@ -1509,18 +1512,20 @@ impl Config {
"function_calling" => complete_bool(self.function_calling),
"use_tools" => {
let mut prefix = String::new();
let mut ignores = HashSet::new();
if let Some((v, _)) = args[1].rsplit_once(',') {
ignores = v.split(',').collect();
prefix = format!("{v},");
}
let mut values = vec![];
if prefix.is_empty() {
values.push("all".to_string());
}
values.extend(self.mapping_tools.keys().map(|v| v.to_string()));
values.extend(self.functions.declarations().iter().map(|v| v.name.clone()));
values.extend(self.mapping_tools.keys().map(|v| v.to_string()));
values
.into_iter()
.filter(|v| !prefix.contains(&format!("{v},")))
.filter(|v| !ignores.contains(v.as_str()))
.map(|v| format!("{prefix}{v}"))
.collect()
}
Expand Down
3 changes: 3 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ async fn shell_execute(config: &GlobalConfig, shell: &Shell, mut input: Input) -
config
.write()
.after_chat_completion(&input, &eval_str, &[])?;
if eval_str.is_empty() {
bail!("No command generated");
}
if config.read().dry_run {
config.read().print_markdown(&eval_str)?;
return Ok(());
Expand Down

0 comments on commit 5a26c59

Please sign in to comment.