[yaml] Fix examples catalog tests #33027

Merged · 6 commits · Nov 13, 2024
105 changes: 95 additions & 10 deletions sdks/python/apache_beam/yaml/examples/testing/examples_test.py
@@ -40,8 +40,8 @@


 def check_output(expected: List[str]):
-  def _check_inner(actual: PCollection[str]):
-    formatted_actual = actual | beam.Map(
+  def _check_inner(actual: List[PCollection[str]]):
+    formatted_actual = actual | beam.Flatten() | beam.Map(
         lambda row: str(beam.Row(**row._asdict())))
     assert_matches_stdout(formatted_actual, expected)
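
The checker now accepts a list of PCollections and flattens them before comparing against the expected rows. A minimal sketch of the same pattern (hypothetical pipeline and rows, not from the PR):

import apache_beam as beam

# Two hypothetical PCollections are flattened into one before matching,
# mirroring what the updated _check_inner does with its list argument.
with beam.Pipeline() as p:
  pc1 = p | 'CreateA' >> beam.Create([beam.Row(shipment_id='S1')])
  pc2 = p | 'CreateB' >> beam.Create([beam.Row(shipment_id='S3')])
  merged = [pc1, pc2] | beam.Flatten()
  merged | beam.Map(lambda row: str(row)) | beam.Map(print)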

Expand All @@ -59,6 +59,57 @@ def products_csv():
])


def spanner_data():
return [{
'shipment_id': 'S1',
'customer_id': 'C1',
'shipment_date': '2023-05-01',
'shipment_cost': 150.0,
'customer_name': 'Alice',
'customer_email': '[email protected]'
},
{
'shipment_id': 'S2',
'customer_id': 'C2',
'shipment_date': '2023-06-12',
'shipment_cost': 300.0,
'customer_name': 'Bob',
'customer_email': '[email protected]'
},
{
'shipment_id': 'S3',
'customer_id': 'C1',
'shipment_date': '2023-05-10',
'shipment_cost': 20.0,
'customer_name': 'Alice',
'customer_email': '[email protected]'
},
{
'shipment_id': 'S4',
'customer_id': 'C4',
'shipment_date': '2024-07-01',
'shipment_cost': 150.0,
'customer_name': 'Derek',
'customer_email': '[email protected]'
},
{
'shipment_id': 'S5',
'customer_id': 'C5',
'shipment_date': '2023-05-09',
'shipment_cost': 300.0,
'customer_name': 'Erin',
'customer_email': '[email protected]'
},
{
'shipment_id': 'S6',
'customer_id': 'C4',
'shipment_date': '2024-07-02',
'shipment_cost': 150.0,
'customer_name': 'Derek',
'customer_email': '[email protected]'
}]


def create_test_method(
pipeline_spec_file: str,
custom_preprocessors: List[Callable[..., Union[Dict, List]]]):
@@ -84,9 +135,12 @@ def test_yaml_example(self):
             pickle_library='cloudpickle',
             **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
                 'options', {})))) as p:
-      actual = yaml_transform.expand_pipeline(p, pipeline_spec)
-      if not actual:
-        actual = p.transforms_stack[0].parts[-1].outputs[None]
+      actual = [yaml_transform.expand_pipeline(p, pipeline_spec)]
+      if not actual[0]:
+        actual = list(p.transforms_stack[0].parts[-1].outputs.values())
+        for transform in p.transforms_stack[0].parts[:-1]:
+          if transform.transform.label == 'log_for_testing':
+            actual += list(transform.outputs.values())
       check_output(expected)(actual)

   return test_yaml_example
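
Under the write preprocessors below, sinks are replaced by LogForTesting, so expand_pipeline can return None; the harness then falls back to the leaf transform's outputs plus those of any transform labeled 'log_for_testing'. A compact sketch of that accumulation (plain dicts stand in for Beam's internal AppliedPTransform objects, purely for illustration):

parts = [
    {'label': 'log_for_testing', 'outputs': {None: 'pc_logged'}},
    {'label': 'WriteToJson', 'outputs': {None: 'pc_final'}},
]
actual = [None]  # what a sink-less expand_pipeline would yield
if not actual[0]:
  actual = list(parts[-1]['outputs'].values())
  for transform in parts[:-1]:
    if transform['label'] == 'log_for_testing':
      actual += list(transform['outputs'].values())
print(actual)  # ['pc_final', 'pc_logged']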
@@ -155,9 +209,13 @@ def _wordcount_test_preprocessor(
       env.input_file('kinglear.txt', '\n'.join(lines)))


-@YamlExamplesTestSuite.register_test_preprocessor(
-    ['test_simple_filter_yaml', 'test_simple_filter_and_combine_yaml'])
-def _file_io_write_test_preprocessor(
+@YamlExamplesTestSuite.register_test_preprocessor([
+    'test_simple_filter_yaml',
+    'test_simple_filter_and_combine_yaml',
+    'test_spanner_read_yaml',
+    'test_spanner_write_yaml'
+])
+def _io_write_test_preprocessor(
     test_spec: dict, expected: List[str], env: TestEnvironment):

   if pipeline := test_spec.get('pipeline', None):
@@ -166,8 +224,8 @@ def _file_io_write_test_preprocessor(
         transform['type'] = 'LogForTesting'
         transform['config'] = {
             k: v
-            for k,
-            v in transform.get('config', {}).items() if k.startswith('__')
+            for (k, v) in transform.get('config', {}).items()
+            if (k.startswith('__') or k == 'error_handling')
         }

   return test_spec
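
In effect, the preprocessor swaps each write sink for LogForTesting, keeping only '__'-prefixed metadata keys and any error_handling config (presumably so that named error outputs such as WriteSpanner.my_error_output still resolve under test). A small self-contained illustration with a hypothetical transform spec, not taken from the PR:

spec = {
    'type': 'WriteToSpanner',
    'config': {
        '__line__': 42,  # hypothetical metadata key
        'project_id': 'apache-beam-testing',
        'error_handling': {'output': 'my_error_output'},
    },
}

# The same rewrite _io_write_test_preprocessor applies to each write sink.
spec['type'] = 'LogForTesting'
spec['config'] = {
    k: v
    for (k, v) in spec.get('config', {}).items()
    if (k.startswith('__') or k == 'error_handling')
}

print(spec)
# {'type': 'LogForTesting',
#  'config': {'__line__': 42, 'error_handling': {'output': 'my_error_output'}}}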
@@ -191,7 +249,30 @@ def _file_io_read_test_preprocessor(
   return test_spec


+@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml'])
+def _spanner_io_read_test_preprocessor(
+    test_spec: dict, expected: List[str], env: TestEnvironment):
+
+  if pipeline := test_spec.get('pipeline', None):
+    for transform in pipeline.get('transforms', []):
+      if transform.get('type', '').startswith('ReadFromSpanner'):
+        config = transform['config']
+        instance, database = config['instance_id'], config['database_id']
+        if (table := config.get('table', None)) is None:
+          table = config.get('query', '').split('FROM')[-1].strip()
+        transform['type'] = 'Create'
+        transform['config'] = {
+            k: v
+            for k, v in config.items() if k.startswith('__')
+        }
+        transform['config']['elements'] = INPUT_TABLES[(
+            str(instance), str(database), str(table))]
+
+  return test_spec
+
+
 INPUT_FILES = {'products.csv': products_csv()}
+INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()}

 YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__))
 ExamplesTest = YamlExamplesTestSuite(
@@ -205,6 +286,10 @@ def _file_io_read_test_preprocessor(
     'AggregationExamplesTest',
     os.path.join(YAML_DOCS_DIR, '../transforms/aggregation/*.yaml')).run()

+IOTest = YamlExamplesTestSuite(
+    'IOExamplesTest', os.path.join(YAML_DOCS_DIR,
+                                   '../transforms/io/*.yaml')).run()
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
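
The read side is the mirror image: ReadFromSpanner becomes a Create whose elements come from the in-memory INPUT_TABLES, with the table name recovered from the query when none is configured. A sketch with a hypothetical spec and a trimmed table:

INPUT_TABLES = {
    ('shipment-test', 'shipment', 'shipments'): [{'shipment_id': 'S1'}],
}

transform = {
    'type': 'ReadFromSpanner',
    'config': {
        'instance_id': 'shipment-test',
        'database_id': 'shipment',
        'query': 'SELECT * FROM shipments',
    },
}

config = transform['config']
instance, database = config['instance_id'], config['database_id']
if (table := config.get('table', None)) is None:
  # 'SELECT * FROM shipments' -> 'shipments'
  table = config.get('query', '').split('FROM')[-1].strip()
transform['type'] = 'Create'
transform['config'] = {k: v for k, v in config.items() if k.startswith('__')}
transform['config']['elements'] = INPUT_TABLES[(
    str(instance), str(database), str(table))]

print(transform['config']['elements'])  # [{'shipment_id': 'S1'}]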
sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml
@@ -18,10 +18,10 @@
 pipeline:
   transforms:

-  # Reading data from a Spanner database. The table used here has the following columns:
-  # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String)
-  # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query
-  # A table with a list of columns can also be specified instead of a query
+    # Reading data from a Spanner database. The table used here has the following columns:
+    # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String)
+    # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query
+    # A table with a list of columns can also be specified instead of a query
     - type: ReadFromSpanner
       name: ReadShipments
       config:
@@ -30,18 +30,18 @@ pipeline:
         database_id: 'shipment'
         query: 'SELECT * FROM shipments'

-  # Filtering the data based on a specific condition
-  # Here, the condition is used to keep only the rows where the customer_id is 'C1'
+    # Filtering the data based on a specific condition
+    # Here, the condition is used to keep only the rows where the customer_id is 'C1'
     - type: Filter
       name: FilterShipments
       input: ReadShipments
       config:
         language: python
         keep: "customer_id == 'C1'"

-  # Mapping the data fields and applying transformations
-  # A new field 'shipment_cost_category' is added with a custom transformation
-  # A callable is defined to categorize shipment cost
+    # Mapping the data fields and applying transformations
+    # A new field 'shipment_cost_category' is added with a custom transformation
+    # A callable is defined to categorize shipment cost
     - type: MapToFields
       name: MapFieldsForSpanner
       input: FilterShipments
@@ -65,16 +65,15 @@ pipeline:
             else:
               return 'High Cost'

-  # Writing the transformed data to a CSV file
+    # Writing the transformed data to a CSV file
     - type: WriteToCsv
       name: WriteBig
       input: MapFieldsForSpanner
       config:
         path: shipments.csv


-# On executing the above pipeline, a new CSV file is created with the following records
-
+# On executing the above pipeline, a new CSV file is created with the following records
 # Expected:
 # Row(shipment_id='S1', customer_id='C1', shipment_date='2023-05-01', shipment_cost=150.0, customer_name='Alice', customer_email='[email protected]', shipment_cost_category='Medium Cost')
 # Row(shipment_id='S3', customer_id='C1', shipment_date='2023-05-10', shipment_cost=20.0, customer_name='Alice', customer_email='[email protected]', shipment_cost_category='Low Cost')
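
The trailing '# Expected:' block is the contract the generated test checks against. A plausible sketch of harvesting those rows from an example file (the real parsing in examples_test.py may differ in details):

def extract_expected(yaml_path):
  expected = []
  capture = False
  with open(yaml_path) as f:
    for line in f:
      stripped = line.strip()
      if stripped == '# Expected:':
        capture = True
      elif capture and stripped.startswith('#'):
        expected.append(stripped.lstrip('#').strip())
  return expected

print(extract_expected('spanner_read.yaml'))
# ["Row(shipment_id='S1', ...)", "Row(shipment_id='S3', ...)"]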
sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml
@@ -18,8 +18,8 @@
 pipeline:
   transforms:

-  # Step 1: Creating rows to be written to Spanner
-  # The element names correspond to the column names in the Spanner table
+    # Step 1: Creating rows to be written to Spanner
+    # The element names correspond to the column names in the Spanner table
     - type: Create
       name: CreateRows
       config:
@@ -31,10 +31,10 @@ pipeline:
             customer_name: "Erin"
             customer_email: "[email protected]"

-  # Step 2: Writing the created rows to a Spanner database
-  # We require the project ID, instance ID, database ID and table ID to connect to Spanner
-  # Error handling can be specified optionally to ensure any failed operations aren't lost
-  # The failed data is passed on in the pipeline and can be handled
+    # Step 2: Writing the created rows to a Spanner database
+    # We require the project ID, instance ID, database ID and table ID to connect to Spanner
+    # Error handling can be specified optionally to ensure any failed operations aren't lost
+    # The failed data is passed on in the pipeline and can be handled
     - type: WriteToSpanner
       name: WriteSpanner
       input: CreateRows
@@ -46,8 +46,11 @@ pipeline:
         error_handling:
           output: my_error_output

-  # Step 3: Writing the failed records to a JSON file
+    # Step 3: Writing the failed records to a JSON file
     - type: WriteToJson
       input: WriteSpanner.my_error_output
       config:
         path: errors.json
+
+# Expected:
+# Row(shipment_id='S5', customer_id='C5', shipment_date='2023-05-09', shipment_cost=300.0, customer_name='Erin', customer_email='[email protected]')
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/yaml/generate_yaml_docs.py
@@ -30,7 +30,7 @@
 from apache_beam.version import __version__ as beam_version
 from apache_beam.yaml import json_utils
 from apache_beam.yaml import yaml_provider
-from apache_beam.yaml.yaml_mapping import ErrorHandlingConfig
+from apache_beam.yaml.yaml_errors import ErrorHandlingConfig


 def _singular(name):
88 changes: 88 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_errors.py
@@ -0,0 +1,88 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import functools
import inspect
from typing import NamedTuple

import apache_beam as beam
from apache_beam.typehints.row_type import RowTypeConstraint


class ErrorHandlingConfig(NamedTuple):
  """Class to define Error Handling parameters.

  Args:
    output (str): Name to use for the output error collection
  """
  output: str
  # TODO: Other parameters are valid here too, but not common to Java.


def exception_handling_args(error_handling_spec):
  if error_handling_spec:
    # The YAML-facing 'output' key becomes Beam's 'dead_letter_tag'.
    return {
        'dead_letter_tag' if k == 'output' else k: v
        for (k, v) in error_handling_spec.items()
    }
  else:
    return None


def map_errors_to_standard_format(input_type):
  # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple.

  return beam.Map(
      lambda x: beam.Row(
          element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2]))
  ).with_output_types(
      RowTypeConstraint.from_fields([("element", input_type), ("msg", str),
                                     ("stack", str)]))


def maybe_with_exception_handling(inner_expand):
  def expand(self, pcoll):
    # Route exceptions to an error output when handling args are configured;
    # otherwise the wrapper is a transparent pass-through.
    wrapped_pcoll = beam.core._MaybePValueWithErrors(
        pcoll, self._exception_handling_args)
    return inner_expand(self, wrapped_pcoll).as_result(
        map_errors_to_standard_format(pcoll.element_type))

  return expand


def maybe_with_exception_handling_transform_fn(transform_fn):
  @functools.wraps(transform_fn)
  def expand(pcoll, error_handling=None, **kwargs):
    wrapped_pcoll = beam.core._MaybePValueWithErrors(
        pcoll, exception_handling_args(error_handling))
    return transform_fn(wrapped_pcoll, **kwargs).as_result(
        map_errors_to_standard_format(pcoll.element_type))

  # Expose 'error_handling' as a keyword-only parameter on the wrapped
  # transform's signature so tooling such as doc generation can see it.
  original_signature = inspect.signature(transform_fn)
  new_parameters = list(original_signature.parameters.values())
  error_handling_param = inspect.Parameter(
      'error_handling',
      inspect.Parameter.KEYWORD_ONLY,
      default=None,
      annotation=ErrorHandlingConfig)
  if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD:
    new_parameters.insert(-1, error_handling_param)
  else:
    new_parameters.append(error_handling_param)
  expand.__signature__ = original_signature.replace(parameters=new_parameters)

  return expand
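
A minimal usage sketch of the transform-function decorator (parse_int is hypothetical, not from the PR): wrapping a transform grafts a keyword-only error_handling parameter onto its signature, which is what generate_yaml_docs.py picks up via ErrorHandlingConfig.

import inspect

import apache_beam as beam
from apache_beam.yaml import yaml_errors


@yaml_errors.maybe_with_exception_handling_transform_fn
def parse_int(pcoll):  # hypothetical transform, for illustration only
  return pcoll | beam.Map(lambda row: beam.Row(value=int(row.raw)))


# The explicit __signature__ set by the decorator wins over functools.wraps:
print(inspect.signature(parse_int))
# (pcoll, *, error_handling: ErrorHandlingConfig = None)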
6 changes: 3 additions & 3 deletions sdks/python/apache_beam/yaml/yaml_io.py
@@ -45,7 +45,7 @@
 from apache_beam.portability.api import schema_pb2
 from apache_beam.typehints import schemas
 from apache_beam.yaml import json_utils
-from apache_beam.yaml import yaml_mapping
+from apache_beam.yaml import yaml_errors
 from apache_beam.yaml import yaml_provider

@@ -289,7 +289,7 @@ def formatter(row):


 @beam.ptransform_fn
-@yaml_mapping.maybe_with_exception_handling_transform_fn
+@yaml_errors.maybe_with_exception_handling_transform_fn
 def read_from_pubsub(
     root,
     *,
@@ -393,7 +393,7 @@ def mapper(msg):


 @beam.ptransform_fn
-@yaml_mapping.maybe_with_exception_handling_transform_fn
+@yaml_errors.maybe_with_exception_handling_transform_fn
 def write_to_pubsub(
     pcoll,
     *,