diff --git a/src/funcx_common/messagepack/message_types/ep_status_report.py b/src/funcx_common/messagepack/message_types/ep_status_report.py index 159af4c..f3b48bf 100644 --- a/src/funcx_common/messagepack/message_types/ep_status_report.py +++ b/src/funcx_common/messagepack/message_types/ep_status_report.py @@ -15,4 +15,4 @@ class EPStatusReport(Message): endpoint_id: uuid.UUID ep_status_report: t.Dict[str, t.Any] - task_statuses: t.Dict[str, TaskTransition] + task_statuses: t.Dict[str, t.List[TaskTransition]] diff --git a/src/funcx_common/messagepack/message_types/manager_status_report.py b/src/funcx_common/messagepack/message_types/manager_status_report.py index 72c87af..c505ed4 100644 --- a/src/funcx_common/messagepack/message_types/manager_status_report.py +++ b/src/funcx_common/messagepack/message_types/manager_status_report.py @@ -11,4 +11,4 @@ class ManagerStatusReport(Message): saying which tasks are now RUNNING. """ - task_statuses: t.Dict[str, TaskTransition] + task_statuses: t.Dict[str, t.List[TaskTransition]] diff --git a/src/funcx_common/redis_task.py b/src/funcx_common/redis_task.py index 693bbd0..87b290a 100644 --- a/src/funcx_common/redis_task.py +++ b/src/funcx_common/redis_task.py @@ -74,7 +74,7 @@ class RedisTask(TaskProtocol, metaclass=HasRedisFieldsMeta): queue_name = t.cast(str, RedisField()) # end required fields - endpoint = t.cast(t.Optional[str], RedisField()) + endpoint_id = t.cast(t.Optional[str], RedisField()) # FIXME: `payload` is a string which is currently being round-tripped through the # JSON_SERDE. However, we cannot remove the use of the serde until we are prepared @@ -114,6 +114,7 @@ def __init__( payload_reference: t.Optional[t.Dict[str, t.Any]] = None, task_group_id: t.Optional[str] = None, queue_name: t.Optional[str] = None, + endpoint_id: t.Optional[str] = None, ): """ If optional values are passed, then they will be written. @@ -127,6 +128,7 @@ def __init__( :param payload: serialized function + input data :param task_group_id: UUID of task group that this task belongs to :param queue_name: name of AMQP queue where results will be sent + :param endpoint_id: UUID of the endpoint the task was sent to """ # non-RedisField attributes of a RedisTask self.hname = f"task_{task_id}" @@ -160,6 +162,8 @@ def __init__( self.task_group_id = task_group_id if queue_name is not None: self.queue_name = queue_name + if endpoint_id is not None: + self.endpoint_id = endpoint_id self.ttl = self.DEFAULT_TTL diff --git a/tests/unit/test_messagepack.py b/tests/unit/test_messagepack.py index d9c4366..8b8238e 100644 --- a/tests/unit/test_messagepack.py +++ b/tests/unit/test_messagepack.py @@ -53,11 +53,13 @@ def crudely_pack_data(data): { "endpoint_id": ID_ZERO, "ep_status_report": { - str(ID_ZERO): TaskTransition( - timestamp=1, - state=TaskState.EXEC_END, - actor=ActorName.INTERCHANGE, - ) + str(ID_ZERO): [ + TaskTransition( + timestamp=1, + state=TaskState.EXEC_END, + actor=ActorName.INTERCHANGE, + ) + ] }, "task_statuses": {}, }, @@ -69,11 +71,13 @@ def crudely_pack_data(data): "endpoint_id": ID_ZERO, "ep_status_report": {}, "task_statuses": { - str(ID_ZERO): TaskTransition( - timestamp=1, - state=TaskState.EXEC_END, - actor=ActorName.INTERCHANGE, - ) + str(ID_ZERO): [ + TaskTransition( + timestamp=1, + state=TaskState.EXEC_END, + actor=ActorName.INTERCHANGE, + ) + ] }, }, None, @@ -83,11 +87,13 @@ def crudely_pack_data(data): ManagerStatusReport, { "task_statuses": { - "foo": TaskTransition( - timestamp=1, - state=TaskState.EXEC_END, - actor=ActorName.INTERCHANGE, - ) + "foo": [ + TaskTransition( + timestamp=1, + state=TaskState.EXEC_END, + actor=ActorName.INTERCHANGE, + ) + ] } }, None,