diff --git a/crates/trigger/src/runtime_config/sqlite.rs b/crates/trigger/src/runtime_config/sqlite.rs index 6fc7c482e..5163791e9 100644 --- a/crates/trigger/src/runtime_config/sqlite.rs +++ b/crates/trigger/src/runtime_config/sqlite.rs @@ -49,24 +49,27 @@ async fn execute_statements( if statements.is_empty() { return Ok(()); } - let Some(default) = databases.get("default") else { - debug_assert!( - false, - "the 'default' sqlite database should always be available but for some reason was not" - ); - return Ok(()); - }; for m in statements { - if let Some(file) = m.strip_prefix('@') { + if let Some(config) = m.strip_prefix('@') { + let (file, database) = parse_file_and_label(config)?; + let database = databases.get(database).with_context(|| { + format!( + "based on the '@{config}' a registered database named '{database}' was expected but not found. The registered databases are '{:?}'", databases.keys() + ) + })?; let sql = std::fs::read_to_string(file).with_context(|| { format!("could not read file '{file}' containing sql statements") })?; - default + database .execute_batch(&sql) .await .with_context(|| format!("failed to execute sql from file '{file}'"))?; } else { + let Some(default) = databases.get("default") else { + debug_assert!(false, "the 'default' sqlite database should always be available but for some reason was not"); + return Ok(()); + }; default .query(m, Vec::new()) .await @@ -76,6 +79,19 @@ async fn execute_statements( Ok(()) } +/// Parses a @{file:label} sqlite statement +fn parse_file_and_label(config: &str) -> anyhow::Result<(&str, &str)> { + let config = config.trim(); + let (file, label) = match config.split_once(':') { + Some((_, label)) if label.trim().is_empty() => { + anyhow::bail!("database label is empty in the '@{config}' sqlite statement") + } + Some((file, label)) => (file.trim(), label.trim()), + None => (config, "default"), + }; + Ok((file, label)) +} + // Holds deserialized options from a `[sqlite_database.]` runtime config section. #[derive(Clone, Debug, serde::Deserialize)] #[serde(rename_all = "snake_case", tag = "type")] @@ -202,3 +218,23 @@ impl TriggerHooks for SqlitePersistenceMessageHook { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn can_parse_file_and_label() { + let config = "file:label"; + let result = parse_file_and_label(config).unwrap(); + assert_eq!(result, ("file", "label")); + + let config = "file:"; + let result = parse_file_and_label(config); + assert!(result.is_err()); + + let config = "file"; + let result = parse_file_and_label(config).unwrap(); + assert_eq!(result, ("file", "default")); + } +}