From 1e953efe654882f87cc1f47d0ec20374650ab76f Mon Sep 17 00:00:00 2001
From: Ali Tavallaie
Date: Sat, 28 Sep 2024 20:31:51 +0330
Subject: [PATCH] pgmq-python: adding support for Transaction (#268)

* feat: adding transaction as decorator
* feat: test for transactions
* add logger
* chore: linting
* feat: successful transaction operation
* chore: linting and formatting
* feat: adding better logger and optional verbose mode
* feat: update readme for transaction
* feat: support for transaction:
  - adding support for logging in sync and async functions
  - adding support for transaction in sync and async operations
  - adding unit tests
  - separated module for transaction decorators to avoid conflicts in unit tests
  - updating readme
* feat: remove perform_transaction
* feat: adding example for transaction
* feat: update readme for using transactions
* chore: linting
* chore: remove unused tnx variable
* feat: update examples for non-db and non-pgmq
* chore: remove extra space in README
* feat: complete async example app
* chore: fixing some python code indentation within README
---
 tembo-pgmq-python/README.md                   |  64 ++-
 .../example/example_app_async.py              | 160 ++++++
 tembo-pgmq-python/example/example_app_sync.py | 255 +++++++++
 .../tembo_pgmq_python/__init__.py             |   3 +-
 .../tembo_pgmq_python/async_queue.py          | 500 +++++++++++++-----
 .../tembo_pgmq_python/decorators.py           |  69 +++
 tembo-pgmq-python/tembo_pgmq_python/queue.py  | 270 ++++++----
 .../tests/test_async_integration.py           |  56 +-
 tembo-pgmq-python/tests/test_integration.py   |  72 ++-
 9 files changed, 1201 insertions(+), 248 deletions(-)
 create mode 100644 tembo-pgmq-python/example/example_app_async.py
 create mode 100644 tembo-pgmq-python/example/example_app_sync.py
 create mode 100644 tembo-pgmq-python/tembo_pgmq_python/decorators.py

diff --git a/tembo-pgmq-python/README.md b/tembo-pgmq-python/README.md
index 77205759..74ffd902 100644
--- a/tembo-pgmq-python/README.md
+++ b/tembo-pgmq-python/README.md
@@ -8,16 +8,15 @@ Install with `pip` from pypi.org:
 pip install tembo-pgmq-python
 ```
 
-In order to use async version install with the optional dependecies:
+To use the async version, install with the optional dependencies:
 
-``` bash
+```bash
 pip install tembo-pgmq-python[async]
 ```
 
-
 Dependencies:
 
-Postgres running the [Tembo PGMQ extension](https://github.com/tembo-io/tembo/tree/main/pgmq).
+- Postgres running the [Tembo PGMQ extension](https://github.com/tembo-io/tembo/tree/main/pgmq).
 
 ## Usage
 
@@ -51,7 +50,7 @@ queue = PGMQueue()
 
 Initialization for the async version requires an explicit call of the initializer:
 
-``` bash
+```python
 from tembo_pgmq_python.async_queue import PGMQueue
 
 async def main():
@@ -81,11 +80,12 @@ queue = PGMQueue(
 queue.create_queue("my_queue")
 ```
 
-### or a partitioned queue
+### Or create a partitioned queue
 
 ```python
 queue.create_partitioned_queue("my_partitioned_queue", partition_interval=10000)
 ```
+
 ### List all queues
 
 ```python
@@ -128,14 +128,18 @@ The `read_with_poll` method allows you to repeatedly check for messages in the q
 
 In the following example, the method will check for up to 5 messages in the queue `my_queue`, making the messages invisible for 30 seconds (`vt`), and will poll for a maximum of 5 seconds (`max_poll_seconds`) with intervals of 100 milliseconds (`poll_interval_ms`) between checks.
 
 ```python
-read_messages: list[Message] = queue.read_with_poll("my_queue", vt=30, qty=5, max_poll_seconds=5, poll_interval_ms=100)
+read_messages: list[Message] = queue.read_with_poll(
+    "my_queue", vt=30, qty=5, max_poll_seconds=5, poll_interval_ms=100
+)
 for message in read_messages:
     print(message)
 ```
 
 This method will continue polling until it either finds the specified number of messages (`qty`) or the `max_poll_seconds` duration is reached. The `poll_interval_ms` parameter controls the interval between successive polls, allowing you to avoid hammering the database with continuous queries.
 
-### Archive the message after we're done with it. Archived messages are moved to an archive table
+### Archive the message after we're done with it
+
+Archived messages are moved to an archive table.
 
 ```python
 archived: bool = queue.archive("my_queue", read_message.msg_id)
@@ -238,5 +242,49 @@ for metrics in all_metrics:
     print(f"Scrape time: {metrics.scrape_time}")
 ```
 
+### Optional Logging Configuration
+
+You can enable verbose logging and specify a custom log filename.
+
+```python
+queue = PGMQueue(
+    host="0.0.0.0",
+    port="5432",
+    username="postgres",
+    password="postgres",
+    database="postgres",
+    verbose=True,
+    log_filename="my_custom_log.log"
+)
+```
+
+## Using Transactions
+
+To perform multiple operations within a single transaction, use the `@transaction` decorator from the `tembo_pgmq_python.decorators` module.
+This ensures that all operations within the function are executed within the same transaction and are either committed together or rolled back if an error occurs.
+
+First, import the transaction decorator:
+
+```python
+from tembo_pgmq_python.decorators import transaction
+```
+
+### Example: Transactional Operation
+
+```python
+@transaction
+def transactional_operation(queue: PGMQueue, conn=None):
+    # Perform multiple queue operations within a transaction
+    queue.create_queue("transactional_queue", conn=conn)
+    queue.send("transactional_queue", {"message": "Hello, World!"}, conn=conn)
+```
+
+To execute the transaction:
+
+```python
+try:
+    transactional_operation(queue)
+except Exception as e:
+    print(f"Transaction failed: {e}")
+```
+
+In this example, the `transactional_operation` function is decorated with `@transaction`, ensuring all operations inside it are part of a single transaction. If an error occurs, the entire transaction is rolled back automatically.
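+
+### Example: Async Transactional Operation
+
+The async client supports the same pattern through the `async_transaction` decorator; `example/example_app_async.py` contains a complete program. A minimal sketch, assuming `queue` is an async `PGMQueue` from `tembo_pgmq_python.async_queue` that has already been initialized with `await queue.init()`:
+
+```python
+from tembo_pgmq_python.decorators import async_transaction as transaction
+
+
+@transaction
+async def transactional_operation(queue: PGMQueue, conn=None):
+    # Both sends join the same transaction; an exception rolls both back.
+    await queue.send("transactional_queue", {"message": "first"}, conn=conn)
+    await queue.send("transactional_queue", {"message": "second"}, conn=conn)
+```
+
+Await `transactional_operation(queue)` from your async code; commit and rollback semantics match the synchronous example above.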
\ No newline at end of file diff --git a/tembo-pgmq-python/example/example_app_async.py b/tembo-pgmq-python/example/example_app_async.py new file mode 100644 index 00000000..296b23c4 --- /dev/null +++ b/tembo-pgmq-python/example/example_app_async.py @@ -0,0 +1,160 @@ +import asyncio +from tembo_pgmq_python.async_queue import PGMQueue +from tembo_pgmq_python.decorators import async_transaction as transaction + + +async def main(): + # Initialize the queue + queue = PGMQueue( + host="0.0.0.0", + port="5432", + username="postgres", + password="postgres", + database="postgres", + verbose=True, + log_filename="pgmq_async.log", + ) + await queue.init() + + test_queue = "transactional_queue_async" + + # Clean up if the queue already exists + queues = await queue.list_queues() + if test_queue in queues: + await queue.drop_queue(test_queue) + await queue.create_queue(test_queue) + + # Example messages + message1 = {"id": 1, "content": "First message"} + message2 = {"id": 2, "content": "Second message"} + + # Transactional operation: send messages within a transaction + @transaction + async def transactional_operation(queue: PGMQueue, conn=None): + # Perform multiple queue operations within a transaction + await queue.send(test_queue, message1, conn=conn) + await queue.send(test_queue, message2, conn=conn) + # Transaction commits if no exception occurs + + # Execute the transactional function (Success Case) + try: + await transactional_operation(queue) + print("Transaction committed successfully.") + except Exception as e: + print(f"Transaction failed: {e}") + + # Read messages outside of the transaction + read_message1 = await queue.read(test_queue) + read_message2 = await queue.read(test_queue) + print("Messages read after transaction commit:") + if read_message1: + print(f"Message 1: {read_message1.message}") + if read_message2: + print(f"Message 2: {read_message2.message}") + + # Purge the queue for the failure case + await queue.purge(test_queue) + + # Transactional operation: simulate failure + @transaction + async def transactional_operation_failure(queue: PGMQueue, conn=None): + await queue.send(test_queue, message1, conn=conn) + await queue.send(test_queue, message2, conn=conn) + # Simulate an error to trigger rollback + raise Exception("Simulated failure") + + # Execute the transactional function (Failure Case) + try: + await transactional_operation_failure(queue) + except Exception as e: + print(f"Transaction failed: {e}") + + # Attempt to read messages after failed transaction + read_message = await queue.read(test_queue) + if read_message: + print("Message read after failed transaction (should not exist):") + print(read_message.message) + else: + print("No messages found after transaction rollback.") + + # Simulate conditional rollback + await queue.purge(test_queue) # Clear the queue before the next test + + @transaction + async def conditional_failure(queue: PGMQueue, conn=None): + # Send messages + msg_ids = await queue.send_batch(test_queue, [message1, message2], conn=conn) + print(f"Messages sent with IDs: {msg_ids}") + messages_in_queue = await queue.read_batch(test_queue, batch_size=10, conn=conn) + print( + f"Messages currently in queue before conditional failure: {messages_in_queue}" + ) + + # Conditional rollback based on number of messages + if len(messages_in_queue) > 3: + await queue.delete( + test_queue, msg_id=messages_in_queue[0].msg_id, conn=conn + ) + print( + f"Message ID {messages_in_queue[0].msg_id} deleted within transaction." 
+ ) + else: + # Simulate failure if queue size is not greater than 3 + print( + "Transaction failed: Not enough messages in queue to proceed with deletion." + ) + raise Exception("Queue size too small to proceed.") + + print("\n=== Executing Conditional Failure Scenario ===") + try: + await conditional_failure(queue) + except Exception as e: + print(f"Conditional Failure Transaction failed: {e}") + + # Simulate success for conditional scenario + @transaction + async def conditional_success(queue: PGMQueue, conn=None): + # Send additional messages to ensure queue has more than 3 messages + additional_messages = [ + {"id": 3, "content": "Third message"}, + {"id": 4, "content": "Fourth message"}, + ] + msg_ids = await queue.send_batch(test_queue, additional_messages, conn=conn) + print(f"Additional messages sent with IDs: {msg_ids}") + + # Read messages in queue + messages_in_queue = await queue.read_batch(test_queue, batch_size=10, conn=conn) + print( + f"Messages currently in queue before successful conditional deletion: {messages_in_queue}" + ) + + if len(messages_in_queue) > 3: + await queue.delete( + test_queue, msg_id=messages_in_queue[0].msg_id, conn=conn + ) + print( + f"Message ID {messages_in_queue[0].msg_id} deleted within transaction." + ) + + print("\n=== Executing Conditional Success Scenario ===") + try: + await conditional_success(queue) + except Exception as e: + print(f"Conditional Success Transaction failed: {e}") + + # Read messages after the conditional scenarios + read_messages = await queue.read_batch(test_queue, batch_size=10) + if read_messages: + print("Messages read after conditional scenarios:") + for msg in read_messages: + print(f"ID: {msg.msg_id}, Content: {msg.message}") + else: + print("No messages found after transactions.") + + await queue.drop_queue(test_queue) + await queue.pool.close() + + +# Run the main function +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tembo-pgmq-python/example/example_app_sync.py b/tembo-pgmq-python/example/example_app_sync.py new file mode 100644 index 00000000..29ec47fb --- /dev/null +++ b/tembo-pgmq-python/example/example_app_sync.py @@ -0,0 +1,255 @@ +from tembo_pgmq_python.queue import PGMQueue +from tembo_pgmq_python.decorators import transaction + +queue = PGMQueue( + host="localhost", + port="5432", + username="postgres", + password="postgres", + database="postgres", + verbose=True, + log_filename="pgmq_sync.log", +) + +test_queue = "transaction_queue_sync" + +# Clean up if the queue already exists +queues = queue.list_queues() +if test_queue in queues: + queue.drop_queue(test_queue) +queue.create_queue(test_queue) + +# Example messages +messages = [ + {"id": 1, "content": "First message"}, + {"id": 2, "content": "Second message"}, + {"id": 3, "content": "Third message"}, +] + + +# Create table function for non-PGMQ DB operation +def create_mytable(conn): + try: + with conn.cursor() as cur: + cur.execute(""" + CREATE TABLE IF NOT EXISTS mytable ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL + ) + """) + print("Table 'mytable' created or already exists.") + except Exception as e: + print(f"Failed to create table 'mytable': {e}") + raise + + +# Transaction with only PGMQ operations +@transaction +def pgmq_operations(queue, conn=None): + # Send multiple messages + msg_ids = queue.send_batch( + test_queue, + messages=messages, + conn=conn, + ) + print(f"PGMQ: Messages sent with IDs: {msg_ids}") + + # Read messages within the transaction + internal_messages = queue.read_batch( + test_queue, + 
batch_size=10, + conn=conn, + ) + print(f"PGMQ: Messages read within transaction: {internal_messages}") + + +# Transaction with non-PGMQ DB operation and PGMQ operation - Success case +@transaction +def non_pgmq_db_operations_success(queue, conn=None): + create_mytable(conn) + + # Non-PGMQ database operation (simulating a custom DB operation) + with conn.cursor() as cur: + cur.execute("INSERT INTO mytable (name) VALUES ('Alice')") + print("Non-PGMQ DB: Inserted into 'mytable'.") + + # Send multiple PGMQ messages + msg_ids = queue.send_batch( + test_queue, + messages=messages, + conn=conn, + ) + print(f"PGMQ: Messages sent with IDs: {msg_ids}") + + +# Transaction with non-PGMQ DB operation and PGMQ operation - Failure case +@transaction +def non_pgmq_db_operations_failure(queue, conn=None): + create_mytable(conn) + + # Non-PGMQ database operation (simulating a custom DB operation) + with conn.cursor() as cur: + cur.execute("INSERT INTO mytable (name) VALUES ('Bob')") + print("Non-PGMQ DB: Inserted into 'mytable'.") + + # Simulating a failure after a PGMQ operation + raise Exception( + "Simulated failure after inserting into mytable and sending messages" + ) + + +# Transaction with PGMQ operations and non-database operation (simple print statement) +@transaction +def non_db_operations(queue, conn=None): + # Send multiple messages + msg_ids = queue.send_batch( + test_queue, + messages=messages, + conn=conn, + ) + print(f"PGMQ: Messages sent with IDs: {msg_ids}") + + # Non-database operation: Print statement + print("Non-DB: Simulating a non-database operation (printing).") + + +# Transaction failure: only delete if queue size is larger than threshold +@transaction +def conditional_failure(queue, conn=None): + # Send multiple messages within the transaction + msg_ids = queue.send_batch( + test_queue, + messages=messages, + conn=conn, + ) + print(f"Messages sent with IDs: {msg_ids}") + + # Read messages currently in the queue within the transaction + messages_in_queue = queue.read_batch( + test_queue, + batch_size=10, + conn=conn, + ) + print( + f"Messages currently in queue before conditional failure: {messages_in_queue}" + ) + + # Simulate a condition: only delete if the queue has more than 3 messages + if len(messages_in_queue) > 3: + queue.delete( + test_queue, + msg_id=messages_in_queue[0].msg_id, + conn=conn, + ) + print(f"Message ID {messages_in_queue[0].msg_id} deleted within transaction.") + else: + # Simulate a failure if the queue size is not greater than 3 + print( + "Transaction failed: Not enough messages in queue to proceed with deletion." 
+        )
+        raise Exception("Queue size too small to proceed.")
+
+    print("Transaction completed successfully.")
+
+
+# Transaction success for conditional scenario
+@transaction
+def conditional_success(queue, conn=None):
+    # Send additional messages to ensure the queue has more than 3 messages
+    additional_messages = [
+        {"id": 4, "content": "Fourth message"},
+        {"id": 5, "content": "Fifth message"},
+    ]
+    msg_ids = queue.send_batch(
+        test_queue,
+        messages=additional_messages,
+        conn=conn,
+    )
+    print(f"Messages sent with IDs: {msg_ids}")
+
+    # Read messages currently in the queue within the transaction
+    messages_in_queue = queue.read_batch(
+        test_queue,
+        batch_size=10,
+        conn=conn,
+    )
+    print(
+        f"Messages currently in queue before successful conditional deletion: {messages_in_queue}"
+    )
+
+    # Proceed with deletion if more than 3 messages are in the queue
+    if len(messages_in_queue) > 3:
+        queue.delete(
+            test_queue,
+            msg_id=messages_in_queue[0].msg_id,
+            conn=conn,
+        )
+        print(f"Message ID {messages_in_queue[0].msg_id} deleted within transaction.")
+
+    print("Conditional success transaction completed.")
+
+
+# Read messages after transaction to see if changes were committed
+def read_queue_after_transaction():
+    external_messages = queue.read_batch(test_queue, batch_size=10)
+    if external_messages:
+        print("Messages read after transaction:")
+        for msg in external_messages:
+            print(f"ID: {msg.msg_id}, Content: {msg.message}")
+    else:
+        print("No messages found after transaction rollback.")
+
+
+# Execute transactions and handle exceptions
+print("=== Executing PGMQ Operations ===")
+try:
+    pgmq_operations(queue)
+except Exception as e:
+    print(f"PGMQ Transaction failed: {e}")
+
+print("\n=== Executing Non-PGMQ DB and PGMQ Operations (Success Case) ===")
+try:
+    non_pgmq_db_operations_success(queue)
+except Exception as e:
+    print(f"Non-PGMQ DB Transaction failed: {e}")
+
+print("\n=== Executing Non-PGMQ DB and PGMQ Operations (Failure Case) ===")
+try:
+    non_pgmq_db_operations_failure(queue)
+except Exception as e:
+    print(f"Non-PGMQ DB Transaction failed: {e}")
+
+print("\n=== Executing Non-DB and PGMQ Operations ===")
+try:
+    non_db_operations(queue)
+except Exception as e:
+    print(f"Non-DB Transaction failed: {e}")
+
+print("\n=== Reading Queue After Transactions ===")
+read_queue_after_transaction()
+
+# Purge the queue for failure case
+queue.purge(test_queue)
+
+print("\n=== Executing Conditional Failure Scenario ===")
+try:
+    conditional_failure(queue)
+except Exception as e:
+    print(f"Conditional Failure Transaction failed: {e}")
+read_queue_after_transaction()
+
+print("\n=== Executing Conditional Success Scenario ===")
+try:
+    conditional_success(queue)
+except Exception as e:
+    print(f"Conditional Success Transaction failed: {e}")
+    read_queue_after_transaction()
+
+# Read the queue after the conditional failure and success
+print("\n=== Reading Queue After Conditional Scenarios ===")
+read_queue_after_transaction()
+
+# Clean up
+queue.drop_queue(test_queue)
diff --git a/tembo-pgmq-python/tembo_pgmq_python/__init__.py b/tembo-pgmq-python/tembo_pgmq_python/__init__.py
index 58d8ab2a..e02685a7 100644
--- a/tembo-pgmq-python/tembo_pgmq_python/__init__.py
+++ b/tembo-pgmq-python/tembo_pgmq_python/__init__.py
@@ -1,3 +1,4 @@
 from tembo_pgmq_python.queue import Message, PGMQueue  # type: ignore
+from tembo_pgmq_python.decorators import transaction, async_transaction
 
-__all__ = ["Message", "PGMQueue"]
+__all__ = ["Message", "PGMQueue", "transaction", "async_transaction"]
diff 
--git a/tembo-pgmq-python/tembo_pgmq_python/async_queue.py b/tembo-pgmq-python/tembo_pgmq_python/async_queue.py index e5926079..f3137820 100644 --- a/tembo-pgmq-python/tembo_pgmq_python/async_queue.py +++ b/tembo-pgmq-python/tembo_pgmq_python/async_queue.py @@ -1,16 +1,21 @@ +# async_queue.py + from dataclasses import dataclass, field from typing import Optional, List import asyncpg import os +import logging +from datetime import datetime from orjson import dumps, loads from tembo_pgmq_python.messages import Message, QueueMetrics +from tembo_pgmq_python.decorators import async_transaction as transaction @dataclass class PGMQueue: - """Base class for interacting with a queue""" + """Asynchronous PGMQueue client for interacting with queues.""" host: str = field(default_factory=lambda: os.getenv("PG_HOST", "localhost")) port: str = field(default_factory=lambda: os.getenv("PG_PORT", "5432")) @@ -20,7 +25,11 @@ class PGMQueue: delay: int = 0 vt: int = 30 pool_size: int = 10 + perform_transaction: bool = False + verbose: bool = False + log_filename: Optional[str] = None pool: asyncpg.pool.Pool = field(init=False) + logger: logging.Logger = field(init=False) def __post_init__(self) -> None: self.host = self.host or "localhost" @@ -32,110 +41,229 @@ def __post_init__(self) -> None: if not all([self.host, self.port, self.database, self.username, self.password]): raise ValueError("Incomplete database connection information provided.") + self._initialize_logging() + self.logger.debug("PGMQueue initialized") + + def _initialize_logging(self) -> None: + if self.verbose: + log_filename = self.log_filename or datetime.now().strftime("pgmq_async_debug_%Y%m%d_%H%M%S.log") + logging.basicConfig( + filename=os.path.join(os.getcwd(), log_filename), + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + else: + logging.basicConfig(level=logging.WARNING) + self.logger = logging.getLogger(__name__) + async def init(self): + self.logger.debug("Creating asyncpg connection pool") self.pool = await asyncpg.create_pool( user=self.username, database=self.database, password=self.password, host=self.host, port=self.port, + min_size=1, + max_size=self.pool_size, ) + self.logger.debug("Initializing pgmq extension") async with self.pool.acquire() as conn: - await conn.fetch("create extension if not exists pgmq cascade;") + await conn.execute("create extension if not exists pgmq cascade;") + @transaction async def create_partitioned_queue( self, queue: str, partition_interval: int = 10000, retention_interval: int = 100000, + conn=None, ) -> None: - """Create a new queue - - Note: Partitions are created pg_partman which must be configured in postgresql.conf - Set `pg_partman_bgw.interval` to set the interval for partition creation and deletion. - A value of 10 will create new/delete partitions every 10 seconds. This value should be tuned - according to the volume of messages being sent to the queue. - - Args: - queue: The name of the queue. - partition_interval: The number of messages per partition. Defaults to 10,000. - retention_interval: The number of messages to retain. Messages exceeding this number will be dropped. - Defaults to 100,000. 
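+        # Note: partitions are created via pg_partman, which must be configured
+        # in postgresql.conf; set pg_partman_bgw.interval to control how often
+        # partitions are created and deleted, tuned to the queue's message volume.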
- """ - - async with self.pool.acquire() as conn: - await conn.execute( - "SELECT pgmq.create($1, $2::text, $3::text);", - queue, - partition_interval, - retention_interval, - ) + """Create a new partitioned queue.""" + self.logger.debug( + f"create_partitioned_queue called with queue='{queue}', " + f"partition_interval={partition_interval}, " + f"retention_interval={retention_interval}, conn={conn}" + ) + if conn is None: + async with self.pool.acquire() as conn: + await self._create_partitioned_queue_internal(queue, partition_interval, retention_interval, conn) + else: + await self._create_partitioned_queue_internal(queue, partition_interval, retention_interval, conn) + + async def _create_partitioned_queue_internal(self, queue, partition_interval, retention_interval, conn): + self.logger.debug(f"Creating partitioned queue '{queue}'") + await conn.execute( + "SELECT pgmq.create($1, $2::text, $3::text);", + queue, + partition_interval, + retention_interval, + ) - async def create_queue(self, queue: str, unlogged: bool = False) -> None: + @transaction + async def create_queue(self, queue: str, unlogged: bool = False, conn=None) -> None: """Create a new queue.""" - async with self.pool.acquire() as conn: - if unlogged: - await conn.execute("SELECT pgmq.create_unlogged($1);", queue) - else: - await conn.execute("SELECT pgmq.create($1);", queue) + self.logger.debug(f"create_queue called with queue='{queue}', unlogged={unlogged}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + await self._create_queue_internal(queue, unlogged, conn) + else: + await self._create_queue_internal(queue, unlogged, conn) + + async def _create_queue_internal(self, queue, unlogged, conn): + self.logger.debug(f"Creating queue '{queue}' with unlogged={unlogged}") + if unlogged: + await conn.execute("SELECT pgmq.create_unlogged($1);", queue) + else: + await conn.execute("SELECT pgmq.create($1);", queue) async def validate_queue_name(self, queue_name: str) -> None: """Validate the length of a queue name.""" + self.logger.debug(f"validate_queue_name called with queue_name='{queue_name}'") async with self.pool.acquire() as conn: await conn.execute("SELECT pgmq.validate_queue_name($1);", queue_name) - async def drop_queue(self, queue: str, partitioned: bool = False) -> bool: + @transaction + async def drop_queue(self, queue: str, partitioned: bool = False, conn=None) -> bool: """Drop a queue.""" - async with self.pool.acquire() as conn: - result = await conn.fetchrow("SELECT pgmq.drop_queue($1, $2);", queue, partitioned) + self.logger.debug(f"drop_queue called with queue='{queue}', partitioned={partitioned}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._drop_queue_internal(queue, partitioned, conn) + else: + return await self._drop_queue_internal(queue, partitioned, conn) + + async def _drop_queue_internal(self, queue, partitioned, conn): + result = await conn.fetchrow("SELECT pgmq.drop_queue($1, $2);", queue, partitioned) + self.logger.debug(f"Queue '{queue}' dropped: {result[0]}") return result[0] - async def list_queues(self) -> List[str]: + @transaction + async def list_queues(self, conn=None) -> List[str]: """List all queues.""" - async with self.pool.acquire() as conn: - rows = await conn.fetch("SELECT queue_name FROM pgmq.list_queues();") - return [row["queue_name"] for row in rows] - - async def send(self, queue: str, message: dict, delay: int = 0) -> int: + self.logger.debug(f"list_queues called with conn={conn}") + if conn is None: + 
async with self.pool.acquire() as conn: + return await self._list_queues_internal(conn) + else: + return await self._list_queues_internal(conn) + + async def _list_queues_internal(self, conn): + rows = await conn.fetch("SELECT queue_name FROM pgmq.list_queues();") + queues = [row["queue_name"] for row in rows] + self.logger.debug(f"Queues listed: {queues}") + return queues + + @transaction + async def send(self, queue: str, message: dict, delay: int = 0, conn=None) -> int: """Send a message to a queue.""" - async with self.pool.acquire() as conn: - result = await conn.fetchrow( - "SELECT * FROM pgmq.send($1, $2::jsonb, $3);", queue, dumps(message).decode("utf-8"), delay - ) + self.logger.debug(f"send called with queue='{queue}', message={message}, delay={delay}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._send_internal(queue, message, delay, conn) + else: + return await self._send_internal(queue, message, delay, conn) + + async def _send_internal(self, queue, message, delay, conn): + self.logger.debug(f"Sending message to queue '{queue}' with delay={delay}") + result = await conn.fetchrow( + "SELECT * FROM pgmq.send($1, $2::jsonb, $3);", + queue, + dumps(message).decode("utf-8"), + delay, + ) + self.logger.debug(f"Message sent with msg_id={result[0]}") return result[0] - async def send_batch(self, queue: str, messages: List[dict], delay: int = 0) -> List[int]: + @transaction + async def send_batch(self, queue: str, messages: List[dict], delay: int = 0, conn=None) -> List[int]: """Send a batch of messages to a queue.""" + self.logger.debug(f"send_batch called with queue='{queue}', messages={messages}, delay={delay}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._send_batch_internal(queue, messages, delay, conn) + else: + return await self._send_batch_internal(queue, messages, delay, conn) + + async def _send_batch_internal(self, queue, messages, delay, conn): + self.logger.debug(f"Sending batch of messages to queue '{queue}' with delay={delay}") jsonb_array = [dumps(message).decode("utf-8") for message in messages] + result = await conn.fetch( + "SELECT * FROM pgmq.send_batch($1, $2::jsonb[], $3);", + queue, + jsonb_array, + delay, + ) + msg_ids = [message[0] for message in result] + self.logger.debug(f"Batch messages sent with msg_ids={msg_ids}") + return msg_ids - async with self.pool.acquire() as conn: - result = await conn.fetch( - "SELECT * FROM pgmq.send_batch($1, $2::jsonb[], $3);", - queue, - jsonb_array, - delay, - ) - return [message[0] for message in result] - - async def read(self, queue: str, vt: Optional[int] = None) -> Optional[Message]: + @transaction + async def read(self, queue: str, vt: Optional[int] = None, conn=None) -> Optional[Message]: """Read a message from a queue.""" + self.logger.debug(f"read called with queue='{queue}', vt={vt}, conn={conn}") batch_size = 1 - async with self.pool.acquire() as conn: - rows = await conn.fetch("SELECT * FROM pgmq.read($1, $2, $3);", queue, vt or self.vt, batch_size) + if conn is None: + async with self.pool.acquire() as conn: + return await self._read_internal(queue, vt, batch_size, conn) + else: + return await self._read_internal(queue, vt, batch_size, conn) + + async def _read_internal(self, queue, vt, batch_size, conn): + self.logger.debug(f"Reading message from queue '{queue}' with vt={vt}") + rows = await conn.fetch( + "SELECT * FROM pgmq.read($1, $2, $3);", + queue, + vt or self.vt, + batch_size, + ) messages = [ - 
Message(msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=loads(row[4])) for row in rows + Message( + msg_id=row[0], + read_ct=row[1], + enqueued_at=row[2], + vt=row[3], + message=loads(row[4]), + ) + for row in rows ] - return messages[0] if len(messages) == 1 else None + self.logger.debug(f"Message read: {messages[0] if messages else None}") + return messages[0] if messages else None - async def read_batch(self, queue: str, vt: Optional[int] = None, batch_size=1) -> Optional[List[Message]]: + @transaction + async def read_batch( + self, queue: str, vt: Optional[int] = None, batch_size=1, conn=None + ) -> Optional[List[Message]]: """Read a batch of messages from a queue.""" - async with self.pool.acquire() as conn: - rows = await conn.fetch("SELECT * FROM pgmq.read($1, $2, $3);", queue, vt or self.vt, batch_size) - - return [ - Message(msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=loads(row[4])) for row in rows + self.logger.debug(f"read_batch called with queue='{queue}', vt={vt}, batch_size={batch_size}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._read_batch_internal(queue, vt, batch_size, conn) + else: + return await self._read_batch_internal(queue, vt, batch_size, conn) + + async def _read_batch_internal(self, queue, vt, batch_size, conn): + self.logger.debug(f"Reading batch of messages from queue '{queue}' with vt={vt}") + rows = await conn.fetch( + "SELECT * FROM pgmq.read($1, $2, $3);", + queue, + vt or self.vt, + batch_size, + ) + messages = [ + Message( + msg_id=row[0], + read_ct=row[1], + enqueued_at=row[2], + vt=row[3], + message=loads(row[4]), + ) + for row in rows ] + self.logger.debug(f"Batch messages read: {messages}") + return messages + @transaction async def read_with_poll( self, queue: str, @@ -143,69 +271,164 @@ async def read_with_poll( qty: int = 1, max_poll_seconds: int = 5, poll_interval_ms: int = 100, + conn=None, ) -> Optional[List[Message]]: """Read messages from a queue with polling.""" - async with self.pool.acquire() as conn: - rows = await conn.fetch( - "SELECT * FROM pgmq.read_with_poll($1, $2, $3, $4, $5);", - queue, - vt or self.vt, - qty, - max_poll_seconds, - poll_interval_ms, + self.logger.debug( + f"read_with_poll called with queue='{queue}', vt={vt}, qty={qty}, " + f"max_poll_seconds={max_poll_seconds}, poll_interval_ms={poll_interval_ms}, conn={conn}" + ) + if conn is None: + async with self.pool.acquire() as conn: + return await self._read_with_poll_internal(queue, vt, qty, max_poll_seconds, poll_interval_ms, conn) + else: + return await self._read_with_poll_internal(queue, vt, qty, max_poll_seconds, poll_interval_ms, conn) + + async def _read_with_poll_internal(self, queue, vt, qty, max_poll_seconds, poll_interval_ms, conn): + self.logger.debug(f"Reading messages with polling from queue '{queue}'") + rows = await conn.fetch( + "SELECT * FROM pgmq.read_with_poll($1, $2, $3, $4, $5);", + queue, + vt or self.vt, + qty, + max_poll_seconds, + poll_interval_ms, + ) + messages = [ + Message( + msg_id=row[0], + read_ct=row[1], + enqueued_at=row[2], + vt=row[3], + message=loads(row[4]), ) - - return [ - Message(msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=loads(row[4])) for row in rows + for row in rows ] + self.logger.debug(f"Messages read with polling: {messages}") + return messages - async def pop(self, queue: str) -> Message: + @transaction + async def pop(self, queue: str, conn=None) -> Message: """Pop a message from a queue.""" - 
async with self.pool.acquire() as conn: - rows = await conn.fetch("SELECT * FROM pgmq.pop($1);", queue) - + self.logger.debug(f"pop called with queue='{queue}', conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._pop_internal(queue, conn) + else: + return await self._pop_internal(queue, conn) + + async def _pop_internal(self, queue, conn): + self.logger.debug(f"Popping message from queue '{queue}'") + rows = await conn.fetch("SELECT * FROM pgmq.pop($1);", queue) messages = [ - Message(msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=loads(row[4])) for row in rows + Message( + msg_id=row[0], + read_ct=row[1], + enqueued_at=row[2], + vt=row[3], + message=loads(row[4]), + ) + for row in rows ] - return messages[0] + self.logger.debug(f"Message popped: {messages[0] if messages else None}") + return messages[0] if messages else None - async def delete(self, queue: str, msg_id: int) -> bool: + @transaction + async def delete(self, queue: str, msg_id: int, conn=None) -> bool: """Delete a message from a queue.""" - async with self.pool.acquire() as conn: - row = await conn.fetchrow("SELECT pgmq.delete($1::text, $2::int);", queue, msg_id) - + self.logger.debug(f"delete called with queue='{queue}', msg_id={msg_id}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._delete_internal(queue, msg_id, conn) + else: + return await self._delete_internal(queue, msg_id, conn) + + async def _delete_internal(self, queue, msg_id, conn): + self.logger.debug(f"Deleting message with msg_id={msg_id} from queue '{queue}'") + row = await conn.fetchrow("SELECT pgmq.delete($1::text, $2::int);", queue, msg_id) + self.logger.debug(f"Message deleted: {row[0]}") return row[0] - async def delete_batch(self, queue: str, msg_ids: List[int]) -> List[int]: + @transaction + async def delete_batch(self, queue: str, msg_ids: List[int], conn=None) -> List[int]: """Delete multiple messages from a queue.""" - async with self.pool.acquire() as conn: - results = await conn.fetch("SELECT * FROM pgmq.delete($1::text, $2::int[]);", queue, msg_ids) - return [result[0] for result in results] - - async def archive(self, queue: str, msg_id: int) -> bool: + self.logger.debug(f"delete_batch called with queue='{queue}', msg_ids={msg_ids}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._delete_batch_internal(queue, msg_ids, conn) + else: + return await self._delete_batch_internal(queue, msg_ids, conn) + + async def _delete_batch_internal(self, queue, msg_ids, conn): + self.logger.debug(f"Deleting messages with msg_ids={msg_ids} from queue '{queue}'") + results = await conn.fetch("SELECT * FROM pgmq.delete($1::text, $2::int[]);", queue, msg_ids) + deleted_ids = [result[0] for result in results] + self.logger.debug(f"Messages deleted: {deleted_ids}") + return deleted_ids + + @transaction + async def archive(self, queue: str, msg_id: int, conn=None) -> bool: """Archive a message from a queue.""" - async with self.pool.acquire() as conn: - row = await conn.fetchrow("SELECT pgmq.archive($1::text, $2::int);", queue, msg_id) - + self.logger.debug(f"archive called with queue='{queue}', msg_id={msg_id}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._archive_internal(queue, msg_id, conn) + else: + return await self._archive_internal(queue, msg_id, conn) + + async def _archive_internal(self, queue, msg_id, conn): + self.logger.debug(f"Archiving message 
with msg_id={msg_id} from queue '{queue}'") + row = await conn.fetchrow("SELECT pgmq.archive($1::text, $2::int);", queue, msg_id) + self.logger.debug(f"Message archived: {row[0]}") return row[0] - async def archive_batch(self, queue: str, msg_ids: List[int]) -> List[int]: + @transaction + async def archive_batch(self, queue: str, msg_ids: List[int], conn=None) -> List[int]: """Archive multiple messages from a queue.""" - async with self.pool.acquire() as conn: - results = await conn.fetch("SELECT * FROM pgmq.archive($1::text, $2::int[]);", queue, msg_ids) - return [result[0] for result in results] - - async def purge(self, queue: str) -> int: + self.logger.debug(f"archive_batch called with queue='{queue}', msg_ids={msg_ids}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._archive_batch_internal(queue, msg_ids, conn) + else: + return await self._archive_batch_internal(queue, msg_ids, conn) + + async def _archive_batch_internal(self, queue, msg_ids, conn): + self.logger.debug(f"Archiving messages with msg_ids={msg_ids} from queue '{queue}'") + results = await conn.fetch("SELECT * FROM pgmq.archive($1::text, $2::int[]);", queue, msg_ids) + archived_ids = [result[0] for result in results] + self.logger.debug(f"Messages archived: {archived_ids}") + return archived_ids + + @transaction + async def purge(self, queue: str, conn=None) -> int: """Purge a queue.""" - async with self.pool.acquire() as conn: - row = await conn.fetchrow("SELECT pgmq.purge_queue($1);", queue) - + self.logger.debug(f"purge called with queue='{queue}', conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._purge_internal(queue, conn) + else: + return await self._purge_internal(queue, conn) + + async def _purge_internal(self, queue, conn): + self.logger.debug(f"Purging queue '{queue}'") + row = await conn.fetchrow("SELECT pgmq.purge_queue($1);", queue) + self.logger.debug(f"Messages purged: {row[0]}") return row[0] - async def metrics(self, queue: str) -> QueueMetrics: - async with self.pool.acquire() as conn: - result = await conn.fetchrow("SELECT * FROM pgmq.metrics($1);", queue) - return QueueMetrics( + @transaction + async def metrics(self, queue: str, conn=None) -> QueueMetrics: + """Get metrics for a specific queue.""" + self.logger.debug(f"metrics called with queue='{queue}', conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._metrics_internal(queue, conn) + else: + return await self._metrics_internal(queue, conn) + + async def _metrics_internal(self, queue, conn): + self.logger.debug(f"Fetching metrics for queue '{queue}'") + result = await conn.fetchrow("SELECT * FROM pgmq.metrics($1);", queue) + metrics = QueueMetrics( queue_name=result[0], queue_length=result[1], newest_msg_age_sec=result[2], @@ -213,11 +436,23 @@ async def metrics(self, queue: str) -> QueueMetrics: total_messages=result[4], scrape_time=result[5], ) - - async def metrics_all(self) -> List[QueueMetrics]: - async with self.pool.acquire() as conn: - results = await conn.fetch("SELECT * FROM pgmq.metrics_all();") - return [ + self.logger.debug(f"Metrics fetched: {metrics}") + return metrics + + @transaction + async def metrics_all(self, conn=None) -> List[QueueMetrics]: + """Get metrics for all queues.""" + self.logger.debug(f"metrics_all called with conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._metrics_all_internal(conn) + else: + return await 
self._metrics_all_internal(conn) + + async def _metrics_all_internal(self, conn): + self.logger.debug("Fetching metrics for all queues") + results = await conn.fetch("SELECT * FROM pgmq.metrics_all();") + metrics_list = [ QueueMetrics( queue_name=row[0], queue_length=row[1], @@ -228,14 +463,43 @@ async def metrics_all(self) -> List[QueueMetrics]: ) for row in results ] + self.logger.debug(f"All metrics fetched: {metrics_list}") + return metrics_list - async def set_vt(self, queue: str, msg_id: int, vt: int) -> Message: + @transaction + async def set_vt(self, queue: str, msg_id: int, vt: int, conn=None) -> Message: """Set the visibility timeout for a specific message.""" - async with self.pool.acquire() as conn: - row = await conn.fetchrow("SELECT * FROM pgmq.set_vt($1, $2, $3);", queue, msg_id, vt) - return Message(msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=row[4]) + self.logger.debug(f"set_vt called with queue='{queue}', msg_id={msg_id}, vt={vt}, conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + return await self._set_vt_internal(queue, msg_id, vt, conn) + else: + return await self._set_vt_internal(queue, msg_id, vt, conn) + + async def _set_vt_internal(self, queue, msg_id, vt, conn): + self.logger.debug(f"Setting VT for msg_id={msg_id} in queue '{queue}' to vt={vt}") + row = await conn.fetchrow("SELECT * FROM pgmq.set_vt($1, $2, $3);", queue, msg_id, vt) + message = Message( + msg_id=row[0], + read_ct=row[1], + enqueued_at=row[2], + vt=row[3], + message=loads(row[4]), + ) + self.logger.debug(f"VT set for message: {message}") + return message - async def detach_archive(self, queue: str) -> None: + @transaction + async def detach_archive(self, queue: str, conn=None) -> None: """Detach an archive from a queue.""" - async with self.pool.acquire() as conn: - await conn.fetch("select pgmq.detach_archive($1);", queue) + self.logger.debug(f"detach_archive called with queue='{queue}', conn={conn}") + if conn is None: + async with self.pool.acquire() as conn: + await self._detach_archive_internal(queue, conn) + else: + await self._detach_archive_internal(queue, conn) + + async def _detach_archive_internal(self, queue, conn): + self.logger.debug(f"Detaching archive from queue '{queue}'") + await conn.execute("SELECT pgmq.detach_archive($1);", queue) + self.logger.debug(f"Archive detached from queue '{queue}'") diff --git a/tembo-pgmq-python/tembo_pgmq_python/decorators.py b/tembo-pgmq-python/tembo_pgmq_python/decorators.py new file mode 100644 index 00000000..537f3843 --- /dev/null +++ b/tembo-pgmq-python/tembo_pgmq_python/decorators.py @@ -0,0 +1,69 @@ +# decorators.py +import functools + + +def transaction(func): + """Decorator to run a function within a database transaction.""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if args and hasattr(args[0], "pool") and hasattr(args[0], "logger"): + self = args[0] + + if "conn" not in kwargs: + with self.pool.connection() as conn: + with conn.transaction(): + self.logger.debug(f"Transaction started with conn: {conn}") + try: + kwargs["conn"] = conn # Inject 'conn' into kwargs + result = func(*args, **kwargs) + self.logger.debug(f"Transaction completed with conn: {conn}") + return result + except Exception as e: + self.logger.error(f"Transaction failed with exception: {e}, rolling back.") + raise + else: + return func(*args, **kwargs) + + else: + queue = kwargs.get("queue") or args[0] + + if "conn" not in kwargs: + with queue.pool.connection() as conn: + with conn.transaction(): + 
queue.logger.debug(f"Transaction started with conn: {conn}")
+                        try:
+                            kwargs["conn"] = conn  # Inject 'conn' into kwargs
+                            result = func(*args, **kwargs)
+                            queue.logger.debug(f"Transaction completed with conn: {conn}")
+                            return result
+                        except Exception as e:
+                            queue.logger.error(f"Transaction failed with exception: {e}, rolling back.")
+                            raise
+            else:
+                return func(*args, **kwargs)
+
+    return wrapper
+
+
+def async_transaction(func):
+    """Asynchronous decorator to run a method within a database transaction."""
+
+    @functools.wraps(func)
+    async def wrapper(self, *args, **kwargs):
+        if "conn" not in kwargs:
+            async with self.pool.acquire() as conn:
+                txn = conn.transaction()
+                await txn.start()
+                try:
+                    kwargs["conn"] = conn
+                    result = await func(self, *args, **kwargs)
+                    await txn.commit()
+                    return result
+                except Exception:
+                    await txn.rollback()
+                    raise
+        else:
+            return await func(self, *args, **kwargs)
+
+    return wrapper
diff --git a/tembo-pgmq-python/tembo_pgmq_python/queue.py b/tembo-pgmq-python/tembo_pgmq_python/queue.py
index 8d871ef1..f4a56e83 100644
--- a/tembo-pgmq-python/tembo_pgmq_python/queue.py
+++ b/tembo-pgmq-python/tembo_pgmq_python/queue.py
@@ -1,9 +1,12 @@
 from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import Optional, List, Union
 from psycopg.types.json import Jsonb
 from psycopg_pool import ConnectionPool
 import os
 from tembo_pgmq_python.messages import Message, QueueMetrics
+from tembo_pgmq_python.decorators import transaction
+import logging
+from datetime import datetime
 
 
 @dataclass
@@ -19,18 +22,12 @@ class PGMQueue:
     vt: int = 30
     pool_size: int = 10
     kwargs: dict = field(default_factory=dict)
+    verbose: bool = False
+    log_filename: Optional[str] = None
     pool: ConnectionPool = field(init=False)
+    logger: logging.Logger = field(init=False)
 
     def __post_init__(self) -> None:
-        self.host = self.host or "localhost"
-        self.port = self.port or "5432"
-        self.database = self.database or "postgres"
-        self.username = self.username or "postgres"
-        self.password = self.password or "postgres"
-
-        if not all([self.host, self.port, self.database, self.username, self.password]):
-            raise ValueError("Incomplete database connection information provided.")
-
         conninfo = f"""
         host={self.host}
         port={self.port}
@@ -39,94 +36,116 @@ def __post_init__(self) -> None:
         password={self.password}
         """
         self.pool = ConnectionPool(conninfo, open=True, **self.kwargs)
-
-        with self.pool.connection() as conn:
-            conn.execute("create extension if not exists pgmq cascade;")
-
+        self._initialize_logging()
+        self._initialize_extensions()
+
+    def _initialize_logging(self) -> None:
+        if self.verbose:
+            log_filename = self.log_filename or datetime.now().strftime("pgmq_debug_%Y%m%d_%H%M%S.log")
+            logging.basicConfig(
+                filename=os.path.join(os.getcwd(), log_filename),
+                level=logging.DEBUG,
+                format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+            )
+        else:
+            logging.basicConfig(level=logging.WARNING)
+        self.logger = logging.getLogger(__name__)
+
+    def _initialize_extensions(self, conn=None) -> None:
+        self._execute_query("create extension if not exists pgmq cascade;", conn=conn)
+
+    def _execute_query(self, query: str, params: Optional[Union[List, tuple]] = None, conn=None) -> None:
+        self.logger.debug(f"Executing query: {query} with params: {params} using conn: {conn}")
+        if conn:
+            conn.execute(query, params)
+        else:
+            with self.pool.connection() as conn:
+                conn.execute(query, params)
+
+    def _execute_query_with_result(self, query: str, params: Optional[Union[List, tuple]] = None, 
conn=None): + self.logger.debug(f"Executing query with result: {query} with params: {params} using conn: {conn}") + if conn: + return conn.execute(query, params).fetchall() + else: + with self.pool.connection() as conn: + return conn.execute(query, params).fetchall() + + @transaction def create_partitioned_queue( self, queue: str, partition_interval: int = 10000, retention_interval: int = 100000, + conn=None, ) -> None: - """Create a new queue - - Note: Partitions are created pg_partman which must be configured in postgresql.conf - Set `pg_partman_bgw.interval` to set the interval for partition creation and deletion. - A value of 10 will create new/delete partitions every 10 seconds. This value should be tuned - according to the volume of messages being sent to the queue. - - Args: - queue: The name of the queue. - partition_interval: The number of messages per partition. Defaults to 10,000. - retention_interval: The number of messages to retain. Messages exceeding this number will be dropped. - Defaults to 100,000. - """ - - with self.pool.connection() as conn: - conn.execute( - "select pgmq.create(%s, %s::text, %s::text);", - [queue, partition_interval, retention_interval], - ) + """Create a new queue""" + query = "select pgmq.create(%s, %s::text, %s::text);" + params = [queue, partition_interval, retention_interval] + self._execute_query(query, params, conn=conn) - def create_queue(self, queue: str, unlogged: bool = False) -> None: + @transaction + def create_queue(self, queue: str, unlogged: bool = False, conn=None) -> None: """Create a new queue.""" - with self.pool.connection() as conn: - if unlogged: - conn.execute("select pgmq.create_unlogged(%s);", [queue]) - else: - conn.execute("select pgmq.create(%s);", [queue]) + self.logger.debug(f"create_queue called with conn: {conn}") + query = "select pgmq.create_unlogged(%s);" if unlogged else "select pgmq.create(%s);" + self._execute_query(query, [queue], conn=conn) - def validate_queue_name(self, queue_name: str) -> None: + def validate_queue_name(self, queue_name: str, conn=None) -> None: """Validate the length of a queue name.""" - with self.pool.connection() as conn: - conn.execute("select pgmq.validate_queue_name(%s);", [queue_name]) + query = "select pgmq.validate_queue_name(%s);" + self._execute_query(query, [queue_name], conn=conn) - def drop_queue(self, queue: str, partitioned: bool = False) -> bool: + @transaction + def drop_queue(self, queue: str, partitioned: bool = False, conn=None) -> bool: """Drop a queue.""" - with self.pool.connection() as conn: - result = conn.execute("select pgmq.drop_queue(%s, %s);", [queue, partitioned]).fetchone() - return result[0] + self.logger.debug(f"drop_queue called with conn: {conn}") + query = "select pgmq.drop_queue(%s, %s);" + result = self._execute_query_with_result(query, [queue, partitioned], conn=conn) + return result[0][0] - def list_queues(self) -> List[str]: + @transaction + def list_queues(self, conn=None) -> List[str]: """List all queues.""" - with self.pool.connection() as conn: - rows = conn.execute("select queue_name from pgmq.list_queues();").fetchall() + self.logger.debug(f"list_queues called with conn: {conn}") + query = "select queue_name from pgmq.list_queues();" + rows = self._execute_query_with_result(query, conn=conn) return [row[0] for row in rows] - def send(self, queue: str, message: dict, delay: int = 0) -> int: + @transaction + def send(self, queue: str, message: dict, delay: int = 0, conn=None) -> int: """Send a message to a queue.""" - with 
self.pool.connection() as conn: - result = conn.execute("select * from pgmq.send(%s, %s, %s);", [queue, Jsonb(message), delay]).fetchall() + self.logger.debug(f"send called with conn: {conn}") + query = "select * from pgmq.send(%s, %s, %s);" + result = self._execute_query_with_result(query, [queue, Jsonb(message), delay], conn=conn) return result[0][0] - def send_batch(self, queue: str, messages: List[dict], delay: int = 0) -> List[int]: + @transaction + def send_batch(self, queue: str, messages: List[dict], delay: int = 0, conn=None) -> List[int]: """Send a batch of messages to a queue.""" - with self.pool.connection() as conn: - result = conn.execute( - "select * from pgmq.send_batch(%s, %s, %s);", - [queue, [Jsonb(message) for message in messages], delay], - ).fetchall() + self.logger.debug(f"send_batch called with conn: {conn}") + query = "select * from pgmq.send_batch(%s, %s, %s);" + params = [queue, [Jsonb(message) for message in messages], delay] + result = self._execute_query_with_result(query, params, conn=conn) return [message[0] for message in result] - def read(self, queue: str, vt: Optional[int] = None) -> Optional[Message]: + @transaction + def read(self, queue: str, vt: Optional[int] = None, conn=None) -> Optional[Message]: """Read a message from a queue.""" - with self.pool.connection() as conn: - rows = conn.execute("select * from pgmq.read(%s, %s, %s);", [queue, vt or self.vt, 1]).fetchall() - + self.logger.debug(f"read called with conn: {conn}") + query = "select * from pgmq.read(%s, %s, %s);" + rows = self._execute_query_with_result(query, [queue, vt or self.vt, 1], conn=conn) messages = [Message(msg_id=x[0], read_ct=x[1], enqueued_at=x[2], vt=x[3], message=x[4]) for x in rows] - return messages[0] if len(messages) == 1 else None + return messages[0] if messages else None - def read_batch(self, queue: str, vt: Optional[int] = None, batch_size=1) -> Optional[List[Message]]: + @transaction + def read_batch(self, queue: str, vt: Optional[int] = None, batch_size=1, conn=None) -> Optional[List[Message]]: """Read a batch of messages from a queue.""" - with self.pool.connection() as conn: - rows = conn.execute( - "select * from pgmq.read(%s, %s, %s);", - [queue, vt or self.vt, batch_size], - ).fetchall() - + self.logger.debug(f"read_batch called with conn: {conn}") + query = "select * from pgmq.read(%s, %s, %s);" + rows = self._execute_query_with_result(query, [queue, vt or self.vt, batch_size], conn=conn) return [Message(msg_id=x[0], read_ct=x[1], enqueued_at=x[2], vt=x[3], message=x[4]) for x in rows] + @transaction def read_with_poll( self, queue: str, @@ -134,60 +153,70 @@ def read_with_poll( qty: int = 1, max_poll_seconds: int = 5, poll_interval_ms: int = 100, + conn=None, ) -> Optional[List[Message]]: """Read messages from a queue with polling.""" - with self.pool.connection() as conn: - rows = conn.execute( - "select * from pgmq.read_with_poll(%s, %s, %s, %s, %s);", - [queue, vt or self.vt, qty, max_poll_seconds, poll_interval_ms], - ).fetchall() - + self.logger.debug(f"read_with_poll called with conn: {conn}") + query = "select * from pgmq.read_with_poll(%s, %s, %s, %s, %s);" + params = [queue, vt or self.vt, qty, max_poll_seconds, poll_interval_ms] + rows = self._execute_query_with_result(query, params, conn=conn) return [Message(msg_id=x[0], read_ct=x[1], enqueued_at=x[2], vt=x[3], message=x[4]) for x in rows] - def pop(self, queue: str) -> Message: + @transaction + def pop(self, queue: str, conn=None) -> Message: """Pop a message from a queue.""" - with 
self.pool.connection() as conn: - rows = conn.execute("select * from pgmq.pop(%s);", [queue]).fetchall() - + self.logger.debug(f"pop called with conn: {conn}") + query = "select * from pgmq.pop(%s);" + rows = self._execute_query_with_result(query, [queue], conn=conn) messages = [Message(msg_id=x[0], read_ct=x[1], enqueued_at=x[2], vt=x[3], message=x[4]) for x in rows] return messages[0] - def delete(self, queue: str, msg_id: int) -> bool: + @transaction + def delete(self, queue: str, msg_id: int, conn=None) -> bool: """Delete a message from a queue.""" - with self.pool.connection() as conn: - row = conn.execute("select pgmq.delete(%s, %s);", [queue, msg_id]).fetchall() - - return row[0][0] + self.logger.debug(f"delete called with conn: {conn}") + query = "select pgmq.delete(%s, %s);" + result = self._execute_query_with_result(query, [queue, msg_id], conn=conn) + return result[0][0] - def delete_batch(self, queue: str, msg_ids: List[int]) -> List[int]: + @transaction + def delete_batch(self, queue: str, msg_ids: List[int], conn=None) -> List[int]: """Delete multiple messages from a queue.""" - with self.pool.connection() as conn: - result = conn.execute("select * from pgmq.delete(%s, %s);", [queue, msg_ids]).fetchall() + self.logger.debug(f"delete_batch called with conn: {conn}") + query = "select * from pgmq.delete(%s, %s);" + result = self._execute_query_with_result(query, [queue, msg_ids], conn=conn) return [x[0] for x in result] - def archive(self, queue: str, msg_id: int) -> bool: + @transaction + def archive(self, queue: str, msg_id: int, conn=None) -> bool: """Archive a message from a queue.""" - with self.pool.connection() as conn: - row = conn.execute("select pgmq.archive(%s, %s);", [queue, msg_id]).fetchall() - - return row[0][0] + self.logger.debug(f"archive called with conn: {conn}") + query = "select pgmq.archive(%s, %s);" + result = self._execute_query_with_result(query, [queue, msg_id], conn=conn) + return result[0][0] - def archive_batch(self, queue: str, msg_ids: List[int]) -> List[int]: + @transaction + def archive_batch(self, queue: str, msg_ids: List[int], conn=None) -> List[int]: """Archive multiple messages from a queue.""" - with self.pool.connection() as conn: - result = conn.execute("select * from pgmq.archive(%s, %s);", [queue, msg_ids]).fetchall() + self.logger.debug(f"archive_batch called with conn: {conn}") + query = "select * from pgmq.archive(%s, %s);" + result = self._execute_query_with_result(query, [queue, msg_ids], conn=conn) return [x[0] for x in result] - def purge(self, queue: str) -> int: + @transaction + def purge(self, queue: str, conn=None) -> int: """Purge a queue.""" - with self.pool.connection() as conn: - row = conn.execute("select pgmq.purge_queue(%s);", [queue]).fetchall() - - return row[0][0] + self.logger.debug(f"purge called with conn: {conn}") + query = "select pgmq.purge_queue(%s);" + result = self._execute_query_with_result(query, [queue], conn=conn) + return result[0][0] - def metrics(self, queue: str) -> QueueMetrics: - with self.pool.connection() as conn: - result = conn.execute("SELECT * FROM pgmq.metrics(%s);", [queue]).fetchone() + @transaction + def metrics(self, queue: str, conn=None) -> QueueMetrics: + """Get metrics for a specific queue.""" + self.logger.debug(f"metrics called with conn: {conn}") + query = "SELECT * FROM pgmq.metrics(%s);" + result = self._execute_query_with_result(query, [queue], conn=conn)[0] return QueueMetrics( queue_name=result[0], queue_length=result[1], @@ -197,9 +226,12 @@ def metrics(self, queue: str) 
             scrape_time=result[5],
         )
 
-    def metrics_all(self) -> List[QueueMetrics]:
-        with self.pool.connection() as conn:
-            results = conn.execute("SELECT * FROM pgmq.metrics_all();").fetchall()
+    @transaction
+    def metrics_all(self, conn=None) -> List[QueueMetrics]:
+        """Get metrics for all queues."""
+        self.logger.debug(f"metrics_all called with conn: {conn}")
+        query = "SELECT * FROM pgmq.metrics_all();"
+        results = self._execute_query_with_result(query, conn=conn)
         return [
             QueueMetrics(
                 queue_name=row[0],
@@ -212,13 +244,23 @@ def metrics_all(self) -> List[QueueMetrics]:
             for row in results
         ]
 
-    def set_vt(self, queue: str, msg_id: int, vt: int) -> Message:
+    @transaction
+    def set_vt(self, queue: str, msg_id: int, vt: int, conn=None) -> Message:
         """Set the visibility timeout for a specific message."""
-        with self.pool.connection() as conn:
-            row = conn.execute("select * from pgmq.set_vt(%s, %s, %s);", [queue, msg_id, vt]).fetchone()
-            return Message(msg_id=row[0], read_ct=row[1], enqueued_at=row[2], vt=row[3], message=row[4])
+        self.logger.debug(f"set_vt called with conn: {conn}")
+        query = "select * from pgmq.set_vt(%s, %s, %s);"
+        result = self._execute_query_with_result(query, [queue, msg_id, vt], conn=conn)[0]
+        return Message(
+            msg_id=result[0],
+            read_ct=result[1],
+            enqueued_at=result[2],
+            vt=result[3],
+            message=result[4],
+        )
 
-    def detach_archive(self, queue: str) -> None:
+    @transaction
+    def detach_archive(self, queue: str, conn=None) -> None:
         """Detach an archive from a queue."""
-        with self.pool.connection() as conn:
-            conn.execute("select pgmq.detach_archive(%s);", [queue])
+        self.logger.debug(f"detach_archive called with conn: {conn}")
+        query = "select pgmq.detach_archive(%s);"
+        self._execute_query(query, [queue], conn=conn)
diff --git a/tembo-pgmq-python/tests/test_async_integration.py b/tembo-pgmq-python/tests/test_async_integration.py
index 057829b7..e4e614ba 100644
--- a/tembo-pgmq-python/tests/test_async_integration.py
+++ b/tembo-pgmq-python/tests/test_async_integration.py
@@ -2,6 +2,7 @@ import time
 from tembo_pgmq_python.messages import Message
 from tembo_pgmq_python.async_queue import PGMQueue
+from tembo_pgmq_python.decorators import async_transaction as transaction
 from datetime import datetime, timezone, timedelta
 
 # Function to load environment variables
@@ -92,7 +93,9 @@ async def test_read_batch(self):
         """Test reading a batch of messages from the queue."""
         messages = [self.test_message, self.test_message]
         await self.queue.send_batch(self.test_queue, messages)
-        read_messages = await self.queue.read_batch(self.test_queue, vt=20, batch_size=2)
+        read_messages = await self.queue.read_batch(
+            self.test_queue, vt=20, batch_size=2
+        )
         self.assertEqual(len(read_messages), 2)
         for message in read_messages:
             self.assertEqual(message.message, self.test_message)
@@ -145,7 +148,9 @@ async def test_archive_batch(self):
         messages = [self.test_message, self.test_message]
         msg_ids = await self.queue.send_batch(self.test_queue, messages)
         await self.queue.archive_batch(self.test_queue, msg_ids)
-        read_messages = await self.queue.read_batch(self.test_queue, vt=20, batch_size=2)
+        read_messages = await self.queue.read_batch(
+            self.test_queue, vt=20, batch_size=2
+        )
         self.assertEqual(len(read_messages), 0)
 
     async def test_delete_batch(self):
         """Test deleting a batch of messages from the queue."""
         messages = [self.test_message, self.test_message]
         msg_ids = await self.queue.send_batch(self.test_queue, messages)
         await self.queue.delete_batch(self.test_queue, msg_ids)
-        read_messages = await self.queue.read_batch(self.test_queue, vt=20, batch_size=2)
+        read_messages = await self.queue.read_batch(
+            self.test_queue, vt=20, batch_size=2
+        )
         self.assertEqual(len(read_messages), 0)
 
     async def test_set_vt(self):
@@ -195,9 +202,50 @@ async def test_validate_queue_name(self):
             await self.queue.validate_queue_name(invalid_queue_name)
         self.assertIn("queue name is too long", str(context.exception))
 
+    async def test_transaction_create_queue(self):
+        @transaction
+        async def transactional_create_queue(queue, conn=None):
+            # The operation must run on the transaction's connection,
+            # otherwise it commits independently and cannot be rolled back.
+            await queue.create_queue("test_queue_txn", conn=conn)
+            raise Exception("Simulated failure")
 
-class TestPGMQueueWithEnv(unittest.IsolatedAsyncioTestCase):
+        try:
+            await transactional_create_queue(self.queue)
+        except Exception:
+            pass
+        queues = await self.queue.list_queues()
+        self.assertNotIn("test_queue_txn", queues)
+
+    async def test_transaction_rollback(self):
+        @transaction
+        async def transactional_operation(queue, conn=None):
+            await queue.send(
+                self.test_queue,
+                self.test_message,
+                conn=conn,
+            )
+            raise Exception("Intentional failure")
+
+        try:
+            await transactional_operation(self.queue)
+        except Exception:
+            pass
+
+        message = await self.queue.read(self.test_queue)
+        self.assertIsNone(message, "No message expected in queue after rollback")
+
+    async def test_transaction_send_and_read_message(self):
+        @transaction
+        async def transactional_send(queue, conn=None):
+            await queue.send(self.test_queue, self.test_message, conn=conn)
+
+        await transactional_send(self.queue)
+
+        message = await self.queue.read(self.test_queue)
+        self.assertIsNotNone(message, "Expected message in queue")
+        self.assertEqual(message.message, self.test_message)
+
+
+class TestPGMQueueWithEnv(unittest.IsolatedAsyncioTestCase):
     async def asyncSetUp(self):
         """Set up a connection to the PGMQueue using environment variables and create a test queue."""
diff --git a/tembo-pgmq-python/tests/test_integration.py b/tembo-pgmq-python/tests/test_integration.py
index d8991c61..1f1190db 100644
--- a/tembo-pgmq-python/tests/test_integration.py
+++ b/tembo-pgmq-python/tests/test_integration.py
@@ -1,8 +1,8 @@
 import unittest
 import time
-from tembo_pgmq_python import Message, PGMQueue
+from tembo_pgmq_python import Message, PGMQueue, transaction
+
 from datetime import datetime, timezone, timedelta
-# Function to load environment variables
 
 
 class BaseTestPGMQueue(unittest.TestCase):
@@ -15,6 +15,7 @@ def setUpClass(cls):
             username="postgres",
             password="postgres",
             database="postgres",
+            verbose=False,
         )
 
         # Test database connection first
@@ -173,7 +174,6 @@ def test_detach_archive(self):
         self.queue.send(self.test_queue, self.test_message)
         self.queue.archive(self.test_queue, 1)
         self.queue.detach_archive(self.test_queue)
-        # This is just a basic call to ensure the method works without exceptions.
 
     def test_drop_queue(self):
         """Test dropping a queue."""
@@ -193,6 +193,72 @@ def test_validate_queue_name(self):
             self.queue.validate_queue_name(invalid_queue_name)
         self.assertIn("queue name is too long", str(context.exception))
 
+    def test_transaction_create_queue(self):
+        """Test creating a queue within a transaction."""
+
+        @transaction
+        def transactional_create_queue(queue, conn=None):
+            queue.create_queue("test_queue_txn", conn=conn)
+            raise Exception("Intentional failure")
+
+        try:
+            transactional_create_queue(self.queue)
+        except Exception:
+            pass
+        finally:
+            queues = self.queue.list_queues()
+            self.assertNotIn("test_queue_txn", queues)
+
+    def test_transaction_send_and_read_message(self):
+        """Test sending and reading a message within a transaction."""
+
+        @transaction
+        def transactional_send(queue, conn=None):
+            queue.send(self.test_queue, self.test_message, conn=conn)
+            raise Exception("Intentional failure")
+
+        try:
+            transactional_send(self.queue)
+        except Exception:
+            pass
+        finally:
+            message = self.queue.read(self.test_queue)
+            self.assertIsNone(message, "No message expected in queue")
+
+    def test_transaction_purge_queue(self):
+        """Test purging a queue within a transaction."""
+
+        self.queue.send(self.test_queue, self.test_message)
+
+        @transaction
+        def transactional_purge(queue, conn=None):
+            queue.purge(self.test_queue, conn=conn)
+            raise Exception("Intentional failure")
+
+        try:
+            transactional_purge(self.queue)
+        except Exception:
+            pass
+        finally:
+            message = self.queue.read(self.test_queue)
+            self.assertIsNotNone(message, "Message expected in queue")
+
+    def test_transaction_rollback(self):
+        """Test rollback of a transaction."""
+
+        @transaction
+        def transactional_operation(queue, conn=None):
+            queue.send(self.test_queue, self.test_message, conn=conn)
+            raise Exception("Intentional failure to trigger rollback")
+
+        try:
+            transactional_operation(self.queue)
+        except Exception:
+            pass
+        finally:
+            message = self.queue.read(self.test_queue)
+            self.assertIsNone(message, "No message expected in queue after rollback")
+
 
 class TestPGMQueueWithEnv(BaseTestPGMQueue):
     @classmethod