diff --git a/Cargo.lock b/Cargo.lock index 393e10f7..365f8486 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1243,10 +1243,10 @@ dependencies = [ [[package]] name = "pgmq" -version = "0.13.1" +version = "0.14.0" dependencies = [ "chrono", - "pgmq 0.15.2", + "pgmq 0.16.0", "pgrx", "pgrx-tests", "rand", @@ -1260,7 +1260,7 @@ dependencies = [ [[package]] name = "pgmq" -version = "0.15.2" +version = "0.16.0" dependencies = [ "chrono", "log", diff --git a/Cargo.toml b/Cargo.toml index bcc9b76f..25e992e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgmq" -version = "0.13.1" +version = "0.14.0" edition = "2021" authors = ["Tembo.io"] description = "Postgres extension for PGMQ" diff --git a/core/Cargo.toml b/core/Cargo.toml index 5da242f7..a5ec068f 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgmq" -version = "0.15.2" +version = "0.16.0" edition = "2021" authors = ["Tembo.io"] description = "A distributed message queue for Rust applications, on Postgres." 
diff --git a/core/sqlx-data.json b/core/sqlx-data.json index 14fabdbb..37dcb775 100644 --- a/core/sqlx-data.json +++ b/core/sqlx-data.json @@ -14,12 +14,12 @@ "type_info": "Int4" }, { - "name": "vt", + "name": "enqueued_at", "ordinal": 2, "type_info": "Timestamptz" }, { - "name": "enqueued_at", + "name": "vt", "ordinal": 3, "type_info": "Timestamptz" }, @@ -220,12 +220,12 @@ "type_info": "Int4" }, { - "name": "vt", + "name": "enqueued_at", "ordinal": 2, "type_info": "Timestamptz" }, { - "name": "enqueued_at", + "name": "vt", "ordinal": 3, "type_info": "Timestamptz" }, @@ -273,6 +273,54 @@ }, "query": "SELECT pgmq_send as msg_id from pgmq_send($1::text, $2::jsonb);" }, + "e4c38347b44aed05aa890d3351a362d3b6f81387e98fc564ec922cefa1e96f71": { + "describe": { + "columns": [ + { + "name": "msg_id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "read_ct", + "ordinal": 1, + "type_info": "Int4" + }, + { + "name": "enqueued_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "vt", + "ordinal": 3, + "type_info": "Timestamptz" + }, + { + "name": "message", + "ordinal": 4, + "type_info": "Jsonb" + } + ], + "nullable": [ + null, + null, + null, + null, + null + ], + "parameters": { + "Left": [ + "Text", + "Int4", + "Int4", + "Int4", + "Int4" + ] + } + }, + "query": "SELECT * from pgmq_read_with_poll($1::text, $2, $3, $4, $5)" + }, "ed8b7aacd0d94fe647899b6d2fe61a29372cd7d6dbc28bf59ac6bb3118e3fe6c": { "describe": { "columns": [ diff --git a/core/src/lib.rs b/core/src/lib.rs index b00273a4..a8523c68 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -163,8 +163,12 @@ pub mod util; pub use pg_ext::PGMQueueExt; use util::fetch_one_message; +use std::time::Duration; + const VT_DEFAULT: i32 = 30; const READ_LIMIT_DEFAULT: i32 = 1; +const POLL_TIMEOUT_DEFAULT: Duration = Duration::from_secs(5); +const POLL_INTERVAL_DEFAULT: Duration = Duration::from_millis(250); /// Message struct received from the queue /// @@ -627,6 +631,45 @@ impl PGMQueue { 
Ok(messages) } + /// Similar to [`read_batch`], but allows waiting until a message is available + /// + /// You can specify a maximum duration for polling (defaults to 5 seconds), + /// and an interval between calls (defaults to 250ms). A higher interval + /// implies higher maximum latency, but less load on the database. + /// + /// Refer to the [`read_batch`] function for more details. + /// + pub async fn read_batch_with_poll Deserialize<'de>>( + &self, + queue_name: &str, + vt: Option, + max_batch_size: i32, + poll_timeout: Option, + poll_interval: Option, + ) -> Result>>, errors::PgmqError> { + let vt_ = vt.unwrap_or(VT_DEFAULT); + let poll_timeout_ = poll_timeout.unwrap_or(POLL_TIMEOUT_DEFAULT); + let poll_interval_ = poll_interval.unwrap_or(POLL_INTERVAL_DEFAULT); + let start_time = std::time::Instant::now(); + loop { + let query = &query::read(queue_name, vt_, max_batch_size)?; + let messages = fetch_messages::(query, &self.connection).await?; + match messages { + Some(m) => { + break Ok(Some(m)); + } + None => { + if start_time.elapsed() < poll_timeout_ { + tokio::time::sleep(poll_interval_).await; + continue; + } else { + break Ok(None); + } + } + } + } + } + /// Delete a message from the queue. /// This is a permanent delete and cannot be undone. If you want to retain a log of the message, /// use the [archive](#method.archive) method.
@@ -886,28 +929,31 @@ async fn fetch_messages Deserialize<'de>>( connection: &Pool, ) -> Result>>, errors::PgmqError> { let mut messages: Vec> = Vec::new(); - let rows: Result, Error> = sqlx::query(query).fetch_all(connection).await; - if let Err(sqlx::error::Error::RowNotFound) = rows { - return Ok(None); - } else if let Err(e) = rows { - return Err(e)?; - } else if let Ok(rows) = rows { - // happy path - successfully read messages - for row in rows.iter() { - let raw_msg = row.get("message"); - let parsed_msg = serde_json::from_value::(raw_msg); - if let Err(e) = parsed_msg { - return Err(errors::PgmqError::JsonParsingError(e)); - } else if let Ok(parsed_msg) = parsed_msg { - messages.push(Message { - msg_id: row.get("msg_id"), - vt: row.get("vt"), - read_ct: row.get("read_ct"), - enqueued_at: row.get("enqueued_at"), - message: parsed_msg, - }) + let result: Result, Error> = sqlx::query(query).fetch_all(connection).await; + match result { + Ok(rows) => { + if rows.is_empty() { + Ok(None) + } else { + // happy path - successfully read messages + for row in rows.iter() { + let raw_msg = row.get("message"); + let parsed_msg = serde_json::from_value::(raw_msg); + if let Err(e) = parsed_msg { + return Err(errors::PgmqError::JsonParsingError(e)); + } else if let Ok(parsed_msg) = parsed_msg { + messages.push(Message { + msg_id: row.get("msg_id"), + vt: row.get("vt"), + read_ct: row.get("read_ct"), + enqueued_at: row.get("enqueued_at"), + message: parsed_msg, + }) + } + } + Ok(Some(messages)) } } + Err(e) => Err(e)?, } - Ok(Some(messages)) } diff --git a/core/src/pg_ext.rs b/core/src/pg_ext.rs index 18e8a7d9..c656bc5a 100644 --- a/core/src/pg_ext.rs +++ b/core/src/pg_ext.rs @@ -7,6 +7,9 @@ use serde::{Deserialize, Serialize}; use sqlx::types::chrono::Utc; use sqlx::{Executor, Pool, Postgres}; +const DEFAULT_POLL_TIMEOUT_S: i32 = 5; +const DEFAULT_POLL_INTERVAL_MS: i32 = 250; + /// Main controller for interacting with a managed by the PGMQ Postgres extension. 
#[derive(Clone, Debug)] pub struct PGMQueueExt { @@ -194,6 +197,57 @@ impl PGMQueueExt { } } + pub async fn read_batch_with_poll Deserialize<'de>>( + &self, + queue_name: &str, + vt: i32, + max_batch_size: i32, + poll_timeout: Option, + poll_interval: Option, + ) -> Result>>, PgmqError> { + check_input(queue_name)?; + let poll_timeout_s = poll_timeout.map_or(DEFAULT_POLL_TIMEOUT_S, |t| t.as_secs() as i32); + let poll_interval_ms = + poll_interval.map_or(DEFAULT_POLL_INTERVAL_MS, |i| i.as_millis() as i32); + let result = sqlx::query!( + "SELECT * from pgmq_read_with_poll($1::text, $2, $3, $4, $5)", + queue_name, + vt, + max_batch_size, + poll_timeout_s, + poll_interval_ms + ) + .fetch_all(&self.connection) + .await; + + match result { + Err(sqlx::error::Error::RowNotFound) => Ok(None), + Err(e) => Err(e)?, + Ok(rows) => { + // happy path - successfully read messages + let mut messages: Vec> = Vec::new(); + for row in rows.iter() { + let raw_msg = row.message.clone().expect("no message"); + let parsed_msg = serde_json::from_value::(raw_msg); + if let Err(e) = parsed_msg { + return Err(PgmqError::JsonParsingError(e)); + } else if let Ok(parsed_msg) = parsed_msg { + messages.push(Message { + msg_id: row.msg_id.expect("msg_id missing from queue table"), + vt: row.vt.expect("vt missing from queue table"), + read_ct: row.read_ct.expect("read_ct missing from queue table"), + enqueued_at: row + .enqueued_at + .expect("enqueued_at missing from queue table"), + message: parsed_msg, + }) + } + } + Ok(Some(messages)) + } + } + } + /// Move a message to the archive table. pub async fn archive(&self, queue_name: &str, msg_id: i64) -> Result { check_input(queue_name)?; diff --git a/core/src/query.rs b/core/src/query.rs index b35fa058..66681cfe 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -1,6 +1,7 @@ //! 
Query constructors use crate::{errors::PgmqError, util::CheckedName}; + use sqlx::types::chrono::Utc; pub const TABLE_PREFIX: &str = r#"pgmq"#; pub const PGMQ_SCHEMA: &str = "public"; @@ -234,14 +235,14 @@ pub fn read(name: &str, vt: i32, limit: i32) -> Result { ( SELECT msg_id FROM {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name} - WHERE vt <= now() + WHERE vt <= clock_timestamp() ORDER BY msg_id ASC LIMIT {limit} FOR UPDATE SKIP LOCKED ) UPDATE {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name} SET - vt = now() + interval '{vt} seconds', + vt = clock_timestamp() + interval '{vt} seconds', read_ct = read_ct + 1 WHERE msg_id in (select msg_id from cte) RETURNING *; @@ -354,20 +355,20 @@ pub fn unassign_archive(name: CheckedName<'_>) -> Result { pub fn assign(table_name: &str) -> String { format!( " - DO $$ + DO $$ BEGIN -- Check if the table is not yet associated with the extension IF NOT EXISTS ( - SELECT 1 - FROM pg_depend + SELECT 1 + FROM pg_depend WHERE refobjid = (SELECT oid FROM pg_extension WHERE extname = 'pgmq') AND objid = (SELECT oid FROM pg_class WHERE relname = '{TABLE_PREFIX}_{table_name}') ) THEN - + EXECUTE 'ALTER EXTENSION pgmq ADD TABLE {PGMQ_SCHEMA}.{TABLE_PREFIX}_{table_name}'; - + END IF; - + END $$; " ) diff --git a/core/tests/integration_test.rs b/core/tests/integration_test.rs index cd448b3e..45c23ffd 100644 --- a/core/tests/integration_test.rs +++ b/core/tests/integration_test.rs @@ -231,6 +231,72 @@ async fn test_send_delay() { assert!(one_messages.is_some()); } +#[tokio::test] +async fn test_read_batch_with_poll() { + let test_queue = "test_read_batch_with_poll".to_owned(); + + let queue = init_queue(&test_queue).await; + + // PUBLISH THREE MESSAGES + let msg = serde_json::json!({ + "foo": "bar1" + }); + let msg_id1 = queue.send(&test_queue, &msg).await.unwrap(); + assert_eq!(msg_id1, 1); + let msg_id2 = queue.send(&test_queue, &msg).await.unwrap(); + assert_eq!(msg_id2, 2); + let msg_id3 = queue.send(&test_queue, &msg).await.unwrap(); + assert_eq!(msg_id3, 3); + + 
// Reading from queue with a 5 seconds VT + let read_message_1 = queue + .read_batch_with_poll::( + &test_queue, + Some(5), + 5, + Some(std::time::Duration::from_secs(6)), + None, + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(read_message_1.len(), 3); + + let starting_time = std::time::Instant::now(); + + // Since VT is 5 seconds, if we poll the queue, it takes around 5 seconds + // to return the result, and returns all 3 messages + let read_message_2 = queue + .read_batch_with_poll::( + &test_queue, + Some(5), + 5, + Some(std::time::Duration::from_secs(6)), + None, + ) + .await + .unwrap() + .unwrap(); + + assert_eq!(read_message_2.len(), 3); + assert!(starting_time.elapsed() > std::time::Duration::from_secs(3)); + + // If we don't poll for long enough, we get none + let read_message_3 = queue + .read_batch_with_poll::( + &test_queue, + Some(3), + 5, + Some(std::time::Duration::from_secs(1)), + None, + ) + .await + .unwrap(); + + assert!(read_message_3.is_none()); +} + #[tokio::test] async fn test_read_batch() { let test_queue = "test_read_batch".to_owned(); @@ -722,7 +788,7 @@ async fn test_extension_api() { assert!(msg_id >= 1); let read_message = queue - .read::(&test_queue, 100) + .read::(&test_queue, 5) .await .expect("error reading message"); assert!(read_message.is_some()); @@ -737,6 +803,22 @@ async fn test_extension_api() { .expect("error reading message"); assert!(read_message.is_none()); + // read with poll, blocks until message visible + let read_messages = queue + .read_batch_with_poll::( + &test_queue, + 5, + 1, + Some(std::time::Duration::from_secs(6)), + None, + ) + .await + .expect("error reading message") + .expect("no message"); + + assert_eq!(read_messages.len(), 1); + assert_eq!(read_messages[0].msg_id, msg_id); + // change the VT to now let _vt_set = queue .set_vt::(&test_queue, msg_id, 0) diff --git a/src/lib.rs b/src/lib.rs index 52c56ff3..ef805d3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,8 @@ use pgmq_crate::query::{ }; 
use thiserror::Error; +use std::time::Duration; + #[derive(Error, Debug)] pub enum PgmqExtError { #[error("")] @@ -126,6 +128,39 @@ fn pgmq_read( Ok(TableIterator::new(results)) } +#[pg_extern] +fn pgmq_read_with_poll( + queue_name: &str, + vt: i32, + limit: i32, + poll_timeout_s: default!(i32, 5), + poll_interval_ms: default!(i32, 250), +) -> Result< + TableIterator< + 'static, + ( + name!(msg_id, i64), + name!(read_ct, i32), + name!(enqueued_at, TimestampWithTimeZone), + name!(vt, TimestampWithTimeZone), + name!(message, pgrx::JsonB), + ), + >, + spi::Error, +> { + let start_time = std::time::Instant::now(); + let poll_timeout_ms = poll_timeout_s.max(0) as u128 * 1000; + loop { + let results = readit(queue_name, vt, limit)?; + if results.is_empty() && start_time.elapsed().as_millis() < poll_timeout_ms { + std::thread::sleep(Duration::from_millis(poll_interval_ms.max(0) as u64)); + continue; + } else { + break Ok(TableIterator::new(results)); + } + } +} + fn readit( queue_name: &str, vt: i32, diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index ac62b733..a60f4b48 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -63,25 +63,33 @@ async fn test_lifecycle() { .expect("expected message"); assert_eq!(message.msg_id, 1); - // set VT to tomorrow - let query = &format!("SELECT * from pgmq_set_vt('{test_default_queue}', {msg_id}, 84600);"); + // set VT to 5 seconds in the future + let query = &format!("SELECT * from pgmq_set_vt('{test_default_queue}', {msg_id}, 5);"); let message = fetch_one_message::(query, &conn) .await .expect("failed reading message") .expect("expected message"); assert_eq!(message.msg_id, 1); let now = chrono::offset::Utc::now(); - // closish to 24 hours from now - assert!(message.vt > now + chrono::Duration::seconds(84000)); + // closish to 5 seconds from now + assert!(message.vt > now + chrono::Duration::seconds(4)); - // read again, assert no messages because we just set VT to tomorrow + // read again, assert
no messages because we just set VT to the future let query = &format!("SELECT * from pgmq_read('{test_default_queue}', 2, 1);"); let message = fetch_one_message::(query, &conn) .await .expect("failed reading message"); assert!(message.is_none()); - // set VT to now + // read again, now using poll to block until message is ready + let query = &format!("SELECT * from pgmq_read_with_poll('{test_default_queue}', 10, 1, 10);"); + let message = fetch_one_message::(query, &conn) + .await + .expect("failed reading message") + .expect("expected message"); + assert_eq!(message.msg_id, 1); + + // after reading it, set VT to now let query = &format!("SELECT * from pgmq_set_vt('{test_default_queue}', {msg_id}, 0);"); let message = fetch_one_message::(query, &conn) .await