Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Enrichment Handlers to PEP 585 typing #33087

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
from collections.abc import Mapping
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Union

Expand All @@ -30,7 +28,7 @@
from apache_beam.transforms.enrichment import EnrichmentSourceHandler

QueryFn = Callable[[beam.Row], str]
ConditionValueFn = Callable[[beam.Row], List[Any]]
ConditionValueFn = Callable[[beam.Row], list[Any]]


def _validate_bigquery_metadata(
Expand All @@ -54,8 +52,8 @@ def _validate_bigquery_metadata(
"`condition_value_fn`")


class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]],
Union[Row, List[Row]]]):
class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]],
Union[Row, list[Row]]]):
"""Enrichment handler for Google Cloud BigQuery.

Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
Expand Down Expand Up @@ -83,8 +81,8 @@ def __init__(
*,
table_name: str = "",
row_restriction_template: str = "",
fields: Optional[List[str]] = None,
column_names: Optional[List[str]] = None,
fields: Optional[list[str]] = None,
column_names: Optional[list[str]] = None,
condition_value_fn: Optional[ConditionValueFn] = None,
query_fn: Optional[QueryFn] = None,
min_batch_size: int = 1,
Expand All @@ -107,10 +105,10 @@ def __init__(
row_restriction_template (str): A template string for the `WHERE` clause
in the BigQuery query with placeholders (`{}`) to dynamically filter
rows based on input data.
fields: (Optional[List[str]]) List of field names present in the input
fields: (Optional[list[str]]) List of field names present in the input
`beam.Row`. These are used to construct the WHERE clause
(if `condition_value_fn` is not provided).
column_names: (Optional[List[str]]) Names of columns to select from the
column_names: (Optional[list[str]]) Names of columns to select from the
BigQuery table. If not provided, all columns (`*`) are selected.
condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function
that takes a `beam.Row` and returns a list of values to populate in the
Expand Down Expand Up @@ -179,11 +177,11 @@ def create_row_key(self, row: beam.Row):
return (tuple(row_dict[field] for field in self.fields))
raise ValueError("Either fields or condition_value_fn must be specified")

def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
if isinstance(request, List):
def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
if isinstance(request, list):
values = []
responses = []
requests_map: Dict[Any, Any] = {}
requests_map: dict[Any, Any] = {}
batch_size = len(request)
raw_query = self.query_template
if batch_size > 1:
Expand Down Expand Up @@ -230,8 +228,8 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
# Context-manager exit hook: closes the client held in `self.client`.
# The exception arguments are not inspected and nothing is returned, so an
# exception raised inside the `with` body is not suppressed here.
def __exit__(self, exc_type, exc_val, exc_tb):
self.client.close()

def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]):
if isinstance(request, List):
def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]):
if isinstance(request, list):
cache_keys = []
for req in request:
req_dict = req._asdict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
# limitations under the License.
#
import logging
from collections.abc import Callable
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional

from google.api_core.exceptions import NotFound
Expand Down Expand Up @@ -115,7 +114,7 @@ def __call__(self, request: beam.Row, *args, **kwargs):
Args:
request: the input `beam.Row` to enrich.
"""
response_dict: Dict[str, Any] = {}
response_dict: dict[str, Any] = {}
row_key_str: str = ""
try:
if self._row_key_fn:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
import datetime
import logging
import unittest
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -57,8 +54,8 @@ class ValidateResponse(beam.DoFn):
def __init__(
self,
n_fields: int,
fields: List[str],
enriched_fields: Dict[str, List[str]],
fields: list[str],
enriched_fields: dict[str, list[str]],
include_timestamp: bool = False,
):
self.n_fields = n_fields
Expand Down Expand Up @@ -88,7 +85,7 @@ def process(self, element: beam.Row, *args, **kwargs):
"Response from bigtable should contain a %s column_family with "
"%s columns." % (column_family, columns))
if (self._include_timestamp and
not isinstance(element_dict[column_family][key][0], Tuple)): # type: ignore[arg-type]
not isinstance(element_dict[column_family][key][0], tuple)):
raise BeamAssertException(
"Response from bigtable should contain timestamp associated with "
"its value.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
#
import logging
import tempfile
from collections.abc import Callable
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from typing import Callable
from typing import List
from typing import Mapping
from typing import Optional

import apache_beam as beam
Expand Down Expand Up @@ -95,7 +94,7 @@ class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row,
def __init__(
self,
feature_store_yaml_path: str,
feature_names: Optional[List[str]] = None,
feature_names: Optional[list[str]] = None,
feature_service_name: Optional[str] = "",
full_feature_names: Optional[bool] = False,
entity_id: str = "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
"""

import unittest
from collections.abc import Mapping
from typing import Any
from typing import Mapping

import pytest

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
#
import logging
from typing import List

import proto
from google.api_core.exceptions import NotFound
Expand Down Expand Up @@ -209,7 +208,7 @@ def __init__(
api_endpoint: str,
feature_store_id: str,
entity_type_id: str,
feature_ids: List[str],
feature_ids: list[str],
row_key: str,
*,
exception_level: ExceptionLevel = ExceptionLevel.WARN,
Expand All @@ -224,7 +223,7 @@ def __init__(
Vertex AI Feature Store (Legacy).
feature_store_id (str): The id of the Vertex AI Feature Store (Legacy).
entity_type_id (str): The entity type of the feature store.
feature_ids (List[str]): A list of feature-ids to fetch
feature_ids (list[str]): A list of feature-ids to fetch
from the Feature Store.
row_key (str): The row key field name containing the entity id
for the feature values.
Expand Down
Loading