Skip to content

Commit

Permalink
Add class ObjectStore
Browse files Browse the repository at this point in the history
  • Loading branch information
Retribution98 committed Dec 13, 2023
1 parent 1c47ae7 commit ebcb62b
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 79 deletions.
14 changes: 8 additions & 6 deletions unidist/core/backends/mpi/core/controller/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
) from None

from unidist.core.backends.mpi.core.serialization import serialize_complex_data
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.controller.garbage_collector import (
Expand All @@ -28,8 +29,6 @@
request_worker_data,
push_data,
RoundRobin,
get_data,
contains_data,
)
import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
Expand Down Expand Up @@ -388,7 +387,6 @@ def put(data):

data_id = local_store.generate_data_id(garbage_collector)
serialized_data = serialize_complex_data(data)
# data is prepared for sending to another process, but is not saved to local storage
if shared_store.is_allocated():
shared_store.put(data_id, serialized_data)
else:
Expand All @@ -413,17 +411,20 @@ def get(data_ids):
object
A Python object.
"""
object_store = ObjectStore.get_instance()
is_list = isinstance(data_ids, list)
if not is_list:
data_ids = [data_ids]
remote_data_ids = [data_id for data_id in data_ids if not contains_data(data_id)]
remote_data_ids = [
data_id for data_id in data_ids if not object_store.contains_data(data_id)
]
# Remote data gets available in the local store inside `request_worker_data`
if remote_data_ids:
request_worker_data(remote_data_ids)

logger.debug("GET {} ids".format(common.unwrapped_data_ids_list(data_ids)))

values = [get_data(data_id) for data_id in data_ids]
values = [object_store.get_data(data_id) for data_id in data_ids]

# Initiate reference count based cleaup
# if all the tasks were completed
Expand Down Expand Up @@ -452,6 +453,7 @@ def wait(data_ids, num_returns=1):
tuple
List of data IDs that are ready and list of the remaining data IDs.
"""
object_store = ObjectStore.get_instance()
if not isinstance(data_ids, list):
data_ids = [data_ids]
# Since the controller should operate MpiDataID(s),
Expand All @@ -463,7 +465,7 @@ def wait(data_ids, num_returns=1):
ready = []
logger.debug("WAIT {} ids".format(common.unwrapped_data_ids_list(data_ids)))
for data_id in not_ready.copy():
if contains_data(data_id):
if object_store.contains_data(data_id):
ready.append(data_id)
not_ready.remove(data_id)
pending_returns -= 1
Expand Down
63 changes: 2 additions & 61 deletions unidist/core/backends/mpi/core/controller/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.serialization import (
deserialize_complex_data,
serialize_complex_data,
)

Expand Down Expand Up @@ -141,10 +140,10 @@ def pull_data(comm, owner_rank=None):
shared_store = SharedObjectStore.get_instance()
data_id = info_package["id"]

if contains_data(data_id):
if local_store.contains(data_id):
return {
"id": data_id,
"data": get_data(data_id),
"data": local_store.get(data_id),
}

data = shared_store.get(data_id, owner_rank, info_package)
Expand Down Expand Up @@ -414,61 +413,3 @@ def push_data(dest_rank, value, is_blocking_op=False):
_push_data_owner(dest_rank, data_id)
else:
raise ValueError("Unknown DataID!")


def contains_data(data_id):
    """
    Check whether the data associated with `data_id` exists in the current process.

    Parameters
    ----------
    data_id : unidist.core.backends.mpi.core.common.MpiDataID
        An ID to data.

    Returns
    -------
    bool
        ``True`` if an object for `data_id` exists in this process, ``False`` otherwise.
    """
    local_store = LocalObjectStore.get_instance()
    shared_store = SharedObjectStore.get_instance()
    # The object may live in the local store (deserialized or serialized)
    # or in the shared-memory store.
    found_locally = local_store.contains(data_id) or local_store.is_already_serialized(
        data_id
    )
    return found_locally or shared_store.contains(data_id)


