Skip to content

Commit

Permalink
feat: replace roles.yaml with roles/<name>.md
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Aug 30, 2024
1 parent 298a452 commit 10c3db0
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 172 deletions.
196 changes: 143 additions & 53 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod session;

pub use self::agent::{list_agents, Agent};
pub use self::input::Input;
pub use self::role::{Role, RoleLike, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE};
pub use self::role::{Role, RoleLike, BUILTIN_ROLES, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE};
use self::session::Session;

use crate::client::{
Expand All @@ -19,7 +19,7 @@ use crate::utils::*;

use anyhow::{anyhow, bail, Context, Result};
use indexmap::IndexMap;
use inquire::{Confirm, Select};
use inquire::{validator::Validation, Confirm, Select, Text};
use parking_lot::RwLock;
use serde::Deserialize;
use serde_json::json;
Expand All @@ -40,7 +40,7 @@ const DARK_THEME: &[u8] = include_bytes!("../../assets/monokai-extended.theme.bi
const LIGHT_THEME: &[u8] = include_bytes!("../../assets/monokai-extended-light.theme.bin");

const CONFIG_FILE_NAME: &str = "config.yaml";
const ROLES_FILE_NAME: &str = "roles.yaml";
const ROLES_DIR_NAME: &str = "roles";
const ENV_FILE_NAME: &str = ".env";
const MESSAGES_FILE_NAME: &str = "messages.md";
const SESSIONS_DIR_NAME: &str = "sessions";
Expand Down Expand Up @@ -128,8 +128,6 @@ pub struct Config {

pub clients: Vec<ClientConfig>,

#[serde(skip)]
pub roles: Vec<Role>,
#[serde(skip)]
pub role: Option<Role>,
#[serde(skip)]
Expand Down Expand Up @@ -195,7 +193,6 @@ impl Default for Config {

clients: vec![],

roles: vec![],
role: None,
session: None,
rag: None,
Expand Down Expand Up @@ -236,7 +233,6 @@ impl Config {
}

config.load_functions()?;
config.load_roles()?;

config.setup_model()?;
config.setup_document_loaders();
Expand Down Expand Up @@ -269,13 +265,17 @@ impl Config {
}
}

pub fn roles_file() -> Result<PathBuf> {
match env::var(get_env_name("roles_file")) {
pub fn roles_dir() -> Result<PathBuf> {
match env::var(get_env_name("roles_dir")) {
Ok(value) => Ok(PathBuf::from(value)),
Err(_) => Self::local_path(ROLES_FILE_NAME),
Err(_) => Self::local_path(ROLES_DIR_NAME),
}
}

pub fn role_file(name: &str) -> Result<PathBuf> {
Ok(Self::roles_dir()?.join(format!("{name}.md")))
}

pub fn env_file() -> Result<PathBuf> {
match env::var(get_env_name("env_file")) {
Ok(value) => Ok(PathBuf::from(value)),
Expand Down Expand Up @@ -487,7 +487,7 @@ impl Config {
} else if let Some(session) = &self.session {
session.export()
} else if let Some(role) = &self.role {
role.export()
Ok(role.export())
} else if let Some(rag) = &self.rag {
rag.export()
} else {
Expand Down Expand Up @@ -531,7 +531,7 @@ impl Config {
("highlight", self.highlight.to_string()),
("light_theme", self.light_theme.to_string()),
("config_file", display_path(&Self::config_file()?)),
("roles_file", display_path(&Self::roles_file()?)),
("roles_dir", display_path(&Self::roles_dir()?)),
("env_file", display_path(&Self::env_file()?)),
("functions_dir", display_path(&Self::functions_dir()?)),
("rags_dir", display_path(&Self::rags_dir()?)),
Expand All @@ -543,9 +543,9 @@ impl Config {
}
let output = items
.iter()
.map(|(name, value)| format!("{name:<24}{value}"))
.map(|(name, value)| format!("{name:<24}{value}\n"))
.collect::<Vec<String>>()
.join("\n");
.join("");
Ok(output)
}

Expand Down Expand Up @@ -716,7 +716,7 @@ impl Config {

pub fn role_info(&self) -> Result<String> {
if let Some(role) = &self.role {
role.export()
Ok(role.export())
} else {
bail!("No role")
}
Expand All @@ -733,17 +733,17 @@ impl Config {
}

pub fn retrieve_role(&self, name: &str) -> Result<Role> {
let mut role = self
.roles
.iter()
.find(|v| v.match_name(name))
.map(|v| {
let mut role = v.clone();
role.complete_prompt_args(name);
role
})
.ok_or_else(|| anyhow!("Unknown role `{name}`"))?;

let mut role = if Self::list_roles(false).contains(&name.to_string()) {
let path = Self::role_file(name)?;
let content = read_to_string(path)?;
Role::new(name, &content)
} else {
BUILTIN_ROLES
.iter()
.find(|v| v.name() == name)
.cloned()
.ok_or_else(|| anyhow!("Unknown role `{name}`"))?
};
match role.model_id() {
Some(model_id) => {
if self.model.id() != model_id {
Expand All @@ -758,6 +758,110 @@ impl Config {
Ok(role)
}

pub fn new_role(&mut self, name: &str) -> Result<()> {
let ans = Confirm::new("Create a new role?")
.with_default(true)
.prompt()?;
if ans {
self.upsert_role(name)?;
}
Ok(())
}

pub fn edit_role(&mut self) -> Result<()> {
if let Some(name) = self.role.as_ref().map(|v| v.name().to_string()) {
self.upsert_role(&name)
} else {
bail!("No role")
}
}

pub fn upsert_role(&mut self, name: &str) -> Result<()> {
let role_path = Self::role_file(name)?;
ensure_parent_exists(&role_path)?;
let editor = self.editor()?;
edit_file(&editor, &role_path)?;
self.use_role(name)?;
Ok(())
}

pub fn save_role(&mut self, name: Option<&str>) -> Result<()> {
let mut role_name = match &self.role {
Some(role) => match name {
Some(v) => v.to_string(),
None => role.name().to_string(),
},
None => bail!("No role"),
};
if role_name == TEMP_ROLE_NAME {
role_name = Text::new("Role name:")
.with_validator(|input: &str| {
if input.trim().is_empty() {
Ok(Validation::Invalid("This field is required".into()))
} else {
Ok(Validation::Valid)
}
})
.prompt()?;
}
let role_path = Self::role_file(&role_name)?;
if let Some(role) = self.role.as_mut() {
let old_name = role.name().to_string();
role.save(&role_name, &role_path, self.working_mode.is_repl())?;
if old_name != role_name {
if let Ok(path) = Self::role_file(&old_name) {
let _ = remove_file(&path);
}
}
}

Ok(())
}

pub fn all_roles() -> Vec<Role> {
let mut roles: HashMap<String, Role> = BUILTIN_ROLES
.iter()
.map(|v| (v.name().to_string(), v.clone()))
.collect();
let names = Self::list_roles(false);
for name in names {
if let Ok(path) = Self::role_file(&name) {
if let Ok(content) = read_to_string(&path) {
let role = Role::new(&name, &content);
roles.insert(name, role);
}
}
}
let mut roles: Vec<_> = roles.into_values().collect();
roles.sort_unstable_by(|a, b| a.name().cmp(b.name()));
roles
}

pub fn list_roles(with_builtin: bool) -> Vec<String> {
let mut names = HashSet::new();
if let Some(rd) = Self::roles_dir().ok().and_then(|dir| read_dir(dir).ok()) {
for entry in rd.flatten() {
if let Some(name) = entry
.file_name()
.to_str()
.and_then(|v| v.strip_suffix(".md"))
{
names.insert(name.to_string());
}
}
}
if with_builtin {
names.extend(BUILTIN_ROLES.iter().map(|v| v.name().to_string()));
}
let mut names: Vec<_> = names.into_iter().collect();
names.sort_unstable();
names
}

pub fn has_role(name: &str) -> bool {
Self::list_roles(true).iter().any(|v| v == name)
}

pub fn use_session(&mut self, session_name: Option<&str>) -> Result<()> {
if self.session.is_some() {
bail!(
Expand Down Expand Up @@ -824,16 +928,22 @@ impl Config {
}

pub fn save_session(&mut self, name: Option<&str>) -> Result<()> {
let name = match &self.session {
let session_name = match &self.session {
Some(session) => match name {
Some(v) => v.to_string(),
None => session.name().to_string(),
},
None => bail!("No session"),
};
let session_path = self.session_file(&name)?;
let session_path = self.session_file(&session_name)?;
if let Some(session) = self.session.as_mut() {
session.save(&session_path, self.working_mode.is_repl())?;
let old_name = session.name().to_string();
session.save(&session_name, &session_path, self.working_mode.is_repl())?;
if old_name != session_name {
if let Ok(path) = self.session_file(&old_name) {
let _ = remove_file(&path);
}
}
}
Ok(())
}
Expand All @@ -843,9 +953,9 @@ impl Config {
Some(session) => session.name().to_string(),
None => bail!("No session"),
};
let editor = self.editor()?;
let session_path = self.session_file(&name)?;
self.save_session(Some(&name))?;
let editor = self.editor()?;
edit_file(&editor, &session_path).with_context(|| {
format!(
"Failed to edit '{}' with '{editor}'",
Expand Down Expand Up @@ -1201,10 +1311,9 @@ impl Config {
let mut filter = "";
if args.len() == 1 {
values = match cmd {
".role" => self
.roles
.iter()
.map(|v| (v.name().to_string(), None))
".role" => Self::list_roles(true)
.into_iter()
.map(|v| (v, None))
.collect(),
".model" => list_chat_models(self)
.into_iter()
Expand Down Expand Up @@ -1702,25 +1811,6 @@ impl Config {
Ok(())
}

fn load_roles(&mut self) -> Result<()> {
let path = Self::roles_file()?;
self.roles = if !path.exists() {
vec![]
} else {
let content = read_to_string(&path)
.with_context(|| format!("Failed to load roles at {}", path.display()))?;
serde_yaml::from_str(&content).with_context(|| "Invalid roles config")?
};
let exist_roles: HashSet<_> = self.roles.iter().map(|v| v.name().to_string()).collect();
let builtin_roles = Role::builtin();
for role in builtin_roles {
if !exist_roles.contains(role.name()) {
self.roles.push(role);
}
}
Ok(())
}

fn setup_model(&mut self) -> Result<()> {
let mut model_id = self.model_id.clone();
if model_id.is_empty() {
Expand Down
Loading

0 comments on commit 10c3db0

Please sign in to comment.