diff --git a/actions.yaml b/actions.yaml deleted file mode 100644 index ce7ef05..0000000 --- a/actions.yaml +++ /dev/null @@ -1,47 +0,0 @@ -show-current-config: - description: > - Display the currently used `slurm.conf`. - - Note: This file only exists in `slurmctld` charm and is automatically - distributed to all compute nodes by Slurm. - - Example usage: - $ juju run-action slurmctld/leader --format=json --wait | jq .[].results.slurm.conf | xargs -I % -0 python3 -c 'print(%)' -drain: - description: > - Drain specified nodes. - - Example usage: - $ juju run-action slurmctld/leader drain nodename=node-[1,2] reason="Updating kernel" - params: - nodename: - type: string - description: The nodes to drain, using the Slurm format, e.g. `node-[1,2]`. - reason: - type: string - description: Reason to drain the nodes. - required: - - nodename - - reason -resume: - description: > - Resume specified nodes. - - Note: Newly added nodes will remain in the `down` state until configured, - with the `node-configured` action. - - Example usage: $ juju run-action slurmctld/leader resume nodename=node-[1,2] - params: - nodename: - type: string - description: > - The nodes to resume, using the Slurm format, e.g. `node-[1,2]`. - required: - - nodename - -influxdb-info: - description: > - Get InfluxDB info. - - This action returns the host, port, username, password, database, and - retention policy regarding to InfluxDB. diff --git a/charmcraft.yaml b/charmcraft.yaml index 125880e..9b57f0f 100644 --- a/charmcraft.yaml +++ b/charmcraft.yaml @@ -1,5 +1,52 @@ # Copyright 2020 Omnivector Solutions, LLC # See LICENSE file for licensing details. +name: slurmctld +summary: | + Slurmctld, the central management daemon of Slurm. +description: | + This charm provides slurmctld, munged, and the bindings to other utilities + that make lifecycle operations a breeze. + + slurmctld is the central management daemon of SLURM. 
It monitors all other + SLURM daemons and resources, accepts work (jobs), and allocates resources + to those jobs. Given the critical functionality of slurmctld, there may be + a backup server to assume these functions in the event that the primary + server fails. + +links: + contact: https://matrix.to/#/#hpc:ubuntu.com + + issues: + - https://github.com/charmed-hpc/slurmctld-operator/issues + + source: + - https://github.com/charmed-hpc/slurmctld-operator + +peers: + slurmctld-peer: + interface: slurmctld-peer +requires: + slurmd: + interface: slurmd + slurmdbd: + interface: slurmdbd + slurmrestd: + interface: slurmrestd + influxdb-api: + interface: influxdb-api + elasticsearch: + interface: elasticsearch + fluentbit: + interface: fluentbit +provides: + prolog-epilog: + interface: prolog-epilog + grafana-source: + interface: grafana-source + scope: global + +assumes: + - juju type: charm bases: @@ -29,3 +76,145 @@ parts: echo $VERSION > $CRAFT_PART_INSTALL/version stage: - version + +config: + options: + custom-slurm-repo: + type: string + default: "" + description: > + Use a custom repository for Slurm installation. + + This can be set to the Organization's local mirror/cache of packages and + supersedes the Omnivector repositories. Alternatively, it can be used to + track a `testing` Slurm version, e.g. by setting to + `ppa:omnivector/osd-testing`. + + Note: The configuration `custom-slurm-repo` must be set *before* + deploying the units. Changing this value after deploying the units will + not reinstall Slurm. + cluster-name: + type: string + default: osd-cluster + description: > + Name to be recorded in database for jobs from this cluster. + + This is important if a single database is used to record information from + multiple Slurm-managed clusters. + default-partition: + type: string + default: "" + description: > + Default Slurm partition. This is only used if defined, and must match an + existing partition. 
+ custom-config: + type: string + default: "" + description: > + User supplied Slurm configuration. + + This value supplements the charm supplied `slurm.conf` that is used for + Slurm Controller and Compute nodes. + + Example usage: + $ juju config slurmcltd custom-config="FirstJobId=1234" + proctrack-type: + type: string + default: proctrack/cgroup + description: > + Identifies the plugin to be used for process tracking on a job step + basis. + cgroup-config: + type: string + default: | + CgroupAutomount=yes + ConstrainCores=yes + description: > + Configuration content for `cgroup.conf`. + + health-check-params: + default: "" + type: string + description: > + Extra parameters for NHC command. + + This option can be used to customize how NHC is called, e.g. to send an + e-mail to an admin when NHC detects an error set this value to + `-M admin@domain.com`. + health-check-interval: + default: 600 + type: int + description: Interval in seconds between executions of the Health Check. + health-check-state: + default: "ANY,CYCLE" + type: string + description: Only run the Health Check on nodes in this state. + + acct-gather-frequency: + type: string + default: "task=30" + description: > + Accounting and profiling sampling intervals for the acct_gather plugins. + + Note: A value of `0` disables the periodic sampling. In this case, the + accounting information is collected when the job terminates. + + Example usage: + $ juju config slurmcltd acct-gather-frequency="task=30,network=30" + acct-gather-custom: + type: string + default: "" + description: > + User supplied `acct_gather.conf` configuration. + + This value supplements the charm supplied `acct_gather.conf` file that is + used for configuring the acct_gather plugins. + +actions: + show-current-config: + description: > + Display the currently used `slurm.conf`. + + Note: This file only exists in `slurmctld` charm and is automatically + distributed to all compute nodes by Slurm. 
+ + Example usage: + $ juju run-action slurmctld/leader --format=json --wait | jq .[].results.slurm.conf | xargs -I % -0 python3 -c 'print(%)' + drain: + description: > + Drain specified nodes. + + Example usage: + $ juju run-action slurmctld/leader drain nodename=node-[1,2] reason="Updating kernel" + params: + nodename: + type: string + description: The nodes to drain, using the Slurm format, e.g. `node-[1,2]`. + reason: + type: string + description: Reason to drain the nodes. + required: + - nodename + - reason + resume: + description: > + Resume specified nodes. + + Note: Newly added nodes will remain in the `down` state until configured, + with the `node-configured` action. + + Example usage: $ juju run-action slurmctld/leader resume nodename=node-[1,2] + params: + nodename: + type: string + description: > + The nodes to resume, using the Slurm format, e.g. `node-[1,2]`. + required: + - nodename + + influxdb-info: + description: > + Get InfluxDB info. + + This action returns the host, port, username, password, database, and + retention policy regarding to InfluxDB. diff --git a/config.yaml b/config.yaml deleted file mode 100644 index 2d4798e..0000000 --- a/config.yaml +++ /dev/null @@ -1,91 +0,0 @@ -options: - custom-slurm-repo: - type: string - default: "" - description: > - Use a custom repository for Slurm installation. - - This can be set to the Organization's local mirror/cache of packages and - supersedes the Omnivector repositories. Alternatively, it can be used to - track a `testing` Slurm version, e.g. by setting to - `ppa:omnivector/osd-testing`. - - Note: The configuration `custom-slurm-repo` must be set *before* - deploying the units. Changing this value after deploying the units will - not reinstall Slurm. - cluster-name: - type: string - default: osd-cluster - description: > - Name to be recorded in database for jobs from this cluster. - - This is important if a single database is used to record information from - multiple Slurm-managed clusters. 
- default-partition: - type: string - default: "" - description: > - Default Slurm partition. This is only used if defined, and must match an - existing partition. - custom-config: - type: string - default: "" - description: > - User supplied Slurm configuration. - - This value supplements the charm supplied `slurm.conf` that is used for - Slurm Controller and Compute nodes. - - Example usage: - $ juju config slurmcltd custom-config="FirstJobId=1234" - proctrack-type: - type: string - default: proctrack/cgroup - description: > - Identifies the plugin to be used for process tracking on a job step - basis. - cgroup-config: - type: string - default: | - CgroupAutomount=yes - ConstrainCores=yes - description: > - Configuration content for `cgroup.conf`. - - health-check-params: - default: "" - type: string - description: > - Extra parameters for NHC command. - - This option can be used to customize how NHC is called, e.g. to send an - e-mail to an admin when NHC detects an error set this value to - `-M admin@domain.com`. - health-check-interval: - default: 600 - type: int - description: Interval in seconds between executions of the Health Check. - health-check-state: - default: "ANY,CYCLE" - type: string - description: Only run the Health Check on nodes in this state. - - acct-gather-frequency: - type: string - default: "task=30" - description: > - Accounting and profiling sampling intervals for the acct_gather plugins. - - Note: A value of `0` disables the periodic sampling. In this case, the - accounting information is collected when the job terminates. - - Example usage: - $ juju config slurmcltd acct-gather-frequency="task=30,network=30" - acct-gather-custom: - type: string - default: "" - description: > - User supplied `acct_gather.conf` configuration. - - This value supplements the charm supplied `acct_gather.conf` file that is - used for configuring the acct_gather plugins. 
diff --git a/lib/charms/operator_libs_linux/v0/apt.py b/lib/charms/operator_libs_linux/v0/apt.py new file mode 100644 index 0000000..1400df7 --- /dev/null +++ b/lib/charms/operator_libs_linux/v0/apt.py @@ -0,0 +1,1361 @@ +# Copyright 2021 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abstractions for the system's Debian/Ubuntu package information and repositories. + +This module contains abstractions and wrappers around Debian/Ubuntu-style repositories and +packages, in order to easily provide an idiomatic and Pythonic mechanism for adding packages and/or +repositories to systems for use in machine charms. + +A sane default configuration is attainable through nothing more than instantiation of the +appropriate classes. `DebianPackage` objects provide information about the architecture, version, +name, and status of a package. + +`DebianPackage` will try to look up a package either from `dpkg -L` or from `apt-cache` when +provided with a string indicating the package name. If it cannot be located, `PackageNotFoundError` +will be returned, as `apt` and `dpkg` otherwise return `100` for all errors, and a meaningful error +message if the package is not known is desirable. 
+ +To install packages with convenience methods: + +```python +try: + # Run `apt-get update` + apt.update() + apt.add_package("zsh") + apt.add_package(["vim", "htop", "wget"]) +except PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") +except PackageError as e: + logger.error("could not install package. Reason: %s", e.message) +```` + +To find details of a specific package: + +```python +try: + vim = apt.DebianPackage.from_system("vim") + + # To find from the apt cache only + # apt.DebianPackage.from_apt_cache("vim") + + # To find from installed packages only + # apt.DebianPackage.from_installed_package("vim") + + vim.ensure(PackageState.Latest) + logger.info("updated vim to version: %s", vim.fullversion) +except PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") +except PackageError as e: + logger.error("could not install package. Reason: %s", e.message) +``` + + +`RepositoryMapping` will return a dict-like object containing enabled system repositories +and their properties (available groups, baseuri. gpg key). This class can add, disable, or +manipulate repositories. Items can be retrieved as `DebianRepository` objects. + +In order add a new repository with explicit details for fields, a new `DebianRepository` can +be added to `RepositoryMapping` + +`RepositoryMapping` provides an abstraction around the existing repositories on the system, +and can be accessed and iterated over like any `Mapping` object, to retrieve values by key, +iterate, or perform other operations. + +Keys are constructed as `{repo_type}-{}-{release}` in order to uniquely identify a repository. + +Repositories can be added with explicit values through a Python constructor. 
+ +Example: +```python +repositories = apt.RepositoryMapping() + +if "deb-example.com-focal" not in repositories: + repositories.add(DebianRepository(enabled=True, repotype="deb", + uri="https://example.com", release="focal", groups=["universe"])) +``` + +Alternatively, any valid `sources.list` line may be used to construct a new +`DebianRepository`. + +Example: +```python +repositories = apt.RepositoryMapping() + +if "deb-us.archive.ubuntu.com-xenial" not in repositories: + line = "deb http://us.archive.ubuntu.com/ubuntu xenial main restricted" + repo = DebianRepository.from_repo_line(line) + repositories.add(repo) +``` +""" + +import fileinput +import glob +import logging +import os +import re +import subprocess +from collections.abc import Mapping +from enum import Enum +from subprocess import PIPE, CalledProcessError, check_output +from typing import Iterable, List, Optional, Tuple, Union +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +# The unique Charmhub library identifier, never change it +LIBID = "7c3dbc9c2ad44a47bd6fcb25caa270e5" + +# Increment this major API version when introducing breaking changes +LIBAPI = 0 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 13 + + +VALID_SOURCE_TYPES = ("deb", "deb-src") +OPTIONS_MATCHER = re.compile(r"\[.*?\]") + + +class Error(Exception): + """Base class of most errors raised by this library.""" + + def __repr__(self): + """Represent the Error.""" + return "<{}.{} {}>".format(type(self).__module__, type(self).__name__, self.args) + + @property + def name(self): + """Return a string representation of the model plus class.""" + return "<{}.{}>".format(type(self).__module__, type(self).__name__) + + @property + def message(self): + """Return the message passed as an argument.""" + return self.args[0] + + +class PackageError(Error): + """Raised when there's an error installing or removing a 
package.""" + + +class PackageNotFoundError(Error): + """Raised when a requested package is not known to the system.""" + + +class PackageState(Enum): + """A class to represent possible package states.""" + + Present = "present" + Absent = "absent" + Latest = "latest" + Available = "available" + + +class DebianPackage: + """Represents a traditional Debian package and its utility functions. + + `DebianPackage` wraps information and functionality around a known package, whether installed + or available. The version, epoch, name, and architecture can be easily queried and compared + against other `DebianPackage` objects to determine the latest version or to install a specific + version. + + The representation of this object as a string mimics the output from `dpkg` for familiarity. + + Installation and removal of packages is handled through the `state` property or `ensure` + method, with the following options: + + apt.PackageState.Absent + apt.PackageState.Available + apt.PackageState.Present + apt.PackageState.Latest + + When `DebianPackage` is initialized, the state of a given `DebianPackage` object will be set to + `Available`, `Present`, or `Latest`, with `Absent` implemented as a convenience for removal + (though it operates essentially the same as `Available`). + """ + + def __init__( + self, name: str, version: str, epoch: str, arch: str, state: PackageState + ) -> None: + self._name = name + self._arch = arch + self._state = state + self._version = Version(version, epoch) + + def __eq__(self, other) -> bool: + """Equality for comparison. 
+ + Args: + other: a `DebianPackage` object for comparison + + Returns: + A boolean reflecting equality + """ + return isinstance(other, self.__class__) and ( + self._name, + self._version.number, + ) == (other._name, other._version.number) + + def __hash__(self): + """Return a hash of this package.""" + return hash((self._name, self._version.number)) + + def __repr__(self): + """Represent the package.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Return a human-readable representation of the package.""" + return "<{}: {}-{}.{} -- {}>".format( + self.__class__.__name__, + self._name, + self._version, + self._arch, + str(self._state), + ) + + @staticmethod + def _apt( + command: str, + package_names: Union[str, List], + optargs: Optional[List[str]] = None, + ) -> None: + """Wrap package management commands for Debian/Ubuntu systems. + + Args: + command: the command given to `apt-get` + package_names: a package name or list of package names to operate on + optargs: an (Optional) list of additioanl arguments + + Raises: + PackageError if an error is encountered + """ + optargs = optargs if optargs is not None else [] + if isinstance(package_names, str): + package_names = [package_names] + _cmd = ["apt-get", "-y", *optargs, command, *package_names] + try: + env = os.environ.copy() + env["DEBIAN_FRONTEND"] = "noninteractive" + subprocess.run(_cmd, capture_output=True, check=True, text=True, env=env) + except CalledProcessError as e: + raise PackageError( + "Could not {} package(s) [{}]: {}".format(command, [*package_names], e.stderr) + ) from None + + def _add(self) -> None: + """Add a package to the system.""" + self._apt( + "install", + "{}={}".format(self.name, self.version), + optargs=["--option=Dpkg::Options::=--force-confold"], + ) + + def _remove(self) -> None: + """Remove a package from the system. 
Implementation-specific.""" + return self._apt("remove", "{}={}".format(self.name, self.version)) + + @property + def name(self) -> str: + """Returns the name of the package.""" + return self._name + + def ensure(self, state: PackageState): + """Ensure that a package is in a given state. + + Args: + state: a `PackageState` to reconcile the package to + + Raises: + PackageError from the underlying call to apt + """ + if self._state is not state: + if state not in (PackageState.Present, PackageState.Latest): + self._remove() + else: + self._add() + self._state = state + + @property + def present(self) -> bool: + """Returns whether or not a package is present.""" + return self._state in (PackageState.Present, PackageState.Latest) + + @property + def latest(self) -> bool: + """Returns whether the package is the most recent version.""" + return self._state is PackageState.Latest + + @property + def state(self) -> PackageState: + """Returns the current package state.""" + return self._state + + @state.setter + def state(self, state: PackageState) -> None: + """Set the package state to a given value. + + Args: + state: a `PackageState` to reconcile the package to + + Raises: + PackageError from the underlying call to apt + """ + if state in (PackageState.Latest, PackageState.Present): + self._add() + else: + self._remove() + self._state = state + + @property + def version(self) -> "Version": + """Returns the version for a package.""" + return self._version + + @property + def epoch(self) -> str: + """Returns the epoch for a package. 
May be unset.""" + return self._version.epoch + + @property + def arch(self) -> str: + """Returns the architecture for a package.""" + return self._arch + + @property + def fullversion(self) -> str: + """Returns the name+epoch for a package.""" + return "{}.{}".format(self._version, self._arch) + + @staticmethod + def _get_epoch_from_version(version: str) -> Tuple[str, str]: + """Pull the epoch, if any, out of a version string.""" + epoch_matcher = re.compile(r"^((?P\d+):)?(?P.*)") + matches = epoch_matcher.search(version).groupdict() + return matches.get("epoch", ""), matches.get("version") + + @classmethod + def from_system( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Locates a package, either on the system or known to apt, and serializes the information. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. If an + architecture is not specified, this will be used for selection. + + """ + try: + return DebianPackage.from_installed_package(package, version, arch) + except PackageNotFoundError: + logger.debug( + "package '%s' is not currently installed or has the wrong architecture.", package + ) + + # Ok, try `apt-cache ...` + try: + return DebianPackage.from_apt_cache(package, version, arch) + except (PackageNotFoundError, PackageError): + # If we get here, it's not known to the systems. + # This seems unnecessary, but virtually all `apt` commands have a return code of `100`, + # and providing meaningful error messages without this is ugly. 
+ raise PackageNotFoundError( + "Package '{}{}' could not be found on the system or in the apt cache!".format( + package, ".{}".format(arch) if arch else "" + ) + ) from None + + @classmethod + def from_installed_package( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Check whether the package is already installed and return an instance. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. + If an architecture is not specified, this will be used for selection. + """ + system_arch = check_output( + ["dpkg", "--print-architecture"], universal_newlines=True + ).strip() + arch = arch if arch else system_arch + + # Regexps are a really terrible way to do this. Thanks dpkg + output = "" + try: + output = check_output(["dpkg", "-l", package], stderr=PIPE, universal_newlines=True) + except CalledProcessError: + raise PackageNotFoundError("Package is not installed: {}".format(package)) from None + + # Pop off the output from `dpkg -l' because there's no flag to + # omit it` + lines = str(output).splitlines()[5:] + + dpkg_matcher = re.compile( + r""" + ^(?P\w+?)\s+ + (?P.*?)(?P:\w+?)?\s+ + (?P.*?)\s+ + (?P\w+?)\s+ + (?P.*) + """, + re.VERBOSE, + ) + + for line in lines: + try: + matches = dpkg_matcher.search(line).groupdict() + package_status = matches["package_status"] + + if not package_status.endswith("i"): + logger.debug( + "package '%s' in dpkg output but not installed, status: '%s'", + package, + package_status, + ) + break + + epoch, split_version = DebianPackage._get_epoch_from_version(matches["version"]) + pkg = DebianPackage( + matches["package_name"], + split_version, + epoch, + matches["arch"], + PackageState.Present, + ) + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg + except AttributeError: + 
logger.warning("dpkg matcher could not parse line: %s", line) + + # If we didn't find it, fail through + raise PackageNotFoundError("Package {}.{} is not installed!".format(package, arch)) + + @classmethod + def from_apt_cache( + cls, package: str, version: Optional[str] = "", arch: Optional[str] = "" + ) -> "DebianPackage": + """Check whether the package is already installed and return an instance. + + Args: + package: a string representing the package + version: an optional string if a specific version is requested + arch: an optional architecture, defaulting to `dpkg --print-architecture`. + If an architecture is not specified, this will be used for selection. + """ + system_arch = check_output( + ["dpkg", "--print-architecture"], universal_newlines=True + ).strip() + arch = arch if arch else system_arch + + # Regexps are a really terrible way to do this. Thanks dpkg + keys = ("Package", "Architecture", "Version") + + try: + output = check_output( + ["apt-cache", "show", package], stderr=PIPE, universal_newlines=True + ) + except CalledProcessError as e: + raise PackageError( + "Could not list packages in apt-cache: {}".format(e.stderr) + ) from None + + pkg_groups = output.strip().split("\n\n") + keys = ("Package", "Architecture", "Version") + + for pkg_raw in pkg_groups: + lines = str(pkg_raw).splitlines() + vals = {} + for line in lines: + if line.startswith(keys): + items = line.split(":", 1) + vals[items[0]] = items[1].strip() + else: + continue + + epoch, split_version = DebianPackage._get_epoch_from_version(vals["Version"]) + pkg = DebianPackage( + vals["Package"], + split_version, + epoch, + vals["Architecture"], + PackageState.Available, + ) + + if (pkg.arch == "all" or pkg.arch == arch) and ( + version == "" or str(pkg.version) == version + ): + return pkg + + # If we didn't find it, fail through + raise PackageNotFoundError("Package {}.{} is not in the apt cache!".format(package, arch)) + + +class Version: + """An abstraction around package versions. 
+ + This seems like it should be strictly unnecessary, except that `apt_pkg` is not usable inside a + venv, and wedging version comparisons into `DebianPackage` would overcomplicate it. + + This class implements the algorithm found here: + https://www.debian.org/doc/debian-policy/ch-controlfields.html#version + """ + + def __init__(self, version: str, epoch: str): + self._version = version + self._epoch = epoch or "" + + def __repr__(self): + """Represent the package.""" + return "<{}.{}: {}>".format(self.__module__, self.__class__.__name__, self.__dict__) + + def __str__(self): + """Return human-readable representation of the package.""" + return "{}{}".format("{}:".format(self._epoch) if self._epoch else "", self._version) + + @property + def epoch(self): + """Returns the epoch for a package. May be empty.""" + return self._epoch + + @property + def number(self) -> str: + """Returns the version number for a package.""" + return self._version + + def _get_parts(self, version: str) -> Tuple[str, str]: + """Separate the version into component upstream and Debian pieces.""" + try: + version.rindex("-") + except ValueError: + # No hyphens means no Debian version + return version, "0" + + upstream, debian = version.rsplit("-", 1) + return upstream, debian + + def _listify(self, revision: str) -> List[str]: + """Split a revision string into a listself. + + This list is comprised of alternating between strings and numbers, + padded on either end to always be "str, int, str, int..." and + always be of even length. This allows us to trivially implement the + comparison algorithm described. 
+ """ + result = [] + while revision: + rev_1, remains = self._get_alphas(revision) + rev_2, remains = self._get_digits(remains) + result.extend([rev_1, rev_2]) + revision = remains + return result + + def _get_alphas(self, revision: str) -> Tuple[str, str]: + """Return a tuple of the first non-digit characters of a revision.""" + # get the index of the first digit + for i, char in enumerate(revision): + if char.isdigit(): + if i == 0: + return "", revision + return revision[0:i], revision[i:] + # string is entirely alphas + return revision, "" + + def _get_digits(self, revision: str) -> Tuple[int, str]: + """Return a tuple of the first integer characters of a revision.""" + # If the string is empty, return (0,'') + if not revision: + return 0, "" + # get the index of the first non-digit + for i, char in enumerate(revision): + if not char.isdigit(): + if i == 0: + return 0, revision + return int(revision[0:i]), revision[i:] + # string is entirely digits + return int(revision), "" + + def _dstringcmp(self, a, b): # noqa: C901 + """Debian package version string section lexical sort algorithm. + + The lexical comparison is a comparison of ASCII values modified so + that all the letters sort earlier than all the non-letters and so that + a tilde sorts before anything, even the end of a part. 
+ """ + if a == b: + return 0 + try: + for i, char in enumerate(a): + if char == b[i]: + continue + # "a tilde sorts before anything, even the end of a part" + # (emptyness) + if char == "~": + return -1 + if b[i] == "~": + return 1 + # "all the letters sort earlier than all the non-letters" + if char.isalpha() and not b[i].isalpha(): + return -1 + if not char.isalpha() and b[i].isalpha(): + return 1 + # otherwise lexical sort + if ord(char) > ord(b[i]): + return 1 + if ord(char) < ord(b[i]): + return -1 + except IndexError: + # a is longer than b but otherwise equal, greater unless there are tildes + if char == "~": + return -1 + return 1 + # if we get here, a is shorter than b but otherwise equal, so check for tildes... + if b[len(a)] == "~": + return 1 + return -1 + + def _compare_revision_strings(self, first: str, second: str): # noqa: C901 + """Compare two debian revision strings.""" + if first == second: + return 0 + + # listify pads results so that we will always be comparing ints to ints + # and strings to strings (at least until we fall off the end of a list) + first_list = self._listify(first) + second_list = self._listify(second) + if first_list == second_list: + return 0 + try: + for i, item in enumerate(first_list): + # explicitly raise IndexError if we've fallen off the edge of list2 + if i >= len(second_list): + raise IndexError + # if the items are equal, next + if item == second_list[i]: + continue + # numeric comparison + if isinstance(item, int): + if item > second_list[i]: + return 1 + if item < second_list[i]: + return -1 + else: + # string comparison + return self._dstringcmp(item, second_list[i]) + except IndexError: + # rev1 is longer than rev2 but otherwise equal, hence greater + # ...except for goddamn tildes + if first_list[len(second_list)][0][0] == "~": + return 1 + return 1 + # rev1 is shorter than rev2 but otherwise equal, hence lesser + # ...except for goddamn tildes + if second_list[len(first_list)][0][0] == "~": + return -1 + 
return -1 + + def _compare_version(self, other) -> int: + if (self.number, self.epoch) == (other.number, other.epoch): + return 0 + + if self.epoch < other.epoch: + return -1 + if self.epoch > other.epoch: + return 1 + + # If none of these are true, follow the algorithm + upstream_version, debian_version = self._get_parts(self.number) + other_upstream_version, other_debian_version = self._get_parts(other.number) + + upstream_cmp = self._compare_revision_strings(upstream_version, other_upstream_version) + if upstream_cmp != 0: + return upstream_cmp + + debian_cmp = self._compare_revision_strings(debian_version, other_debian_version) + if debian_cmp != 0: + return debian_cmp + + return 0 + + def __lt__(self, other) -> bool: + """Less than magic method impl.""" + return self._compare_version(other) < 0 + + def __eq__(self, other) -> bool: + """Equality magic method impl.""" + return self._compare_version(other) == 0 + + def __gt__(self, other) -> bool: + """Greater than magic method impl.""" + return self._compare_version(other) > 0 + + def __le__(self, other) -> bool: + """Less than or equal to magic method impl.""" + return self.__eq__(other) or self.__lt__(other) + + def __ge__(self, other) -> bool: + """Greater than or equal to magic method impl.""" + return self.__gt__(other) or self.__eq__(other) + + def __ne__(self, other) -> bool: + """Not equal to magic method impl.""" + return not self.__eq__(other) + + +def add_package( + package_names: Union[str, List[str]], + version: Optional[str] = "", + arch: Optional[str] = "", + update_cache: Optional[bool] = False, +) -> Union[DebianPackage, List[DebianPackage]]: + """Add a package or list of packages to the system. + + Args: + package_names: single package name, or list of package names + name: the name(s) of the package(s) + version: an (Optional) version as a string. 
Defaults to the latest known + arch: an optional architecture for the package + update_cache: whether or not to run `apt-get update` prior to operating + + Raises: + TypeError if no package name is given, or explicit version is set for multiple packages + PackageNotFoundError if the package is not in the cache. + PackageError if packages fail to install + """ + cache_refreshed = False + if update_cache: + update() + cache_refreshed = True + + packages = {"success": [], "retry": [], "failed": []} + + package_names = [package_names] if isinstance(package_names, str) else package_names + if not package_names: + raise TypeError("Expected at least one package name to add, received zero!") + + if len(package_names) != 1 and version: + raise TypeError( + "Explicit version should not be set if more than one package is being added!" + ) + + for p in package_names: + pkg, success = _add(p, version, arch) + if success: + packages["success"].append(pkg) + else: + logger.warning("failed to locate and install/update '%s'", pkg) + packages["retry"].append(p) + + if packages["retry"] and not cache_refreshed: + logger.info("updating the apt-cache and retrying installation of failed packages.") + update() + + for p in packages["retry"]: + pkg, success = _add(p, version, arch) + if success: + packages["success"].append(pkg) + else: + packages["failed"].append(p) + + if packages["failed"]: + raise PackageError("Failed to install packages: {}".format(", ".join(packages["failed"]))) + + return packages["success"] if len(packages["success"]) > 1 else packages["success"][0] + + +def _add( + name: str, + version: Optional[str] = "", + arch: Optional[str] = "", +) -> Tuple[Union[DebianPackage, str], bool]: + """Add a package to the system. + + Args: + name: the name(s) of the package(s) + version: an (Optional) version as a string. 
Defaults to the latest known + arch: an optional architecture for the package + + Returns: a tuple of `DebianPackage` if found, or a :str: if it is not, and + a boolean indicating success + """ + try: + pkg = DebianPackage.from_system(name, version, arch) + pkg.ensure(state=PackageState.Present) + return pkg, True + except PackageNotFoundError: + return name, False + + +def remove_package( + package_names: Union[str, List[str]] +) -> Union[DebianPackage, List[DebianPackage]]: + """Remove package(s) from the system. + + Args: + package_names: the name of a package + + Raises: + PackageNotFoundError if the package is not found. + """ + packages = [] + + package_names = [package_names] if isinstance(package_names, str) else package_names + if not package_names: + raise TypeError("Expected at least one package name to add, received zero!") + + for p in package_names: + try: + pkg = DebianPackage.from_installed_package(p) + pkg.ensure(state=PackageState.Absent) + packages.append(pkg) + except PackageNotFoundError: + logger.info("package '%s' was requested for removal, but it was not installed.", p) + + # the list of packages will be empty when no package is removed + logger.debug("packages: '%s'", packages) + return packages[0] if len(packages) == 1 else packages + + +def update() -> None: + """Update the apt cache via `apt-get update`.""" + subprocess.run(["apt-get", "update"], capture_output=True, check=True) + + +def import_key(key: str) -> str: + """Import an ASCII Armor key. + + A Radix64 format keyid is also supported for backwards + compatibility. In this case Ubuntu keyserver will be + queried for a key via HTTPS by its keyid. This method + is less preferable because https proxy servers may + require traffic decryption which is equivalent to a + man-in-the-middle attack (a proxy server impersonates + keyserver TLS certificates and has to be explicitly + trusted by the system). 
+ + Args: + key: A GPG key in ASCII armor format, including BEGIN + and END markers or a keyid. + + Returns: + The GPG key filename written. + + Raises: + GPGKeyError if the key could not be imported + """ + key = key.strip() + if "-" in key or "\n" in key: + # Send everything not obviously a keyid to GPG to import, as + # we trust its validation better than our own. eg. handling + # comments before the key. + logger.debug("PGP key found (looks like ASCII Armor format)") + if ( + "-----BEGIN PGP PUBLIC KEY BLOCK-----" in key + and "-----END PGP PUBLIC KEY BLOCK-----" in key + ): + logger.debug("Writing provided PGP key in the binary format") + key_bytes = key.encode("utf-8") + key_name = DebianRepository._get_keyid_by_gpg_key(key_bytes) + key_gpg = DebianRepository._dearmor_gpg_key(key_bytes) + gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key_name) + DebianRepository._write_apt_gpg_keyfile( + key_name=gpg_key_filename, key_material=key_gpg + ) + return gpg_key_filename + else: + raise GPGKeyError("ASCII armor markers missing from GPG key") + else: + logger.warning( + "PGP key found (looks like Radix64 format). " + "SECURELY importing PGP key from keyserver; " + "full key not provided." + ) + # as of bionic add-apt-repository uses curl with an HTTPS keyserver URL + # to retrieve GPG keys. `apt-key adv` command is deprecated as is + # apt-key in general as noted in its manpage. See lp:1433761 for more + # history. 
Instead, /etc/apt/trusted.gpg.d is used directly to drop + # gpg + key_asc = DebianRepository._get_key_by_keyid(key) + # write the key in GPG format so that apt-key list shows it + key_gpg = DebianRepository._dearmor_gpg_key(key_asc.encode("utf-8")) + gpg_key_filename = "/etc/apt/trusted.gpg.d/{}.gpg".format(key) + DebianRepository._write_apt_gpg_keyfile(key_name=gpg_key_filename, key_material=key_gpg) + return gpg_key_filename + + +class InvalidSourceError(Error): + """Exceptions for invalid source entries.""" + + +class GPGKeyError(Error): + """Exceptions for GPG keys.""" + + +class DebianRepository: + """An abstraction to represent a repository.""" + + def __init__( + self, + enabled: bool, + repotype: str, + uri: str, + release: str, + groups: List[str], + filename: Optional[str] = "", + gpg_key_filename: Optional[str] = "", + options: Optional[dict] = None, + ): + self._enabled = enabled + self._repotype = repotype + self._uri = uri + self._release = release + self._groups = groups + self._filename = filename + self._gpg_key_filename = gpg_key_filename + self._options = options + + @property + def enabled(self): + """Return whether or not the repository is enabled.""" + return self._enabled + + @property + def repotype(self): + """Return whether it is binary or source.""" + return self._repotype + + @property + def uri(self): + """Return the URI.""" + return self._uri + + @property + def release(self): + """Return which Debian/Ubuntu releases it is valid for.""" + return self._release + + @property + def groups(self): + """Return the enabled package groups.""" + return self._groups + + @property + def filename(self): + """Returns the filename for a repository.""" + return self._filename + + @filename.setter + def filename(self, fname: str) -> None: + """Set the filename used when a repo is written back to disk. + + Args: + fname: a filename to write the repository information to. 
+ """ + if not fname.endswith(".list"): + raise InvalidSourceError("apt source filenames should end in .list!") + + self._filename = fname + + @property + def gpg_key(self): + """Returns the path to the GPG key for this repository.""" + return self._gpg_key_filename + + @property + def options(self): + """Returns any additional repo options which are set.""" + return self._options + + def make_options_string(self) -> str: + """Generate the complete options string for a a repository. + + Combining `gpg_key`, if set, and the rest of the options to find + a complex repo string. + """ + options = self._options if self._options else {} + if self._gpg_key_filename: + options["signed-by"] = self._gpg_key_filename + + return ( + "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in options.items()])) + if options + else "" + ) + + @staticmethod + def prefix_from_uri(uri: str) -> str: + """Get a repo list prefix from the uri, depending on whether a path is set.""" + uridetails = urlparse(uri) + path = ( + uridetails.path.lstrip("/").replace("/", "-") if uridetails.path else uridetails.netloc + ) + return "/etc/apt/sources.list.d/{}".format(path) + + @staticmethod + def from_repo_line(repo_line: str, write_file: Optional[bool] = True) -> "DebianRepository": + """Instantiate a new `DebianRepository` a `sources.list` entry line. + + Args: + repo_line: a string representing a repository entry + write_file: boolean to enable writing the new repo to disk + """ + repo = RepositoryMapping._parse(repo_line, "UserInput") + fname = "{}-{}.list".format( + DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") + ) + repo.filename = fname + + options = repo.options if repo.options else {} + if repo.gpg_key: + options["signed-by"] = repo.gpg_key + + # For Python 3.5 it's required to use sorted in the options dict in order to not have + # different results in the order of the options between executions. 
+ options_str = ( + "[{}] ".format(" ".join(["{}={}".format(k, v) for k, v in sorted(options.items())])) + if options + else "" + ) + + if write_file: + with open(fname, "wb") as f: + f.write( + ( + "{}".format("#" if not repo.enabled else "") + + "{} {}{} ".format(repo.repotype, options_str, repo.uri) + + "{} {}\n".format(repo.release, " ".join(repo.groups)) + ).encode("utf-8") + ) + + return repo + + def disable(self) -> None: + """Remove this repository from consideration. + + Disable it instead of removing from the repository file. + """ + searcher = "{} {}{} {}".format( + self.repotype, self.make_options_string(), self.uri, self.release + ) + for line in fileinput.input(self._filename, inplace=True): + if re.match(r"^{}\s".format(re.escape(searcher)), line): + print("# {}".format(line), end="") + else: + print(line, end="") + + def import_key(self, key: str) -> None: + """Import an ASCII Armor key. + + A Radix64 format keyid is also supported for backwards + compatibility. In this case Ubuntu keyserver will be + queried for a key via HTTPS by its keyid. This method + is less preferable because https proxy servers may + require traffic decryption which is equivalent to a + man-in-the-middle attack (a proxy server impersonates + keyserver TLS certificates and has to be explicitly + trusted by the system). + + Args: + key: A GPG key in ASCII armor format, + including BEGIN and END markers or a keyid. + + Raises: + GPGKeyError if the key could not be imported + """ + self._gpg_key_filename = import_key(key) + + @staticmethod + def _get_keyid_by_gpg_key(key_material: bytes) -> str: + """Get a GPG key fingerprint by GPG key material. + + Gets a GPG key fingerprint (40-digit, 160-bit) by the ASCII armor-encoded + or binary GPG key material. Can be used, for example, to generate file + names for keys passed via charm options. 
+ """ + # Use the same gpg command for both Xenial and Bionic + cmd = ["gpg", "--with-colons", "--with-fingerprint"] + ps = subprocess.run( + cmd, + stdout=PIPE, + stderr=PIPE, + input=key_material, + ) + out, err = ps.stdout.decode(), ps.stderr.decode() + if "gpg: no valid OpenPGP data found." in err: + raise GPGKeyError("Invalid GPG key material provided") + # from gnupg2 docs: fpr :: Fingerprint (fingerprint is in field 10) + return re.search(r"^fpr:{9}([0-9A-F]{40}):$", out, re.MULTILINE).group(1) + + @staticmethod + def _get_key_by_keyid(keyid: str) -> str: + """Get a key via HTTPS from the Ubuntu keyserver. + + Different key ID formats are supported by SKS keyservers (the longer ones + are more secure, see "dead beef attack" and https://evil32.com/). Since + HTTPS is used, if SSLBump-like HTTPS proxies are in place, they will + impersonate keyserver.ubuntu.com and generate a certificate with + keyserver.ubuntu.com in the CN field or in SubjAltName fields of a + certificate. If such proxy behavior is expected it is necessary to add the + CA certificate chain containing the intermediate CA of the SSLBump proxy to + every machine that this code runs on via ca-certs cloud-init directive (via + cloudinit-userdata model-config) or via other means (such as through a + custom charm option). Also note that DNS resolution for the hostname in a + URL is done at a proxy server - not at the client side. 
+ 8-digit (32 bit) key ID + https://keyserver.ubuntu.com/pks/lookup?search=0x4652B4E6 + 16-digit (64 bit) key ID + https://keyserver.ubuntu.com/pks/lookup?search=0x6E85A86E4652B4E6 + 40-digit key ID: + https://keyserver.ubuntu.com/pks/lookup?search=0x35F77D63B5CEC106C577ED856E85A86E4652B4E6 + + Args: + keyid: An 8, 16 or 40 hex digit keyid to find a key for + + Returns: + A string containing key material for the specified GPG key id + + + Raises: + subprocess.CalledProcessError + """ + # options=mr - machine-readable output (disables html wrappers) + keyserver_url = ( + "https://keyserver.ubuntu.com" "/pks/lookup?op=get&options=mr&exact=on&search=0x{}" + ) + curl_cmd = ["curl", keyserver_url.format(keyid)] + # use proxy server settings in order to retrieve the key + return check_output(curl_cmd).decode() + + @staticmethod + def _dearmor_gpg_key(key_asc: bytes) -> bytes: + """Convert a GPG key in the ASCII armor format to the binary format. + + Args: + key_asc: A GPG key in ASCII armor format. + + Returns: + A GPG key in binary format as a string + + Raises: + GPGKeyError + """ + ps = subprocess.run(["gpg", "--dearmor"], stdout=PIPE, stderr=PIPE, input=key_asc) + out, err = ps.stdout, ps.stderr.decode() + if "gpg: no valid OpenPGP data found." in err: + raise GPGKeyError( + "Invalid GPG key material. Check your network setup" + " (MTU, routing, DNS) and/or proxy server settings" + " as well as destination keyserver status." + ) + else: + return out + + @staticmethod + def _write_apt_gpg_keyfile(key_name: str, key_material: bytes) -> None: + """Write GPG key material into a file at a provided path. + + Args: + key_name: A key name to use for a key file (could be a fingerprint) + key_material: A GPG key material (binary) + """ + with open(key_name, "wb") as keyf: + keyf.write(key_material) + + +class RepositoryMapping(Mapping): + """A representation of known repositories.
+ + Instantiation of `RepositoryMapping` will iterate through the + filesystem, parse out repository files in `/etc/apt/...`, and create + `DebianRepository` objects in this list. + + Typical usage: + + repositories = apt.RepositoryMapping() + repositories.add(DebianRepository( + enabled=True, repotype="deb", uri="https://example.com", release="focal", + groups=["universe"] + )) + """ + + def __init__(self): + self._repository_map = {} + # Repositories that we're adding -- used to implement mode param + self.default_file = "/etc/apt/sources.list" + + # read sources.list if it exists + if os.path.isfile(self.default_file): + self.load(self.default_file) + + # read sources.list.d + for file in glob.iglob("/etc/apt/sources.list.d/*.list"): + self.load(file) + + def __contains__(self, key: str) -> bool: + """Magic method for checking presence of repo in mapping.""" + return key in self._repository_map + + def __len__(self) -> int: + """Return number of repositories in map.""" + return len(self._repository_map) + + def __iter__(self) -> Iterable[DebianRepository]: + """Return iterator for RepositoryMapping.""" + return iter(self._repository_map.values()) + + def __getitem__(self, repository_uri: str) -> DebianRepository: + """Return a given `DebianRepository`.""" + return self._repository_map[repository_uri] + + def __setitem__(self, repository_uri: str, repository: DebianRepository) -> None: + """Add a `DebianRepository` to the cache.""" + self._repository_map[repository_uri] = repository + + def load(self, filename: str): + """Load a repository source file into the cache. 
+ + Args: + filename: the path to the repository file + """ + parsed = [] + skipped = [] + with open(filename, "r") as f: + for n, line in enumerate(f): + try: + repo = self._parse(line, filename) + except InvalidSourceError: + skipped.append(n) + else: + repo_identifier = "{}-{}-{}".format(repo.repotype, repo.uri, repo.release) + self._repository_map[repo_identifier] = repo + parsed.append(n) + logger.debug("parsed repo: '%s'", repo_identifier) + + if skipped: + skip_list = ", ".join(str(s) for s in skipped) + logger.debug("skipped the following lines in file '%s': %s", filename, skip_list) + + if parsed: + logger.info("parsed %d apt package repositories", len(parsed)) + else: + raise InvalidSourceError("all repository lines in '{}' were invalid!".format(filename)) + + @staticmethod + def _parse(line: str, filename: str) -> DebianRepository: + """Parse a line in a sources.list file. + + Args: + line: a single line from `load` to parse + filename: the filename being read + + Raises: + InvalidSourceError if the source type is unknown + """ + enabled = True + repotype = uri = release = gpg_key = "" + options = {} + groups = [] + + line = line.strip() + if line.startswith("#"): + enabled = False + line = line[1:] + + # Check for "#" in the line and treat a part after it as a comment then strip it off. + i = line.find("#") + if i > 0: + line = line[:i] + + # Split a source into substrings to initialize a new repo. + source = line.strip() + if source: + # Match any repo options, and get a dict representation. 
+ for v in re.findall(OPTIONS_MATCHER, source): + opts = dict(o.split("=") for o in v.strip("[]").split()) + # Extract the 'signed-by' option for the gpg_key + gpg_key = opts.pop("signed-by", "") + options = opts + + # Remove any options from the source string and split the string into chunks + source = re.sub(OPTIONS_MATCHER, "", source) + chunks = source.split() + + # Check we've got a valid list of chunks + if len(chunks) < 3 or chunks[0] not in VALID_SOURCE_TYPES: + raise InvalidSourceError("An invalid sources line was found in %s!", filename) + + repotype = chunks[0] + uri = chunks[1] + release = chunks[2] + groups = chunks[3:] + + return DebianRepository( + enabled, repotype, uri, release, groups, filename, gpg_key, options + ) + else: + raise InvalidSourceError("An invalid sources line was found in %s!", filename) + + def add(self, repo: DebianRepository, default_filename: Optional[bool] = False) -> None: + """Add a new repository to the system. + + Args: + repo: a `DebianRepository` object + default_filename: an (Optional) filename if the default is not desirable + """ + new_filename = "{}-{}.list".format( + DebianRepository.prefix_from_uri(repo.uri), repo.release.replace("/", "-") + ) + + fname = repo.filename or new_filename + + options = repo.options if repo.options else {} + if repo.gpg_key: + options["signed-by"] = repo.gpg_key + + with open(fname, "wb") as f: + f.write( + ( + "{}".format("#" if not repo.enabled else "") + + "{} {}{} ".format(repo.repotype, repo.make_options_string(), repo.uri) + + "{} {}\n".format(repo.release, " ".join(repo.groups)) + ).encode("utf-8") + ) + + self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo + + def disable(self, repo: DebianRepository) -> None: + """Remove a repository. Disable by default. 
+ + Args: + repo: a `DebianRepository` to disable + """ + searcher = "{} {}{} {}".format( + repo.repotype, repo.make_options_string(), repo.uri, repo.release + ) + + for line in fileinput.input(repo.filename, inplace=True): + if re.match(r"^{}\s".format(re.escape(searcher)), line): + print("# {}".format(line), end="") + else: + print(line, end="") + + self._repository_map["{}-{}-{}".format(repo.repotype, repo.uri, repo.release)] = repo diff --git a/lib/charms/operator_libs_linux/v1/systemd.py b/lib/charms/operator_libs_linux/v1/systemd.py new file mode 100644 index 0000000..cdcbad6 --- /dev/null +++ b/lib/charms/operator_libs_linux/v1/systemd.py @@ -0,0 +1,288 @@ +# Copyright 2021 Canonical Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Abstractions for stopping, starting and managing system services via systemd. + +This library assumes that your charm is running on a platform that uses systemd. E.g., +Centos 7 or later, Ubuntu Xenial (16.04) or later. + +For the most part, we transparently provide an interface to a commonly used selection of +systemd commands, with a few shortcuts baked in. For example, service_pause and +service_resume with run the mask/unmask and enable/disable invocations. 
+ +Example usage: + +```python +from charms.operator_libs_linux.v0.systemd import service_running, service_reload + +# Start a service +if not service_running("mysql"): + success = service_start("mysql") + +# Attempt to reload a service, restarting if necessary +success = service_reload("nginx", restart_on_failure=True) +``` +""" + +__all__ = [ # Don't export `_systemctl`. (It's not the intended way of using this lib.) + "SystemdError", + "daemon_reload", + "service_disable", + "service_enable", + "service_failed", + "service_pause", + "service_reload", + "service_restart", + "service_resume", + "service_running", + "service_start", + "service_stop", +] + +import logging +import subprocess + +logger = logging.getLogger(__name__) + +# The unique Charmhub library identifier, never change it +LIBID = "045b0d179f6b4514a8bb9b48aee9ebaf" + +# Increment this major API version when introducing breaking changes +LIBAPI = 1 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 4 + + +class SystemdError(Exception): + """Custom exception for SystemD related errors.""" + + +def _systemctl(*args: str, check: bool = False) -> int: + """Control a system service using systemctl. + + Args: + *args: Arguments to pass to systemctl. + check: Check the output of the systemctl command. Default: False. + + Returns: + Returncode of systemctl command execution. + + Raises: + SystemdError: Raised if calling systemctl returns a non-zero returncode and check is True. + """ + cmd = ["systemctl", *args] + logger.debug(f"Executing command: {cmd}") + try: + proc = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + encoding="utf-8", + check=check, + ) + logger.debug( + f"Command {cmd} exit code: {proc.returncode}. 
systemctl output:\n{proc.stdout}" + ) + return proc.returncode + except subprocess.CalledProcessError as e: + raise SystemdError( + f"Command {cmd} failed with returncode {e.returncode}. systemctl output:\n{e.stdout}" + ) + + +def service_running(service_name: str) -> bool: + """Report whether a system service is running. + + Args: + service_name: The name of the service to check. + + Return: + True if service is running/active; False if not. + """ + # If returncode is 0, this means that is service is active. + return _systemctl("--quiet", "is-active", service_name) == 0 + + +def service_failed(service_name: str) -> bool: + """Report whether a system service has failed. + + Args: + service_name: The name of the service to check. + + Returns: + True if service is marked as failed; False if not. + """ + # If returncode is 0, this means that the service has failed. + return _systemctl("--quiet", "is-failed", service_name) == 0 + + +def service_start(*args: str) -> bool: + """Start a system service. + + Args: + *args: Arguments to pass to `systemctl start` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl start ...` returns a non-zero returncode. + """ + return _systemctl("start", *args, check=True) == 0 + + +def service_stop(*args: str) -> bool: + """Stop a system service. + + Args: + *args: Arguments to pass to `systemctl stop` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl stop ...` returns a non-zero returncode. + """ + return _systemctl("stop", *args, check=True) == 0 + + +def service_restart(*args: str) -> bool: + """Restart a system service. + + Args: + *args: Arguments to pass to `systemctl restart` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. 
+ + Raises: + SystemdError: Raised if `systemctl restart ...` returns a non-zero returncode. + """ + return _systemctl("restart", *args, check=True) == 0 + + +def service_enable(*args: str) -> bool: + """Enable a system service. + + Args: + *args: Arguments to pass to `systemctl enable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl enable ...` returns a non-zero returncode. + """ + return _systemctl("enable", *args, check=True) == 0 + + +def service_disable(*args: str) -> bool: + """Disable a system service. + + Args: + *args: Arguments to pass to `systemctl disable` (normally the service name). + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl disable ...` returns a non-zero returncode. + """ + return _systemctl("disable", *args, check=True) == 0 + + +def service_reload(service_name: str, restart_on_failure: bool = False) -> bool: + """Reload a system service, optionally falling back to restart if reload fails. + + Args: + service_name: The name of the service to reload. + restart_on_failure: + Boolean indicating whether to fall back to a restart if the reload fails. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl reload|restart ...` returns a non-zero returncode. + """ + try: + return _systemctl("reload", service_name, check=True) == 0 + except SystemdError: + if restart_on_failure: + return service_restart(service_name) + else: + raise + + +def service_pause(service_name: str) -> bool: + """Pause a system service. + + Stops the service and prevents the service from starting again at boot. + + Args: + service_name: The name of the service to pause. + + Returns: + On success, this function returns True for historical reasons. 
+ + Raises: + SystemdError: Raised if service is still running after being paused by systemctl. + """ + _systemctl("disable", "--now", service_name) + _systemctl("mask", service_name) + + if service_running(service_name): + raise SystemdError(f"Attempted to pause {service_name!r}, but it is still running.") + + return True + + +def service_resume(service_name: str) -> bool: + """Resume a system service. + + Re-enable starting the service again at boot. Start the service. + + Args: + service_name: The name of the service to resume. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if service is not running after being resumed by systemctl. + """ + _systemctl("unmask", service_name) + _systemctl("enable", "--now", service_name) + + if not service_running(service_name): + raise SystemdError(f"Attempted to resume {service_name!r}, but it is not running.") + + return True + + +def daemon_reload() -> bool: + """Reload systemd manager configuration. + + Returns: + On success, this function returns True for historical reasons. + + Raises: + SystemdError: Raised if `systemctl daemon-reload` returns a non-zero returncode. + """ + return _systemctl("daemon-reload", check=True) == 0 diff --git a/metadata.yaml b/metadata.yaml deleted file mode 100644 index 713b150..0000000 --- a/metadata.yaml +++ /dev/null @@ -1,44 +0,0 @@ -name: slurmctld -summary: | - Slurmctld, the central management daemon of Slurm. -description: | - This charm provides slurmctld, munged, and the bindings to other utilities - that make lifecycle operations a breeze. - - slurmctld is the central management daemon of SLURM. It monitors all other - SLURM daemons and resources, accepts work (jobs), and allocates resources - to those jobs. Given the critical functionality of slurmctld, there may be - a backup server to assume these functions in the event that the primary - server fails. 
-source: https://github.com/omnivector-solutions/slurmctld-operator -issues: https://github.com/omnivector-solutions/slurmctld-operator/issues -maintainers: - - OmniVector Solutions - - Jason C. Nucciarone - - David Gomez - -peers: - slurmctld-peer: - interface: slurmctld-peer -requires: - slurmd: - interface: slurmd - slurmdbd: - interface: slurmdbd - slurmrestd: - interface: slurmrestd - influxdb-api: - interface: influxdb-api - elasticsearch: - interface: elasticsearch - fluentbit: - interface: fluentbit -provides: - prolog-epilog: - interface: prolog-epilog - grafana-source: - interface: grafana-source - scope: global - -assumes: - - juju diff --git a/pyproject.toml b/pyproject.toml index ec80137..859ad12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,8 @@ target-version = ["py38"] # Linting tools configuration [tool.ruff] line-length = 99 -select = ["E", "W", "F", "C", "N", "D", "I001"] -extend-ignore = [ +lint.select = ["E", "W", "F", "C", "N", "D", "I001"] +lint.extend-ignore = [ "D203", "D204", "D213", @@ -49,9 +49,9 @@ extend-ignore = [ "D409", "D413", ] -ignore = ["E501", "D107"] +lint.ignore = ["E501", "D107"] extend-exclude = ["__pycache__", "*.egg_info"] -per-file-ignores = {"tests/*" = ["D100","D101","D102","D103","D104"]} +lint.per-file-ignores = {"tests/*" = ["D100","D101","D102","D103","D104"]} -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 10 diff --git a/requirements.txt b/requirements.txt index 064d375..1e1821d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ ops==2.* influxdb==5.3.1 jinja2==3.1.3 -git+https://github.com/omnivector-solutions/slurm-ops-manager.git@0.8.16 \ No newline at end of file +distro +pycryptodome diff --git a/src/charm.py b/src/charm.py index 46a6386..6e2cf01 100755 --- a/src/charm.py +++ b/src/charm.py @@ -24,7 +24,7 @@ from ops.framework import StoredState from ops.main import main from ops.model import ActiveStatus, BlockedStatus, WaitingStatus -from slurm_ops_manager 
import SlurmManager +from slurmctld_ops import SlurmctldManager logger = logging.getLogger() @@ -48,7 +48,7 @@ def __init__(self, *args): down_nodes=[], ) - self._slurm_manager = SlurmManager(self, "slurmctld") + self._slurm_manager = SlurmctldManager(self, "slurmctld") self._slurmd = Slurmd(self, "slurmd") self._slurmdbd = Slurmdbd(self, "slurmdbd") @@ -219,17 +219,16 @@ def is_slurm_installed(self): def _on_show_current_config(self, event): """Show current slurm.conf.""" - slurm_conf = self._slurm_manager.get_slurm_conf() + slurm_conf = self._slurm_manager.slurm_conf_path.read_text() event.set_results({"slurm.conf": slurm_conf}) def _on_install(self, event): """Perform installation operations for slurmctld.""" - self.unit.set_workload_version(Path("version").read_text().strip()) - self.unit.status = WaitingStatus("Installing slurmctld") - custom_repo = self.config.get("custom-slurm-repo") - successful_installation = self._slurm_manager.install(custom_repo) + successful_installation = self._slurm_manager.install() + + self.unit.set_workload_version(self._slurm_manager.version()) if successful_installation: self._stored.slurm_installed = True @@ -243,7 +242,7 @@ def _on_install(self, event): # peer relation. 
self._stored.jwt_rsa = self._slurm_manager.generate_jwt_rsa() self._stored.munge_key = self._slurm_manager.get_munge_key() - self._slurm_manager.configure_jwt_rsa(self.get_jwt_rsa()) + self._slurm_manager.write_jwt_rsa(self.get_jwt_rsa()) else: # NOTE: the secondary slurmctld should get the jwt and munge # keys from the peer relation here @@ -420,12 +419,6 @@ def _on_slurmrestd_available(self, event): event.defer() return - if self._stored.slurmrestd_available: - self._slurmrestd.set_slurm_config_on_app_relation_data( - slurm_config, - ) - self._slurmrestd.restart_slurmrestd() - def _on_slurmdbd_available(self, event): self._set_slurmdbd_available(True) self._on_write_slurm_config(event) @@ -451,7 +444,7 @@ def _on_write_slurm_config(self, event): self._slurm_manager.render_slurm_configs(slurm_config) # restart is needed if nodes are added/removed from the cluster - self._slurm_manager.slurm_systemctl("restart") + self._slurm_manager.restart_slurmctld() self._slurm_manager.slurm_cmd("scontrol", "reconfigure") # send the custom NHC parameters to all slurmd diff --git a/src/interface_prolog_epilog.py b/src/interface_prolog_epilog.py index 3d603f4..736de51 100644 --- a/src/interface_prolog_epilog.py +++ b/src/interface_prolog_epilog.py @@ -1,4 +1,5 @@ """Slurm Prolog and Epilog interface.""" + import json import logging diff --git a/src/slurmctld_ops.py b/src/slurmctld_ops.py new file mode 100644 index 0000000..896ccdc --- /dev/null +++ b/src/slurmctld_ops.py @@ -0,0 +1,708 @@ +# Copyright 2024 Omnivector, LLC. +# See LICENSE file for licensing details. 
+"""This module provides the SlurmManager.""" + +import logging +import os +import shlex +import shutil +import socket +import subprocess +from base64 import b64decode, b64encode +from pathlib import Path + +import charms.operator_libs_linux.v0.apt as apt +import charms.operator_libs_linux.v1.systemd as systemd +import distro +from Crypto.PublicKey import RSA +from jinja2 import Environment, FileSystemLoader +from ops.framework import ( + Object, + StoredState, +) + +logger = logging.getLogger() + + +TEMPLATE_DIR = Path(os.path.dirname(os.path.abspath(__file__))) / "templates" + + +SLURM_PPA_KEY: str = """ +-----BEGIN PGP PUBLIC KEY BLOCK----- +Comment: Hostname: +Version: Hockeypuck 2.1.1-10-gec3b0e7 + +xsFNBGTuZb8BEACtJ1CnZe6/hv84DceHv+a54y3Pqq0gqED0xhTKnbj/E2ByJpmT +NlDNkpeITwPAAN1e3824Me76Qn31RkogTMoPJ2o2XfG253RXd67MPxYhfKTJcnM3 +CEkmeI4u2Lynh3O6RQ08nAFS2AGTeFVFH2GPNWrfOsGZW03Jas85TZ0k7LXVHiBs +W6qonbsFJhshvwC3SryG4XYT+z/+35x5fus4rPtMrrEOD65hij7EtQNaE8owuAju +Kcd0m2b+crMXNcllWFWmYMV0VjksQvYD7jwGrWeKs+EeHgU8ZuqaIP4pYHvoQjag +umqnH9Qsaq5NAXiuAIAGDIIV4RdAfQIR4opGaVgIFJdvoSwYe3oh2JlrLPBlyxyY +dayDifd3X8jxq6/oAuyH1h5K/QLs46jLSR8fUbG98SCHlRmvozTuWGk+e07ALtGe +sGv78ToHKwoM2buXaTTHMwYwu7Rx8LZ4bZPHdersN1VW/m9yn1n5hMzwbFKy2s6/ +D4Q2ZBsqlN+5aW2q0IUmO+m0GhcdaDv8U7RVto1cWWPr50HhiCi7Yvei1qZiD9jq +57oYZVqTUNCTPxi6NeTOdEc+YqNynWNArx4PHh38LT0bqKtlZCGHNfoAJLPVYhbB +b2AHj9edYtHU9AAFSIy+HstET6P0UDxy02IeyE2yxoUBqdlXyv6FL44E+wARAQAB +zRxMYXVuY2hwYWQgUFBBIGZvciBVYnVudHUgSFBDwsGOBBMBCgA4FiEErocSHcPk +oLD4H/Aj9tDF1ca+s3sFAmTuZb8CGwMFCwkIBwIGFQoJCAsCBBYCAwECHgECF4AA +CgkQ9tDF1ca+s3sz3w//RNawsgydrutcbKf0yphDhzWS53wgfrs2KF1KgB0u/H+u +6Kn2C6jrVM0vuY4NKpbEPCduOj21pTCepL6PoCLv++tICOLVok5wY7Zn3WQFq0js +Iy1wO5t3kA1cTD/05v/qQVBGZ2j4DsJo33iMcQS5AjHvSr0nu7XSvDDEE3cQE55D +87vL7lgGjuTOikPh5FpCoS1gpemBfwm2Lbm4P8vGOA4/witRjGgfC1fv1idUnZLM +TbGrDlhVie8pX2kgB6yTYbJ3P3kpC1ZPpXSRWO/cQ8xoYpLBTXOOtqwZZUnxyzHh +gM+hv42vPTOnCo+apD97/VArsp59pDqEVoAtMTk72fdBqR+BB77g2hBkKESgQIEq 
+EiE1/TOISioMkE0AuUdaJ2ebyQXugSHHuBaqbEC47v8t5DVN5Qr9OriuzCuSDNFn
+6SBHpahN9ZNi9w0A/Yh1+lFfpkVw2t04Q2LNuupqOpW+h3/62AeUqjUIAIrmfeML
+IDRE2VdquYdIXKuhNvfpJYGdyvx/wAbiAeBWg0uPSepwTfTG59VPQmj0FtalkMnN
+ya2212K5q68O5eXOfCnGeMvqIXxqzpdukxSZnLkgk40uFJnJVESd/CxHquqHPUDE
+fy6i2AnB3kUI27D4HY2YSlXLSRbjiSxTfVwNCzDsIh7Czefsm6ITK2+cVWs0hNQ=
+=cs1s
+-----END PGP PUBLIC KEY BLOCK-----
+"""
+
+
+class Slurmctld:
+    """Facilitate slurmctld package lifecycle ops."""
+
+    _package_name: str = "slurmctld"
+    _keyring_path: Path = Path("/usr/share/keyrings/slurm-wlm.asc")
+
+    def _repo(self) -> apt.DebianRepository:
+        """Return the slurmctld repo."""
+        ppa_url: str = "https://ppa.launchpadcontent.net/ubuntu-hpc/slurm-wlm-23.02/ubuntu"
+        sources_list: str = (
+            f"deb [signed-by={self._keyring_path}] {ppa_url} {distro.codename()} main"
+        )
+        return apt.DebianRepository.from_repo_line(sources_list)
+
+    def install(self) -> None:
+        """Install the slurmctld package using lib apt."""
+        # Install the key.
+        if self._keyring_path.exists():
+            self._keyring_path.unlink()
+        self._keyring_path.write_text(SLURM_PPA_KEY)
+
+        # Add the repo.
+        repositories = apt.RepositoryMapping()
+        repositories.add(self._repo())
+
+        # Install the slurmctld, slurm-client packages.
+        try:
+            # Run `apt-get update`
+            apt.update()
+            apt.add_package(["mailutils", "logrotate"])
+            apt.add_package([self._package_name, "slurm-client"])
+        except apt.PackageNotFoundError:
+            logger.error(f"{self._package_name} not found in package cache or on system")
+        except apt.PackageError as e:
+            logger.error("Could not install %s. Reason: %s", self._package_name, e.message)
+
+    def uninstall(self) -> None:
+        """Uninstall the slurmctld package using libapt."""
+        # Uninstall the slurmctld package.
+        if apt.remove_package(self._package_name):
+            logger.info(f"{self._package_name} removed from system.")
+        else:
+            logger.error(f"{self._package_name} not found on system")
+
+        # Disable the slurmctld repo.
+ repositories = apt.RepositoryMapping() + repositories.disable(self._repo()) + + # Remove the key. + if self._keyring_path.exists(): + self._keyring_path.unlink() + + def upgrade_to_latest(self) -> None: + """Upgrade slurmctld to latest.""" + try: + slurmctld = apt.DebianPackage.from_system(self._package_name) + slurmctld.ensure(apt.PackageState.Latest) + logger.info("updated vim to version: %s", slurmctld.version.number) + except apt.PackageNotFoundError: + logger.error("a specified package not found in package cache or on system") + except apt.PackageError as e: + logger.error("could not install package. Reason: %s", e.message) + + def version(self) -> str: + """Return the slurmctld version.""" + try: + slurmctld = apt.DebianPackage.from_installed_package(self._package_name) + except apt.PackageNotFoundError: + logger.error(f"{self._package_name} not found on system") + return slurmctld.version.number + + +class SlurmctldManager(Object): + """SlurmctldManager.""" + + _stored = StoredState() + + def __init__(self, charm, component): + """Set the initial attribute values.""" + super().__init__(charm, component) + + self._charm = charm + + self._stored.set_default(slurm_installed=False) + self._stored.set_default(slurm_version_set=False) + + """Set the initial values for attributes in the base class.""" + self._slurm_conf_template_name = "slurm.conf.tmpl" + self._slurm_conf_path = self._slurm_conf_dir / "slurm.conf" + + self._slurmd_log_file = self._slurm_log_dir / "slurmd.log" + self._slurmctld_log_file = self._slurm_log_dir / "slurmctld.log" + + self._slurmd_pid_file = self._slurm_pid_dir / "slurmd.pid" + self._slurmctld_pid_file = self._slurm_pid_dir / "slurmctld.pid" + + # NOTE: Come back to mitigate this configless cruft + self._slurmctld_parameters = ["enable_configless"] + + self._slurm_conf_template_location = TEMPLATE_DIR / self._slurm_conf_template_name + + @property + def hostname(self) -> str: + """Return the hostname.""" + return 
socket.gethostname().split(".")[0] + + @property + def port(self) -> str: + """Return the port.""" + return "6817" + + @property + def slurm_conf_path(self) -> Path: + """Return the slurm conf path.""" + return self._slurm_conf_path + + def slurm_is_active(self) -> bool: + """Return True if the slurm component is running.""" + try: + cmd = f"systemctl is-active {self._slurm_systemd_service}" + r = subprocess.check_output(shlex.split(cmd)) + r = r.decode().strip().lower() + logger.debug(f"### systemctl is-active {self._slurm_systemd_service}: {r}") + return "active" == r + except subprocess.CalledProcessError as e: + logger.error(f"#### Error checking if slurm is active: {e}") + return False + return False + + @property + def _slurm_bin_dir(self) -> Path: + """Return the directory where the slurm bins live.""" + return Path("/usr/bin") + + @property + def _slurm_conf_dir(self) -> Path: + """Return the directory for Slurm configuration files.""" + return Path("/etc/slurm") + + @property + def _slurm_spool_dir(self) -> Path: + """Return the directory for slurmd's state information.""" + return Path("/var/spool/slurmd") + + @property + def _slurm_state_dir(self) -> Path: + """Return the directory for slurmctld's state information.""" + return Path("/var/spool/slurmctld") + + @property + def _slurm_log_dir(self) -> Path: + """Return the directory for Slurm logs.""" + return Path("/var/log/slurm") + + @property + def _slurm_pid_dir(self) -> Path: + """Return the directory for Slurm PID file.""" + return Path("/var/run/") + + @property + def _jwt_rsa_key_file(self) -> Path: + """Return the jwt rsa key file path.""" + return self._slurm_state_dir / "jwt_hs256.key" + + @property + def _munge_key_path(self) -> Path: + """Return the full path to the munge key.""" + return Path("/etc/munge/munge.key") + + @property + def _munge_socket(self) -> Path: + """Return the munge socket.""" + return Path("/var/run/munge/munge.socket.2") + + @property + def _munged_systemd_service(self) 
-> str: + """Return the name of the Munge Systemd unit file.""" + return "munge.service" + + @property + def _munge_user(self) -> str: + """Return the user for munge daemon.""" + return "munge" + + @property + def _munge_group(self) -> str: + """Return the group for munge daemon.""" + return "munge" + + @property + def _slurm_plugstack_dir(self) -> Path: + """Return the directory to the SPANK plugins.""" + return Path("/etc/slurm/plugstack.conf.d") + + @property + def _slurm_plugstack_conf(self) -> Path: + """Return the full path to the root plugstack configuration file.""" + return self._slurm_conf_dir / "plugstack.conf" + + @property + def _slurm_systemd_service(self) -> str: + """Return the Slurm systemd unit file.""" + return "slurmctld.service" + + @property + def _slurm_user(self) -> str: + """Return the slurm user.""" + return "slurm" + + @property + def _slurm_user_id(self) -> str: + """Return the slurm user ID.""" + return "64030" + + @property + def _slurm_group(self) -> str: + """Return the slurm group.""" + return "slurm" + + @property + def _slurm_group_id(self) -> str: + """Return the slurm group ID.""" + return "64030" + + @property + def _slurmd_user(self) -> str: + """Return the slurmd user.""" + return "root" + + @property + def _slurmd_group(self) -> str: + """Return the slurmd group.""" + return "root" + + def create_systemd_override_for_nofile(self): + """Create the override.conf file for slurm systemd service.""" + systemd_override_dir = Path(f"/etc/systemd/system/{self._slurm_systemd_service}.d") + if not systemd_override_dir.exists(): + systemd_override_dir.mkdir(exist_ok=True) + + systemd_override_conf = systemd_override_dir / "override.conf" + systemd_override_conf_tmpl = TEMPLATE_DIR / "override.conf" + + shutil.copyfile(systemd_override_conf_tmpl, systemd_override_conf) + + def slurm_config_nhc_values(self, interval=600, state="ANY,CYCLE"): + """NHC parameters for slurm.conf.""" + return { + "nhc_bin": "/usr/sbin/omni-nhc-wrapper", + 
"health_check_interval": interval, + "health_check_node_state": state, + } + + def write_acct_gather_conf(self, context: dict) -> None: + """Render the acct_gather.conf.""" + template_name = "acct_gather.conf.tmpl" + source = TEMPLATE_DIR / template_name + target = self._slurm_conf_dir / "acct_gather.conf" + + if not isinstance(context, dict): + raise TypeError("Incorrect type for config.") + + if not source.exists(): + raise FileNotFoundError("The acct_gather template cannot be found.") + + rendered_template = Environment(loader=FileSystemLoader(TEMPLATE_DIR)).get_template( + template_name + ) + + if target.exists(): + target.unlink() + + target.write_text(rendered_template.render(context)) + + def remove_acct_gather_conf(self) -> None: + """Remove acct_gather.conf.""" + target = self._slurm_conf_dir / "acct_gather.conf" + if target.exists(): + target.unlink() + + def write_slurm_config(self, context) -> None: + """Render the context to a template, adding in common configs.""" + common_config = { + "munge_socket": str(self._munge_socket), + "mail_prog": str(self._mail_prog), + "slurm_state_dir": str(self._slurm_state_dir), + "slurm_spool_dir": str(self._slurm_spool_dir), + "slurm_plugin_dir": str(self._slurm_plugin_dir), + "slurmd_log_file": str(self._slurmd_log_file), + "slurmctld_log_file": str(self._slurmctld_log_file), + "slurmd_pid_file": str(self._slurmd_pid_file), + "slurmctld_pid_file": str(self._slurmctld_pid_file), + "jwt_rsa_key_file": str(self._jwt_rsa_key_file), + "slurmctld_parameters": ",".join(self._slurmctld_parameters), + "slurm_plugstack_conf": str(self._slurm_plugstack_conf), + "slurm_user": str(self._slurm_user), + "slurmd_user": str(self._slurmd_user), + } + + template_name = self._slurm_conf_template_name + source = self._slurm_conf_template_location + target = self._slurm_conf_path + + if not isinstance(context, dict): + raise TypeError("Incorrect type for config.") + + if not source.exists(): + raise FileNotFoundError("The slurm config 
template cannot be found.") + + # Preprocess merging slurmctld_parameters if they exist in the context + context_slurmctld_parameters = context.get("slurmctld_parameters") + if context_slurmctld_parameters: + slurmctld_parameters = list( + set( + common_config["slurmctld_parameters"].split(",") + + context_slurmctld_parameters.split(",") + ) + ) + + common_config["slurmctld_parameters"] = ",".join(slurmctld_parameters) + context.pop("slurmctld_parameters") + + rendered_template = Environment(loader=FileSystemLoader(TEMPLATE_DIR)).get_template( + template_name + ) + + if target.exists(): + target.unlink() + + target.write_text(rendered_template.render({**context, **common_config})) + + user_group = f"{self._slurm_user}:{self._slurm_group}" + subprocess.call(["chown", user_group, target]) + + def write_munge_key(self, munge_key): + """Base64 decode and write the munge key.""" + key = b64decode(munge_key.encode()) + self._munge_key_path.write_bytes(key) + + def write_jwt_rsa(self, jwt_rsa): + """Write the jwt_rsa key.""" + # Remove jwt_rsa if exists. + if self._jwt_rsa_key_file.exists(): + self._jwt_rsa_key_file.write_bytes(os.urandom(2048)) + self._jwt_rsa_key_file.unlink() + + # Write the jwt_rsa key to the file and chmod 0600, + # chown to slurm_user. 
+        self._jwt_rsa_key_file.write_text(jwt_rsa)
+        self._jwt_rsa_key_file.chmod(0o600)
+        subprocess.call(
+            [
+                "chown",
+                self._slurm_user,
+                str(self._jwt_rsa_key_file),
+            ]
+        )
+
+    def write_cgroup_conf(self, content):
+        """Write the cgroup.conf file."""
+        cgroup_conf_path = self._slurm_conf_dir / "cgroup.conf"
+        cgroup_conf_path.write_text(content)
+
+    def get_munge_key(self) -> str:
+        """Read the bytes, encode to base64, decode to a string, return."""
+        munge_key = self._munge_key_path.read_bytes()
+        return b64encode(munge_key).decode()
+
+    def start_munged(self):
+        """Start munge.service."""
+        logger.debug("## Starting munge")
+
+        munge = self._munged_systemd_service
+        try:
+            subprocess.check_output(["systemctl", "start", munge])
+        except subprocess.CalledProcessError as e:
+            logger.error(f"## Error starting munge: {e}")
+            return False
+
+        return self._is_active_munged()
+
+    def _is_active_munged(self):
+        munge = self._munged_systemd_service
+        try:
+            status = subprocess.check_output(f"systemctl is-active {munge}", shell=True)
+            status = status.decode().strip()
+            if status == "active":
+                logger.debug("#### Munge daemon active")
+                return True
+            else:
+                logger.error(f"## Munge not running: {status}")
+                return False
+        except subprocess.CalledProcessError as e:
+            logger.error(f"## Error querying munged - {e}")
+            return False
+
+    def check_munged(self) -> bool:
+        """Check if munge is working correctly."""
+        # check if systemd service unit is active
+        if not self._is_active_munged():
+            return False
+
+        # check if munge is working, i.e., can use the credentials correctly
+        try:
+            logger.debug("## Testing if munge is working correctly")
+            cmd = "munge -n"
+            munge = subprocess.Popen(
+                shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            )
+            unmunge = subprocess.Popen(
+                ["unmunge"], stdin=munge.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            )
+            munge.stdout.close()
+            output = unmunge.communicate()[0]
+            if "Success" in output.decode():
+                logger.debug(f"## Munge working as expected: {output}")
+                return True
+            logger.error(f"## Munge not working: {output}")
+        except subprocess.CalledProcessError as e:
+            logger.error(f"## Error testing munge: {e}")
+
+        return False
+
+    @property
+    def _slurm_plugin_dir(self) -> Path:
+        # Debian packages slurm plugins in /usr/lib/x86_64-linux-gnu/slurm-wlm/
+        # but we symlink /usr/lib64/slurm to it for compatibility with centos
+        return Path("/usr/lib64/slurm/")
+
+    @property
+    def _mail_prog(self) -> Path:
+        return Path("/usr/bin/mail.mailutils")
+
+    def version(self) -> str:
+        """Return slurm version."""
+        return Slurmctld().version()
+
+    def _install_slurm_from_apt(self) -> bool:
+        """Install Slurm debs.
+
+        Returns:
+            bool: True on success and False otherwise.
+        """
+        Slurmctld().install()
+
+        # symlink /usr/lib64/slurm -> /usr/lib/x86_64-linux-gnu/slurm-wlm/ to
+        # have "standard" location across OSes
+        lib64_slurm = Path("/usr/lib64/slurm")
+        if lib64_slurm.exists():
+            lib64_slurm.unlink()
+        lib64_slurm.symlink_to("/usr/lib/x86_64-linux-gnu/slurm-wlm/")
+        return True
+
+    def upgrade(self) -> bool:
+        """Run upgrade operations."""
+        Slurmctld().upgrade_to_latest()
+
+        # symlink /usr/lib64/slurm -> /usr/lib/x86_64-linux-gnu/slurm-wlm/ to
+        # have "standard" location across OSes
+        lib64_slurm = Path("/usr/lib64/slurm")
+        if lib64_slurm.exists():
+            lib64_slurm.unlink()
+        lib64_slurm.symlink_to("/usr/lib/x86_64-linux-gnu/slurm-wlm/")
+        return True
+
+    def _setup_plugstack_dir_and_config(self) -> None:
+        """Create plugstack directory and config."""
+        # Create the plugstack config directory.
+        plugstack_dir = self._slurm_plugstack_dir
+
+        if plugstack_dir.exists():
+            shutil.rmtree(plugstack_dir)
+
+        plugstack_dir.mkdir()
+        subprocess.call(["chown", "-R", f"{self._slurm_user}:{self._slurm_group}", plugstack_dir])
+
+        # Write the plugstack config.
+ plugstack_conf = self._slurm_plugstack_conf + + if plugstack_conf.exists(): + plugstack_conf.unlink() + + plugstack_conf.write_text(f"include {plugstack_dir}/*.conf") + + def _setup_paths(self): + """Create needed paths with correct permissions.""" + user = f"{self._slurm_user}:{self._slurm_group}" + + all_paths = [ + self._slurm_conf_dir, + self._slurm_log_dir, + self._slurm_state_dir, + self._slurm_spool_dir, + ] + for syspath in all_paths: + if not syspath.exists(): + syspath.mkdir() + subprocess.call(["chown", "-R", user, syspath]) + + def restart_munged(self) -> bool: + """Restart the munged process. + + Return True on success, and False otherwise. + """ + try: + logger.debug("## Restarting munge") + systemd.service_restart("munge") + except Exception("Error restarting munge") as e: + logger.error(e.message) + return False + return self.check_munged() + + def restart_slurmctld(self) -> bool: + """Restart the slurmctld process. + + Return True on success, and False otherwise. + """ + try: + logger.debug("## Restarting slurmctld") + systemd.service_restart("slurmctld") + except Exception("Error restarting slurmctld") as e: + logger.error(e.message) + return False + return True + + def slurm_cmd(self, command, arg_string): + """Run a slurm command.""" + try: + return subprocess.call([f"{command}"] + arg_string.split()) + except subprocess.CalledProcessError as e: + logger.error(f"Error running {command} - {e}") + return -1 + + def generate_jwt_rsa(self) -> str: + """Generate the rsa key to encode the jwt with.""" + return RSA.generate(2048).export_key("PEM").decode() + + @property + def slurm_installed(self) -> bool: + """Return the bool from the stored state.""" + return self._stored.slurm_installed + + @property + def slurm_component(self) -> str: + """Return the slurm component.""" + return "slurmctld" + + @property + def fluentbit_config_slurm(self) -> list: + """Return Fluentbit configuration parameters to forward Slurm logs.""" + log_file = 
self._slurmctld_log_file + + cfg = [ + { + "input": [ + ("name", "tail"), + ("path", log_file.as_posix()), + ("path_key", "filename"), + ("tag", "slurmctld"), + ("parser", "slurm"), + ] + }, + { + "parser": [ + ("name", "slurm"), + ("format", "regex"), + ("regex", r"^\[(?