From 9e0e2c922c6c98012bda2617ab71a408f70a4263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joaqu=C3=ADn=20Rosales?= Date: Wed, 15 May 2024 20:26:19 -0600 Subject: [PATCH] feat: add function to determine if a statement is a query operation --- catalyst-gateway/bin/src/event_db/mod.rs | 34 ++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/catalyst-gateway/bin/src/event_db/mod.rs b/catalyst-gateway/bin/src/event_db/mod.rs index ff10aebcd4..31dc0421d4 100644 --- a/catalyst-gateway/bin/src/event_db/mod.rs +++ b/catalyst-gateway/bin/src/event_db/mod.rs @@ -4,6 +4,7 @@ use std::{str::FromStr, sync::Arc}; use bb8::Pool; use bb8_postgres::PostgresConnectionManager; use dotenvy::dotenv; +use stringzilla::StringZilla; use tokio::sync::RwLock; use tokio_postgres::{types::ToSql, NoTls, Row}; use tracing::{debug, debug_span, Instrument}; @@ -211,3 +212,36 @@ pub(crate) async fn establish_connection(url: Option) -> anyhow::Result< inspection_settings: Arc::new(RwLock::new(DatabaseInspectionSettings::default())), }) } + +/// Determine if the statement is a query statement. +/// +/// If the query statement starts with `SELECT` or contains `RETURNING`, then it is a +/// query. +#[allow(dead_code)] +fn is_query_stmt(stmt: &str) -> bool { + matches!( + (stmt.sz_find("SELECT"), stmt.sz_find("RETURNING"),), + (Some(0), _) | (_, Some(_)), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_query_statement() { + let stmt = "SELECT * FROM dummy"; + assert!(is_query_stmt(stmt)); + let stmt = "UPDATE dummy SET foo = $1 WHERE bar = $2 RETURNING *"; + assert!(is_query_stmt(stmt)); + } + + #[test] + fn test_is_not_query_statement() { + let stmt = "UPDATE dummy SET foo_count = foo_count + 1 WHERE bar = (SELECT bar_id FROM foos WHERE name = 'FooBar')"; + assert!(!is_query_stmt(stmt)); + let stmt = "UPDATE dummy SET foo = $1 WHERE bar = $2"; + assert!(!is_query_stmt(stmt)); + } +}