diff --git a/nslsii/md_dict.py b/nslsii/md_dict.py index e370698d..786d95d2 100644 --- a/nslsii/md_dict.py +++ b/nslsii/md_dict.py @@ -1,6 +1,6 @@ import logging +import re from collections import ChainMap, UserDict -from pprint import pformat from uuid import uuid4 import msgpack @@ -14,6 +14,24 @@ class RunEngineRedisDict(UserDict): + """ + A class for storing RunEngine metadata added to RE.md on a Redis server. + + This class has two strong ideas about the metadata it manages: + 1. Some key-values are considered "global" or "facility-wide". These + are in use at all NSLS-II bluesky beamlines and include + proposal_id, data_session, cycle, SAF, and scan_id. The "global" + key-values are stored as Redis key-values. Redis does not support + numeric types, so the RunEngineRedisDict also keeps track of the + types of the "global" key-values. The intention is that this + metadata is accessible by any Redis client. + 2. Non-global, or beamline-specific, metadata is stored in Redis as + a msgpack-ed blob. This means data type conversion between Redis + and Python is handled by msgpack, including numpy arrays. + The drawback is that this "local" metadata key-values are not + directly readable or writeable by Redis clients. 
+ """ + PACKED_RUNENGINE_METADATA_KEY = "runengine-metadata-blob" def __init__( @@ -23,6 +41,7 @@ def __init__( db=0, re_md_channel_name="runengine-metadata", global_keys=None, + global_values_types=None, ): # send no initial data to UserDict.__init__ # since we will replace UserDict.data entirely @@ -34,12 +53,13 @@ def __init__( self._re_md_channel_name = re_md_channel_name self._uuid = str(uuid4()) - redis_dict_log.info(f"connecting to Redis at %s:%s", self._host, self._port) + redis_dict_log.info("connecting to Redis at %s:%s", self._host, self._port) # global metadata will be stored as Redis key-value pairs - # tell the global Redis client to do bytes-to-str conversion + # tell the global metadata Redis client to do bytes-to-str conversion self._redis_global_client = redis.Redis( host=host, port=port, db=db, decode_responses=True ) + # ping() will raise redis.exceptions.ConnectionError on failure self._redis_global_client.ping() # local metadata will be msgpack-ed, so decoding @@ -48,7 +68,6 @@ def __init__( self._redis_local_client = redis.Redis( host=host, port=port, db=db, decode_responses=False ) - # ping() will raise redis.exceptions.ConnectionError on failure self._redis_local_client.ping() if global_keys is None: @@ -61,33 +80,45 @@ def __init__( "SAF", "scan_id", ) + else: + self._global_keys = global_keys + + if global_values_types is None: + # remember numeric types for global metadata + # global metadata keys not specified here will default to str + self._global_values_types = {"scan_id": int} + else: + self._global_values_types = global_values_types # is local metadata already in redis? 
packed_local_md = self._redis_local_client.get( self.PACKED_RUNENGINE_METADATA_KEY ) if packed_local_md is None: - redis_dict_log.info(f"no local metadata found in Redis") + redis_dict_log.info("no local metadata found in Redis") self._local_md = dict() self._set_local_metadata_on_server() else: - redis_dict_log.info(f"unpacking local metadata from Redis") + redis_dict_log.info("unpacking local metadata from Redis") self._local_md = self._get_local_metadata_from_server() - redis_dict_log.debug(f"unpacked local metadata:\n%s", self._local_md) + redis_dict_log.debug("unpacked local metadata:\n%s", self._local_md) - # what if the global keys do not exist? - # could get all Redis keys and exclude the local md blob key ? self._global_md = dict() for global_key in self._global_keys: global_value = self._redis_global_client.get(global_key) if global_value is None: - redis_dict_log.info(f"no value yet for global key {global_key}") - self._redis_global_client.set(name=global_key, value=global_key) - self._global_md[global_key] = global_value - redis_dict_log.info(f"global metadata:\n{pformat(self._global_md)}") + # if a global key does not exist on the Redis server + # then it will not exist in the RunEngineRedisDict + redis_dict_log.info("no value yet for global key %s", global_key) + else: + if global_key in self._global_values_types: + global_value = self._global_values_types[global_key](global_value) + self._global_md[global_key] = global_value + redis_dict_log.info("global metadata: %s", self._global_md) - # keep in mind _local_md is the first map in this ChainMap - # for when _local_md has to be replaced + # when self._local_md has to be replaced with the metadata + # blob in Redis we must be careful to replace the first + # dict in the ChainMap's list of mappings self.data = ChainMap(self._local_md, self._global_md) # Redis documentation says do not issue commands from @@ -99,9 +130,12 @@ def __init__( ignore_subscribe_messages=True ) - # register _update_on_message 
to handle Redis messages + # register self._handle_update_message to handle Redis messages + # this is how the RunEngineMetadataDict knows a key-value + # has been modified on the Redis server, and therefore + # self._local_md must be updated from the server self._redis_pubsub.subscribe( - **{self._re_md_channel_name: self._update_on_message} + **{self._re_md_channel_name: self._handle_update_message} ) # start a thread to pass messages to _update_on_message self._update_on_message_thread = self._redis_pubsub.run_in_thread( @@ -109,8 +143,21 @@ def __init__( ) def __setitem__(self, key, value): - if key in self._global_md: + if key in self._global_keys: + # can't rely on self._global_md for this check because + # if global metadata is not in Redis it is not added to self._global_md redis_dict_log.debug("setting global metadata %s:%s", key, value) + # global metadata may be constrained to be of a certain type + # check that value does not violate the type expected for key + expected_value_type = self._global_values_types.get(key, str) + if isinstance(value, expected_value_type): + # everything is good + pass + else: + raise ValueError( + f"expected value for key '{key}' to have type '{expected_value_type}'" + f"but '{value}' has type '{type(value)}'" + ) # update the global key-value pair explicitly in self._global_md # because it can not be updated through the self.data ChainMap # since self._global_md is not the first dictionary in that ChainMap @@ -129,7 +176,7 @@ def __setitem__(self, key, value): # tell subscribers a key-value has changed redis_dict_log.debug("publishing update %s:%s", key, value) - self._publish_metadata_update(key) + self._publish_metadata_update_message(key) def __delitem__(self, key): if key in self._global_keys: @@ -139,9 +186,9 @@ def __delitem__(self, key): self._set_local_metadata_on_server() # tell everyone a (local) key-value has been changed - self._publish_metadata_update(key) + self._publish_metadata_update_message(key) - def 
_publish_metadata_update(self, key): + def _publish_metadata_update_message(self, key): """ Publish a message that includes the updated key and the identifying UUID for this RunEngineRedisDict. @@ -162,38 +209,54 @@ def _set_local_metadata_on_server(self): self.PACKED_RUNENGINE_METADATA_KEY, self._pack(self._local_md) ) - @staticmethod - def _parse_message_data(message): + _message_data_pattern = re.compile(r"^(?P<key>.+):(?P<uuid>.+)$") + + @classmethod + def _parse_message_data(klass, message): """ - The message parameter looks like this - b"abd:39f1f7fa-aeef-4d83-a802-c1c7f5ff5cb8" + message["data"] should look like this + b"abc:39f1f7fa-aeef-4d83-a802-c1c7f5ff5cb8" Splitting the message on ":" gives the updated key - and the UUID of the RunEngineRedisDict that made - the update. + ("abc" in this example) and the UUID of the RunEngineRedisDict + that made the update. The UUID is used to determine if + the update message came from this RunEngineRedisDict, in + which case it is not necessary to update the local metadata + from the Redis server. 
""" - message_key, publisher_uuid = message["data"].rsplit(b":", maxsplit=1) - return message_key.decode(), publisher_uuid.decode() + decoded_message_data = message["data"].decode() + message_data_match = klass._message_data_pattern.match(decoded_message_data) - def _update_on_message(self, message): + if message_data_match is None: + raise ValueError( + f"message[data]=`{decoded_message_data}` could not be parsed" + ) + return message_data_match.group("key"), message_data_match.group("uuid") + + def _handle_update_message(self, message): redis_dict_log.debug("_update_on_message: %s", message) updated_key, publisher_uuid = self._parse_message_data(message) if publisher_uuid == self._uuid: + # this RunEngineRedisDict is the source of this update message, + # so there is no need to go to the Redis server for the new metadata redis_dict_log.debug("update published by me!") - pass elif updated_key in self._global_keys: redis_dict_log.debug("updated key belongs to global metadata") - # we can assume the updated_key is not a new key - # get the key from the Redis database - self._global_md[updated_key] = self._redis_global_client.get( - name=updated_key - ) + # because the updated_key belongs to "global" metadata + # we can assume it is not a new or deleted key, so just + # get the key's value from the Redis database and convert + # its type if necessary (eg, from string to int) + updated_value = self._redis_global_client.get(name=updated_key) + if updated_key in self._global_values_types: + updated_value = self._global_values_types[updated_key](updated_value) + self._global_md[updated_key] = updated_value else: redis_dict_log.debug("updated key belongs to local metadata") # the updated key belongs to local metadata # it may be a newly added or deleted key, so # we have to update the entire local metadata dictionary self._local_md = self._get_local_metadata_from_server() - # update the ChainMap + # update the ChainMap - "local" metadata is always the + # first element in 
ChainMap.maps self.data.maps[0] = self._local_md @staticmethod diff --git a/nslsii/tests/conftest.py b/nslsii/tests/conftest.py index 49c1e875..82ae4d08 100644 --- a/nslsii/tests/conftest.py +++ b/nslsii/tests/conftest.py @@ -1,18 +1,18 @@ -from contextlib import contextmanager +from contextlib import contextmanager # noqa import redis import pytest from bluesky.tests.conftest import RE # noqa -from bluesky_kafka import BlueskyConsumer -from bluesky_kafka.tests.conftest import ( +from bluesky_kafka import BlueskyConsumer # noqa +from bluesky_kafka.tests.conftest import ( # noqa pytest_addoption, kafka_bootstrap_servers, broker_authorization_config, consume_documents_from_kafka_until_first_stop_document, temporary_topics, -) # noqa +) from ophyd.tests.conftest import hw # noqa from nslsii.md_dict import RunEngineRedisDict @@ -23,13 +23,33 @@ def redis_dict_factory(): """ Return a "fixture as a factory" that will build identical RunEngineRedisDicts. Before the factory is returned, the Redis server will be cleared. + + The factory builds only RunEngineRedisDict instances for a Redis server running + on localhost:6379, db=0. + + If "host", "port", or "db" are specified as kwargs to the factory function + an exception will be raised. 
""" - redis_client = redis.Redis(host="localhost", port=6379, db=0) - redis_client.flushdb() + redis_server_kwargs = { + "host": "localhost", + "port": 6379, + "db": 0, + } - def _factory(re_md_channel_name): - return RunEngineRedisDict(host="localhost", port=6379, db=0, re_md_channel_name=re_md_channel_name) + redis_client = redis.Redis(**redis_server_kwargs) + redis_client.flushdb() - return _factory + def _factory(**kwargs): + disallowed_kwargs_preset = set(redis_server_kwargs.keys()).intersection( + kwargs.keys() + ) + if len(disallowed_kwargs_preset) > 0: + raise KeyError( + f"{disallowed_kwargs_preset} given, but 'host', 'port', and 'db' may not be specified" + ) + else: + kwargs.update(redis_server_kwargs) + return RunEngineRedisDict(**kwargs) + return _factory diff --git a/nslsii/tests/test_redis_dict.py b/nslsii/tests/test_redis_dict.py index d88189d0..8c083bf6 100644 --- a/nslsii/tests/test_redis_dict.py +++ b/nslsii/tests/test_redis_dict.py @@ -31,7 +31,7 @@ def _get_waiting_messages(redis_subscriber): message = redis_subscriber.get_message() if message is None: # it can happen that there are messages - # even if None is returned + # even if None is returned the first time message = redis_subscriber.get_message() while message is not None: message_list.append(message) @@ -45,7 +45,7 @@ def test_instantiate_with_server(redis_dict_factory): """ Instantiate a RunEngineRedisDict and expect success. 
""" - redis_dict = redis_dict_factory(re_md_channel_name="test_instantiate_with_server") + redis_dict_factory(re_md_channel_name="test_instantiate_with_server") def test_instantiate_no_server(): @@ -58,6 +58,27 @@ def test_instantiate_no_server(): RunEngineRedisDict(host="localhost", port=9999) +def test__parse_message_data(): + """ + Test a simple message "abc:uuid" and + a potentially problematic message "a:b:c:uuid" + """ + message = {"data": b"abc:uuid"} + key, uuid = RunEngineRedisDict._parse_message_data(message) + assert key == "abc" + assert uuid == "uuid" + + # what if the key contains one or more colons? + message = {"data": b"a:b:c:uuid"} + key, uuid = RunEngineRedisDict._parse_message_data(message) + assert key == "a:b:c" + assert uuid == "uuid" + + message = {"data": b"abcuuid"} + with pytest.raises(ValueError): + RunEngineRedisDict._parse_message_data(message) + + def test_local_int_value(redis_dict_factory): """ Test that an integer is stored and retrieved. @@ -82,15 +103,50 @@ def test_local_ndarray_value(redis_dict_factory): """ Test that a numpy NDArray is stored and retrieved. """ - redis_dict = redis_dict_factory(re_md_channel_name="test_local_float_value") + redis_dict = redis_dict_factory(re_md_channel_name="test_local_ndarray_value") redis_dict["local_array"] = np.ones((10, 10)) assert np.array_equal(redis_dict["local_array"], np.ones((10, 10))) +def test_no_global_metadata(redis_dict_factory): + """ + Construct a RunEngineRedisDict with no "global" metadata. + """ + redis_dict = redis_dict_factory( + re_md_channel_name="test_no_global_metadata", global_keys=[] + ) + + assert len(redis_dict) == 0 + + +def test_global_int_value(redis_dict_factory): + """ + Test that an integer is stored and retrieved. 
+ """ + redis_dict_1 = redis_dict_factory(re_md_channel_name="test_global_int_value") + + # scan_id does not exist yet + with pytest.raises(KeyError): + redis_dict_1["scan_id"] + + redis_dict_1["scan_id"] = 0 + assert redis_dict_1["scan_id"] == 0 + + redis_dict_2 = redis_dict_factory(re_md_channel_name="test_global_int_value") + assert redis_dict_2["scan_id"] == 0 + + # expect an exception because "scan_id" is + # constrained to be an integer + with pytest.raises(ValueError): + redis_dict_1["scan_id"] = "one" + + assert redis_dict_1["scan_id"] == 0 + + def test_del_global_key(redis_dict_factory): """ - Test that attempting to delete a "global" key raised KeyError. + Test that attempting to delete a "global" key raises KeyError. """ redis_dict = redis_dict_factory(re_md_channel_name="test_del_global_key") with pytest.raises(KeyError): @@ -132,18 +188,18 @@ def test_items(redis_dict_factory): """ redis_dict = redis_dict_factory(re_md_channel_name="test_items") - # expect to find the global keys with value None - expected_global_items = {gk: None for gk in redis_dict._global_keys} + # no global metadata exists yet actual_global_items = {gk: gv for gk, gv in redis_dict.items()} - assert actual_global_items == expected_global_items + assert actual_global_items == {} # set a value for each global key global_md_updates = {gk: gk for gk in redis_dict._global_keys} + global_md_updates["scan_id"] = 1 redis_dict.update(global_md_updates) + actual_global_items = {gk: gv for gk, gv in redis_dict.items()} # _local_md should still be empty # since only global metadata was updated - # this is not the default behavior of ChainMap! 
assert len(redis_dict._local_md) == 0 assert actual_global_items == global_md_updates @@ -197,12 +253,51 @@ def test_two_messages(redis_dict_factory): assert publisher_uuid == redis_dict._uuid -def test_synchronization(redis_dict_factory): +def test_global_metadata_synchronization(redis_dict_factory): + """ + Test "global metadata" synchronization between separate RunEngineRedisDicts. + """ + redis_dict_1 = redis_dict_factory( + re_md_channel_name="test_global_metadata_synchronization" + ) + redis_dict_2 = redis_dict_factory( + re_md_channel_name="test_global_metadata_synchronization" + ) + redis_dict_2_subscriber = _build_redis_subscriber(redis_dict_2) + + # make one change + redis_dict_1["proposal_id"] = "PROPOSAL ID" + redis_dict_1["scan_id"] = 0 + + time.sleep(1) + redis_dict_2_messages = _get_waiting_messages(redis_dict_2_subscriber) + assert len(redis_dict_2_messages) == 2 + + assert redis_dict_2["proposal_id"] == "PROPOSAL ID" + assert redis_dict_2["scan_id"] == 0 + + redis_dict_3 = redis_dict_factory( + re_md_channel_name="test_global_metadata_synchronization" + ) + assert redis_dict_3["proposal_id"] == "PROPOSAL ID" + assert redis_dict_3["scan_id"] == 0 + + redis_dict_3["scan_id"] = 1 + time.sleep(1) + assert redis_dict_1["scan_id"] == 1 + assert redis_dict_2["scan_id"] == 1 + + +def test_local_metadata_synchronization(redis_dict_factory): """ - Test synchronization between separate RunEngineRedisDicts. + Test "local metadata" synchronization between separate RunEngineRedisDicts. 
""" - redis_dict_1 = redis_dict_factory(re_md_channel_name="test_synchronization") - redis_dict_2 = redis_dict_factory(re_md_channel_name="test_synchronization") + redis_dict_1 = redis_dict_factory( + re_md_channel_name="test_local_metadata_synchronization" + ) + redis_dict_2 = redis_dict_factory( + re_md_channel_name="test_local_metadata_synchronization" + ) redis_dict_2_subscriber = _build_redis_subscriber(redis_dict_2) # make one change @@ -220,7 +315,9 @@ def test_synchronization(redis_dict_factory): assert redis_dict_2["float"] == np.pi assert np.array_equal(redis_dict_2["array"], np.ones((10, 10))) - redis_dict_3 = redis_dict_factory(re_md_channel_name="test_synchronization") + redis_dict_3 = redis_dict_factory( + re_md_channel_name="test_local_metadata_synchronization" + ) assert redis_dict_3["string"] == "string" assert redis_dict_3["int"] == 0 assert redis_dict_3["float"] == np.pi