Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pgmq_read_with_poll #64

Merged
merged 13 commits into from
Aug 18, 2023
56 changes: 52 additions & 4 deletions core/sqlx-data.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
"type_info": "Int4"
},
{
"name": "vt",
"name": "enqueued_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "enqueued_at",
"name": "vt",
"ordinal": 3,
"type_info": "Timestamptz"
},
Expand Down Expand Up @@ -220,12 +220,12 @@
"type_info": "Int4"
},
{
"name": "vt",
"name": "enqueued_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "enqueued_at",
"name": "vt",
"ordinal": 3,
"type_info": "Timestamptz"
},
Expand Down Expand Up @@ -273,6 +273,54 @@
},
"query": "SELECT pgmq_send as msg_id from pgmq_send($1::text, $2::jsonb);"
},
"e4c38347b44aed05aa890d3351a362d3b6f81387e98fc564ec922cefa1e96f71": {
"describe": {
"columns": [
{
"name": "msg_id",
"ordinal": 0,
"type_info": "Int8"
},
{
"name": "read_ct",
"ordinal": 1,
"type_info": "Int4"
},
{
"name": "enqueued_at",
"ordinal": 2,
"type_info": "Timestamptz"
},
{
"name": "vt",
"ordinal": 3,
"type_info": "Timestamptz"
},
{
"name": "message",
"ordinal": 4,
"type_info": "Jsonb"
}
],
"nullable": [
null,
null,
null,
null,
null
],
"parameters": {
"Left": [
"Text",
"Int4",
"Int4",
"Int4",
"Int4"
]
}
},
"query": "SELECT * from pgmq_read_with_poll($1::text, $2, $3, $4, $5)"
},
"ed8b7aacd0d94fe647899b6d2fe61a29372cd7d6dbc28bf59ac6bb3118e3fe6c": {
"describe": {
"columns": [
Expand Down
49 changes: 49 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,12 @@ pub mod util;
pub use pg_ext::PGMQueueExt;
use util::fetch_one_message;

use std::time::Duration;

/// Visibility timeout, in seconds, used by `read_batch_with_poll` when the caller passes `None`.
const VT_DEFAULT: i32 = 30;
/// NOTE(review): presumably the default number of messages fetched per read;
/// its use is not visible in this part of the file — confirm.
const READ_LIMIT_DEFAULT: i32 = 1;
/// Maximum time `read_batch_with_poll` keeps polling when no `poll_timeout` is given.
const POLL_TIMEOUT_DEFAULT: Duration = Duration::from_secs(5);
/// Pause between two polls in `read_batch_with_poll` when no `poll_interval` is given.
const POLL_INTERVAL_DEFAULT: Duration = Duration::from_millis(250);

/// Message struct received from the queue
///
Expand Down Expand Up @@ -627,6 +631,51 @@ impl PGMQueue {
Ok(messages)
}

/// Similar to [`read_batch`], but allows waiting until a message is available
///
/// You can specify a maximum duration for polling (defaults to 5 seconds),
/// and an interval between calls (defaults to 250ms). A lower interval
/// implies higher maximum latency, but less load on the database.
///
/// Refer to the [`read_batch`] function for more details.
///
pub async fn read_batch_with_poll<T: for<'de> Deserialize<'de>>(
&self,
queue_name: &str,
vt: Option<i32>,
max_batch_size: i32,
poll_timeout: Option<Duration>,
poll_interval: Option<Duration>,
) -> Result<Option<Vec<Message<T>>>, errors::PgmqError> {
let vt_ = vt.unwrap_or(VT_DEFAULT);
let poll_timeout_ = poll_timeout.unwrap_or(POLL_TIMEOUT_DEFAULT);
let poll_interval_ = poll_interval.unwrap_or(POLL_INTERVAL_DEFAULT);
let start_time = std::time::Instant::now();
loop {
let query = &query::read(queue_name, vt_, max_batch_size)?;
let messages = fetch_messages::<T>(query, &self.connection).await?;
// Why can fetch_messages return both `None` and `Option(())`
v0idpwn marked this conversation as resolved.
Show resolved Hide resolved
match messages {
None => {
if start_time.elapsed() < poll_timeout_ {
tokio::time::sleep(poll_interval_).await;
continue;
} else {
break Ok(None);
}
}
Some(m) => {
if m.is_empty() && start_time.elapsed() < poll_timeout_ {
tokio::time::sleep(poll_interval_).await;
continue;
} else {
break Ok(Some(m));
}
}
}
v0idpwn marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Delete a message from the queue.
/// This is a permanent delete and cannot be undone. If you want to retain a log of the message,
/// use the [archive](#method.archive) method.
Expand Down
54 changes: 54 additions & 0 deletions core/src/pg_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use serde::{Deserialize, Serialize};
use sqlx::types::chrono::Utc;
use sqlx::{Executor, Pool, Postgres};

/// Poll timeout, in seconds, sent to `pgmq_read_with_poll` when the caller passes no `poll_timeout`.
const DEFAULT_POLL_TIMEOUT_S: i32 = 5;
/// Poll interval, in milliseconds, sent to `pgmq_read_with_poll` when the caller passes no `poll_interval`.
const DEFAULT_POLL_INTERVAL_MS: i32 = 250;

/// Main controller for interacting with a managed by the PGMQ Postgres extension.
#[derive(Clone, Debug)]
pub struct PGMQueueExt {
Expand Down Expand Up @@ -194,6 +197,57 @@ impl PGMQueueExt {
}
}

pub async fn read_batch_with_poll<T: for<'de> Deserialize<'de>>(
&self,
queue_name: &str,
vt: i32,
max_batch_size: i32,
poll_timeout: Option<std::time::Duration>,
poll_interval: Option<std::time::Duration>,
) -> Result<Option<Vec<Message<T>>>, PgmqError> {
check_input(queue_name)?;
let poll_timeout_s = poll_timeout.map_or(DEFAULT_POLL_TIMEOUT_S, |t| t.as_secs() as i32);
let poll_interval_ms =
poll_interval.map_or(DEFAULT_POLL_INTERVAL_MS, |i| i.as_millis() as i32);
let result = sqlx::query!(
"SELECT * from pgmq_read_with_poll($1::text, $2, $3, $4, $5)",
queue_name,
vt,
max_batch_size,
poll_timeout_s,
poll_interval_ms
)
.fetch_all(&self.connection)
.await;

match result {
Err(sqlx::error::Error::RowNotFound) => Ok(None),
Err(e) => Err(e)?,
Ok(rows) => {
// happy path - successfully read messages
let mut messages: Vec<Message<T>> = Vec::new();
for row in rows.iter() {
let raw_msg = row.message.clone().expect("no message");
let parsed_msg = serde_json::from_value::<T>(raw_msg);
if let Err(e) = parsed_msg {
return Err(PgmqError::JsonParsingError(e));
} else if let Ok(parsed_msg) = parsed_msg {
messages.push(Message {
msg_id: row.msg_id.expect("msg_id missing from queue table"),
vt: row.vt.expect("vt missing from queue table"),
read_ct: row.read_ct.expect("read_ct missing from queue table"),
enqueued_at: row
.enqueued_at
.expect("enqueued_at missing from queue table"),
message: parsed_msg,
})
}
}
Ok(Some(messages))
}
}
}

/// Move a message to the archive table.
pub async fn archive(&self, queue_name: &str, msg_id: i64) -> Result<bool, PgmqError> {
check_input(queue_name)?;
Expand Down
17 changes: 9 additions & 8 deletions core/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Query constructors

use crate::{errors::PgmqError, util::CheckedName};

use sqlx::types::chrono::Utc;
pub const TABLE_PREFIX: &str = r#"pgmq"#;
pub const PGMQ_SCHEMA: &str = "public";
Expand Down Expand Up @@ -234,14 +235,14 @@ pub fn read(name: &str, vt: i32, limit: i32) -> Result<String, PgmqError> {
(
SELECT msg_id
FROM {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name}
WHERE vt <= now()
WHERE vt <= clock_timestamp()
Copy link
Collaborator Author

@v0idpwn v0idpwn Aug 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now() uses transaction start time. This can be pretty useful, but in this case, if we repeatedly query it will always use the same timestamp. I also tried statement_timestamp(), but for some reason while using the SPI I got the same timestamp for different runs 🤔, so I resorted to clock_timestamp()

ORDER BY msg_id ASC
LIMIT {limit}
FOR UPDATE SKIP LOCKED
)
UPDATE {PGMQ_SCHEMA}.{TABLE_PREFIX}_{name}
SET
vt = now() + interval '{vt} seconds',
vt = clock_timestamp() + interval '{vt} seconds',
read_ct = read_ct + 1
WHERE msg_id in (select msg_id from cte)
RETURNING *;
Expand Down Expand Up @@ -354,20 +355,20 @@ pub fn unassign_archive(name: CheckedName<'_>) -> Result<String, PgmqError> {
pub fn assign(table_name: &str) -> String {
format!(
"
DO $$
DO $$
BEGIN
-- Check if the table is not yet associated with the extension
IF NOT EXISTS (
SELECT 1
FROM pg_depend
SELECT 1
FROM pg_depend
WHERE refobjid = (SELECT oid FROM pg_extension WHERE extname = 'pgmq')
AND objid = (SELECT oid FROM pg_class WHERE relname = '{TABLE_PREFIX}_{table_name}')
) THEN

EXECUTE 'ALTER EXTENSION pgmq ADD TABLE {PGMQ_SCHEMA}.{TABLE_PREFIX}_{table_name}';

END IF;

END $$;
"
)
Expand Down
85 changes: 84 additions & 1 deletion core/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,73 @@ async fn test_send_delay() {
assert!(one_messages.is_some());
}

#[tokio::test]
async fn test_read_batch_with_poll() {
    use std::time::{Duration, Instant};

    let test_queue = String::from("test_read_batch_with_poll");
    let queue = init_queue(&test_queue).await;

    // Enqueue the same payload three times; the ids must come back as 1, 2, 3.
    let payload = serde_json::json!({
        "foo": "bar1"
    });
    for expected_id in 1..=3 {
        let msg_id = queue.send(&test_queue, &payload).await.unwrap();
        assert_eq!(msg_id, expected_id);
    }

    let long_poll = Some(Duration::from_secs(6));
    let short_poll = Some(Duration::from_secs(1));

    // First read, with a 5 second VT: all three messages come back.
    let first_batch = queue
        .read_batch_with_poll::<Value>(&test_queue, Some(5), 5, long_poll, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(first_batch.len(), 3);

    let poll_started = Instant::now();

    // The messages were just read with a 5 second VT, so a second poll has to
    // wait around 5 seconds before it gets all three back.
    let second_batch = queue
        .read_batch_with_poll::<Value>(&test_queue, Some(5), 5, long_poll, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(second_batch.len(), 3);
    assert!(poll_started.elapsed() > Duration::from_secs(3));

    // Polling for only one second is too short: the batch comes back empty.
    let third_batch = queue
        .read_batch_with_poll::<Value>(&test_queue, Some(3), 5, short_poll, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(third_batch.len(), 0);
}

#[tokio::test]
async fn test_read_batch() {
let test_queue = "test_read_batch".to_owned();
Expand Down Expand Up @@ -722,7 +789,7 @@ async fn test_extension_api() {
assert!(msg_id >= 1);

let read_message = queue
.read::<MyMessage>(&test_queue, 100)
.read::<MyMessage>(&test_queue, 5)
.await
.expect("error reading message");
assert!(read_message.is_some());
Expand All @@ -737,6 +804,22 @@ async fn test_extension_api() {
.expect("error reading message");
assert!(read_message.is_none());

// read with poll, blocks until message visible
let read_messages = queue
.read_batch_with_poll::<MyMessage>(
&test_queue,
5,
1,
Some(std::time::Duration::from_secs(6)),
None,
)
.await
.expect("error reading message")
.expect("no message");

assert_eq!(read_messages.len(), 1);
assert_eq!(read_messages[0].msg_id, msg_id);

// change the VT to now
let _vt_set = queue
.set_vt::<MyMessage>(&test_queue, msg_id, 0)
Expand Down
Loading
Loading