[yaml] Fix examples catalog tests (#33027)
Polber authored Nov 13, 2024
1 parent bff3eac commit ab5c069
Showing 10 changed files with 224 additions and 110 deletions.
105 changes: 95 additions & 10 deletions sdks/python/apache_beam/yaml/examples/testing/examples_test.py
@@ -40,8 +40,8 @@


 def check_output(expected: List[str]):
-  def _check_inner(actual: PCollection[str]):
-    formatted_actual = actual | beam.Map(
+  def _check_inner(actual: List[PCollection[str]]):
+    formatted_actual = actual | beam.Flatten() | beam.Map(
         lambda row: str(beam.Row(**row._asdict())))
     assert_matches_stdout(formatted_actual, expected)

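The checker now receives a list of PCollections (the tracked pipeline outputs plus any LogForTesting outputs collected in test_yaml_example below) rather than a single one, so it merges them with beam.Flatten() before matching against stdout. A minimal sketch of that idiom, with illustrative names that are not part of the commit:

import apache_beam as beam

with beam.Pipeline() as p:
  first = p | 'First' >> beam.Create(['a', 'b'])
  second = p | 'Second' >> beam.Create(['c'])
  # A list (or tuple) of PCollections can be applied to Flatten directly,
  # yielding one PCollection with all elements: 'a', 'b', 'c'.
  merged = [first, second] | 'Merge' >> beam.Flatten()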
@@ -59,6 +59,57 @@ def products_csv():
])


def spanner_data():
  return [{
      'shipment_id': 'S1',
      'customer_id': 'C1',
      'shipment_date': '2023-05-01',
      'shipment_cost': 150.0,
      'customer_name': 'Alice',
      'customer_email': '[email protected]'
  },
  {
      'shipment_id': 'S2',
      'customer_id': 'C2',
      'shipment_date': '2023-06-12',
      'shipment_cost': 300.0,
      'customer_name': 'Bob',
      'customer_email': '[email protected]'
  },
  {
      'shipment_id': 'S3',
      'customer_id': 'C1',
      'shipment_date': '2023-05-10',
      'shipment_cost': 20.0,
      'customer_name': 'Alice',
      'customer_email': '[email protected]'
  },
  {
      'shipment_id': 'S4',
      'customer_id': 'C4',
      'shipment_date': '2024-07-01',
      'shipment_cost': 150.0,
      'customer_name': 'Derek',
      'customer_email': '[email protected]'
  },
  {
      'shipment_id': 'S5',
      'customer_id': 'C5',
      'shipment_date': '2023-05-09',
      'shipment_cost': 300.0,
      'customer_name': 'Erin',
      'customer_email': '[email protected]'
  },
  {
      'shipment_id': 'S6',
      'customer_id': 'C4',
      'shipment_date': '2024-07-02',
      'shipment_cost': 150.0,
      'customer_name': 'Derek',
      'customer_email': '[email protected]'
  }]


def create_test_method(
    pipeline_spec_file: str,
    custom_preprocessors: List[Callable[..., Union[Dict, List]]]):
@@ -84,9 +84,12 @@ def test_yaml_example(self):
         pickle_library='cloudpickle',
         **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
             'options', {})))) as p:
-      actual = yaml_transform.expand_pipeline(p, pipeline_spec)
-      if not actual:
-        actual = p.transforms_stack[0].parts[-1].outputs[None]
+      actual = [yaml_transform.expand_pipeline(p, pipeline_spec)]
+      if not actual[0]:
+        actual = list(p.transforms_stack[0].parts[-1].outputs.values())
+        for transform in p.transforms_stack[0].parts[:-1]:
+          if transform.transform.label == 'log_for_testing':
+            actual += list(transform.outputs.values())
       check_output(expected)(actual)

  return test_yaml_example
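Restated outside the harness: the block above first trusts whatever expand_pipeline returns; when that is empty it falls back to the outputs of the last top-level transform and also gathers the outputs of every transform labeled 'log_for_testing' (the label under which a LogForTesting transform appears once the preprocessors below swap it in for a real sink). A hedged sketch of that gathering, assuming p is the just-run pipeline:

# Sketch only: mirrors the collection logic in test_yaml_example above.
def collect_outputs(p):
  parts = p.transforms_stack[0].parts  # top-level applied transforms
  outputs = list(parts[-1].outputs.values())  # outputs of the final transform
  for applied in parts[:-1]:
    if applied.transform.label == 'log_for_testing':
      outputs += list(applied.outputs.values())  # kept for assertions
  return outputs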
@@ -155,9 +209,13 @@ def _wordcount_test_preprocessor(
env.input_file('kinglear.txt', '\n'.join(lines)))


-@YamlExamplesTestSuite.register_test_preprocessor(
-    ['test_simple_filter_yaml', 'test_simple_filter_and_combine_yaml'])
-def _file_io_write_test_preprocessor(
+@YamlExamplesTestSuite.register_test_preprocessor([
+    'test_simple_filter_yaml',
+    'test_simple_filter_and_combine_yaml',
+    'test_spanner_read_yaml',
+    'test_spanner_write_yaml'
+])
+def _io_write_test_preprocessor(
     test_spec: dict, expected: List[str], env: TestEnvironment):

   if pipeline := test_spec.get('pipeline', None):
@@ -166,8 +224,8 @@ def _io_write_test_preprocessor(
         transform['type'] = 'LogForTesting'
         transform['config'] = {
             k: v
-            for k,
-            v in transform.get('config', {}).items() if k.startswith('__')
+            for (k, v) in transform.get('config', {}).items()
+            if (k.startswith('__') or k == 'error_handling')
         }

   return test_spec
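Concretely, this preprocessor swaps a real sink for LogForTesting while keeping only the loader's metadata keys (those starting with '__') and any error_handling block, so error outputs referenced elsewhere in the pipeline still resolve. An illustrative before/after as Python dicts; the field values here are hypothetical:

before = {
    'type': 'WriteToSpanner',
    'name': 'WriteSpanner',
    'input': 'CreateRows',
    'config': {
        'table_id': 'shipments',  # dropped: real connection detail
        'error_handling': {'output': 'my_error_output'},  # kept
        '__line__': 24,  # kept: loader metadata (hypothetical value)
    },
}

after = dict(
    before,
    type='LogForTesting',
    config={
        'error_handling': {'output': 'my_error_output'},
        '__line__': 24,
    })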
@@ -191,7 +249,30 @@ def _file_io_read_test_preprocessor(
  return test_spec


@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml'])
def _spanner_io_read_test_preprocessor(
    test_spec: dict, expected: List[str], env: TestEnvironment):

  if pipeline := test_spec.get('pipeline', None):
    for transform in pipeline.get('transforms', []):
      if transform.get('type', '').startswith('ReadFromSpanner'):
        config = transform['config']
        instance, database = config['instance_id'], config['database_id']
        if (table := config.get('table', None)) is None:
          table = config.get('query', '').split('FROM')[-1].strip()
        transform['type'] = 'Create'
        transform['config'] = {
            k: v
            for k, v in config.items() if k.startswith('__')
        }
        transform['config']['elements'] = INPUT_TABLES[(
            str(instance), str(database), str(table))]

  return test_spec
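For the Spanner read example, the rewrite replaces the source with a Create whose elements come from INPUT_TABLES, keyed by (instance_id, database_id, table); when only a query is given, the table name is recovered by splitting on 'FROM'. A worked example under those assumptions, using the config values visible in the YAML below:

config = {
    'instance_id': 'shipment-test',
    'database_id': 'shipment',
    'query': 'SELECT * FROM shipments',
}
table = config.get('query', '').split('FROM')[-1].strip()  # -> 'shipments'
key = ('shipment-test', 'shipment', 'shipments')
# The transform becomes, in effect:
# {'type': 'Create', 'config': {'elements': INPUT_TABLES[key]}}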


INPUT_FILES = {'products.csv': products_csv()}
INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()}

YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__))
ExamplesTest = YamlExamplesTestSuite(
@@ -205,6 +286,10 @@
    'AggregationExamplesTest',
    os.path.join(YAML_DOCS_DIR, '../transforms/aggregation/*.yaml')).run()

IOTest = YamlExamplesTestSuite(
    'IOExamplesTest', os.path.join(YAML_DOCS_DIR,
                                   '../transforms/io/*.yaml')).run()

if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml
@@ -18,10 +18,10 @@
pipeline:
  transforms:

    # Reading data from a Spanner database. The table used here has the following columns:
    # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String)
    # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query
    # A table with a list of columns can also be specified instead of a query
    - type: ReadFromSpanner
      name: ReadShipments
      config:
@@ -30,18 +30,18 @@ pipeline:
        database_id: 'shipment'
        query: 'SELECT * FROM shipments'

    # Filtering the data based on a specific condition
    # Here, the condition is used to keep only the rows where the customer_id is 'C1'
    - type: Filter
      name: FilterShipments
      input: ReadShipments
      config:
        language: python
        keep: "customer_id == 'C1'"

    # Mapping the data fields and applying transformations
    # A new field 'shipment_cost_category' is added with a custom transformation
    # A callable is defined to categorize shipment cost
    - type: MapToFields
      name: MapFieldsForSpanner
      input: FilterShipments
@@ -65,16 +65,15 @@ pipeline:
            else:
              return 'High Cost'
    # Writing the transformed data to a CSV file
    - type: WriteToCsv
      name: WriteBig
      input: MapFieldsForSpanner
      config:
        path: shipments.csv


# On executing the above pipeline, a new CSV file is created with the following records
# Expected:
# Row(shipment_id='S1', customer_id='C1', shipment_date='2023-05-01', shipment_cost=150.0, customer_name='Alice', customer_email='[email protected]', shipment_cost_category='Medium Cost')
# Row(shipment_id='S3', customer_id='C1', shipment_date='2023-05-10', shipment_cost=20.0, customer_name='Alice', customer_email='[email protected]', shipment_cost_category='Low Cost')
sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml
@@ -18,8 +18,8 @@
pipeline:
  transforms:

    # Step 1: Creating rows to be written to Spanner
    # The element names correspond to the column names in the Spanner table
    - type: Create
      name: CreateRows
      config:
@@ -31,10 +31,10 @@ pipeline:
            customer_name: "Erin"
            customer_email: "[email protected]"

    # Step 2: Writing the created rows to a Spanner database
    # We require the project ID, instance ID, database ID and table ID to connect to Spanner
    # Error handling can be specified optionally to ensure any failed operations aren't lost
    # The failed data is passed on in the pipeline and can be handled
    - type: WriteToSpanner
      name: WriteSpanner
      input: CreateRows
@@ -46,8 +46,11 @@ pipeline:
        error_handling:
          output: my_error_output

    # Step 3: Writing the failed records to a JSON file
    - type: WriteToJson
      input: WriteSpanner.my_error_output
      config:
        path: errors.json

# Expected:
# Row(shipment_id='S5', customer_id='C5', shipment_date='2023-05-09', shipment_cost=300.0, customer_name='Erin', customer_email='[email protected]')
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/yaml/generate_yaml_docs.py
@@ -30,7 +30,7 @@
 from apache_beam.version import __version__ as beam_version
 from apache_beam.yaml import json_utils
 from apache_beam.yaml import yaml_provider
-from apache_beam.yaml.yaml_mapping import ErrorHandlingConfig
+from apache_beam.yaml.yaml_errors import ErrorHandlingConfig


def _singular(name):
88 changes: 88 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_errors.py
@@ -0,0 +1,88 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import functools
import inspect
from typing import NamedTuple

import apache_beam as beam
from apache_beam.typehints.row_type import RowTypeConstraint


class ErrorHandlingConfig(NamedTuple):
  """Class to define Error Handling parameters.

  Args:
    output (str): Name to use for the output error collection
  """
  output: str
  # TODO: Other parameters are valid here too, but not common to Java.


def exception_handling_args(error_handling_spec):
  if error_handling_spec:
    return {
        'dead_letter_tag' if k == 'output' else k: v
        for (k, v) in error_handling_spec.items()
    }
  else:
    return None
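# Illustratively: {'output': 'my_errors'} becomes
# {'dead_letter_tag': 'my_errors'}, the keyword accepted by Beam's
# with_exception_handling(); any other keys pass through unchanged.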


def map_errors_to_standard_format(input_type):
  # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple.

  return beam.Map(
      lambda x: beam.Row(
          element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2]))
  ).with_output_types(
      RowTypeConstraint.from_fields([("element", input_type), ("msg", str),
                                     ("stack", str)]))


def maybe_with_exception_handling(inner_expand):
  def expand(self, pcoll):
    wrapped_pcoll = beam.core._MaybePValueWithErrors(
        pcoll, self._exception_handling_args)
    return inner_expand(self, wrapped_pcoll).as_result(
        map_errors_to_standard_format(pcoll.element_type))

  return expand


def maybe_with_exception_handling_transform_fn(transform_fn):
  @functools.wraps(transform_fn)
  def expand(pcoll, error_handling=None, **kwargs):
    wrapped_pcoll = beam.core._MaybePValueWithErrors(
        pcoll, exception_handling_args(error_handling))
    return transform_fn(wrapped_pcoll, **kwargs).as_result(
        map_errors_to_standard_format(pcoll.element_type))

  original_signature = inspect.signature(transform_fn)
  new_parameters = list(original_signature.parameters.values())
  error_handling_param = inspect.Parameter(
      'error_handling',
      inspect.Parameter.KEYWORD_ONLY,
      default=None,
      annotation=ErrorHandlingConfig)
  if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD:
    new_parameters.insert(-1, error_handling_param)
  else:
    new_parameters.append(error_handling_param)
  expand.__signature__ = original_signature.replace(parameters=new_parameters)

  return expand
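As used in yaml_io.py below, the transform-fn decorator gives a provider function an optional error_handling parameter. A minimal sketch under assumed names (parse_ints is hypothetical, not a Beam API):

import apache_beam as beam
from apache_beam.yaml import yaml_errors

@beam.ptransform_fn
@yaml_errors.maybe_with_exception_handling_transform_fn
def parse_ints(pcoll):
  # When error_handling is configured, a failing int() call is routed to
  # the named error output instead of crashing the pipeline.
  return pcoll | beam.Map(lambda s: beam.Row(value=int(s)))

# Usage sketch: failed elements surface under the configured output name as
# Row(element=..., msg=..., stack=...) rows.
# results = strings | parse_ints(error_handling={'output': 'my_errors'})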
6 changes: 3 additions & 3 deletions sdks/python/apache_beam/yaml/yaml_io.py
@@ -45,7 +45,7 @@
 from apache_beam.portability.api import schema_pb2
 from apache_beam.typehints import schemas
 from apache_beam.yaml import json_utils
-from apache_beam.yaml import yaml_mapping
+from apache_beam.yaml import yaml_errors
 from apache_beam.yaml import yaml_provider


@@ -289,7 +289,7 @@ def formatter(row):


 @beam.ptransform_fn
-@yaml_mapping.maybe_with_exception_handling_transform_fn
+@yaml_errors.maybe_with_exception_handling_transform_fn
 def read_from_pubsub(
     root,
     *,
@@ -393,7 +393,7 @@ def mapper(msg):


 @beam.ptransform_fn
-@yaml_mapping.maybe_with_exception_handling_transform_fn
+@yaml_errors.maybe_with_exception_handling_transform_fn
 def write_to_pubsub(
     pcoll,
     *,