def get_data(data_id):
    """
    Get data from any location in the current process.

    Parameters
    ----------
    data_id : unidist.core.backends.mpi.core.common.MpiDataID
        An ID to data.

    Returns
    -------
    object
        Return data associated with `data_id`.

    Raises
    ------
    ValueError
        If the data ID is not contained in any store of the current process.

    Notes
    -----
    If the value is found only in serialized form or in shared memory,
    it is materialized and cached in the local store for future lookups.
    """
    local_store = LocalObjectStore.get_instance()
    shared_store = SharedObjectStore.get_instance()

    # Fast path: the value is already materialized locally.
    if local_store.contains(data_id):
        return local_store.get(data_id)

    if local_store.is_already_serialized(data_id):
        serialized_data = local_store.get_serialized_data(data_id)
        value = deserialize_complex_data(
            serialized_data["s_data"],
            serialized_data["raw_buffers"],
            serialized_data["buffer_count"],
        )
    elif shared_store.contains(data_id):
        value = shared_store.get(data_id)
    else:
        # Fixed typo in the error message ("procces" -> "process").
        raise ValueError("The current data ID is not contained in the process.")
    # Cache the materialized value so subsequent gets hit the fast path.
    local_store.put(data_id, value)
    return value
3 changes: 2 additions & 1 deletion unidist/core/backends/mpi/core/local_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,9 @@ def cache_serialized_data(self, data_id, data):
data : object
Serialized data to cache.
"""
# Copying is necessary to avoid corruption of data obtained through out-of-band serialization,
# We make a copy to avoid corrupting data obtained through out-of-band serialization,
# and the resulting buffers are read-only, which prevents them from being modified.
# The `tobytes()` call handles both points.
data["raw_buffers"] = [buf.tobytes() for buf in data["raw_buffers"]]
self._serialization_cache[data_id] = data
self.maybe_update_data_id_map(data_id)
Expand Down
84 changes: 84 additions & 0 deletions unidist/core/backends/mpi/core/object_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.serialization import deserialize_complex_data
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore


class ObjectStore:
    """
    Class that combines checking and retrieving data from all stores in the current process.

    Notes
    -----
    The store checks for both deserialized and serialized data.
    """

    # Singleton instance, lazily created by `get_instance`.
    __instance = None

    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance of ``ObjectStore``.

        Returns
        -------
        ObjectStore
        """
        if cls.__instance is None:
            cls.__instance = ObjectStore()
        return cls.__instance

    def contains_data(self, data_id):
        """
        Check if the data associated with `data_id` exists in the current process.

        Parameters
        ----------
        data_id : unidist.core.backends.mpi.core.common.MpiDataID
            An ID to data.

        Returns
        -------
        bool
            ``True`` if an object for `data_id` exists in the current process.
        """
        local_store = LocalObjectStore.get_instance()
        shared_store = SharedObjectStore.get_instance()
        # The object may exist deserialized or serialized in the local store,
        # or in the shared-memory store.
        return (
            local_store.contains(data_id)
            or local_store.is_already_serialized(data_id)
            or shared_store.contains(data_id)
        )

    def get_data(self, data_id):
        """
        Get data from any location in the current process.

        Parameters
        ----------
        data_id : unidist.core.backends.mpi.core.common.MpiDataID
            An ID to data.

        Returns
        -------
        object
            Return data associated with `data_id`.

        Raises
        ------
        ValueError
            If the data ID is not contained in any store of the current process.
        """
        local_store = LocalObjectStore.get_instance()
        shared_store = SharedObjectStore.get_instance()

        # Fast path: the value is already materialized locally.
        if local_store.contains(data_id):
            return local_store.get(data_id)

        if local_store.is_already_serialized(data_id):
            serialized_data = local_store.get_serialized_data(data_id)
            value = deserialize_complex_data(
                serialized_data["s_data"],
                serialized_data["raw_buffers"],
                serialized_data["buffer_count"],
            )
        elif shared_store.contains(data_id):
            value = shared_store.get(data_id)
        else:
            # Fixed typo in the error message ("procces" -> "process").
            raise ValueError("The current data ID is not contained in the process.")
        # Cache the materialized value so subsequent gets hit the fast path.
        local_store.put(data_id, value)
        return value
