Skip to content

Commit

Permalink
Merge pull request #30976 [YAML] Add the ability to pre-process yaml …
Browse files Browse the repository at this point in the history
…files with jinja2.
  • Loading branch information
robertwb authored May 8, 2024
2 parents faaa68c + 037704b commit 6197657
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@

* Added Feast feature store handler for enrichment transform (Python) ([#30957](https://github.com/apache/beam/issues/30964)).
* BigQuery per-worker metrics are reported by default for Streaming Dataflow Jobs (Java) ([#31015](https://github.com/apache/beam/pull/31015))
* Beam YAML now supports the jinja templating syntax.
Template variables can be passed with the (json-formatted) `--jinja_variables` flag.
* DataFrame API now supports pandas 2.1.x and adds 12 more string functions for Series.([#31185](https://github.com/apache/beam/pull/31185)).

## Breaking Changes
Expand Down
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,13 @@ def get_urn_by_logial_type(self, logical_type):
def get_logical_type_by_language_type(self, representation_type):
return self.by_language_type.get(representation_type, None)

def copy(self):
copy = LogicalTypeRegistry()
copy.by_urn.update(self.by_urn)
copy.by_logical_type.update(self.by_logical_type)
copy.by_language_type.update(self.by_language_type)
return copy


LanguageT = TypeVar('LanguageT')
RepresentationT = TypeVar('RepresentationT')
Expand Down
64 changes: 49 additions & 15 deletions sdks/python/apache_beam/yaml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
#

import argparse
import contextlib
import json

import jinja2
import yaml

import apache_beam as beam
Expand All @@ -25,9 +28,6 @@
from apache_beam.typehints.schemas import MillisInstant
from apache_beam.yaml import yaml_transform

# Workaround for https://github.com/apache/beam/issues/28151.
LogicalType.register_logical_type(MillisInstant)


def _configure_parser(argv):
parser = argparse.ArgumentParser()
Expand All @@ -45,6 +45,12 @@ def _configure_parser(argv):
help='none: do no pipeline validation against the schema; '
'generic: validate the pipeline shape, but not individual transforms; '
'per_transform: also validate the config of known transforms')
parser.add_argument(
'--jinja_variables',
default=None,
type=json.loads,
help='A json dict of variables used when invoking the jinja preprocessor '
'on the provided yaml pipeline.')
return parser.parse_known_args(argv)


Expand All @@ -64,22 +70,50 @@ def _pipeline_spec_from_args(known_args):
return pipeline_yaml


class _BeamFileIOLoader(jinja2.BaseLoader):
def get_source(self, environment, path):
with FileSystems.open(path) as fin:
source = fin.read().decode()
return source, path, lambda: True


@contextlib.contextmanager
def _fix_xlang_instant_coding():
# Scoped workaround for https://github.com/apache/beam/issues/28151.
old_registry = LogicalType._known_logical_types
LogicalType._known_logical_types = old_registry.copy()
try:
LogicalType.register_logical_type(MillisInstant)
yield
finally:
LogicalType._known_logical_types = old_registry


def run(argv=None):
known_args, pipeline_args = _configure_parser(argv)
pipeline_yaml = _pipeline_spec_from_args(known_args)
pipeline_template = _pipeline_spec_from_args(known_args)
pipeline_yaml = ( # keep formatting
jinja2.Environment(
undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader())
.from_string(pipeline_template)
.render(**known_args.jinja_variables or {}))
pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)

with beam.Pipeline( # linebreak for better yapf formatting
options=beam.options.pipeline_options.PipelineOptions(
pipeline_args,
pickle_library='cloudpickle',
**yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
'options', {}))),
display_data={'yaml': pipeline_yaml}) as p:
print("Building pipeline...")
yaml_transform.expand_pipeline(
p, pipeline_spec, validate_schema=known_args.json_schema_validation)
print("Running pipeline...")
with _fix_xlang_instant_coding():
with beam.Pipeline( # linebreak for better yapf formatting
options=beam.options.pipeline_options.PipelineOptions(
pipeline_args,
pickle_library='cloudpickle',
**yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
'options', {}))),
display_data={'yaml': pipeline_yaml,
'yaml_jinja_template': pipeline_template,
'yaml_jinja_variables': json.dumps(
known_args.jinja_variables)}) as p:
print("Building pipeline...")
yaml_transform.expand_pipeline(
p, pipeline_spec, validate_schema=known_args.json_schema_validation)
print("Running pipeline...")


if __name__ == '__main__':
Expand Down
76 changes: 76 additions & 0 deletions sdks/python/apache_beam/yaml/main_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import glob
import logging
import os
import tempfile
import unittest

from apache_beam.yaml import main

TEST_PIPELINE = '''
pipeline:
type: chain
transforms:
- type: Create
config:
elements: [ELEMENT]
# Writing to an actual file here rather than just using AssertThat
# because this is an integration test and above all we want to ensure
# the pipeline actually runs (and asserts may not fail if there's a
# bug in the invocation logic).
- type: WriteToText
config:
path: PATH
'''


class MainTest(unittest.TestCase):
def test_pipeline_spec_from_file(self):
with tempfile.TemporaryDirectory() as tmpdir:
yaml_path = os.path.join(tmpdir, 'test.yaml')
out_path = os.path.join(tmpdir, 'out.txt')
with open(yaml_path, 'wt') as fout:
fout.write(TEST_PIPELINE.replace('PATH', out_path))
main.run(['--yaml_pipeline_file', yaml_path])
with open(glob.glob(out_path + '*')[0], 'rt') as fin:
self.assertEqual(fin.read().strip(), 'ELEMENT')

def test_pipeline_spec_from_flag(self):
with tempfile.TemporaryDirectory() as tmpdir:
out_path = os.path.join(tmpdir, 'out.txt')
main.run(['--yaml_pipeline', TEST_PIPELINE.replace('PATH', out_path)])
with open(glob.glob(out_path + '*')[0], 'rt') as fin:
self.assertEqual(fin.read().strip(), 'ELEMENT')

def test_jinja_variables(self):
with tempfile.TemporaryDirectory() as tmpdir:
out_path = os.path.join(tmpdir, 'out.txt')
main.run([
'--yaml_pipeline',
TEST_PIPELINE.replace('PATH', out_path).replace('ELEMENT', '{{var}}'),
'--jinja_variables',
'{"var": "my_line"}'
])
with open(glob.glob(out_path + '*')[0], 'rt') as fin:
self.assertEqual(fin.read().strip(), 'my_line')


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
2 changes: 2 additions & 0 deletions sdks/python/scripts/run_pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ ISORT_EXCLUDED=(
"process_tfma.py"
"doctests_test.py"
"render_test.py"
"yaml/main.py"
"main_test.py"
)
SKIP_PARAM=""
for file in "${ISORT_EXCLUDED[@]}"; do
Expand Down
1 change: 1 addition & 0 deletions sdks/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def get_portability_package_data():
# BEAM-8840: Do NOT use tests_require or setup_requires.
extras_require={
'docs': [
'jinja2>=3.0,<3.1',
'Sphinx>=1.5.2,<2.0',
'docstring-parser>=0.15,<1.0',
# Pinning docutils as a workaround for Sphinx issue:
Expand Down

0 comments on commit 6197657

Please sign in to comment.