Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add .save agent-config repl command #870

Merged
merged 1 commit into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ editor: null # Specifies the command used to edit input buff
wrap: no # Controls text wrapping (no, auto, <max-width>)
wrap_code: false # Enables or disables wrapping of code blocks

# ---- function-calling ----
# Visit https://github.com/sigoden/llm-functions for setup instructions
function_calling: true # Enables or disables function calling (Globally).
mapping_tools: # Alias for a tool or toolset
fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
use_tools: null # Which tools to use by default. (e.g. 'fs,web_search')

# ---- prelude ----
prelude: null # Set a default role or session to start with (e.g. role:<name>, session:<name>)
repl_prelude: null # Overrides the `prelude` setting specifically for conversations started in REPL
Expand All @@ -26,13 +33,6 @@ summarize_prompt: 'Summarize the discussion briefly in 200 words or less to use
# Text prompt used for including the summary of the entire session
summary_prompt: 'This is a summary of the chat history as a recap: '

# ---- function-calling ----
# Visit https://github.com/sigoden/llm-functions for setup instructions
function_calling: true # Enables or disables function calling (Globally).
mapping_tools: # Alias for a tool or toolset
fs: 'fs_cat,fs_ls,fs_mkdir,fs_rm,fs_write'
use_tools: null # Which tools to use by default. (e.g. 'fs,web_search')

# ---- RAG ----
# See [RAG-Guide](https://github.com/sigoden/aichat/wiki/RAG-Guide) for more details.
rag_embedding_model: null # Specifies the embedding model to use
Expand Down
35 changes: 26 additions & 9 deletions src/config/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl Agent {
let agent_config = if config_path.exists() {
AgentConfig::load(&config_path)?
} else {
AgentConfig::default()
AgentConfig::new(&config.read())
};
let mut definition = AgentDefinition::load(&definition_file_path)?;
init_variables(&variables_path, &mut definition.variables)
Expand Down Expand Up @@ -91,6 +91,18 @@ impl Agent {
})
}

pub fn save_config(&self) -> Result<()> {
let config_path = Config::agent_config_file(&self.name)?;
ensure_parent_exists(&config_path)?;
let content = serde_yaml::to_string(&self.config)?;
fs::write(&config_path, content).with_context(|| {
format!("Failed to save agent config to '{}'", config_path.display())
})?;

println!("✨ Saved agent config to '{}'", config_path.display());
Ok(())
}

pub fn export(&self) -> Result<String> {
let mut agent = self.clone();
agent.definition.instructions = self.interpolated_instructions();
Expand Down Expand Up @@ -143,6 +155,10 @@ impl Agent {
self.config.agent_prelude.as_deref()
}

pub fn set_agent_prelude(&mut self, value: Option<String>) {
self.config.agent_prelude = value;
}

pub fn variables(&self) -> &[AgentVariable] {
&self.definition.variables
}
Expand Down Expand Up @@ -208,22 +224,23 @@ impl RoleLike for Agent {

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct AgentConfig {
#[serde(
rename(serialize = "model", deserialize = "model"),
skip_serializing_if = "Option::is_none"
)]
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_tools: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub agent_prelude: Option<String>,
}

