diff --git a/core/src/query.rs b/core/src/query.rs index 7ea09b00..6aec4ed0 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -1,12 +1,12 @@ //! Query constructors -use crate::errors::PgmqError; +use crate::{errors::PgmqError, util::CheckedName}; use sqlx::types::chrono::Utc; pub const TABLE_PREFIX: &str = r#"pgmq"#; pub const PGMQ_SCHEMA: &str = "public"; pub fn init_queue(name: &str) -> Result, PgmqError> { - check_input(name)?; + let name = CheckedName::new(name)?; Ok(vec![ create_meta(), create_queue(name)?, @@ -20,7 +20,7 @@ pub fn init_queue(name: &str) -> Result, PgmqError> { } pub fn destroy_queue(name: &str) -> Result, PgmqError> { - check_input(name)?; + let name = CheckedName::new(name)?; Ok(vec![ drop_queue(name)?, delete_queue_index(name)?, @@ -29,8 +29,7 @@ pub fn destroy_queue(name: &str) -> Result, PgmqError> { ]) } -pub fn create_queue(name: &str) -> Result { - check_input(name)?; +pub fn create_queue(name: CheckedName<'_>) -> Result { Ok(format!( " CREATE TABLE IF NOT EXISTS {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name} ( @@ -44,8 +43,7 @@ pub fn create_queue(name: &str) -> Result { )) } -pub fn create_archive(name: &str) -> Result { - check_input(name)?; +pub fn create_archive(name: CheckedName<'_>) -> Result { Ok(format!( " CREATE TABLE IF NOT EXISTS {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name}_archive ( @@ -95,20 +93,17 @@ pub fn grant_pgmon_meta() -> String { } // pg_monitor needs to query queue tables -pub fn grant_pgmon_queue(name: &str) -> Result { - check_input(name)?; +pub fn grant_pgmon_queue(name: CheckedName<'_>) -> Result { let table = format!("{PGMQ_SCHEMA}.{TABLE_PREFIX}_{name}"); Ok(grant_stmt(&table)) } -pub fn grant_pgmon_queue_seq(name: &str) -> Result { - check_input(name)?; +pub fn grant_pgmon_queue_seq(name: CheckedName<'_>) -> Result { let table = format!("{PGMQ_SCHEMA}.{TABLE_PREFIX}_{name}_msg_id_seq"); Ok(grant_stmt(&table)) } -pub fn drop_queue(name: &str) -> Result { - check_input(name)?; +pub fn drop_queue(name: CheckedName<'_>) -> Result { Ok(format!( " DROP TABLE IF EXISTS {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name}; @@ -116,8 +111,7 @@ pub fn drop_queue(name: &str) -> Result { )) } -pub fn delete_queue_index(name: &str) -> Result { - check_input(name)?; +pub fn delete_queue_index(name: CheckedName<'_>) -> Result { Ok(format!( " DROP INDEX IF EXISTS {TABLE_PREFIX}_{name}.vt_idx_{name}; @@ -125,8 +119,7 @@ pub fn delete_queue_index(name: &str) -> Result { )) } -pub fn delete_queue_metadata(name: &str) -> Result { - check_input(name)?; +pub fn delete_queue_metadata(name: CheckedName<'_>) -> Result { Ok(format!( " DO $$ @@ -145,8 +138,7 @@ pub fn delete_queue_metadata(name: &str) -> Result { )) } -pub fn drop_queue_archive(name: &str) -> Result { - check_input(name)?; +pub fn drop_queue_archive(name: CheckedName<'_>) -> Result { Ok(format!( " DROP TABLE IF EXISTS {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name}_archive; @@ -154,8 +146,7 @@ pub fn drop_queue_archive(name: &str) -> Result { )) } -pub fn insert_meta(name: &str) -> Result { - check_input(name)?; +pub fn insert_meta(name: CheckedName<'_>) -> Result { Ok(format!( " INSERT INTO {PGMQ_SCHEMA}.{TABLE_PREFIX}_meta (queue_name) @@ -166,8 +157,7 @@ pub fn insert_meta(name: &str) -> Result { )) } -pub fn create_archive_index(name: &str) -> Result { - check_input(name)?; +pub fn create_archive_index(name: CheckedName<'_>) -> Result { Ok(format!( " CREATE INDEX IF NOT EXISTS deleted_at_idx_{name} ON {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name}_archive (deleted_at); @@ -176,8 +166,7 @@ pub fn create_archive_index(name: &str) -> Result { } // indexes are created ascending to support FIFO -pub fn create_index(name: &str) -> Result { - check_input(name)?; +pub fn create_index(name: CheckedName<'_>) -> Result { Ok(format!( " CREATE INDEX IF NOT EXISTS msg_id_vt_idx_{name} ON {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name} (vt ASC, msg_id ASC); @@ -192,7 +181,7 @@ pub fn enqueue( ) -> Result { // construct string of comma separated messages check_input(name)?; - let mut values: String = "".to_owned(); + let mut values = "".to_owned(); for message in messages.iter() { let full_msg = format!( "((now() at time zone 'utc' + interval '{delay} seconds'), '{message}'::json)," @@ -330,7 +319,8 @@ mod tests { #[test] fn test_create() { - let query = create_queue("yolo"); + let queue_name = CheckedName::new("yolo").unwrap(); + let query = create_queue(queue_name); assert!(query.unwrap().contains("pgmq_yolo")); } diff --git a/core/src/util.rs b/core/src/util.rs index f6d40b9c..06a91416 100644 --- a/core/src/util.rs +++ b/core/src/util.rs @@ -1,3 +1,6 @@ +use std::fmt::Display; + +use crate::query::check_input; use crate::{Message, PgmqError}; use log::LevelFilter; use serde::Deserialize; @@ -63,3 +66,28 @@ pub async fn fetch_one_message Deserialize<'de>>( Err(e) => Err(e)?, } } + +/// A string that is known to be formed of only ASCII alphanumeric or an underscore; +#[derive(Clone, Copy)] +pub struct CheckedName<'a>(&'a str); + +impl<'a> CheckedName<'a> { + /// Accepts `input` as a CheckedName if it is a valid queue identifier + pub fn new(input: &'a str) -> Result { + check_input(input)?; + + Ok(Self(input)) + } +} + +impl AsRef for CheckedName<'_> { + fn as_ref(&self) -> &str { + self.0 + } +} + +impl Display for CheckedName<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} diff --git a/src/partition.rs b/src/partition.rs index 011d5b69..55880c9b 100644 --- a/src/partition.rs +++ b/src/partition.rs @@ -4,9 +4,10 @@ use pgrx::prelude::*; use pgmq_crate::{ errors::PgmqError, query::{ - check_input, create_archive, create_index, create_meta, grant_pgmon_meta, - grant_pgmon_queue, grant_pgmon_queue_seq, insert_meta, PGMQ_SCHEMA, TABLE_PREFIX, + create_archive, create_index, create_meta, grant_pgmon_meta, grant_pgmon_queue, + grant_pgmon_queue_seq, insert_meta, PGMQ_SCHEMA, TABLE_PREFIX, }, + util::CheckedName, }; // for now, put pg_partman in the public PGMQ_SCHEMA @@ -17,7 +18,7 @@ pub fn init_partitioned_queue( partition_interval: &str, retention_interval: &str, ) -> Result, PgmqError> { - check_input(name)?; + let name = CheckedName::new(name)?; let partition_col = map_partition_col(partition_interval); Ok(vec![ create_meta(), @@ -44,8 +45,10 @@ fn map_partition_col(partition_interval: &str) -> &'static str { } } -fn create_partitioned_queue(queue: &str, partition_col: &str) -> Result { - check_input(queue)?; +fn create_partitioned_queue( + queue: CheckedName<'_>, + partition_col: &str, +) -> Result { Ok(format!( " CREATE TABLE IF NOT EXISTS {PGMQ_SCHEMA}.{TABLE_PREFIX}_{queue} ( @@ -59,8 +62,10 @@ fn create_partitioned_queue(queue: &str, partition_col: &str) -> Result Result { - check_input(queue)?; +pub fn create_partitioned_index( + queue: CheckedName<'_>, + partiton_col: &str, +) -> Result { Ok(format!( " CREATE INDEX IF NOT EXISTS pgmq_partition_idx_{queue} ON {PGMQ_SCHEMA}.{TABLE_PREFIX}_{queue} ({partiton_col}); @@ -69,7 +74,7 @@ pub fn create_partitioned_index(queue: &str, partiton_col: &str) -> Result, partition_col: &str, partition_interval: &str, ) -> Result { @@ -86,8 +91,7 @@ fn create_partitioned_table( // messages .archived() will be retained forever on the `_archive` table // https://github.com/pgpartman/pg_partman/blob/ca212077f66af19c0ca317c206091cd31d3108b8/doc/pg_partman.md#retention // integer value will set that any partitions with an id value less than the current maximum id value minus the retention value will be dropped -fn set_retention_config(queue: &str, retention: &str) -> Result { - check_input(queue)?; +fn set_retention_config(queue: CheckedName<'_>, retention: &str) -> Result { Ok(format!( " UPDATE {PGMQ_SCHEMA}.part_config