diff --git a/CHANGES.md b/CHANGES.md
index c70a1bee9bd..47ee816837c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -16,6 +16,11 @@ ones in. -->
### Enhancements
+[#5187](https://github.com/cylc/cylc-flow/pull/5189) - Allow
+`cylc validate --revalidate` to use template variables collected from
+the workflow database. Also applied to `cylc graph`, `cylc view` and
+`cylc config`.
+
[#5032](https://github.com/cylc/cylc-flow/pull/5032) - set a default limit of
100 for the "default" queue.
diff --git a/cylc/flow/option_parsers.py b/cylc/flow/option_parsers.py
index a0609e0e172..37a04c52cb2 100644
--- a/cylc/flow/option_parsers.py
+++ b/cylc/flow/option_parsers.py
@@ -36,6 +36,8 @@
from typing import Any, Dict, Optional, List, Tuple
from cylc.flow import LOG
+from cylc.flow.exceptions import WorkflowConfigError
+from cylc.flow.pathutil import is_in_a_rundir
from cylc.flow.terminal import supports_color, DIM
import cylc.flow.flags
from cylc.flow.loggingutil import (
@@ -289,6 +291,7 @@ def __init__(
argdoc: Optional[List[Tuple[str, str]]] = None,
comms: bool = False,
jset: bool = False,
+ revalidate: bool = False,
multitask: bool = False,
multiworkflow: bool = False,
auto_add: bool = True,
@@ -303,6 +306,7 @@ def __init__(
instructions. Optional list of tuples of (name, description).
comms: If True, allow the --comms-timeout option.
jset: If True, allow the Jinja2 --set option.
+ revalidate: If True, allow the --revalidate option.
multitask: If True, insert the multitask text into the
usage instructions.
multiworkflow: If True, insert the multiworkflow text into the
@@ -327,6 +331,7 @@ def __init__(
self.unlimited_args = False
self.comms = comms
self.jset = jset
+ self.revalidate = revalidate
self.color = color
# Whether to log messages that are below warning level to stdout
# instead of stderr:
@@ -440,6 +445,13 @@ def add_std_options(self):
),
action="store", default=None, dest="templatevars_file")
+ if self.revalidate:
+ self.add_std_option(
+ '--revalidate',
+ help="Get template variables from prevous workflow run.",
+ action='store_true', default=False
+ )
+
def add_cylc_rose_options(self) -> None:
"""Add extra options for cylc-rose plugin if it is installed."""
try:
@@ -607,3 +619,11 @@ def __call__(self, **kwargs) -> Values:
setattr(opts, key, value)
return opts
+
+
+def can_revalidate(flow_file, opts):
+ if not is_in_a_rundir(flow_file) and opts.revalidate:
+ raise WorkflowConfigError(
+ 'Revalidation only works with installed workflows.'
+ )
+ return True
diff --git a/cylc/flow/parsec/fileparse.py b/cylc/flow/parsec/fileparse.py
index 510a46d6322..8ed192ebd2d 100644
--- a/cylc/flow/parsec/fileparse.py
+++ b/cylc/flow/parsec/fileparse.py
@@ -251,7 +251,7 @@ def process_plugins(fpath, opts):
# If you want it to work on sourcedirs you need to get the options
# to here.
plugin_result = entry_point.resolve()(
- srcdir=fpath, opts=opts
+ fpath, opts=opts
)
except Exception as exc:
# NOTE: except Exception (purposefully vague)
diff --git a/cylc/flow/pathutil.py b/cylc/flow/pathutil.py
index 0d62ef37ad4..10e1a665665 100644
--- a/cylc/flow/pathutil.py
+++ b/cylc/flow/pathutil.py
@@ -456,3 +456,8 @@ def get_workflow_name_from_id(workflow_id: str) -> str:
name_path = id_path
return str(name_path.relative_to(cylc_run_dir))
+
+
+def is_in_a_rundir(path_):
+ """Is this path in a run directory"""
+ return is_relative_to(path_, Path(get_cylc_run_dir()))
diff --git a/cylc/flow/pre_configure/get_old_tvars.py b/cylc/flow/pre_configure/get_old_tvars.py
new file mode 100644
index 00000000000..de004a33783
--- /dev/null
+++ b/cylc/flow/pre_configure/get_old_tvars.py
@@ -0,0 +1,70 @@
+# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
+# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+"""Retrieve template variables stored in a workflow database.
+"""
+
+from cylc.flow.rundb import CylcWorkflowDAO
+from cylc.flow.templatevars import eval_var
+from optparse import Values
+from pathlib import Path
+from typing import Union
+
+
+class OldTemplateVars:
+ """Gets template variables stored in workflow database.
+
+ Mirrors the interface used in scheduler.py to get db nfo on restart.
+ """
+ DB = 'log/db'
+
+ def __init__(self, run_dir):
+ self.template_vars = {}
+ self._get_db_template_vars(Path(run_dir))
+
+ def _callback(self, _, row):
+ """Extract key and value and run eval_var on them assigning
+ them to self.template_vars.
+ """
+ self.template_vars[row[0]] = eval_var(row[1])
+
+ def _get_db_template_vars(self, run_dir):
+ dao = CylcWorkflowDAO(str(run_dir / self.DB))
+ dao.select_workflow_template_vars(self._callback)
+
+
+# Entry point:
+def main(srcdir: Union[Path, str], opts: 'Values') -> dict:
+ # We can calculate the source directory here!
+ """Get options from a previously installed run.
+
+ These options are stored in the database.
+ Calculate the templating language used from the shebang line.
+
+ N.B. The srcdir for this plugin to operate on is a workflow run dir.
+
+ Args:
+ srcdir: The directory of a previously run workflow.
+ opts: Options Object
+ """
+ if not hasattr(opts, 'revalidate') or not opts.revalidate:
+ return {}
+ else:
+ return {
+ 'template_variables':
+ OldTemplateVars(srcdir).template_vars,
+ 'templating_detected':
+ 'template variables'
+ }
diff --git a/cylc/flow/scripts/config.py b/cylc/flow/scripts/config.py
index 4907ae77176..ec7b3921b41 100755
--- a/cylc/flow/scripts/config.py
+++ b/cylc/flow/scripts/config.py
@@ -49,19 +49,20 @@
$ cylc config --initial-cycle-point=now myflow
"""
+import asyncio
import os.path
from typing import List, Optional, TYPE_CHECKING
from cylc.flow.cfgspec.glbl_cfg import glbl_cfg
from cylc.flow.config import WorkflowConfig
-from cylc.flow.id_cli import parse_id
-from cylc.flow.exceptions import InputError
+from cylc.flow.id_cli import parse_id_async
+from cylc.flow.exceptions import InputError, WorkflowConfigError
from cylc.flow.option_parsers import (
WORKFLOW_ID_OR_PATH_ARG_DOC,
CylcOptionParser as COP,
icp_option,
)
-from cylc.flow.pathutil import get_workflow_run_dir
+from cylc.flow.pathutil import get_workflow_run_dir, is_in_a_rundir
from cylc.flow.templatevars import get_template_vars
from cylc.flow.terminal import cli_function
from cylc.flow.workflow_files import WorkflowFiles
@@ -75,6 +76,7 @@ def get_option_parser() -> COP:
__doc__,
argdoc=[COP.optional(WORKFLOW_ID_OR_PATH_ARG_DOC)],
jset=True,
+ revalidate=True,
)
parser.add_option(
@@ -149,6 +151,14 @@ def main(
options: 'Values',
*ids,
) -> None:
+ asyncio.run(_main(parser, options, *ids))
+
+
+async def _main(
+ parser: COP,
+ options: 'Values',
+ *ids,
+) -> None:
if options.print_platform_names and options.print_platforms:
options.print_platform_names = False
@@ -178,12 +188,17 @@ def main(
)
return
- workflow_id, _, flow_file = parse_id(
+ workflow_id, _, flow_file = await parse_id_async(
*ids,
src=True,
constraint='workflows',
)
+ if not is_in_a_rundir(flow_file) and options.revalidate:
+ raise WorkflowConfigError(
+ 'Revalidation only works with installed workflows.'
+ )
+
if options.print_hierarchy:
print("\n".join(get_config_file_hierarchy(workflow_id)))
return
diff --git a/cylc/flow/scripts/graph.py b/cylc/flow/scripts/graph.py
index 3488177ac17..3b7d32fba48 100644
--- a/cylc/flow/scripts/graph.py
+++ b/cylc/flow/scripts/graph.py
@@ -35,6 +35,7 @@
$ cylc graph one -o 'one.svg'
"""
+import asyncio
from difflib import unified_diff
from shutil import which
from subprocess import Popen, PIPE
@@ -45,11 +46,12 @@
from cylc.flow.config import WorkflowConfig
from cylc.flow.exceptions import InputError, CylcError
from cylc.flow.id import Tokens
-from cylc.flow.id_cli import parse_id
+from cylc.flow.id_cli import parse_id_async
from cylc.flow.option_parsers import (
WORKFLOW_ID_OR_PATH_ARG_DOC,
CylcOptionParser as COP,
icp_option,
+ can_revalidate,
)
from cylc.flow.templatevars import get_template_vars
from cylc.flow.terminal import cli_function
@@ -108,9 +110,10 @@ def get_nodes_and_edges(
workflow_id,
start,
stop,
+ flow_file,
) -> Tuple[List[Node], List[Edge]]:
"""Return graph sorted nodes and edges."""
- config = get_config(workflow_id, opts)
+ config = get_config(workflow_id, opts, flow_file)
if opts.namespaces:
nodes, edges = _get_inheritance_nodes_and_edges(config)
else:
@@ -194,13 +197,8 @@ def _get_inheritance_nodes_and_edges(
return sorted(nodes), sorted(edges)
-def get_config(workflow_id: str, opts: 'Values') -> WorkflowConfig:
+def get_config(workflow_id: str, opts: 'Values', flow_file) -> WorkflowConfig:
"""Return a WorkflowConfig object for the provided reg / path."""
- workflow_id, _, flow_file = parse_id(
- workflow_id,
- src=True,
- constraint='workflows',
- )
template_vars = get_template_vars(opts)
return WorkflowConfig(
workflow_id, flow_file, opts, template_vars=template_vars
@@ -334,7 +332,7 @@ def open_image(filename):
img.show()
-def graph_render(opts, workflow_id, start, stop) -> int:
+def graph_render(opts, workflow_id, start, stop, flow_file) -> int:
"""Render the workflow graph to the specified format.
Graph is rendered to the specified format. The Graphviz "dot" format
@@ -349,6 +347,7 @@ def graph_render(opts, workflow_id, start, stop) -> int:
workflow_id,
start,
stop,
+ flow_file
)
# format the graph in graphviz-dot format
@@ -382,7 +381,9 @@ def graph_render(opts, workflow_id, start, stop) -> int:
return 0
-def graph_reference(opts, workflow_id, start, stop, write=print) -> int:
+def graph_reference(
+ opts, workflow_id, start, stop, flow_file, write=print,
+) -> int:
"""Format the workflow graph using the cylc reference format."""
# get nodes and edges
nodes, edges = get_nodes_and_edges(
@@ -390,6 +391,7 @@ def graph_reference(opts, workflow_id, start, stop, write=print) -> int:
workflow_id,
start,
stop,
+ flow_file
)
for line in format_cylc_reference(opts, nodes, edges):
write(line)
@@ -397,13 +399,22 @@ def graph_reference(opts, workflow_id, start, stop, write=print) -> int:
return 0
-def graph_diff(opts, workflow_a, workflow_b, start, stop) -> int:
+async def graph_diff(opts, workflow_a, workflow_b, start, stop, flow_file) -> int:
"""Difference the workflow graphs using the cylc reference format."""
+
+ workflow_b, _, flow_file_b = await parse_id_async(
+ workflow_b,
+ src=True,
+ constraint='workflows',
+ )
+
# load graphs
graph_a: List[str] = []
graph_b: List[str] = []
- graph_reference(opts, workflow_a, start, stop, write=graph_a.append),
- graph_reference(opts, workflow_b, start, stop, write=graph_b.append),
+ graph_reference(
+ opts, workflow_a, start, stop, flow_file, write=graph_a.append),
+ graph_reference(
+ opts, workflow_b, start, stop, flow_file_b, write=graph_b.append),
# compare graphs
diff_lines = list(
@@ -427,6 +438,7 @@ def get_option_parser() -> COP:
parser = COP(
__doc__,
jset=True,
+ revalidate=True,
argdoc=[
WORKFLOW_ID_OR_PATH_ARG_DOC,
COP.optional(
@@ -507,20 +519,36 @@ def main(
start: Optional[str] = None,
stop: Optional[str] = None
) -> None:
+ result = asyncio.run(_main(parser, opts, workflow_id, start, stop))
+ sys.exit(result)
+
+
+async def _main(
+ parser: COP,
+ opts: 'Values',
+ workflow_id: str,
+ start: Optional[str] = None,
+ stop: Optional[str] = None
+) -> int:
"""Implement ``cylc graph``."""
if opts.grouping and opts.namespaces:
raise InputError('Cannot combine --group and --namespaces.')
if opts.cycles and opts.namespaces:
raise InputError('Cannot combine --cycles and --namespaces.')
+ workflow_id, _, flow_file = await parse_id_async(
+ workflow_id,
+ src=True,
+ constraint='workflows',
+ )
+
+ can_revalidate(flow_file, opts)
+
if opts.diff:
- sys.exit(
- graph_diff(opts, workflow_id, opts.diff, start, stop)
- )
+ return await graph_diff(
+ opts, workflow_id, opts.diff, start, stop, flow_file)
if opts.reference:
- sys.exit(
- graph_reference(opts, workflow_id, start, stop)
- )
- sys.exit(
- graph_render(opts, workflow_id, start, stop)
- )
+ return graph_reference(
+ opts, workflow_id, start, stop, flow_file)
+
+ return graph_render(opts, workflow_id, start, stop, flow_file)
diff --git a/cylc/flow/scripts/validate.py b/cylc/flow/scripts/validate.py
index 6e5031c91ee..49affaec0f1 100755
--- a/cylc/flow/scripts/validate.py
+++ b/cylc/flow/scripts/validate.py
@@ -25,6 +25,7 @@
use 'cylc view -i,--inline WORKFLOW' for comparison.
"""
+import asyncio
from ansimarkup import parse as cparse
from optparse import Values
import sys
@@ -32,18 +33,19 @@
from cylc.flow import LOG, __version__ as CYLC_VERSION
from cylc.flow.config import WorkflowConfig
from cylc.flow.exceptions import (
- WorkflowConfigError,
TaskProxySequenceBoundsError,
- TriggerExpressionError
+ TriggerExpressionError,
+ WorkflowConfigError
)
import cylc.flow.flags
-from cylc.flow.id_cli import parse_id
+from cylc.flow.id_cli import parse_id_async
from cylc.flow.loggingutil import disable_timestamps
from cylc.flow.option_parsers import (
WORKFLOW_ID_OR_PATH_ARG_DOC,
CylcOptionParser as COP,
Options,
icp_option,
+ can_revalidate,
)
from cylc.flow.profiler import Profiler
from cylc.flow.task_proxy import TaskProxy
@@ -55,6 +57,7 @@ def get_option_parser():
parser = COP(
__doc__,
jset=True,
+ revalidate=True,
argdoc=[WORKFLOW_ID_OR_PATH_ARG_DOC],
)
@@ -80,6 +83,11 @@ def get_option_parser():
default="live", dest="run_mode",
choices=['live', 'dummy', 'simulation'])
+ parser.add_option(
+ '--revalidate', help="Validate as if for re-install",
+ default=False, dest="revalidate", action="store_true"
+ )
+
parser.add_option(icp_option)
parser.add_cylc_rose_options()
@@ -102,6 +110,11 @@ def get_option_parser():
@cli_function(get_option_parser)
def main(parser: COP, options: 'Values', workflow_id: str) -> None:
+ """cylc validate CLI."""
+ asyncio.run(_main(parser, options, workflow_id))
+
+
+async def _main(parser: COP, options: 'Values', workflow_id: str) -> None:
"""cylc validate CLI."""
profiler = Profiler(None, options.profile_mode)
profiler.start()
@@ -109,11 +122,13 @@ def main(parser: COP, options: 'Values', workflow_id: str) -> None:
if cylc.flow.flags.verbosity < 2:
disable_timestamps(LOG)
- workflow_id, _, flow_file = parse_id(
+ workflow_id, _, flow_file = await parse_id_async(
workflow_id,
src=True,
constraint='workflows',
)
+ can_revalidate(flow_file, options)
+
cfg = WorkflowConfig(
workflow_id,
flow_file,
diff --git a/cylc/flow/scripts/view.py b/cylc/flow/scripts/view.py
index f5afbf0bb9e..1768285d54a 100755
--- a/cylc/flow/scripts/view.py
+++ b/cylc/flow/scripts/view.py
@@ -25,14 +25,16 @@
configuration (as Cylc would see it).
"""
+import asyncio
from typing import TYPE_CHECKING
-from cylc.flow.id_cli import parse_id
+from cylc.flow.id_cli import parse_id_async
from cylc.flow.option_parsers import (
WORKFLOW_ID_OR_PATH_ARG_DOC,
CylcOptionParser as COP,
)
from cylc.flow.parsec.fileparse import read_and_proc
+from cylc.flow.option_parsers import can_revalidate
from cylc.flow.templatevars import load_template_vars
from cylc.flow.terminal import cli_function
@@ -44,6 +46,7 @@ def get_option_parser():
parser = COP(
__doc__,
jset=True,
+ revalidate=True,
argdoc=[WORKFLOW_ID_OR_PATH_ARG_DOC],
)
@@ -98,12 +101,18 @@ def get_option_parser():
@cli_function(get_option_parser)
def main(parser: COP, options: 'Values', workflow_id: str) -> None:
- workflow_id, _, flow_file = parse_id(
+ asyncio.run(_main(parser, options, workflow_id))
+
+
+async def _main(parser: COP, options: 'Values', workflow_id: str) -> None:
+ workflow_id, _, flow_file = await parse_id_async(
workflow_id,
src=True,
constraint='workflows',
)
+ can_revalidate(flow_file, options)
+
# read in the flow.cylc file
viewcfg = {
'mark': options.mark,
diff --git a/setup.cfg b/setup.cfg
index 7566519f0bd..a8e585e3851 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -210,6 +210,7 @@ cylc.main_loop =
# NOTE: all entry points should be listed here even if Cylc Flow does not
# provide any implementations, to make entry point scraping easier
cylc.pre_configure =
+ get_old_tvars = cylc.flow.pre_configure.get_old_tvars:main
cylc.post_install =
log_vc_info = cylc.flow.install_plugins.log_vc_info:main
diff --git a/tests/integration/plugins/test_get_old_tvars.py b/tests/integration/plugins/test_get_old_tvars.py
new file mode 100644
index 00000000000..a04faf8a929
--- /dev/null
+++ b/tests/integration/plugins/test_get_old_tvars.py
@@ -0,0 +1,132 @@
+# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
+# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from cylc.flow.scheduler_cli import get_option_parser
+from cylc.flow.parsec.exceptions import Jinja2Error
+from cylc.flow.pre_configure.get_old_tvars import main as get_old_tvars
+import pytest
+from pytest import param
+from types import SimpleNamespace
+
+from cylc.flow.scripts.validate import (
+ _main as validate,
+ get_option_parser as validate_gop
+)
+from cylc.flow.scripts.view import (
+ _main as view,
+ get_option_parser as view_gop
+)
+from cylc.flow.scripts.graph import (
+ _main as graph,
+ get_option_parser as graph_gop
+)
+from cylc.flow.scripts.config import (
+ _main as config,
+ get_option_parser as config_gop
+)
+
+
+@pytest.fixture(scope='module')
+def create_workflow(mod_one_conf, mod_flow, mod_scheduler):
+ # Set up opts and parser
+ parser = get_option_parser()
+ opts = SimpleNamespace(**parser.get_default_values().__dict__)
+ opts.templatevars = ['FOO="From cylc template variables"']
+ opts.templatevars_file = []
+
+ conf = mod_one_conf
+ # Set up scheduler
+ schd = mod_scheduler(mod_flow(conf), templatevars=['FOO="bar"'])
+
+ yield SimpleNamespace(schd=schd, opts=opts)
+
+
+@pytest.mark.parametrize(
+ 'revalidate, expect',
+ [
+ (False, {}),
+ (True, 'bar')
+ ]
+)
+async def test_basic(create_workflow, mod_start, revalidate, expect):
+ """It returns a pre-existing configuration if opts.revalidate is True"""
+ opts = create_workflow.opts
+ opts.revalidate = revalidate
+
+ async with mod_start(create_workflow.schd):
+ result = get_old_tvars(create_workflow.schd.workflow_run_dir, opts)
+ if expect:
+ assert result['template_variables']['FOO'] == expect
+ else:
+ assert result == expect
+
+
+@pytest.fixture(scope='module')
+def _setup(mod_scheduler, mod_flow):
+ """Provide an installed flow with a database to try assorted
+ simple Cylc scripts against.
+ """
+ conf = {
+ '#!jinja2': '',
+ 'scheduler': {
+ 'allow implicit tasks': True
+ },
+ 'scheduling': {
+ 'graph': {
+ 'R1': r'{{FOO}}'
+ }
+ }
+ }
+ schd = mod_scheduler(mod_flow(conf), templatevars=['FOO="bar"'])
+
+ yield schd
+
+
+@pytest.mark.parametrize(
+ 'function, parser, expect',
+ (
+ param(validate, validate_gop, 'Valid for', id="validate"),
+ param(view, view_gop, 'FOO', id="view"),
+ param(graph, graph_gop, '1/bar', id='graph'),
+ param(config, config_gop, 'R1 = bar', id='config')
+ )
+)
+@pytest.mark.parametrize(
+ 'revalidate',
+ [
+ (False),
+ (True)
+ ]
+)
+async def test_revalidate_validate(
+ _setup, mod_start, capsys, function, parser, revalidate, expect,
+):
+ """It validates with Cylc Validate."""
+ parser = parser()
+ opts = SimpleNamespace(**parser.get_default_values().__dict__)
+ opts.templatevars = []
+ opts.templatevars_file = []
+ opts.revalidate = revalidate
+ if function == graph:
+ opts.reference = True
+
+ async with mod_start(_setup):
+ if revalidate or expect == 'FOO':
+ await function(parser, opts, _setup.workflow_name)
+ assert expect in capsys.readouterr().out
+ else:
+ with pytest.raises(Jinja2Error, match="'FOO' is undefined"):
+ await function(parser, opts, _setup.workflow_name)
diff --git a/tests/unit/plugins/test_get_old_tvars.py b/tests/unit/plugins/test_get_old_tvars.py
new file mode 100644
index 00000000000..4d869575905
--- /dev/null
+++ b/tests/unit/plugins/test_get_old_tvars.py
@@ -0,0 +1,65 @@
+# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
+# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+
+from cylc.flow.pre_configure.get_old_tvars import OldTemplateVars
+import sqlite3
+import pytest
+
+
+@pytest.fixture(scope='module')
+def _setup_db(tmp_path_factory):
+ tmp_path = tmp_path_factory.mktemp('test_get_old_tvars')
+ logfolder = tmp_path / "log/"
+ logfolder.mkdir()
+ db_path = logfolder / 'db'
+ conn = sqlite3.connect(db_path)
+ conn.execute(
+ r'''
+ CREATE TABLE workflow_template_vars (
+ key,
+ value
+ )
+ '''
+ )
+ conn.execute(
+ r'''
+ INSERT INTO workflow_template_vars
+ VALUES
+ ("FOO", "42"),
+ ("BAR", "'hello world'"),
+ ("BAZ", "'foo', 'bar', 48"),
+ ("QUX", "['foo', 'bar', 21]")
+ '''
+ )
+ conn.commit()
+ conn.close()
+ yield OldTemplateVars(tmp_path)
+
+
+@pytest.mark.parametrize(
+ 'key, expect',
+ (
+ ('FOO', 42),
+ ('BAR', 'hello world'),
+ ('BAZ', ('foo', 'bar', 48)),
+ ('QUX', ['foo', 'bar', 21])
+ )
+)
+def test_OldTemplateVars(key, expect, _setup_db):
+ """It can extract a variety of items from a workflow database.
+ """
+ assert _setup_db.template_vars[key] == expect
diff --git a/tests/unit/scripts/test_graph.py b/tests/unit/scripts/test_graph.py
index aa64002fca7..e0638734e85 100644
--- a/tests/unit/scripts/test_graph.py
+++ b/tests/unit/scripts/test_graph.py
@@ -89,7 +89,7 @@ def _get_parents_lists(*args, **kwargs):
monkeypatch.setattr(
'cylc.flow.scripts.graph.get_config',
- lambda x, y: config
+ lambda x, y, z: config
)
@@ -295,11 +295,11 @@ def test_null(null_config):
grouping=False,
show_suicide=False
)
- assert get_nodes_and_edges(opts, None, 1, 2) == ([], [])
+ assert get_nodes_and_edges(opts, None, 1, 2, 'foo') == ([], [])
opts = SimpleNamespace(
namespaces=True,
grouping=False,
show_suicide=False
)
- assert get_nodes_and_edges(opts, None, 1, 2) == ([], [])
+ assert get_nodes_and_edges(opts, None, 1, 2, 'foo') == ([], [])
diff --git a/tests/unit/test_option_parsers.py b/tests/unit/test_option_parsers.py
index dcd2f6e8977..9694d454e75 100644
--- a/tests/unit/test_option_parsers.py
+++ b/tests/unit/test_option_parsers.py
@@ -15,13 +15,20 @@
# along with this program. If not, see .
import pytest
+from pytest import param
from typing import List
import sys
import io
from contextlib import redirect_stdout
+from cylc.flow.exceptions import WorkflowConfigError
import cylc.flow.flags
-from cylc.flow.option_parsers import CylcOptionParser as COP, Options
+from cylc.flow.option_parsers import (
+ CylcOptionParser as COP,
+ Options,
+ can_revalidate,
+)
+from types import SimpleNamespace
USAGE_WITH_COMMENT = "usage \n # comment"
@@ -93,3 +100,39 @@ def test_Options_std_opts():
MyOptions = Options(parser)
MyValues = MyOptions(verbosity=1)
assert MyValues.verbosity == 1
+
+
+@pytest.mark.parametrize(
+ 'rundir, revalidate, expect',
+ (
+ (True, True, ''),
+ (True, False, ''),
+ (False, True, False),
+ (False, False, ''),
+ )
+)
+def test_can_revalidate(monkeypatch, tmp_path, rundir, revalidate, expect):
+ """It raises an error if revalidation isn't allowed and the user
+ has asked for revalidation.
+ """
+ is_ = tmp_path / 'is'
+ not_ = tmp_path / 'not'
+ monkeypatch.setattr(
+ 'cylc.flow.pathutil.get_cylc_run_dir', lambda: is_)
+
+ flow_file = is_ if rundir else not_
+ flow_file = flow_file / 'foo/bar/baz/flow.cylc'
+ if expect is False:
+ with pytest.raises(WorkflowConfigError):
+ can_revalidate(
+ flow_file,
+ SimpleNamespace(**{'revalidate': revalidate})
+ )
+ else:
+ assert (
+ can_revalidate(
+ flow_file,
+ SimpleNamespace(**{'revalidate': revalidate})
+ )
+ is True
+ )
diff --git a/tests/unit/test_pathutil.py b/tests/unit/test_pathutil.py
index 67a779e6b7e..5a8baba8174 100644
--- a/tests/unit/test_pathutil.py
+++ b/tests/unit/test_pathutil.py
@@ -40,6 +40,7 @@
get_workflow_run_share_dir,
get_workflow_run_work_dir,
get_workflow_test_log_path,
+ is_in_a_rundir,
make_localhost_symlinks,
make_workflow_run_tree,
parse_rm_dirs,
@@ -574,3 +575,12 @@ def test_get_workflow_name_from_id(
result = get_workflow_name_from_id(id_)
assert result == name
+
+
+def test_is_in_a_rundir(monkeypatch, tmp_path):
+ is_ = tmp_path / 'is'
+ not_ = tmp_path / 'not'
+ monkeypatch.setattr(
+ 'cylc.flow.pathutil.get_cylc_run_dir', lambda: is_)
+ assert is_in_a_rundir(is_ / 'foo/bar/baz')
+ assert not is_in_a_rundir(not_ / 'foo/bar/baz')