impl AgentConfig {
pub fn new(config: &Config) -> Self {
Self {
use_tools: config.use_tools.clone(),
agent_prelude: config.agent_prelude.clone(),
..Default::default()
}
}

pub fn load(path: &Path) -> Result<Self> {
let contents = read_to_string(path)
.with_context(|| format!("Failed to read agent config file at '{}'", path.display()))?;
Expand Down
127 changes: 73 additions & 54 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ pub struct Config {
pub wrap: Option<String>,
pub wrap_code: bool,

pub function_calling: bool,
pub mapping_tools: IndexMap<String, String>,
pub use_tools: Option<String>,

pub prelude: Option<String>,
pub repl_prelude: Option<String>,
pub agent_prelude: Option<String>,
Expand All @@ -108,10 +112,6 @@ pub struct Config {
pub summarize_prompt: Option<String>,
pub summary_prompt: Option<String>,

pub function_calling: bool,
pub mapping_tools: IndexMap<String, String>,
pub use_tools: Option<String>,

pub rag_embedding_model: Option<String>,
pub rag_reranker_model: Option<String>,
pub rag_top_k: usize,
Expand Down Expand Up @@ -166,6 +166,10 @@ impl Default for Config {
wrap: None,
wrap_code: false,

function_calling: true,
mapping_tools: Default::default(),
use_tools: None,

prelude: None,
repl_prelude: None,
agent_prelude: None,
Expand All @@ -175,10 +179,6 @@ impl Default for Config {
summarize_prompt: None,
summary_prompt: None,

function_calling: true,
mapping_tools: Default::default(),
use_tools: None,

rag_embedding_model: None,
rag_reranker_model: None,
rag_top_k: 4,
Expand Down Expand Up @@ -402,7 +402,7 @@ impl Config {
self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into())
}

pub fn log(is_serve: bool) -> Result<(LevelFilter, Option<PathBuf>)> {
pub fn log_config(is_serve: bool) -> Result<(LevelFilter, Option<PathBuf>)> {
let log_level = env::var(get_env_name("log_level"))
.ok()
.and_then(|v| v.parse().ok())
Expand Down Expand Up @@ -513,10 +513,14 @@ impl Config {
.wrap
.clone()
.map_or_else(|| String::from("no"), |v| v.to_string());
let (rag_reranker_model, rag_top_k) = match self.rag.as_ref() {
let (rag_reranker_model, rag_top_k) = match &self.rag {
Some(rag) => rag.get_config(),
None => (self.rag_reranker_model.clone(), self.rag_top_k),
};
let agent_prelude = match &self.agent {
Some(agent) => agent.agent_prelude(),
None => self.agent_prelude.as_deref(),
};
let role = self.extract_role();
let mut items = vec![
("model", role.model().id()),
Expand All @@ -535,10 +539,11 @@ impl Config {
("keybindings", self.keybindings.clone()),
("wrap", wrap),
("wrap_code", self.wrap_code.to_string()),
("save_session", format_option_value(&self.save_session)),
("compress_threshold", self.compress_threshold.to_string()),
("function_calling", self.function_calling.to_string()),
("use_tools", format_option_value(&role.use_tools())),
("agent_prelude", format_option_value(&agent_prelude)),
("save_session", format_option_value(&self.save_session)),
("compress_threshold", self.compress_threshold.to_string()),
(
"rag_reranker_model",
format_option_value(&rag_reranker_model),
Expand All @@ -554,7 +559,7 @@ impl Config {
("functions_dir", display_path(&Self::functions_dir()?)),
("messages_file", display_path(&self.messages_file()?)),
];
if let Ok((_, Some(log_path))) = Self::log(self.working_mode.is_serve()) {
if let Ok((_, Some(log_path))) = Self::log_config(self.working_mode.is_serve()) {
items.push(("log_path", display_path(&log_path)));
}
let output = items
Expand Down Expand Up @@ -597,14 +602,6 @@ impl Config {
let value = value.parse().with_context(|| "Invalid value")?;
config.write().save = value;
}
"rag_reranker_model" => {
let value = parse_value(value)?;
Self::set_rag_reranker_model(config, value)?;
}
"rag_top_k" => {
let value = value.parse().with_context(|| "Invalid value")?;
Self::set_rag_top_k(config, value)?;
}
"function_calling" => {
let value = value.parse().with_context(|| "Invalid value")?;
if value && config.write().functions.is_empty() {
Expand All @@ -616,6 +613,10 @@ impl Config {
let value = parse_value(value)?;
config.write().set_use_tools(value);
}
"agent_prelude" => {
let value = parse_value(value)?;
config.write().set_agent_prelude(value);
}
"save_session" => {
let value = parse_value(value)?;
config.write().set_save_session(value);
Expand All @@ -624,6 +625,14 @@ impl Config {
let value = parse_value(value)?;
config.write().set_compress_threshold(value);
}
"rag_reranker_model" => {
let value = parse_value(value)?;
Self::set_rag_reranker_model(config, value)?;
}
"rag_top_k" => {
let value = value.parse().with_context(|| "Invalid value")?;
Self::set_rag_top_k(config, value)?;
}
"highlight" => {
let value = value.parse().with_context(|| "Invalid value")?;
config.write().highlight = value;
Expand All @@ -638,7 +647,7 @@ impl Config {
"roles" => (Self::roles_dir()?, Some(".md")),
"sessions" => (config.read().sessions_dir()?, Some(".yaml")),
"rags" => (Self::rags_dir()?, Some(".yaml")),
"agents-config" => (Self::agents_config_dir()?, None),
"agents" => (Self::agents_config_dir()?, None),
_ => bail!("Unknown kind '{kind}'"),
};
let names = match read_dir(&dir) {
Expand Down Expand Up @@ -722,6 +731,13 @@ impl Config {
}
}

pub fn set_agent_prelude(&mut self, value: Option<String>) {
match self.agent.as_mut() {
Some(agent) => agent.set_agent_prelude(value),
None => self.agent_prelude = value,
}
}

pub fn set_save_session(&mut self, value: Option<bool>) {
if let Some(session) = self.session.as_mut() {
session.set_save_session(value);
Expand Down Expand Up @@ -1269,13 +1285,9 @@ impl Config {
bail!("Already in a agent, please run '.exit agent' first to exit the current agent.");
}
let agent = Agent::init(config, name, abort_signal).await?;
let session = session.map(|v| v.to_string()).or_else(|| {
agent
.agent_prelude()
.map(|v| v.to_string())
.or_else(|| config.read().agent_prelude.clone())
.and_then(|v| if v.is_empty() { None } else { Some(v) })
});
let session = session
.map(|v| v.to_string())
.or_else(|| agent.agent_prelude().map(|v| v.to_string()));
config.write().rag = agent.rag();
config.write().agent = Some(agent);
if let Some(session) = session {
Expand Down Expand Up @@ -1314,6 +1326,14 @@ impl Config {
Ok(())
}

pub fn save_agent_config(&mut self) -> Result<()> {
let agent = match &self.agent {
Some(v) => v,
None => bail!("No agent"),
};
agent.save_config()
}

pub fn exit_agent(&mut self) -> Result<()> {
self.exit_session()?;
if self.agent.take().is_some() {
Expand Down Expand Up @@ -1472,10 +1492,11 @@ impl Config {
"dry_run",
"stream",
"save",
"save_session",
"compress_threshold",
"function_calling",
"use_tools",
"agent_prelude",
"save_session",
"compress_threshold",
"rag_reranker_model",
"rag_top_k",
"highlight",
Expand All @@ -1486,9 +1507,7 @@ impl Config {
.map(|v| (format!("{v} "), None))
.collect()
}
".delete" => {
map_completion_values(vec!["roles", "sessions", "rags", "agents-config"])
}
".delete" => map_completion_values(vec!["roles", "sessions", "rags", "agents"]),
_ => vec![],
};
filter = args[0]
Expand All @@ -1501,14 +1520,6 @@ impl Config {
"dry_run" => complete_bool(self.dry_run),
"stream" => complete_bool(self.stream),
"save" => complete_bool(self.save),
"save_session" => {
let save_session = if let Some(session) = &self.session {
session.save_session()
} else {
self.save_session
};
complete_option_bool(save_session)
}
"function_calling" => complete_bool(self.function_calling),
"use_tools" => {
let mut prefix = String::new();
Expand All @@ -1529,6 +1540,14 @@ impl Config {
.map(|v| format!("{prefix}{v}"))
.collect()
}
"save_session" => {
let save_session = if let Some(session) = &self.session {
session.save_session()
} else {
self.save_session
};
complete_option_bool(save_session)
}
"rag_reranker_model" => list_reranker_models(self).iter().map(|v| v.id()).collect(),
"highlight" => complete_bool(self.highlight),
_ => vec![],
Expand Down Expand Up @@ -1840,6 +1859,18 @@ impl Config {
self.wrap_code = v;
}

if let Some(Some(v)) = read_env_bool("function_calling") {
self.function_calling = v;
}
if let Ok(v) = env::var(get_env_name("mapping_tools")) {
if let Ok(v) = serde_json::from_str(&v) {
self.mapping_tools = v;
}
}
if let Some(v) = read_env_value::<String>("use_tools") {
self.use_tools = v;
}

if let Some(v) = read_env_value::<String>("prelude") {
self.prelude = v;
}
Expand All @@ -1863,18 +1894,6 @@ impl Config {
self.summary_prompt = v;
}

if let Some(Some(v)) = read_env_bool("function_calling") {
self.function_calling = v;
}
if let Ok(v) = env::var(get_env_name("mapping_tools")) {
if let Ok(v) = serde_json::from_str(&v) {
self.mapping_tools = v;
}
}
if let Some(v) = read_env_value::<String>("use_tools") {
self.use_tools = v;
}

if let Some(v) = read_env_value::<String>("rag_embedding_model") {
self.rag_embedding_model = v;
}
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ async fn create_input(
}

fn setup_logger(is_serve: bool) -> Result<()> {
let (log_level, log_path) = Config::log(is_serve)?;
let (log_level, log_path) = Config::log_config(is_serve)?;
if log_level == LevelFilter::Off {
return Ok(());
}
Expand Down
Loading