Skip to content

Commit

Permalink
Merge pull request #402 from egraphs-good/yihozhang-fix-repl
Browse files Browse the repository at this point in the history
In REPL, evaluates only when parens are closed
  • Loading branch information
yihozhang authored Jul 31, 2024
2 parents d4accf6 + 4689fac commit e0fd116
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,8 @@ impl EGraph {
self.desugar.parse_program(filename, input)
}

/// Parse and run a program, returning a list of messages.
/// If filename is None, a default name will be provided
pub fn parse_and_run_program(
&mut self,
filename: Option<String>,
Expand Down
106 changes: 96 additions & 10 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,53 @@ struct Args {
max_calls_per_function: usize,
}

// test if the current command should be evaluated
fn should_eval(curr_cmd: &str) -> bool {
let mut count = 0;
let mut indices = curr_cmd.chars();
while let Some(ch) = indices.next() {
match ch {
'(' => count += 1,
')' => {
count -= 1;
// if we have a negative count,
// this means excessive closing parenthesis
// which we would like to throw an error eagerly
if count < 0 {
return true;
}
}
';' => {
// `any` moves the iterator forward until it finds a match
if !indices.any(|ch| ch == '\n') {
return false;
}
}
'"' => {
if !indices.any(|ch| ch == '"') {
return false;
}
}
_ => {}
}
}
count <= 0
}

#[allow(clippy::disallowed_macros)]
fn run_command_in_scripting(egraph: &mut EGraph, command: &str) {
match egraph.parse_and_run_program(None, command) {
Ok(msgs) => {
for msg in msgs {
println!("{msg}");
}
}
Err(err) => {
log::error!("{err}");
}
}
}

#[allow(clippy::disallowed_macros)]
fn main() {
env_logger::Builder::new()
Expand Down Expand Up @@ -79,18 +126,19 @@ fn main() {
log::info!("Welcome to Egglog!");
let mut egraph = mk_egraph();

let mut cmd_buffer = String::new();

for line in BufReader::new(stdin).lines() {
match line {
Ok(line_str) => match egraph.parse_and_run_program(None, &line_str) {
Ok(msgs) => {
for msg in msgs {
println!("{msg}");
}
Ok(line_str) => {
cmd_buffer.push_str(&line_str);
cmd_buffer.push('\n');
// handles multi-line commands
if should_eval(&cmd_buffer) {
run_command_in_scripting(&mut egraph, &cmd_buffer);
cmd_buffer = String::new();
}
Err(err) => {
log::error!("{err}");
}
},
}
Err(err) => {
log::error!("{err}");
std::process::exit(1)
Expand All @@ -102,7 +150,11 @@ fn main() {
}
}

std::process::exit(1)
if !cmd_buffer.is_empty() {
run_command_in_scripting(&mut egraph, &cmd_buffer)
}

std::process::exit(0)
}

for (idx, input) in args.inputs.iter().enumerate() {
Expand Down Expand Up @@ -197,3 +249,37 @@ fn main() {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_should_eval() {
#[rustfmt::skip]
let test_cases = vec![
vec![
"(extract",
"\"1",
")",
"(",
")))",
"\"",
";; )",
")"
],
vec![
"(extract 1) (extract",
"2) (",
"extract 3) (extract 4) ;;;; ("
]];
for test in test_cases {
let mut cmd_buffer = String::new();
for (i, line) in test.iter().enumerate() {
cmd_buffer.push_str(line);
cmd_buffer.push('\n');
assert_eq!(should_eval(&cmd_buffer), i == test.len() - 1);
}
}
}
}

0 comments on commit e0fd116

Please sign in to comment.