6 changes: 4 additions & 2 deletions unidist/core/backends/mpi/core/shared_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,11 @@ def get(self, data_id, owner_rank=None, shared_info=None):
An ID to data.
owner_rank : int, default: None
The rank that sent the data.
This value is used to synchronize data in shared memory between different hosts if the value is defined.
This value is used to synchronize data in shared memory between different hosts
if the value is not ``None``.
shared_info : dict, default: None
The necessary information to properly deserialize data from shared memory. If `shared_info` is None, the data already exists in shared memory.
The necessary information to properly deserialize data from shared memory.
If `shared_info` is ``None``, the data already exists in shared memory in the current process.
"""
if shared_info is None:
shared_info = self.get_shared_info(data_id)
Expand Down
6 changes: 4 additions & 2 deletions unidist/core/backends/mpi/core/worker/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.worker.request_store import RequestStore
from unidist.core.backends.mpi.core.worker.task_store import TaskStore
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.controller.common import pull_data, get_data
from unidist.core.backends.mpi.core.controller.common import pull_data
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore

# TODO: Find a way to move this after all imports
Expand Down Expand Up @@ -85,6 +86,7 @@ async def worker_loop():
``unidist.core.backends.mpi.core.common.Operations`` defines a set of supported operations.
"""
task_store = TaskStore.get_instance()
object_store = ObjectStore.get_instance()
local_store = LocalObjectStore.get_instance()
request_store = RequestStore.get_instance()
async_operations = AsyncOperations.get_instance()
Expand Down Expand Up @@ -185,7 +187,7 @@ async def worker_loop():
if not ready_to_shutdown_posted:
# Prepare the data
# Actor method here is a data id so we have to retrieve it from the storage
method_name = get_data(request["task"])
method_name = object_store.get_data(request["task"])
handler = request["handler"]
actor_method = getattr(actor_map[handler], method_name)
request["task"] = actor_method
Expand Down
9 changes: 6 additions & 3 deletions unidist/core/backends/mpi/core/worker/request_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.controller.common import push_data, contains_data
from unidist.core.backends.mpi.core.controller.common import push_data
from unidist.core.backends.mpi.core.object_store import ObjectStore


mpi_state = communication.MPIState.get_instance()
Expand Down Expand Up @@ -212,7 +213,8 @@ def process_wait_request(self, data_id):
-----
Only ROOT rank is supported for now, therefore no rank argument needed.
"""
if contains_data(data_id):
object_store = ObjectStore.get_instance()
if object_store.contains_data(data_id):
# Executor wait just for signal
# We use a blocking send here because the receiver is waiting for the result.
communication.mpi_send_object(
Expand Down Expand Up @@ -246,7 +248,8 @@ def process_get_request(self, source_rank, data_id, is_blocking_op=False):
-----
Request is asynchronous, no wait for the data sending.
"""
if contains_data(data_id):
object_store = ObjectStore.get_instance()
if object_store.contains_data(data_id):
push_data(
source_rank,
data_id,
Expand Down
10 changes: 6 additions & 4 deletions unidist/core/backends/mpi/core/worker/task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from unidist.core.backends.common.data_id import is_data_id
import unidist.core.backends.mpi.core.common as common
import unidist.core.backends.mpi.core.communication as communication
from unidist.core.backends.mpi.core.controller.common import get_data, contains_data
from unidist.core.backends.mpi.core.async_operations import AsyncOperations
from unidist.core.backends.mpi.core.object_store import ObjectStore
from unidist.core.backends.mpi.core.local_object_store import LocalObjectStore
from unidist.core.backends.mpi.core.shared_object_store import SharedObjectStore
from unidist.core.backends.mpi.core.serialization import serialize_complex_data
Expand Down Expand Up @@ -188,8 +188,9 @@ def unwrap_local_data_id(self, arg):
"""
if is_data_id(arg):
local_store = LocalObjectStore.get_instance()
if contains_data(arg):
value = get_data(arg)
object_store = ObjectStore.get_instance()
if object_store.contains_data(arg):
value = object_store.get_data(arg)
# Data is already local or was pushed from master
return value, False
elif local_store.contains_data_owner(arg):
Expand Down Expand Up @@ -418,12 +419,13 @@ def process_task_request(self, request):
dict or None
Same request if the task couldn`t be executed, otherwise ``None``.
"""
object_store = ObjectStore.get_instance()
# Parse request
task = request["task"]
# Remote function here is a data id so we have to retrieve it from the storage,
# whereas actor method is already materialized in the worker loop.
if is_data_id(task):
task = get_data(task)
task = object_store.get_data(task)
args = request["args"]
kwargs = request["kwargs"]
output_ids = request["output"]
Expand Down

0 comments on commit ebcb62b

Please sign in to comment.