Skip to content

Commit

Permalink
Add initial workflow runner (#4)
Browse files Browse the repository at this point in the history
* add function to run a workflow, and retrieve output (or error)

* use subprocess, as os.exec replaces existing

also, ditch separate workflow module

* use Deployment.run_workflow method from cli

* w/o profile, use default cache

* only update AWS_PROFILE if set

* wire in force_rerun

* use "is_flag" click syntax

* leave "force_rerun" to the library

* slashes are special, only add dash and time in microseconds

* revert to nanoseconds for force-rerun timestamp

* use backoff lib to add exponential backoff/retry to
statedb get

* add import

* add accessor to exception

* use predicate backoff since exception
doesn't seem to be raised during failed
state db get

* expand comment on boto session hacking

* Update src/cirrus/plugins/management/deployment.py

* bump up the polling timeout

* updates per PR feedback

- poll_interval is user configurable
- run_workflow command uses stdin instead of files
- Deployment.exec changes reverted, and Deployment.call method added
- stale doc strings updated

* remove non-standard 'force_rerun' parameter

one should use 'replace' instead of this parameter

* add functions to invoke deployment lambdas

* lint: black and isort

* update CHANGELOG.md

* handle 50 lambda limit

---------

Co-authored-by: Arthur Elmes <[email protected]>
  • Loading branch information
ircwaves and arthurelmes authored Aug 1, 2023
1 parent 67ba278 commit 66fabc4
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 8 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

### Added
- CLI and library functions to run a cirrus workflow, and collect its output.
Also adds a `call` command which is similar to `exec` but uses a
subprocess. And adds `invoke_lambda` which invokes lambdas that are part of
the cirrus deployment. ([#4](https://github.com/cirrus-geo/cirrus-mgmt/pull/4))


## [v0.1.0] -

Initial release
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
click-option-group>=0.5.5
cirrus-geo>=0.9.0a0
cirrus-geo>=0.9.0
backoff>=2.2.1
76 changes: 75 additions & 1 deletion src/cirrus/plugins/management/commands/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cirrus.cli.utils import click as utils_click
from click_option_group import RequiredMutuallyExclusiveOptionGroup, optgroup

from cirrus.plugins.management.deployment import Deployment
from cirrus.plugins.management.deployment import WORKFLOW_POLL_INTERVAL, Deployment
from cirrus.plugins.management.utils.click import (
additional_variables,
silence_templating_errors,
Expand Down Expand Up @@ -111,6 +111,36 @@ def refresh(deployment, stackname=None, profile=None):
deployment.refresh(stackname=stackname, profile=profile)


@manage.command("run-workflow")
@click.option(
    "-t",
    "--timeout",
    type=int,
    default=3600,
    help="Maximum time (seconds) to allow for the workflow to complete",
)
@click.option(
    "-p",
    "--poll-interval",
    type=int,
    default=WORKFLOW_POLL_INTERVAL,
    help="Seconds to wait between checks of the workflow status",
)
@raw_option
@pass_deployment
def run_workflow(deployment, timeout, raw, poll_interval):
    """Pass a payload (from stdin) off to a deployment, wait for the workflow
    to finish, then retrieve and print its output payload."""
    # json.load expects a file-like object; sys.stdin qualifies.
    # (json.load(sys.stdin.read()) would pass a str, which has no .read()
    # method and raises AttributeError.)
    payload = json.load(sys.stdin)

    output = deployment.run_workflow(
        payload=payload,
        timeout=timeout,
        poll_interval=poll_interval,
    )
    # json.dumps returns the serialized string for click.echo to print;
    # json.dump writes to a stream and returns None, which would make
    # echo print a literal "None".
    click.echo(json.dumps(output, indent=4 if not raw else None))


@manage.command("get-payload")
@click.argument(
"payload-id",
Expand Down Expand Up @@ -203,6 +233,18 @@ def process(deployment):
click.echo(json.dumps(deployment.process_payload(sys.stdin), indent=4))


@manage.command()
@click.argument(
    "lambda-name",
)
@pass_deployment
def invoke_lambda(deployment, lambda_name):
    """Invoke lambda with event (from stdin)"""
    event = sys.stdin.read()
    result = deployment.invoke_lambda(event, lambda_name)
    click.echo(json.dumps(result, indent=4))


@manage.command("template-payload")
@additional_variables
@silence_templating_errors
Expand Down Expand Up @@ -245,6 +287,38 @@ def _exec(ctx, deployment, command, include_user_vars):
deployment.exec(command, include_user_vars=include_user_vars)


@manage.command(
    "call",
    context_settings={
        "ignore_unknown_options": True,
    },
)
@click.argument(
    "command",
    nargs=-1,
)
@include_user_vars
@pass_deployment
@click.pass_context
def _call(ctx, deployment, command, include_user_vars):
    """Run an executable, in a new process, with the deployment environment vars loaded"""
    # No-op when no command tokens were supplied.
    if command:
        deployment.call(command, include_user_vars=include_user_vars)


@manage.command()
@pass_deployment
@click.pass_context
def list_lambdas(ctx, deployment):
    """List lambda functions"""
    listing = {"Functions": deployment.get_lambda_functions()}
    # default=str keeps any non-JSON-serializable values printable
    click.echo(json.dumps(listing, indent=4, default=str))


# check-pipeline
# - this is like failmgr check
# - not sure how to reconcile with cache above
Expand Down
108 changes: 106 additions & 2 deletions src/cirrus/plugins/management/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import os
from datetime import datetime, timezone
from pathlib import Path
from subprocess import check_call
from time import sleep, time

import backoff
from cirrus.lib2.process_payload import ProcessPayload

from . import exceptions
from .utils.boto3 import get_mfa_session, validate_session
Expand All @@ -14,6 +19,8 @@
MAX_SQS_MESSAGE_LENGTH = 2**18 # max length of SQS message
CONFIG_VERSION = 0

WORKFLOW_POLL_INTERVAL = 15 # seconds between state checks


def deployments_dir_from_project(project):
_dir = project.dot_dir.joinpath(DEFAULT_DEPLOYMENTS_DIR_NAME)
Expand Down Expand Up @@ -72,6 +79,7 @@ def __init__(self, path: Path, *args, **kwargs):
super().__init__(*args, **kwargs)

self._session = None
self._functions = None

@classmethod
def create(cls, name: str, project, stackname: str = None, profile: str = None):
Expand Down Expand Up @@ -152,6 +160,24 @@ def get_env_from_lambda(stackname: str, session):

return process_conf["Environment"]["Variables"]

def get_lambda_functions(self):
    """Return the deployment's lambda function names, stack prefix stripped.

    The listing is fetched once and cached on the instance.
    """
    if self._functions is None:
        client = self.get_session().client("lambda")
        prefix = f"{self.stackname}-"

        def names_in(page):
            # Keep only this stack's functions, minus the stack-name prefix.
            return [
                fn["FunctionName"].replace(prefix, "")
                for fn in page["Functions"]
                if fn["FunctionName"].startswith(self.stackname)
            ]

        page = client.list_functions()
        names = names_in(page)
        # list_functions responses are paginated; follow NextMarker
        # until the full listing has been collected.
        while "NextMarker" in page:
            page = client.list_functions(Marker=page["NextMarker"])
            names += names_in(page)
        self._functions = names
    return self._functions

def get_session(self):
if not self._session:
self._session = self._get_session(profile=self.profile)
Expand All @@ -171,7 +197,8 @@ def set_env(self, include_user_vars=False):
os.environ.update(self.environment)
if include_user_vars:
os.environ.update(self.user_vars)
os.environ["AWS_PROFILE"] = self.profile
if self.profile:
os.environ["AWS_PROFILE"] = self.profile

def add_user_vars(self, _vars, save=False):
self.user_vars.update(_vars)
Expand All @@ -198,14 +225,30 @@ def exec(self, command, include_user_vars=True, isolated=False):
self.set_env(include_user_vars=include_user_vars)
os.execlp(command[0], *command)

def call(self, command, include_user_vars=True, isolated=False):
    """Run ``command`` in a subprocess with the deployment's env vars set.

    With ``isolated`` the variables are passed only to the child process;
    otherwise they are loaded into this process's environment first.
    """
    if not isolated:
        self.set_env(include_user_vars=include_user_vars)
        check_call(command)
    else:
        env = dict(self.environment)
        if include_user_vars:
            env.update(self.user_vars)
        check_call(command, env=env)

def get_payload_state(self, payload_id):
    """Fetch a payload's state-db item, retrying while it is absent.

    Raises:
        exceptions.PayloadNotFoundError: if the item never appears
            within the backoff window.
    """
    from cirrus.lib2.statedb import StateDB

    # Retry on a None result (predicate backoff): a failed state db get
    # does not raise, it just returns nothing.
    @backoff.on_predicate(backoff.expo, lambda item: item is None, max_time=60)
    def _poll(db, pid):
        return db.get_dbitem(pid)

    statedb = StateDB(
        table_name=self.environment["CIRRUS_STATE_DB"],
        session=self.get_session(),
    )

    item = _poll(statedb, payload_id)
    if not item:
        raise exceptions.PayloadNotFoundError(payload_id)
    return item
Expand Down Expand Up @@ -265,6 +308,67 @@ def get_execution_by_payload_id(self, payload_id):

return self.get_execution(exec_arn)

def invoke_lambda(self, event, function_name):
    """Invoke a deployment lambda with ``event`` and return its JSON output.

    Raises:
        ValueError: if ``function_name`` is not part of this deployment.
        RuntimeError: if the invocation returns a non-2xx status code.
    """
    if function_name not in self.get_lambda_functions():
        raise ValueError(
            f"lambda named '{function_name}' not found in deployment '{self.name}'"
        )
    client = self.get_session().client("lambda")
    response = client.invoke(
        FunctionName=f"{self.stackname}-{function_name}",
        Payload=event,
    )
    status = response["StatusCode"]
    if not 200 <= status <= 299:
        raise RuntimeError(response)
    # Payload is a streaming body; json.load reads and decodes it.
    return json.load(response["Payload"])

def run_workflow(
    self,
    payload: dict,
    timeout: int = 3600,
    poll_interval: int = WORKFLOW_POLL_INTERVAL,
) -> dict:
    """Submit a payload to this deployment and wait for the workflow to end.

    Args:
        payload (dict): payload to pass to the deployment to kick off
            the workflow.
        timeout (int): upper bound on the number of seconds to poll the
            deployment before considering the run failed.
        poll_interval (int): seconds to delay between checks of the
            workflow status.

    Returns:
        dict containing the output payload, or an error message under
        the ``last_error`` key.
    """
    payload = ProcessPayload(payload)
    wf_id = payload["id"]
    logger.info("Submitting %s to %s", wf_id, self.name)
    resp = self.process_payload(json.dumps(payload))
    logger.debug(resp)

    state = "PROCESSING"
    # Subtract one interval so the final sleep cannot push us past `timeout`.
    end_time = time() + timeout - poll_interval
    while state == "PROCESSING" and time() < end_time:
        sleep(poll_interval)
        resp = self.get_payload_state(wf_id)
        # state_updated looks like "<STATE>_<timestamp>"; keep the state.
        state = resp["state_updated"].split("_")[0]
        logger.debug({"state": state})

    execution = self.get_execution_by_payload_id(wf_id)

    if state == "COMPLETED":
        output = dict(ProcessPayload.from_event(json.loads(execution["output"])))
    elif state == "PROCESSING":
        # Still processing after the deadline: we timed out, not the workflow.
        output = {"last_error": "Unknown: cirrus-mgmt polling timeout exceeded"}
    else:
        output = {"last_error": resp.get("last_error", "last error not recorded")}

    return output

def template_payload(
self,
payload: str,
Expand Down
12 changes: 8 additions & 4 deletions src/cirrus/plugins/management/utils/boto3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@ def get_mfa_session(**kwargs):
key/value pairs defined in this dictionary will override the
corresponding variables defined in ``SESSION_VARIABLES``.
"""
# Change the cache path from the default of
# ~/.aws/boto/cache to the one used by awscli
working_dir = os.path.join(os.path.expanduser("~"), ".aws/cli/cache")
profile = kwargs.get("profile", None)

# Construct botocore session with cache
session = botocore.session.Session(**kwargs)
provider = session.get_component("credential_provider").get_provider("assume-role")
provider.cache = credentials.JSONFileCache(working_dir)
if profile:
# If ``profile`` is provided, then we need to
# change the cache path from the default of
# ~/.aws/boto/cache to the one used by awscli.
# Without ``profile``, we defer to normal boto operations.
working_dir = os.path.join(os.path.expanduser("~"), ".aws/cli/cache")
provider.cache = credentials.JSONFileCache(working_dir)

return boto3.Session(botocore_session=session)

Expand Down

0 comments on commit 66fabc4

Please sign in to comment.