FIX-#407: Make data put into MPI object store immutable (#409)
Co-authored-by: Iaroslav Igoshev <[email protected]>
Signed-off-by: Kirill Suvorov <[email protected]>
Retribution98 and YarShev authored Dec 13, 2023
1 parent 491272e commit c36fbe1
Showing 9 changed files with 195 additions and 75 deletions.
14 changes: 6 additions & 8 deletions unidist/core/backends/mpi/core/controller/api.py
@@ -19,6 +19,7 @@
) from None

from unidist.core.backends.mpi.core.serialization import serialize_complex_data
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.controller.garbage_collector import (
@@ -386,7 +387,6 @@ def put(data):

data_id = local_store.generate_data_id(garbage_collector)
serialized_data = serialize_complex_data(data)
local_store.put(data_id, data)
if shared_store.is_allocated():
shared_store.put(data_id, serialized_data)
else:
@@ -411,21 +411,20 @@ def get(data_ids):
object
A Python object.
"""
local_store = LocalObjectStore.get_instance()

object_store = ObjectStore.get_instance()
is_list = isinstance(data_ids, list)
if not is_list:
data_ids = [data_ids]
remote_data_ids = [
data_id for data_id in data_ids if not local_store.contains(data_id)
data_id for data_id in data_ids if not object_store.contains(data_id)
]
# Remote data becomes available in the local store inside `request_worker_data`
if remote_data_ids:
request_worker_data(remote_data_ids)

logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids)))

values = [local_store.get(data_id) for data_id in data_ids]
values = [object_store.get(data_id) for data_id in data_ids]

# Initiate reference count based cleanup
# if all the tasks were completed
@@ -454,6 +453,7 @@ def wait(data_ids, num_returns=1):
tuple
List of data IDs that are ready and list of the remaining data IDs.
"""
object_store = ObjectStore.get_instance()
if not isinstance(data_ids, list):
data_ids = [data_ids]
# Since the controller should operate MpiDataID(s),
@@ -463,11 +463,9 @@
not_ready = data_ids
pending_returns = num_returns
ready = []
local_store = LocalObjectStore.get_instance()

logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids)))
for data_id in not_ready.copy():
if local_store.contains(data_id):
if object_store.contains(data_id):
ready.append(data_id)
not_ready.remove(data_id)
pending_returns -= 1
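
The api.py changes above are the user-visible core of the fix: `put()` now serializes the object up front instead of keeping the raw Python object in the local store, and `get()`/`wait()` resolve IDs through the new combined `ObjectStore`. A minimal sketch of the intended behavior, assuming unidist's public `init`/`put`/`get` API on the MPI backend:

    # Illustrative sketch (assumed public API: unidist.init/put/get on the MPI backend).
    import numpy as np
    import unidist

    unidist.init()

    data = np.arange(4)
    ref = unidist.put(data)    # the object is serialized (and copied) at put() time

    data[0] = 42               # mutating the original object after put() ...

    result = unidist.get(ref)  # ... no longer affects the stored value
    assert result[0] == 0
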
29 changes: 14 additions & 15 deletions unidist/core/backends/mpi/core/controller/common.py
@@ -12,7 +12,9 @@
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.serialization import serialize_complex_data
from unidist.core.backends.mpi.core.serialization import (
serialize_complex_data,
)


logger = common.get_logger("common", "common.log")
@@ -394,22 +396,19 @@ def push_data(dest_rank, value, is_blocking_op=False):
data_id = value
if shared_store.contains(data_id):
_push_shared_data(dest_rank, data_id, is_blocking_op)
elif local_store.is_already_serialized(data_id):
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
elif local_store.contains(data_id):
if local_store.is_already_serialized(data_id):
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
data = local_store.get(data_id)
serialized_data = serialize_complex_data(data)
if shared_store.is_allocated() and shared_store.should_be_shared(
serialized_data
):
shared_store.put(data_id, serialized_data)
_push_shared_data(dest_rank, data_id, is_blocking_op)
else:
data = local_store.get(data_id)
serialized_data = serialize_complex_data(data)
if shared_store.is_allocated() and shared_store.should_be_shared(
serialized_data
):
shared_store.put(data_id, serialized_data)
_push_shared_data(dest_rank, data_id, is_blocking_op)
else:
local_store.cache_serialized_data(data_id, serialized_data)
_push_local_data(
dest_rank, data_id, is_blocking_op, is_serialized=True
)
local_store.cache_serialized_data(data_id, serialized_data)
_push_local_data(dest_rank, data_id, is_blocking_op, is_serialized=True)
elif local_store.contains_data_owner(data_id):
_push_data_owner(dest_rank, data_id)
else:
4 changes: 4 additions & 0 deletions unidist/core/backends/mpi/core/local_object_store.py
@@ -277,6 +277,10 @@ def cache_serialized_data(self, data_id, data):
data : object
Serialized data to cache.
"""
# We make a copy to avoid data corruption caused by out-of-band serialization,
# and the resulting buffers are immutable so they cannot be modified afterwards.
# The `tobytes()` call handles both points.
data["raw_buffers"] = [buf.tobytes() for buf in data["raw_buffers"]]
self._serialization_cache[data_id] = data
self.maybe_update_data_id_map(data_id)

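
The new comment in `cache_serialized_data` is the heart of the immutability fix: out-of-band buffers produced during serialization alias the original object's memory, so the cache copies them with `tobytes()`. A standalone sketch of that effect using plain pickle protocol 5 and NumPy (no unidist involved); the `s_data`/`raw_buffers` names only mirror the keys visible elsewhere in this diff:

    # Standalone illustration: out-of-band pickle buffers alias the source memory,
    # while tobytes() produces immutable copies that survive later mutation.
    import pickle
    import numpy as np

    arr = np.arange(4, dtype=np.int64)
    raw_buffers = []
    s_data = pickle.dumps(arr, protocol=5, buffer_callback=raw_buffers.append)

    views = [buf.raw() for buf in raw_buffers]              # zero-copy views into arr
    copies = [buf.raw().tobytes() for buf in raw_buffers]   # immutable bytes copies

    arr[0] = 42  # mutate the source object after "caching" its buffers

    assert bytes(views[0])[:8] != copies[0][:8]  # the view changed, the copy did not
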
90 changes: 90 additions & 0 deletions unidist/core/backends/mpi/core/object_store.py
@@ -0,0 +1,90 @@
# Copyright (C) 2021-2023 Modin authors
#
# SPDX-License-Identifier: Apache-2.0

"""`ObjectStore` functionality."""

from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.serialization import deserialize_complex_data
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore


class ObjectStore:
"""
Class that combines checking and retrieving data from the shared and local stores in the current process.

Notes
-----
The store checks for both deserialized and serialized data.
"""

__instance = None

@classmethod
def get_instance(cls):
"""
Get instance of ``ObjectStore``.

Returns
-------
ObjectStore
"""
if cls.__instance is None:
cls.__instance = ObjectStore()
return cls.__instance

def contains(self, data_id):
"""
Check if the data associated with `data_id` exists in the current process.

Parameters
----------
data_id : unidist.core.backends.mpi.core.common.MpiDataID
An ID to data.

Returns
-------
bool
Return ``True`` if an object associated with `data_id` exists in the current process, ``False`` otherwise.
"""
local_store = LocalObjectStore.get_instance()
shared_store = SharedObjectStore.get_instance()
return (
local_store.contains(data_id)
or local_store.is_already_serialized(data_id)
or shared_store.contains(data_id)
)

def get(self, data_id):
"""
Get data from any location in the current process.

Parameters
----------
data_id : unidist.core.backends.mpi.core.common.MpiDataID
An ID to data.

Returns
-------
object
Return data associated with `data_id`.
"""
local_store = LocalObjectStore.get_instance()
shared_store = SharedObjectStore.get_instance()

if local_store.contains(data_id):
return local_store.get(data_id)

if local_store.is_already_serialized(data_id):
serialized_data = local_store.get_serialized_data(data_id)
value = deserialize_complex_data(
serialized_data["s_data"],
serialized_data["raw_buffers"],
serialized_data["buffer_count"],
)
elif shared_store.contains(data_id):
value = shared_store.get(data_id)
else:
raise ValueError("The current data ID is not contained in the process.")
local_store.put(data_id, value)
return value
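
The new `ObjectStore` acts as a read-side facade over the deserialized local store, the serialized cache, and shared memory. A hypothetical helper, sketched under the assumption that `data_id` is an `MpiDataID` already known to one of those stores, shows how controller code can use it:

    # Hypothetical helper built on the facade added in this commit.
    from unidist.core.backends.mpi.core.object_store import ObjectStore


    def resolve_locally(data_id):
        """Return the value for `data_id` if it is present anywhere in this process."""
        object_store = ObjectStore.get_instance()
        if object_store.contains(data_id):
            # Covers the deserialized local store, the serialized cache and shared
            # memory; `get` also caches the deserialized value for later lookups.
            return object_store.get(data_id)
        # Otherwise the data lives on another worker and must be requested first,
        # as `get()` in controller/api.py does via `request_worker_data`.
        return None
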
98 changes: 55 additions & 43 deletions unidist/core/backends/mpi/core/shared_object_store.py
@@ -747,62 +747,74 @@ def put(self, data_id, serialized_data):
# put shared info
self._put_shared_info(data_id, shared_info)

def get(self, data_id, owner_rank, shared_info):
def get(self, data_id, owner_rank=None, shared_info=None):
"""
Get data from another worker using shared memory.
Parameters
----------
data_id : unidist.core.backends.mpi.core.common.MpiDataID
An ID to data.
owner_rank : int
owner_rank : int, default: None
The rank that sent the data.
shared_info : dict
This value is used to synchronize data in shared memory between different hosts
if the value is not ``None``.
shared_info : dict, default: None
The necessary information to properly deserialize data from shared memory.
If `shared_info` is ``None``, the data already exists in shared memory in the current process.
"""
mpi_state = communication.MPIState.get_instance()
s_data_len = shared_info["s_data_len"]
raw_buffers_len = shared_info["raw_buffers_len"]
service_index = shared_info["service_index"]
buffer_count = shared_info["buffer_count"]

# check data in shared memory
if not self._check_service_info(data_id, service_index):
# reserve shared memory
shared_data_len = s_data_len + sum([buf for buf in raw_buffers_len])
reservation_info = communication.send_reserve_operation(
mpi_state.global_comm, data_id, shared_data_len
)

service_index = reservation_info["service_index"]
# check if worker should sync shared buffer or another worker is already doing it
if reservation_info["is_first_request"]:
# synchronize shared buffer
self._sync_shared_memory_from_another_host(
mpi_state.global_comm,
data_id,
owner_rank,
reservation_info["first_index"],
reservation_info["last_index"],
service_index,
)
# put service info
self._put_service_info(
service_index, data_id, reservation_info["first_index"]
if shared_info is None:
shared_info = self.get_shared_info(data_id)
else:
mpi_state = communication.MPIState.get_instance()
s_data_len = shared_info["s_data_len"]
raw_buffers_len = shared_info["raw_buffers_len"]
service_index = shared_info["service_index"]
buffer_count = shared_info["buffer_count"]

# check data in shared memory
if not self._check_service_info(data_id, service_index):
# reserve shared memory
shared_data_len = s_data_len + sum([buf for buf in raw_buffers_len])
reservation_info = communication.send_reserve_operation(
mpi_state.global_comm, data_id, shared_data_len
)
else:
# wait while another worker synchronizes the shared buffer
while not self._check_service_info(data_id, service_index):
time.sleep(MpiBackoff.get())

# put shared info with updated data_id and service_index
shared_info = common.MetadataPackage.get_shared_info(
data_id, s_data_len, raw_buffers_len, buffer_count, service_index
)
self._put_shared_info(data_id, shared_info)
service_index = reservation_info["service_index"]
# check if worker should sync shared buffer or another worker is already doing it
if reservation_info["is_first_request"]:
# synchronize shared buffer
if owner_rank is None:
raise ValueError(
"The data is not in the host's shared memory and the data must be synchronized, "
+ "but the owner rank is not defined."
)

self._sync_shared_memory_from_another_host(
mpi_state.global_comm,
data_id,
owner_rank,
reservation_info["first_index"],
reservation_info["last_index"],
service_index,
)
# put service info
self._put_service_info(
service_index, data_id, reservation_info["first_index"]
)
else:
# wait while another worker synchronizes the shared buffer
while not self._check_service_info(data_id, service_index):
time.sleep(MpiBackoff.get())

# put shared info with updated data_id and service_index
shared_info = common.MetadataPackage.get_shared_info(
data_id, s_data_len, raw_buffers_len, buffer_count, service_index
)
self._put_shared_info(data_id, shared_info)

# increment ref
self._increment_ref_number(data_id, shared_info["service_index"])
# increment ref
self._increment_ref_number(data_id, shared_info["service_index"])

# read from shared buffer and deserialized
return self._read_from_shared_buffer(data_id, shared_info)
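
The reworked `SharedObjectStore.get` now supports two call modes: with `shared_info=None` it reads data that is already present in this process's shared memory, while passing `owner_rank` and `shared_info` keeps the original cross-host synchronization path. A small illustrative wrapper (both arguments are assumed to come from unidist's internal metadata handling):

    # Illustrative wrapper over the two call modes of SharedObjectStore.get.
    from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore


    def read_shared(data_id, owner_rank=None, shared_info=None):
        shared_store = SharedObjectStore.get_instance()
        if shared_info is None:
            # Local mode (new default): the data is already in shared memory,
            # so get() looks up the shared info itself.
            return shared_store.get(data_id)
        # Remote mode: the shared buffer may first have to be synchronized
        # from `owner_rank` on another host before it can be read.
        return shared_store.get(data_id, owner_rank, shared_info)
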
4 changes: 3 additions & 1 deletion unidist/core/backends/mpi/core/worker/loop.py
@@ -16,6 +16,7 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.worker.request_store import RequestStore
from unidist.core.backends.mpi.core.worker.task_store import TaskStore
@@ -85,6 +86,7 @@ async def worker_loop():
``unidist.core.backends.mpi.core.common.Operations`` defines a set of supported operations.
"""
task_store = TaskStore.get_instance()
object_store = ObjectStore.get_instance()
local_store = LocalObjectStore.get_instance()
request_store = RequestStore.get_instance()
async_operations = AsyncOperations.get_instance()
@@ -185,7 +187,7 @@ async def worker_loop():
if not ready_to_shutdown_posted:
# Prepare the data
# Actor method here is a data id so we have to retrieve it from the storage
method_name = local_store.get(request["task"])
method_name = object_store.get(request["task"])
handler = request["handler"]
actor_method = getattr(actor_map[handler], method_name)
request["task"] = actor_method
9 changes: 5 additions & 4 deletions unidist/core/backends/mpi/core/worker/request_store.py
@@ -6,8 +6,8 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.controller.common import push_data
from unidist.core.backends.mpi.core.object_store import ObjectStore


mpi_state = communication.MPIState.get_instance()
@@ -213,7 +214,8 @@ def process_wait_request(self, data_id):
-----
Only ROOT rank is supported for now, therefore no rank argument needed.
"""
if LocalObjectStore.get_instance().contains(data_id):
object_store = ObjectStore.get_instance()
if object_store.contains(data_id):
# Executor wait just for signal
# We use a blocking send here because the receiver is waiting for the result.
communication.mpi_send_object(
@@ -247,8 +248,8 @@ def process_get_request(self, source_rank, data_id, is_blocking_op=False):
-----
Request is asynchronous, no wait for the data sending.
"""
local_store = LocalObjectStore.get_instance()
if local_store.contains(data_id):
object_store = ObjectStore.get_instance()
if object_store.contains(data_id):
push_data(
source_rank,
data_id,