From 66fabc464da28809648c3187fee3a4e5e1c35c36 Mon Sep 17 00:00:00 2001 From: Ian Cooke Date: Tue, 1 Aug 2023 16:10:45 -0400 Subject: [PATCH] Add initial workflow runner (#4) * add function to run a workflow, and retrieve output (or error) * use subprocess, as os.exec replaces existing also, ditch separate workflow module * use Deployment.run_workflow method from cli * w/o profile, use default cache * only update AWS_PROFILE if set * wire in force_rerun * use "is_flag" click syntax * leave "force_rerun" to the library * slashes are special, only add dash and time in microseconds * revert to nanoseconds for force-rerun timestamp * use backoff lib to add exponential backoff/retry to statedb get * add import * add accessor to exception * use predicate backoff since exception doesn't seem to be raised during failed state db get * expand comment on boto session hacking * Update src/cirrus/plugins/management/deployment.py * bump up the polling timeout * updates per PR feedback - poll_interval is user configurable - run_workflow command uses stdin instead of files - Deployment.exec changes reverted, and Deployment.call method added - stale doc strings updated * remove non-standard 'force_rerun' parameter one should use 'replace' instead of this parameter * add functions to invoke deployment lambdas * lint: black and isort * update CHANGELOG.md * handle 50 lambda limit --------- Co-authored-by: Arthur Elmes --- CHANGELOG.md | 7 ++ requirements.txt | 3 +- .../plugins/management/commands/manage.py | 76 +++++++++++- src/cirrus/plugins/management/deployment.py | 108 +++++++++++++++++- src/cirrus/plugins/management/utils/boto3.py | 12 +- 5 files changed, 198 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 615c408..e910a82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. 
## [Unreleased] +### Added +- CLI and library functions to run a cirrus workflow, and collect its output. + Also adds a `call` command which is similar to `exec` but uses a + subprocess. And adds `invoke_lambda` which invokes lambdas that are part of + the cirrus deployment. ([#4](https://github.com/cirrus-geo/cirrus-mgmt/pull/4)) + + ## [v0.1.0] - Initial release diff --git a/requirements.txt b/requirements.txt index 0a6c964..0ec6eac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ click-option-group>=0.5.5 -cirrus-geo>=0.9.0a0 +cirrus-geo>=0.9.0 +backoff>=2.2.1 diff --git a/src/cirrus/plugins/management/commands/manage.py b/src/cirrus/plugins/management/commands/manage.py index da51d80..dd4b2dc 100644 --- a/src/cirrus/plugins/management/commands/manage.py +++ b/src/cirrus/plugins/management/commands/manage.py @@ -7,7 +7,7 @@ from cirrus.cli.utils import click as utils_click from click_option_group import RequiredMutuallyExclusiveOptionGroup, optgroup -from cirrus.plugins.management.deployment import Deployment +from cirrus.plugins.management.deployment import WORKFLOW_POLL_INTERVAL, Deployment from cirrus.plugins.management.utils.click import ( additional_variables, silence_templating_errors, @@ -111,6 +111,36 @@ def refresh(deployment, stackname=None, profile=None): deployment.refresh(stackname=stackname, profile=profile) +@manage.command("run-workflow") +@click.option( + "-t", + "--timeout", + type=int, + default=3600, + help="Maximum time (seconds) to allow for the workflow to complete", +) +@click.option( + "-p", + "--poll-interval", + type=int, + default=WORKFLOW_POLL_INTERVAL, + help="Seconds to wait between checks of the workflow status", +) +@raw_option +@pass_deployment +def run_workflow(deployment, timeout, raw, poll_interval): + """Pass a payload (from stdin) off to a deployment, wait for the workflow to finish, + retrieve and return its output payload""" + payload = json.loads(sys.stdin.read()) + + output = 
deployment.run_workflow( + payload=payload, + timeout=timeout, + poll_interval=poll_interval, + ) + click.echo(json.dumps(output, indent=4 if not raw else None)) + + @manage.command("get-payload") @click.argument( "payload-id", @@ -203,6 +233,18 @@ def process(deployment): click.echo(json.dumps(deployment.process_payload(sys.stdin), indent=4)) +@manage.command() +@click.argument( + "lambda-name", +) +@pass_deployment +def invoke_lambda(deployment, lambda_name): + """Invoke lambda with event (from stdin)""" + click.echo( + json.dumps(deployment.invoke_lambda(sys.stdin.read(), lambda_name), indent=4) + ) + + @manage.command("template-payload") @additional_variables @silence_templating_errors @@ -245,6 +287,38 @@ def _exec(ctx, deployment, command, include_user_vars): deployment.exec(command, include_user_vars=include_user_vars) +@manage.command( + "call", + context_settings={ + "ignore_unknown_options": True, + }, +) +@click.argument( + "command", + nargs=-1, +) +@include_user_vars +@pass_deployment +@click.pass_context +def _call(ctx, deployment, command, include_user_vars): + """Run an executable, in a new process, with the deployment environment vars loaded""" + if not command: + return + deployment.call(command, include_user_vars=include_user_vars) + + +@manage.command() +@pass_deployment +@click.pass_context +def list_lambdas(ctx, deployment): + """List lambda functions""" + click.echo( + json.dumps( + {"Functions": deployment.get_lambda_functions()}, indent=4, default=str ) + ) + + # check-pipeline # - this is like failmgr check # - not sure how to reconcile with cache above diff --git a/src/cirrus/plugins/management/deployment.py b/src/cirrus/plugins/management/deployment.py index e048fd7..b92793d 100644 --- a/src/cirrus/plugins/management/deployment.py +++ b/src/cirrus/plugins/management/deployment.py @@ -4,6 +4,11 @@ import os from datetime import datetime, timezone from pathlib import Path +from subprocess import check_call +from time import 
sleep, time + +import backoff +from cirrus.lib2.process_payload import ProcessPayload from . import exceptions from .utils.boto3 import get_mfa_session, validate_session @@ -14,6 +19,8 @@ MAX_SQS_MESSAGE_LENGTH = 2**18 # max length of SQS message CONFIG_VERSION = 0 +WORKFLOW_POLL_INTERVAL = 15 # seconds between state checks + def deployments_dir_from_project(project): _dir = project.dot_dir.joinpath(DEFAULT_DEPLOYMENTS_DIR_NAME) @@ -72,6 +79,7 @@ def __init__(self, path: Path, *args, **kwargs): super().__init__(*args, **kwargs) self._session = None + self._functions = None @classmethod def create(cls, name: str, project, stackname: str = None, profile: str = None): @@ -152,6 +160,24 @@ def get_env_from_lambda(stackname: str, session): return process_conf["Environment"]["Variables"] + def get_lambda_functions(self): + if self._functions is None: + aws_lambda = self.get_session().client("lambda") + + def deployment_functions_filter(response): + return [ + f["FunctionName"].replace(f"{self.stackname}-", "") + for f in response["Functions"] + if f["FunctionName"].startswith(self.stackname) + ] + + resp = aws_lambda.list_functions() + self._functions = deployment_functions_filter(resp) + while "NextMarker" in resp: + resp = aws_lambda.list_functions(Marker=resp["NextMarker"]) + self._functions += deployment_functions_filter(resp) + return self._functions + def get_session(self): if not self._session: self._session = self._get_session(profile=self.profile) @@ -171,7 +197,8 @@ def set_env(self, include_user_vars=False): os.environ.update(self.environment) if include_user_vars: os.environ.update(self.user_vars) - os.environ["AWS_PROFILE"] = self.profile + if self.profile: + os.environ["AWS_PROFILE"] = self.profile def add_user_vars(self, _vars, save=False): self.user_vars.update(_vars) @@ -198,6 +225,16 @@ def exec(self, command, include_user_vars=True, isolated=False): self.set_env(include_user_vars=include_user_vars) os.execlp(command[0], *command) + def call(self, 
command, include_user_vars=True, isolated=False): + if isolated: + env = self.environment.copy() + if include_user_vars: + env.update(self.user_vars) + check_call(command, env=env) + else: + self.set_env(include_user_vars=include_user_vars) + check_call(command) + def get_payload_state(self, payload_id): from cirrus.lib2.statedb import StateDB @@ -205,7 +242,13 @@ def get_payload_state(self, payload_id): table_name=self.environment["CIRRUS_STATE_DB"], session=self.get_session(), ) - state = statedb.get_dbitem(payload_id) + + @backoff.on_predicate(backoff.expo, lambda x: x is None, max_time=60) + def _get_payload_item_from_statedb(statedb, payload_id): + return statedb.get_dbitem(payload_id) + + state = _get_payload_item_from_statedb(statedb, payload_id) + if not state: raise exceptions.PayloadNotFoundError(payload_id) return state @@ -265,6 +308,67 @@ def get_execution_by_payload_id(self, payload_id): return self.get_execution(exec_arn) + def invoke_lambda(self, event, function_name): + aws_lambda = self.get_session().client("lambda") + if function_name not in self.get_lambda_functions(): + raise ValueError( + f"lambda named '{function_name}' not found in deployment '{self.name}'" + ) + full_name = f"{self.stackname}-{function_name}" + response = aws_lambda.invoke(FunctionName=full_name, Payload=event) + if response["StatusCode"] < 200 or response["StatusCode"] > 299: + raise RuntimeError(response) + + return json.load(response["Payload"]) + + def run_workflow( + self, + payload: dict, + timeout: int = 3600, + poll_interval: int = WORKFLOW_POLL_INTERVAL, + ) -> dict: + """ + + Args: + deployment (Deployment): where the workflow will be run. + + payload (str): payload to pass to the deployment to kick off the workflow. + + timeout (Optional[int]): - upper bound on the number of seconds to poll the + deployment before considering the test failed. + + poll_interval (Optional[int]): - seconds to delay between checks of the + workflow status. 
+ + Returns: + dict containing output payload or error message + + """ + payload = ProcessPayload(payload) + wf_id = payload["id"] + logger.info("Submitting %s to %s", wf_id, self.name) + resp = self.process_payload(json.dumps(payload)) + logger.debug(resp) + + state = "PROCESSING" + end_time = time() + timeout - poll_interval + while state == "PROCESSING" and time() < end_time: + sleep(poll_interval) + resp = self.get_payload_state(wf_id) + state = resp["state_updated"].split("_")[0] + logger.debug({"state": state}) + + execution = self.get_execution_by_payload_id(wf_id) + + if state == "COMPLETED": + output = dict(ProcessPayload.from_event(json.loads(execution["output"]))) + elif state == "PROCESSING": + output = {"last_error": "Unknown: cirrus-mgmt polling timeout exceeded"} + else: + output = {"last_error": resp.get("last_error", "last error not recorded")} + + return output + def template_payload( self, payload: str, diff --git a/src/cirrus/plugins/management/utils/boto3.py b/src/cirrus/plugins/management/utils/boto3.py index c30b4e1..023cbdd 100644 --- a/src/cirrus/plugins/management/utils/boto3.py +++ b/src/cirrus/plugins/management/utils/boto3.py @@ -25,14 +25,18 @@ def get_mfa_session(**kwargs): key/value pairs defined in this dictionary will override the corresponding variables defined in ``SESSION_VARIABLES``. """ - # Change the cache path from the default of - # ~/.aws/boto/cache to the one used by awscli - working_dir = os.path.join(os.path.expanduser("~"), ".aws/cli/cache") + profile = kwargs.get("profile", None) # Construct botocore session with cache session = botocore.session.Session(**kwargs) provider = session.get_component("credential_provider").get_provider("assume-role") - provider.cache = credentials.JSONFileCache(working_dir) + if profile: + # If ``profile`` is provided, then we need to + # change the cache path from the default of + # ~/.aws/boto/cache to the one used by awscli. + # Without ``profile``, we defer to normal boto operations. 
+ working_dir = os.path.join(os.path.expanduser("~"), ".aws/cli/cache") + provider.cache = credentials.JSONFileCache(working_dir) return boto3.Session(botocore_session=session)