diff --git a/cylc/flow/network/__init__.py b/cylc/flow/network/__init__.py index 916b129e244..96da617db40 100644 --- a/cylc/flow/network/__init__.py +++ b/cylc/flow/network/__init__.py @@ -13,279 +13,18 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""Package for network interfaces to Cylc scheduler objects.""" -import asyncio -import getpass -import json -from typing import Optional, Tuple +"""Cylc networking code. -import zmq -import zmq.asyncio -import zmq.auth +Contains: +* Server code (hosted by the scheduler process). +* Client implementations (used to communicate with the scheduler). +* Workflow scanning logic. +* Schema and interface definitions. +""" -from cylc.flow import LOG -from cylc.flow.exceptions import ( - ClientError, - CylcError, - CylcVersionError, - ServiceFileError, - WorkflowStopped -) -from cylc.flow.hostuserutil import get_fqdn_by_host -from cylc.flow.workflow_files import ( - ContactFileFields, - KeyType, - KeyOwner, - KeyInfo, - load_contact_file, - get_workflow_srv_dir -) - -API = 5 # cylc API version -MSG_TIMEOUT = "TIMEOUT" - - -def encode_(message): - """Convert the structure holding a message field from JSON to a string.""" - try: - return json.dumps(message) - except TypeError as exc: - return json.dumps({'errors': [{'message': str(exc)}]}) - - -def decode_(message): - """Convert an encoded message string to JSON with an added 'user' field.""" - msg = json.loads(message) - msg['user'] = getpass.getuser() # assume this is the user - return msg - - -def get_location(workflow: str) -> Tuple[str, int, int]: - """Extract host and port from a workflow's contact file. - - NB: if it fails to load the workflow contact file, it will exit. - - Args: - workflow: workflow ID - Returns: - Tuple (host name, port number, publish port number) - Raises: - WorkflowStopped: if the workflow is not running. - CylcVersionError: if target is a Cylc 7 (or earlier) workflow. 
- """ - try: - contact = load_contact_file(workflow) - except (IOError, ValueError, ServiceFileError): - # Contact file does not exist or corrupted, workflow should be dead - raise WorkflowStopped(workflow) - - host = contact[ContactFileFields.HOST] - host = get_fqdn_by_host(host) - port = int(contact[ContactFileFields.PORT]) - if ContactFileFields.PUBLISH_PORT in contact: - pub_port = int(contact[ContactFileFields.PUBLISH_PORT]) - else: - version = contact.get('CYLC_VERSION', None) - raise CylcVersionError(version=version) - return host, port, pub_port - - -class ZMQSocketBase: - """Initiate the ZMQ socket bind for specified pattern. - - NOTE: Security to be provided via zmq.auth (see PR #3359). - - Args: - pattern (enum): ZeroMQ message pattern (zmq.PATTERN). - - context (object, optional): instantiated ZeroMQ context, defaults - to zmq.asyncio.Context(). - - This class is designed to be inherited by REP Server (REQ/REP) - and by PUB Publisher (PUB/SUB), as the start-up logic is similar. - - - To tailor this class overwrite it's method on inheritance. 
- - """ - - def __init__( - self, - pattern, - workflow: str, - bind: bool = False, - context: Optional[zmq.Context] = None, - ): - self.bind = bind - if context is None: - self.context: zmq.Context = zmq.asyncio.Context() - else: - self.context = context - self.pattern = pattern - self.workflow = workflow - self.host: Optional[str] = None - self.port: Optional[int] = None - self.socket: Optional[zmq.Socket] = None - self.loop: Optional[asyncio.AbstractEventLoop] = None - self.stopping = False - - def start(self, *args, **kwargs): - """Create the async loop, and bind socket.""" - # set asyncio loop - try: - self.loop = asyncio.get_running_loop() - except RuntimeError: - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - if self.bind: - self._socket_bind(*args, **kwargs) - else: - self._socket_connect(*args, **kwargs) - - # initiate bespoke items - self._bespoke_start() - - # Keeping srv_prv_key_loc as optional arg so as to not break interface - def _socket_bind(self, min_port, max_port, srv_prv_key_loc=None): - """Bind socket. - - Will use a port range provided to select random ports. - - """ - if srv_prv_key_loc is None: - # Create new KeyInfo object for the server private key - workflow_srv_dir = get_workflow_srv_dir(self.workflow) - srv_prv_key_info = KeyInfo( - KeyType.PRIVATE, - KeyOwner.SERVER, - workflow_srv_dir=workflow_srv_dir) - else: - srv_prv_key_info = KeyInfo( - KeyType.PRIVATE, - KeyOwner.SERVER, - full_key_path=srv_prv_key_loc) - - # create socket - self.socket = self.context.socket(self.pattern) - self._socket_options() - - try: - server_public_key, server_private_key = zmq.auth.load_certificate( - srv_prv_key_info.full_key_path) - except ValueError: - raise ServiceFileError( - f"Failed to find server's public " - f"key in " - f"{srv_prv_key_info.full_key_path}." - ) - except OSError: - raise ServiceFileError( - f"IO error opening server's private " - f"key from " - f"{srv_prv_key_info.full_key_path}." 
- ) - if server_private_key is None: # this can't be caught by exception - raise ServiceFileError( - f"Failed to find server's private " - f"key in " - f"{srv_prv_key_info.full_key_path}." - ) - self.socket.curve_publickey = server_public_key - self.socket.curve_secretkey = server_private_key - self.socket.curve_server = True - - try: - if min_port == max_port: - self.port = min_port - self.socket.bind(f'tcp://*:{min_port}') - else: - self.port = self.socket.bind_to_random_port( - 'tcp://*', min_port, max_port) - except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc: - raise CylcError(f'could not start Cylc ZMQ server: {exc}') - - # Keeping srv_public_key_loc as optional arg so as to not break interface - def _socket_connect(self, host, port, srv_public_key_loc=None): - """Connect socket to stub.""" - workflow_srv_dir = get_workflow_srv_dir(self.workflow) - if srv_public_key_loc is None: - # Create new KeyInfo object for the server public key - srv_pub_key_info = KeyInfo( - KeyType.PUBLIC, - KeyOwner.SERVER, - workflow_srv_dir=workflow_srv_dir) - - else: - srv_pub_key_info = KeyInfo( - KeyType.PUBLIC, - KeyOwner.SERVER, - full_key_path=srv_public_key_loc) - - self.host = host - self.port = port - self.socket = self.context.socket(self.pattern) - self._socket_options() - - client_priv_key_info = KeyInfo( - KeyType.PRIVATE, - KeyOwner.CLIENT, - workflow_srv_dir=workflow_srv_dir) - error_msg = "Failed to find user's private key, so cannot connect." 
- try: - client_public_key, client_priv_key = zmq.auth.load_certificate( - client_priv_key_info.full_key_path) - except (OSError, ValueError): - raise ClientError(error_msg) - if client_priv_key is None: # this can't be caught by exception - raise ClientError(error_msg) - self.socket.curve_publickey = client_public_key - self.socket.curve_secretkey = client_priv_key - - # A client can only connect to the server if it knows its public key, - # so we grab this from the location it was created on the filesystem: - try: - # 'load_certificate' will try to load both public & private keys - # from a provided file but will return None, not throw an error, - # for the latter item if not there (as for all public key files) - # so it is OK to use; there is no method to load only the - # public key. - server_public_key = zmq.auth.load_certificate( - srv_pub_key_info.full_key_path)[0] - self.socket.curve_serverkey = server_public_key - except (OSError, ValueError): # ValueError raised w/ no public key - raise ClientError( - "Failed to load the workflow's public key, so cannot connect.") - - self.socket.connect(f'tcp://{host}:{port}') - - def _socket_options(self): - """Set socket options. - - i.e. self.socket.sndhwm - """ - self.socket.sndhwm = 10000 - - def _bespoke_start(self): - """Initiate bespoke items at start.""" - self.stopping = False - - def stop(self, stop_loop=True): - """Stop the server. - - Args: - stop_loop (Boolean): Stop running IOLoop. - - """ - self._bespoke_stop() - if stop_loop and self.loop and self.loop.is_running(): - self.loop.stop() - if self.socket and not self.socket.closed: - self.socket.close() - LOG.debug('...stopped') - - def _bespoke_stop(self): - """Bespoke stop items.""" - LOG.debug('stopping zmq socket...') - self.stopping = True +# Cylc API version. +# This is the Cylc protocol version number that determines whether a client can +# communicate with a server. 
This should be changed when breaking changes are +# made for which backwards compatibility cannot be provided. +API = 5 diff --git a/cylc/flow/network/base.py b/cylc/flow/network/base.py new file mode 100644 index 00000000000..1842407b448 --- /dev/null +++ b/cylc/flow/network/base.py @@ -0,0 +1,237 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +"""Base ZMQ socket implementation for network server/client implementations.""" + +import asyncio +from typing import Optional + +import zmq +import zmq.asyncio +import zmq.auth + +from cylc.flow import LOG +from cylc.flow.exceptions import ( + ClientError, + CylcError, + ServiceFileError, +) +from cylc.flow.workflow_files import ( + KeyType, + KeyOwner, + KeyInfo, + get_workflow_srv_dir, +) + + +class ZMQSocketBase: + """Initiate the ZMQ socket bind for specified pattern. + + NOTE: Security to be provided via zmq.auth (see PR #3359). + + Args: + pattern (enum): ZeroMQ message pattern (zmq.PATTERN). + + context (object, optional): instantiated ZeroMQ context, defaults + to zmq.asyncio.Context(). + + This class is designed to be inherited by REP Server (REQ/REP) + and by PUB Publisher (PUB/SUB), as the start-up logic is similar. + + + To tailor this class, override its methods on inheritance. 
+ + """ + + def __init__( + self, + pattern, + workflow: str, + bind: bool = False, + context: Optional[zmq.Context] = None, + ): + self.bind = bind + if context is None: + self.context: zmq.Context = zmq.asyncio.Context() + else: + self.context = context + self.pattern = pattern + self.workflow = workflow + self.host: Optional[str] = None + self.port: Optional[int] = None + self.socket: Optional[zmq.Socket] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + self.stopping = False + + def start(self, *args, **kwargs): + """Create the async loop, and bind socket.""" + # set asyncio loop + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + if self.bind: + self._socket_bind(*args, **kwargs) + else: + self._socket_connect(*args, **kwargs) + + # initiate bespoke items + self._bespoke_start() + + # Keeping srv_prv_key_loc as optional arg so as to not break interface + def _socket_bind(self, min_port, max_port, srv_prv_key_loc=None): + """Bind socket. + + Will use a port range provided to select random ports. + + """ + if srv_prv_key_loc is None: + # Create new KeyInfo object for the server private key + workflow_srv_dir = get_workflow_srv_dir(self.workflow) + srv_prv_key_info = KeyInfo( + KeyType.PRIVATE, + KeyOwner.SERVER, + workflow_srv_dir=workflow_srv_dir) + else: + srv_prv_key_info = KeyInfo( + KeyType.PRIVATE, + KeyOwner.SERVER, + full_key_path=srv_prv_key_loc) + + # create socket + self.socket = self.context.socket(self.pattern) + self._socket_options() + + try: + server_public_key, server_private_key = zmq.auth.load_certificate( + srv_prv_key_info.full_key_path) + except ValueError: + raise ServiceFileError( + f"Failed to find server's public " + f"key in " + f"{srv_prv_key_info.full_key_path}." + ) + except OSError: + raise ServiceFileError( + f"IO error opening server's private " + f"key from " + f"{srv_prv_key_info.full_key_path}." 
+ ) + if server_private_key is None: # this can't be caught by exception + raise ServiceFileError( + f"Failed to find server's private " + f"key in " + f"{srv_prv_key_info.full_key_path}." + ) + self.socket.curve_publickey = server_public_key + self.socket.curve_secretkey = server_private_key + self.socket.curve_server = True + + try: + if min_port == max_port: + self.port = min_port + self.socket.bind(f'tcp://*:{min_port}') + else: + self.port = self.socket.bind_to_random_port( + 'tcp://*', min_port, max_port) + except (zmq.error.ZMQError, zmq.error.ZMQBindError) as exc: + raise CylcError(f'could not start Cylc ZMQ server: {exc}') + + # Keeping srv_public_key_loc as optional arg so as to not break interface + def _socket_connect(self, host, port, srv_public_key_loc=None): + """Connect socket to stub.""" + workflow_srv_dir = get_workflow_srv_dir(self.workflow) + if srv_public_key_loc is None: + # Create new KeyInfo object for the server public key + srv_pub_key_info = KeyInfo( + KeyType.PUBLIC, + KeyOwner.SERVER, + workflow_srv_dir=workflow_srv_dir) + + else: + srv_pub_key_info = KeyInfo( + KeyType.PUBLIC, + KeyOwner.SERVER, + full_key_path=srv_public_key_loc) + + self.host = host + self.port = port + self.socket = self.context.socket(self.pattern) + self._socket_options() + + client_priv_key_info = KeyInfo( + KeyType.PRIVATE, + KeyOwner.CLIENT, + workflow_srv_dir=workflow_srv_dir) + error_msg = "Failed to find user's private key, so cannot connect." 
+ try: + client_public_key, client_priv_key = zmq.auth.load_certificate( + client_priv_key_info.full_key_path) + except (OSError, ValueError): + raise ClientError(error_msg) + if client_priv_key is None: # this can't be caught by exception + raise ClientError(error_msg) + self.socket.curve_publickey = client_public_key + self.socket.curve_secretkey = client_priv_key + + # A client can only connect to the server if it knows its public key, + # so we grab this from the location it was created on the filesystem: + try: + # 'load_certificate' will try to load both public & private keys + # from a provided file but will return None, not throw an error, + # for the latter item if not there (as for all public key files) + # so it is OK to use; there is no method to load only the + # public key. + server_public_key = zmq.auth.load_certificate( + srv_pub_key_info.full_key_path)[0] + self.socket.curve_serverkey = server_public_key + except (OSError, ValueError): # ValueError raised w/ no public key + raise ClientError( + "Failed to load the workflow's public key, so cannot connect.") + + self.socket.connect(f'tcp://{host}:{port}') + + def _socket_options(self): + """Set socket options. + + i.e. self.socket.sndhwm + """ + self.socket.sndhwm = 10000 + + def _bespoke_start(self): + """Initiate bespoke items at start.""" + self.stopping = False + + def stop(self, stop_loop=True): + """Stop the server. + + Args: + stop_loop (Boolean): Stop running IOLoop. 
+ + """ + self._bespoke_stop() + if stop_loop and self.loop and self.loop.is_running(): + self.loop.stop() + if self.socket and not self.socket.closed: + self.socket.close() + LOG.debug('...stopped') + + def _bespoke_stop(self): + """Bespoke stop items.""" + LOG.debug('stopping zmq socket...') + self.stopping = True diff --git a/cylc/flow/network/client.py b/cylc/flow/network/client.py index e7e26954d56..6f8206ee786 100644 --- a/cylc/flow/network/client.py +++ b/cylc/flow/network/client.py @@ -35,14 +35,14 @@ WorkflowStopped, ) from cylc.flow.hostuserutil import get_fqdn_by_host -from cylc.flow.network import ( +from cylc.flow.network.base import ZMQSocketBase +from cylc.flow.network.client_factory import CommsMeth +from cylc.flow.network.server import PB_METHOD_MAP +from cylc.flow.network.util import ( encode_, decode_, get_location, - ZMQSocketBase ) -from cylc.flow.network.client_factory import CommsMeth -from cylc.flow.network.server import PB_METHOD_MAP from cylc.flow.workflow_files import ( detect_old_contact_file, ) diff --git a/cylc/flow/network/publisher.py b/cylc/flow/network/publisher.py index 70d40d3cdb9..78574f9e8c5 100644 --- a/cylc/flow/network/publisher.py +++ b/cylc/flow/network/publisher.py @@ -21,7 +21,7 @@ import zmq from cylc.flow import LOG -from cylc.flow.network import ZMQSocketBase +from cylc.flow.network.base import ZMQSocketBase def serialize_data( diff --git a/cylc/flow/network/replier.py b/cylc/flow/network/replier.py index 09bfb55f662..a40756c05b4 100644 --- a/cylc/flow/network/replier.py +++ b/cylc/flow/network/replier.py @@ -21,7 +21,8 @@ import zmq from cylc.flow import LOG -from cylc.flow.network import encode_, decode_, ZMQSocketBase +from cylc.flow.network.base import ZMQSocketBase +from cylc.flow.network.util import encode_, decode_ if TYPE_CHECKING: from cylc.flow.network.server import WorkflowRuntimeServer diff --git a/cylc/flow/network/subscriber.py b/cylc/flow/network/subscriber.py index 66bd16f81f8..28d5b5d1bb2 100644 --- 
a/cylc/flow/network/subscriber.py +++ b/cylc/flow/network/subscriber.py @@ -22,8 +22,9 @@ import zmq -from cylc.flow.network import ZMQSocketBase, get_location from cylc.flow.data_store_mgr import DELTAS_MAP +from cylc.flow.network.base import ZMQSocketBase +from cylc.flow.network.util import get_location if TYPE_CHECKING: import zmq.asyncio diff --git a/cylc/flow/network/util.py b/cylc/flow/network/util.py new file mode 100644 index 00000000000..6d1a006060d --- /dev/null +++ b/cylc/flow/network/util.py @@ -0,0 +1,77 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +"""Common networking utilities.""" + +import getpass +import json +from typing import Tuple + +from cylc.flow.exceptions import ( + CylcVersionError, + ServiceFileError, + WorkflowStopped +) +from cylc.flow.hostuserutil import get_fqdn_by_host +from cylc.flow.workflow_files import ( + ContactFileFields, + load_contact_file, +) + + +def encode_(message): + """Convert the structure holding a message field to a JSON string.""" + try: + return json.dumps(message) + except TypeError as exc: + return json.dumps({'errors': [{'message': str(exc)}]}) + + +def decode_(message): + """Convert an encoded message string to JSON with an added 'user' field.""" + msg = json.loads(message) + msg['user'] = getpass.getuser() # assume this is the user + return msg + + +def get_location(workflow: str) -> Tuple[str, int, int]: + """Extract host and port from a workflow's contact file. + + NB: if it fails to load the workflow contact file, WorkflowStopped is raised. + + Args: + workflow: workflow ID + Returns: + Tuple (host name, port number, publish port number) + Raises: + WorkflowStopped: if the workflow is not running. + CylcVersionError: if target is a Cylc 7 (or earlier) workflow. 
+ """ + try: + contact = load_contact_file(workflow) + except (IOError, ValueError, ServiceFileError): + # Contact file does not exist or corrupted, workflow should be dead + raise WorkflowStopped(workflow) + + host = contact[ContactFileFields.HOST] + host = get_fqdn_by_host(host) + port = int(contact[ContactFileFields.PORT]) + if ContactFileFields.PUBLISH_PORT in contact: + pub_port = int(contact[ContactFileFields.PUBLISH_PORT]) + else: + version = contact.get('CYLC_VERSION', None) + raise CylcVersionError(version=version) + return host, port, pub_port diff --git a/cylc/flow/scripts/subscribe.py b/cylc/flow/scripts/subscribe.py index 5d174718c23..bd42dd6c90d 100755 --- a/cylc/flow/scripts/subscribe.py +++ b/cylc/flow/scripts/subscribe.py @@ -34,7 +34,7 @@ WORKFLOW_ID_ARG_DOC, CylcOptionParser as COP, ) -from cylc.flow.network import get_location +from cylc.flow.network.util import get_location from cylc.flow.network.subscriber import WorkflowSubscriber, process_delta_msg from cylc.flow.terminal import cli_function from cylc.flow.data_store_mgr import DELTAS_MAP diff --git a/tests/integration/test_replier.py b/tests/integration/test_replier.py index ce0b53fdaa8..7e219e8dd44 100644 --- a/tests/integration/test_replier.py +++ b/tests/integration/test_replier.py @@ -14,11 +14,12 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . 
-from async_timeout import timeout -from cylc.flow.network import decode_ -from cylc.flow.network.client import WorkflowRuntimeClient import asyncio +from cylc.flow.network.client import WorkflowRuntimeClient +from cylc.flow.network.util import decode_ + +from async_timeout import timeout import pytest diff --git a/tests/integration/test_zmq.py b/tests/integration/test_zmq.py index 24c8db6d9b0..41ef6be2767 100644 --- a/tests/integration/test_zmq.py +++ b/tests/integration/test_zmq.py @@ -18,7 +18,7 @@ import zmq from cylc.flow.exceptions import CylcError -from cylc.flow.network import ZMQSocketBase +from cylc.flow.network.base import ZMQSocketBase from .key_setup import setup_keys diff --git a/tests/unit/network/test__init__.py b/tests/unit/network/test_util.py similarity index 84% rename from tests/unit/network/test__init__.py rename to tests/unit/network/test_util.py index 71c32cf9bd1..aaed85dabce 100644 --- a/tests/unit/network/test__init__.py +++ b/tests/unit/network/test_util.py @@ -19,8 +19,8 @@ import cylc from cylc.flow.exceptions import CylcVersionError -from cylc.flow.network import get_location -from cylc.flow.workflow_files import load_contact_file, ContactFileFields +from cylc.flow.network.util import get_location +from cylc.flow.workflow_files import ContactFileFields BASE_CONTACT_DATA = { @@ -33,7 +33,7 @@ def mpatch_get_fqdn_by_host(monkeypatch): """Monkeypatch function used the same by all tests.""" monkeypatch.setattr( - cylc.flow.network, 'get_fqdn_by_host', lambda _ : 'myhost.x.y.z' + cylc.flow.network.util, 'get_fqdn_by_host', lambda _ : 'myhost.x.y.z' ) @@ -42,7 +42,7 @@ def test_get_location_ok(monkeypatch, mpatch_get_fqdn_by_host): contact_data = BASE_CONTACT_DATA.copy() contact_data[ContactFileFields.PUBLISH_PORT] = '8042' monkeypatch.setattr( - cylc.flow.network, 'load_contact_file', lambda _ : contact_data + cylc.flow.network.util, 'load_contact_file', lambda _ : contact_data ) assert get_location('_') == ( 'myhost.x.y.z', 42, 8042 @@ 
-55,7 +55,7 @@ def test_get_location_old_contact_file(monkeypatch, mpatch_get_fqdn_by_host): contact_data['CYLC_SUITE_PUBLISH_PORT'] = '8042' contact_data['CYLC_VERSION'] = '5.1.2' monkeypatch.setattr( - cylc.flow.network, 'load_contact_file', lambda _ : contact_data + cylc.flow.network.util, 'load_contact_file', lambda _ : contact_data ) with pytest.raises(CylcVersionError, match=r'.*5.1.2.*') as exc: get_location('_')