From df073f545668ab3fb57020366cefe14524de5fd9 Mon Sep 17 00:00:00 2001 From: Ondrej Sedlacek Date: Fri, 23 Aug 2024 13:59:16 +0200 Subject: [PATCH] FIX: Send ControlMessage *only* to worker 0. --- dp3/common/control.py | 3 +++ dp3/common/task.py | 18 ++++++++++++++++++ dp3/scripts/add_hashes.py | 2 +- dp3/task_processing/task_queue.py | 15 ++------------- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/dp3/common/control.py b/dp3/common/control.py index 85447137..c3ccf3f7 100644 --- a/dp3/common/control.py +++ b/dp3/common/control.py @@ -33,6 +33,9 @@ class ControlMessage(Task): def routing_key(self): return "" + def hashed_routing_key(self) -> int: + return 0 + def as_message(self) -> str: return self.model_dump_json() diff --git a/dp3/common/task.py b/dp3/common/task.py index f9a208c3..73d790c4 100644 --- a/dp3/common/task.py +++ b/dp3/common/task.py @@ -1,3 +1,4 @@ +import hashlib from abc import ABC, abstractmethod from collections.abc import Iterator from contextlib import contextmanager @@ -21,6 +22,16 @@ _init_context_var = ContextVar("_init_context_var", default=None) +def HASH(key: str) -> int: + """Hash function used to distribute tasks to worker processes. + Args: + key: to be hashed + Returns: + the last 4 hex digits (2 bytes) of the MD5 digest, as an integer + """ + return int(hashlib.md5(key.encode("utf8")).hexdigest()[-4:], 16) + + @contextmanager def task_context(model_spec: ModelSpec) -> Iterator[None]: """Context manager for setting the `model_spec` context variable.""" @@ -45,6 +56,13 @@ def routing_key(self) -> str: A string to be used as a routing key between workers. """ + def hashed_routing_key(self) -> int: + """ + Returns: + An integer to be used as a hashed routing key between workers. 
+ """ + return HASH(self.routing_key()) + @abstractmethod def as_message(self) -> str: """ diff --git a/dp3/scripts/add_hashes.py b/dp3/scripts/add_hashes.py index 75e6f1fe..87571bac 100755 --- a/dp3/scripts/add_hashes.py +++ b/dp3/scripts/add_hashes.py @@ -9,8 +9,8 @@ from pymongo import UpdateOne from dp3.common.config import ModelSpec, read_config_dir +from dp3.common.task import HASH from dp3.database.database import EntityDatabase, MongoConfig -from dp3.task_processing.task_queue import HASH # Arguments parser parser = argparse.ArgumentParser( diff --git a/dp3/task_processing/task_queue.py b/dp3/task_processing/task_queue.py index 04db03d6..6ddd1718 100644 --- a/dp3/task_processing/task_queue.py +++ b/dp3/task_processing/task_queue.py @@ -33,7 +33,6 @@ import collections import contextlib -import hashlib import logging import threading import time @@ -53,16 +52,6 @@ DEFAULT_PRIORITY_QUEUE = "{}-worker-{}-pri" -def HASH(key: str) -> int: - """Hash function used to distribute tasks to worker processes. - Args: - key: to be hashed - Returns: - last 4 bytes of MD5 - """ - return int(hashlib.md5(key.encode("utf8")).hexdigest()[-4:], 16) - - # When reading, pre-fetch only a limited amount of messages # (because pre-fetched messages are not counted to queue length limit) PREFETCH_COUNT = 50 @@ -294,8 +283,8 @@ def put_task(self, task: Task, priority: bool = False) -> None: # Prepare routing key body = task.as_message() - key = task.routing_key() - routing_key = HASH(key) % self.workers # index of the worker to send the task to + # index of the worker to send the task to + routing_key = task.hashed_routing_key() % self.workers exchange = self.exchange_pri if priority else self.exchange self._send_message(routing_key, exchange, body)