Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Sep 13, 2024
2 parents 902cce7 + 5a26c59 commit 8e03b9b
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 101 deletions.
211 changes: 137 additions & 74 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ use crate::utils::*;

use anyhow::{anyhow, bail, Context, Result};
use indexmap::IndexMap;
use inquire::{validator::Validation, Confirm, Select, Text};
use inquire::{list_option::ListOption, validator::Validation, Confirm, MultiSelect, Select, Text};
use parking_lot::RwLock;
use serde::Deserialize;
use serde_json::json;
use simplelog::LevelFilter;
use std::collections::{HashMap, HashSet};
use std::{
env,
fs::{create_dir_all, read_dir, read_to_string, remove_file, File, OpenOptions},
fs::{
create_dir_all, read_dir, read_to_string, remove_dir_all, remove_file, File, OpenOptions,
},
io::Write,
path::{Path, PathBuf},
process,
Expand Down Expand Up @@ -463,7 +465,12 @@ impl Config {
role.clone()
} else {
let mut role = Role::default();
role.batch_set(&self.model, self.temperature, self.top_p, None);
role.batch_set(
&self.model,
self.temperature,
self.top_p,
self.use_tools.clone(),
);
role
};
if role.temperature().is_none() && self.temperature.is_some() {
Expand All @@ -472,9 +479,6 @@ impl Config {
if role.top_p().is_none() && self.top_p.is_some() {
role.set_top_p(self.top_p);
}
if role.use_tools().is_none() && self.use_tools.is_some() {
role.set_use_tools(self.use_tools.clone())
}
role
}

Expand Down Expand Up @@ -542,12 +546,12 @@ impl Config {
("rag_top_k", rag_top_k.to_string()),
("highlight", self.highlight.to_string()),
("light_theme", self.light_theme.to_string()),
("env_file", display_path(&Self::env_file()?)),
("config_file", display_path(&Self::config_file()?)),
("roles_dir", display_path(&Self::roles_dir()?)),
("env_file", display_path(&Self::env_file()?)),
("functions_dir", display_path(&Self::functions_dir()?)),
("rags_dir", display_path(&Self::rags_dir()?)),
("sessions_dir", display_path(&self.sessions_dir()?)),
("rags_dir", display_path(&Self::rags_dir()?)),
("functions_dir", display_path(&Self::functions_dir()?)),
("messages_file", display_path(&self.messages_file()?)),
];
if let Ok((_, Some(log_path))) = Self::log(self.working_mode.is_serve()) {
Expand Down Expand Up @@ -624,11 +628,79 @@ impl Config {
let value = value.parse().with_context(|| "Invalid value")?;
config.write().highlight = value;
}
_ => bail!("Unknown key `{key}`"),
_ => bail!("Unknown key '{key}'"),
}
Ok(())
}

pub fn delete(config: &GlobalConfig, kind: &str) -> Result<()> {
let (dir, file_ext) = match kind {
"roles" => (Self::roles_dir()?, Some(".md")),
"sessions" => (config.read().sessions_dir()?, Some(".yaml")),
"rags" => (Self::rags_dir()?, Some(".yaml")),
"agents-config" => (Self::agents_config_dir()?, None),
_ => bail!("Unknown kind '{kind}'"),
};
let names = match read_dir(&dir) {
Ok(rd) => {
let mut names = vec![];
for entry in rd.flatten() {
let name = entry.file_name();
match file_ext {
Some(file_ext) => {
if let Some(name) = name.to_string_lossy().strip_suffix(file_ext) {
names.push(name.to_string());
}
}
None => {
if entry.path().is_dir() {
names.push(name.to_string_lossy().to_string());
}
}
}
}
names.sort_unstable();
names
}
Err(_) => vec![],
};

if names.is_empty() {
bail!("No {kind} to delete")
}

let select_names = MultiSelect::new(&format!("Select {kind} to delete:"), names)
.with_validator(|list: &[ListOption<&String>]| {
if list.is_empty() {
Ok(Validation::Invalid(
"At least one item must be selected".into(),
))
} else {
Ok(Validation::Valid)
}
})
.prompt()?;

for name in select_names {
match file_ext {
Some(ext) => {
let path = dir.join(format!("{name}{ext}"));
remove_file(&path).with_context(|| {
format!("Failed to delete {kind} at '{}'", path.display())
})?;
}
None => {
let path = dir.join(name);
remove_dir_all(&path).with_context(|| {
format!("Failed to delete {kind} at '{}'", path.display())
})?;
}
}
}
println!("✨ Successfully deleted {kind}.");
Ok(())
}

pub fn set_temperature(&mut self, value: Option<f64>) {
match self.role_like_mut() {
Some(role_like) => role_like.set_temperature(value),
Expand Down Expand Up @@ -848,13 +920,7 @@ impl Config {
}
let role_path = Self::role_file(&role_name)?;
if let Some(role) = self.role.as_mut() {
let old_name = role.name().to_string();
role.save(&role_name, &role_path, self.working_mode.is_repl())?;
if old_name != role_name {
if let Ok(path) = Self::role_file(&old_name) {
let _ = remove_file(&path);
}
}
}

Ok(())
Expand Down Expand Up @@ -980,13 +1046,7 @@ impl Config {
};
let session_path = self.session_file(&session_name)?;
if let Some(session) = self.session.as_mut() {
let old_name = session.name().to_string();
session.save(&session_name, &session_path, self.working_mode.is_repl())?;
if old_name != session_name {
if let Ok(path) = self.session_file(&old_name) {
let _ = remove_file(&path);
}
}
}
Ok(())
}
Expand Down Expand Up @@ -1309,20 +1369,21 @@ impl Config {
.iter()
.map(|v| v.name.to_string())
.collect();
for item in use_tools.split(',') {
let item = item.trim();
if item == "all" {
tool_names.extend(declaration_names);
break;
} else if let Some(values) = self.mapping_tools.get(item) {
tool_names.extend(
values
.split(',')
.map(|v| v.to_string())
.filter(|v| declaration_names.contains(v)),
)
} else if declaration_names.contains(item) {
tool_names.insert(item.to_string());
if use_tools == "all" {
tool_names.extend(declaration_names);
} else {
for item in use_tools.split(',') {
let item = item.trim();
if let Some(values) = self.mapping_tools.get(item) {
tool_names.extend(
values
.split(',')
.map(|v| v.to_string())
.filter(|v| declaration_names.contains(v)),
)
} else if declaration_names.contains(item) {
tool_names.insert(item.to_string());
}
}
}
functions = self
Expand Down Expand Up @@ -1383,27 +1444,16 @@ impl Config {
let mut filter = "";
if args.len() == 1 {
values = match cmd {
".role" => Self::list_roles(true)
.into_iter()
.map(|v| (v, None))
.collect(),
".role" => map_completion_values(Self::list_roles(true)),
".model" => list_chat_models(self)
.into_iter()
.map(|v| (v.id(), Some(v.description())))
.collect(),
".session" => self
.list_sessions()
.into_iter()
.map(|v| (v, None))
.collect(),
".rag" => Self::list_rags().into_iter().map(|v| (v, None)).collect(),
".agent" => list_agents().into_iter().map(|v| (v, None)).collect(),
".session" => map_completion_values(self.list_sessions()),
".rag" => map_completion_values(Self::list_rags()),
".agent" => map_completion_values(list_agents()),
".starter" => match &self.agent {
Some(agent) => agent
.conversation_staters()
.iter()
.map(|v| (v.clone(), None))
.collect(),
Some(agent) => map_completion_values(agent.conversation_staters().to_vec()),
None => vec![],
},
".variable" => match &self.agent {
Expand All @@ -1414,24 +1464,31 @@ impl Config {
.collect(),
None => vec![],
},
".set" => vec![
"max_output_tokens",
"temperature",
"top_p",
"dry_run",
"stream",
"save",
"save_session",
"compress_threshold",
"function_calling",
"use_tools",
"rag_reranker_model",
"rag_top_k",
"highlight",
]
.into_iter()
.map(|v| (format!("{v} "), None))
.collect(),
".set" => {
let mut values = vec![
"max_output_tokens",
"temperature",
"top_p",
"dry_run",
"stream",
"save",
"save_session",
"compress_threshold",
"function_calling",
"use_tools",
"rag_reranker_model",
"rag_top_k",
"highlight",
];
values.sort_unstable();
values
.into_iter()
.map(|v| (format!("{v} "), None))
.collect()
}
".delete" => {
map_completion_values(vec!["roles", "sessions", "rags", "agents-config"])
}
_ => vec![],
};
filter = args[0]
Expand All @@ -1455,18 +1512,20 @@ impl Config {
"function_calling" => complete_bool(self.function_calling),
"use_tools" => {
let mut prefix = String::new();
let mut ignores = HashSet::new();
if let Some((v, _)) = args[1].rsplit_once(',') {
ignores = v.split(',').collect();
prefix = format!("{v},");
}
let mut values = vec![];
if prefix.is_empty() {
values.push("all".to_string());
}
values.extend(self.mapping_tools.keys().map(|v| v.to_string()));
values.extend(self.functions.declarations().iter().map(|v| v.name.clone()));
values.extend(self.mapping_tools.keys().map(|v| v.to_string()));
values
.into_iter()
.filter(|v| !prefix.contains(&format!("{v},")))
.filter(|v| !ignores.contains(v.as_str()))
.map(|v| format!("{prefix}{v}"))
.collect()
}
Expand Down Expand Up @@ -1505,7 +1564,7 @@ impl Config {
let theme_path = Self::local_path(&theme_filename)?;
if theme_path.exists() {
let theme = ThemeSet::get_theme(&theme_path)
.with_context(|| format!("Invalid theme at {}", theme_path.display()))?;
.with_context(|| format!("Invalid theme at '{}'", theme_path.display()))?;
Some(theme)
} else {
let theme = if self.light_theme {
Expand Down Expand Up @@ -1703,7 +1762,7 @@ impl Config {

fn load_from_file(config_path: &Path) -> Result<Self> {
let content = read_to_string(config_path)
.with_context(|| format!("Failed to load config at {}", config_path.display()))?;
.with_context(|| format!("Failed to load config at '{}'", config_path.display()))?;
let config: Self = serde_yaml::from_str(&content).map_err(|err| {
let err_msg = err.to_string();
let err_msg = if err_msg.starts_with(&format!("{}: ", CLIENTS_FIELD)) {
Expand Down Expand Up @@ -2067,6 +2126,10 @@ fn complete_option_bool(value: Option<bool>) -> Vec<String> {
}
}

fn map_completion_values<T: ToString>(value: Vec<T>) -> Vec<(String, Option<String>)> {
value.into_iter().map(|v| (v.to_string(), None)).collect()
}

fn update_rag<F>(config: &GlobalConfig, f: F) -> Result<()>
where
F: FnOnce(&mut Rag) -> Result<()>,
Expand Down
3 changes: 3 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ async fn shell_execute(config: &GlobalConfig, shell: &Shell, mut input: Input) -
config
.write()
.after_chat_completion(&input, &eval_str, &[])?;
if eval_str.is_empty() {
bail!("No command generated");
}
if config.read().dry_run {
config.read().print_markdown(&eval_str)?;
return Ok(());
Expand Down
2 changes: 1 addition & 1 deletion src/repl/completer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Completer for ReplCompleter {
.map(|(v, _)| *v)
.collect::<Vec<&str>>()
.join(" ");
cmd.name.starts_with(&line)
cmd.name.starts_with(&line) && cmd.name != ".set"
})
.collect();

Expand Down
13 changes: 11 additions & 2 deletions src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ lazy_static::lazy_static! {
const MENU_NAME: &str = "completion_menu";

lazy_static::lazy_static! {
static ref REPL_COMMANDS: [ReplCommand; 30] = [
static ref REPL_COMMANDS: [ReplCommand; 31] = [
ReplCommand::new(".help", "Show this help message", AssertState::pass()),
ReplCommand::new(".info", "View system info", AssertState::pass()),
ReplCommand::new(".model", "Change the current LLM", AssertState::pass()),
Expand Down Expand Up @@ -147,7 +147,8 @@ lazy_static::lazy_static! {
"Regenerate the last response",
AssertState::pass()
),
ReplCommand::new(".set", "Adjust settings", AssertState::pass()),
ReplCommand::new(".set", "Adjust runtime configuration", AssertState::pass()),
ReplCommand::new(".delete", "Delete roles/sessions/RAGs/agents-config", AssertState::pass()),
ReplCommand::new(".copy", "Copy the last response", AssertState::pass()),
ReplCommand::new(".exit", "Exit the REPL", AssertState::pass()),
];
Expand Down Expand Up @@ -392,6 +393,14 @@ impl Repl {
println!("Usage: .set <key> <value>...")
}
},
".delete" => match args {
Some(args) => {
Config::delete(&self.config, args)?;
}
_ => {
println!("Usage: .delete [roles|sessions|rags|agents-config]")
}
},
".copy" => {
let config = self.config.read();
self.copy(config.last_reply())
Expand Down
Loading

0 comments on commit 8e03b9b

Please sign in to comment.