diff --git a/CHANGES.md b/CHANGES.md index 80b702eb7115..610df5d15e66 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,8 @@ * Added Feast feature store handler for enrichment transform (Python) ([#30957](https://github.com/apache/beam/issues/30964)). * BigQuery per-worker metrics are reported by default for Streaming Dataflow Jobs (Java) ([#31015](https://github.com/apache/beam/pull/31015)) +* Beam YAML now supports the jinja templating syntax. + Template variables can be passed with the (json-formatted) `--jinja_variables` flag. * DataFrame API now supports pandas 2.1.x and adds 12 more string functions for Series.([#31185](https://github.com/apache/beam/pull/31185)). ## Breaking Changes diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index 147a46f0bea5..0a6253e3c23e 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -664,6 +664,13 @@ def get_urn_by_logial_type(self, logical_type): def get_logical_type_by_language_type(self, representation_type): return self.by_language_type.get(representation_type, None) + def copy(self): + copy = LogicalTypeRegistry() + copy.by_urn.update(self.by_urn) + copy.by_logical_type.update(self.by_logical_type) + copy.by_language_type.update(self.by_language_type) + return copy + LanguageT = TypeVar('LanguageT') RepresentationT = TypeVar('RepresentationT') diff --git a/sdks/python/apache_beam/yaml/main.py b/sdks/python/apache_beam/yaml/main.py index 5d1d3f7cea0b..6c87a1ba7e68 100644 --- a/sdks/python/apache_beam/yaml/main.py +++ b/sdks/python/apache_beam/yaml/main.py @@ -16,7 +16,10 @@ # import argparse +import contextlib +import json +import jinja2 import yaml import apache_beam as beam @@ -25,9 +28,6 @@ from apache_beam.typehints.schemas import MillisInstant from apache_beam.yaml import yaml_transform -# Workaround for https://github.com/apache/beam/issues/28151. -LogicalType.register_logical_type(MillisInstant) - def _configure_parser(argv): parser = argparse.ArgumentParser() @@ -45,6 +45,12 @@ def _configure_parser(argv): help='none: do no pipeline validation against the schema; ' 'generic: validate the pipeline shape, but not individual transforms; ' 'per_transform: also validate the config of known transforms') + parser.add_argument( + '--jinja_variables', + default=None, + type=json.loads, + help='A json dict of variables used when invoking the jinja preprocessor ' + 'on the provided yaml pipeline.') return parser.parse_known_args(argv) @@ -64,22 +70,50 @@ def _pipeline_spec_from_args(known_args): return pipeline_yaml +class _BeamFileIOLoader(jinja2.BaseLoader): + def get_source(self, environment, path): + with FileSystems.open(path) as fin: + source = fin.read().decode() + return source, path, lambda: True + + +@contextlib.contextmanager +def _fix_xlang_instant_coding(): + # Scoped workaround for https://github.com/apache/beam/issues/28151. + old_registry = LogicalType._known_logical_types + LogicalType._known_logical_types = old_registry.copy() + try: + LogicalType.register_logical_type(MillisInstant) + yield + finally: + LogicalType._known_logical_types = old_registry + + def run(argv=None): known_args, pipeline_args = _configure_parser(argv) - pipeline_yaml = _pipeline_spec_from_args(known_args) + pipeline_template = _pipeline_spec_from_args(known_args) + pipeline_yaml = ( # keep formatting + jinja2.Environment( + undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader()) + .from_string(pipeline_template) + .render(**known_args.jinja_variables or {})) pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader) - with beam.Pipeline( # linebreak for better yapf formatting - options=beam.options.pipeline_options.PipelineOptions( - pipeline_args, - pickle_library='cloudpickle', - **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( - 'options', {}))), - display_data={'yaml': pipeline_yaml}) as p: - print("Building pipeline...") - yaml_transform.expand_pipeline( - p, pipeline_spec, validate_schema=known_args.json_schema_validation) - print("Running pipeline...") + with _fix_xlang_instant_coding(): + with beam.Pipeline( # linebreak for better yapf formatting + options=beam.options.pipeline_options.PipelineOptions( + pipeline_args, + pickle_library='cloudpickle', + **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( + 'options', {}))), + display_data={'yaml': pipeline_yaml, + 'yaml_jinja_template': pipeline_template, + 'yaml_jinja_variables': json.dumps( + known_args.jinja_variables)}) as p: + print("Building pipeline...") + yaml_transform.expand_pipeline( + p, pipeline_spec, validate_schema=known_args.json_schema_validation) + print("Running pipeline...") if __name__ == '__main__': diff --git a/sdks/python/apache_beam/yaml/main_test.py b/sdks/python/apache_beam/yaml/main_test.py new file mode 100644 index 000000000000..b10c788bccaa --- /dev/null +++ b/sdks/python/apache_beam/yaml/main_test.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import glob +import logging +import os +import tempfile +import unittest + +from apache_beam.yaml import main + +TEST_PIPELINE = ''' +pipeline: + type: chain + transforms: + - type: Create + config: + elements: [ELEMENT] + # Writing to an actual file here rather than just using AssertThat + # because this is an integration test and above all we want to ensure + # the pipeline actually runs (and asserts may not fail if there's a + # bug in the invocation logic). + - type: WriteToText + config: + path: PATH +''' + + +class MainTest(unittest.TestCase): + def test_pipeline_spec_from_file(self): + with tempfile.TemporaryDirectory() as tmpdir: + yaml_path = os.path.join(tmpdir, 'test.yaml') + out_path = os.path.join(tmpdir, 'out.txt') + with open(yaml_path, 'wt') as fout: + fout.write(TEST_PIPELINE.replace('PATH', out_path)) + main.run(['--yaml_pipeline_file', yaml_path]) + with open(glob.glob(out_path + '*')[0], 'rt') as fin: + self.assertEqual(fin.read().strip(), 'ELEMENT') + + def test_pipeline_spec_from_flag(self): + with tempfile.TemporaryDirectory() as tmpdir: + out_path = os.path.join(tmpdir, 'out.txt') + main.run(['--yaml_pipeline', TEST_PIPELINE.replace('PATH', out_path)]) + with open(glob.glob(out_path + '*')[0], 'rt') as fin: + self.assertEqual(fin.read().strip(), 'ELEMENT') + + def test_jinja_variables(self): + with tempfile.TemporaryDirectory() as tmpdir: + out_path = os.path.join(tmpdir, 'out.txt') + main.run([ + '--yaml_pipeline', + TEST_PIPELINE.replace('PATH', out_path).replace('ELEMENT', '{{var}}'), + '--jinja_variables', + '{"var": "my_line"}' + ]) + with open(glob.glob(out_path + '*')[0], 'rt') as fin: + self.assertEqual(fin.read().strip(), 'my_line') + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/scripts/run_pylint.sh b/sdks/python/scripts/run_pylint.sh index 96d4b4885b62..89ea7fe441e4 100755 --- a/sdks/python/scripts/run_pylint.sh +++ b/sdks/python/scripts/run_pylint.sh @@ -99,6 +99,8 @@ ISORT_EXCLUDED=( "process_tfma.py" "doctests_test.py" "render_test.py" + "yaml/main.py" + "main_test.py" ) SKIP_PARAM="" for file in "${ISORT_EXCLUDED[@]}"; do diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 8c4fbf415f3c..53139a9ce289 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -406,6 +406,7 @@ def get_portability_package_data(): # BEAM-8840: Do NOT use tests_require or setup_requires. extras_require={ 'docs': [ + 'jinja2>=3.0,<3.1', 'Sphinx>=1.5.2,<2.0', 'docstring-parser>=0.15,<1.0', # Pinning docutils as a workaround for Sphinx issue: