CDK: Add schema normalization to declarative stream (#32786)
Co-authored-by: Eugene Kulak <[email protected]>
Co-authored-by: Yevhenii Kurochkin <[email protected]>
Co-authored-by: Alexandre Girard <[email protected]>
4 people authored Dec 21, 2023
1 parent 08c2da2 commit 4061f08
Showing 14 changed files with 212 additions and 48 deletions.
@@ -1203,12 +1203,10 @@ definitions:
http_method:
title: HTTP Method
description: The HTTP method used to fetch data from the source (can be GET or POST).
anyOf:
- type: string
- type: string
enum:
- GET
- POST
type: string
enum:
- GET
- POST
default: GET
examples:
- GET
@@ -1822,9 +1820,22 @@ definitions:
title: Record Filter
description: Responsible for filtering records to be emitted by the Source.
"$ref": "#/definitions/RecordFilter"
schema_normalization:
"$ref": "#/definitions/SchemaNormalization"
default: None
$parameters:
type: object
additionalProperties: true
SchemaNormalization:
title: Schema Normalization
description: Responsible for normalization according to the schema.
type: string
enum:
- None
- Default
examples:
- None
- Default
RemoveFields:
title: Remove Fields
description: A transformation which removes fields from a record. The fields removed are designated using FieldPointers. During transformation, if a field or any of its parents does not exist in the record, no error is thrown.
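
The new SchemaNormalization option maps onto the CDK's existing TypeTransformer configs (see SCHEMA_TRANSFORMER_TYPE_MAPPING in record_selector.py below). A rough, illustrative sketch of what the Default setting does at runtime; the schema and record values here are made up and are not part of this commit:

from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer

# "Default" casts record values to the types declared in the stream's JSON schema.
transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)

schema = {"type": "object", "properties": {"id": {"type": "integer"}, "price": {"type": "number"}}}
record = {"id": "42", "price": "9.99"}

transformer.transform(record, schema)  # coerces the values in place
# record is now {"id": 42, "price": 9.99}; with the None option nothing is changed.
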
@@ -101,7 +101,7 @@ def read_records(
"""
:param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
"""
yield from self.retriever.read_records(stream_slice)
yield from self.retriever.read_records(self.get_json_schema(), stream_slice)

def get_json_schema(self) -> Mapping[str, Any]: # type: ignore
"""
@@ -22,13 +22,15 @@ def select_records(
self,
response: requests.Response,
stream_state: StreamState,
records_schema: Mapping[str, Any],
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> List[Record]:
"""
Selects records from the response
:param response: The response to select the records from
:param stream_state: The stream state
:param records_schema: json schema of records to return
:param stream_slice: The stream slice
:param next_page_token: The paginator token
:return: List of Records selected from the response
@@ -9,8 +9,15 @@
from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector
from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
from airbyte_cdk.sources.declarative.models import SchemaNormalization
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.declarative.types import Config, Record, StreamSlice, StreamState
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer

SCHEMA_TRANSFORMER_TYPE_MAPPING = {
SchemaNormalization.None_: TransformConfig.NoTransform,
SchemaNormalization.Default: TransformConfig.DefaultSchemaNormalization,
}


@dataclass
@@ -21,13 +28,15 @@ class RecordSelector(HttpSelector):
Attributes:
extractor (RecordExtractor): The record extractor responsible for extracting records from a response
schema_normalization (TypeTransformer): The record normalizer responsible for casting record values to stream schema types
record_filter (RecordFilter): The record filter responsible for filtering extracted records
transformations (List[RecordTransformation]): The transformations to be done on the records
"""

extractor: RecordExtractor
config: Config
parameters: InitVar[Mapping[str, Any]]
schema_normalization: TypeTransformer
record_filter: Optional[RecordFilter] = None
transformations: List[RecordTransformation] = field(default_factory=lambda: [])

@@ -38,14 +47,31 @@ def select_records(
self,
response: requests.Response,
stream_state: StreamState,
records_schema: Mapping[str, Any],
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> List[Record]:
"""
Selects records from the response
:param response: The response to select the records from
:param stream_state: The stream state
:param records_schema: json schema of records to return
:param stream_slice: The stream slice
:param next_page_token: The paginator token
:return: List of Records selected from the response
"""
all_data = self.extractor.extract_records(response)
filtered_data = self._filter(all_data, stream_state, stream_slice, next_page_token)
self._transform(filtered_data, stream_state, stream_slice)
self._normalize_by_schema(filtered_data, schema=records_schema)
return [Record(data, stream_slice) for data in filtered_data]

def _normalize_by_schema(self, records: List[Mapping[str, Any]], schema: Optional[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
if schema:
# record has type Mapping[str, Any], but dict[str, Any] expected
return [self.schema_normalization.transform(record, schema) for record in records] # type: ignore
return records

def _filter(
self,
records: List[Mapping[str, Any]],
@@ -67,4 +93,5 @@ def _transform(
) -> None:
for record in records:
for transformation in self.transformations:
transformation.transform(record, config=self.config, stream_state=stream_state, stream_slice=stream_slice)
# record has type Mapping[str, Any], but Record expected
transformation.transform(record, config=self.config, stream_state=stream_state, stream_slice=stream_slice) # type: ignore
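
Putting the pieces together, a rough end-to-end sketch of the selector with schema normalization enabled. The fake response (built by setting _content directly) and the example payload and schema are assumptions for illustration only, not part of this commit:

import json

import requests

from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordSelector
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer

# Fake an HTTP response carrying one record under "data".
response = requests.Response()
response._content = json.dumps({"data": [{"id": "1", "price": "9.99"}]}).encode("utf-8")

selector = RecordSelector(
    extractor=DpathExtractor(field_path=["data"], config={}, parameters={}),
    config={},
    parameters={},
    schema_normalization=TypeTransformer(TransformConfig.DefaultSchemaNormalization),
)

records = selector.select_records(
    response=response,
    stream_state={},
    records_schema={"type": "object", "properties": {"id": {"type": "integer"}, "price": {"type": "number"}}},
)
# Each record's "id" and "price" are cast to int and float before being wrapped in Record objects.
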
@@ -340,7 +340,7 @@ class SessionTokenRequestBearerAuthenticator(BaseModel):
type: Literal['Bearer']


class HttpMethodEnum(Enum):
class HttpMethod(Enum):
GET = 'GET'
POST = 'POST'

@@ -572,6 +572,11 @@ class RecordFilter(BaseModel):
parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters')


class SchemaNormalization(Enum):
None_ = 'None'
Default = 'Default'


class RemoveFields(BaseModel):
type: Literal['RemoveFields']
field_pointers: List[List[str]] = Field(
@@ -1019,6 +1024,7 @@ class RecordSelector(BaseModel):
description='Responsible for filtering records to be emitted by the Source.',
title='Record Filter',
)
schema_normalization: Optional[SchemaNormalization] = SchemaNormalization.None_
parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters')


@@ -1232,8 +1238,8 @@ class HttpRequester(BaseModel):
description='Error handler component that defines how to handle errors.',
title='Error Handler',
)
http_method: Optional[Union[str, HttpMethodEnum]] = Field(
'GET',
http_method: Optional[HttpMethod] = Field(
HttpMethod.GET,
description='The HTTP method used to fetch data from the source (can be GET or POST).',
examples=['GET', 'POST'],
title='HTTP Method',
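
Note that "None" is a reserved word in Python, so the generated enum member is named None_ while its value stays the literal string "None" that a manifest declares. A small illustration, not part of the diff:

from airbyte_cdk.sources.declarative.models import SchemaNormalization

SchemaNormalization("None")     # -> SchemaNormalization.None_
SchemaNormalization("Default")  # -> SchemaNormalization.Default
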
@@ -26,6 +26,7 @@
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.decoders import JsonDecoder
from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordFilter, RecordSelector
from airbyte_cdk.sources.declarative.extractors.record_selector import SCHEMA_TRANSFORMER_TYPE_MAPPING
from airbyte_cdk.sources.declarative.incremental import Cursor, CursorFactory, DatetimeBasedCursor, PerPartitionCursor
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
@@ -107,6 +108,7 @@
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOptionType
from airbyte_cdk.sources.declarative.requesters.request_options import InterpolatedRequestOptionsProvider
from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever, SimpleRetrieverTestReadDecorator
from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader, InlineSchemaLoader, JsonFileSchemaLoader
from airbyte_cdk.sources.declarative.spec import Spec
@@ -115,6 +117,7 @@
from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition
from airbyte_cdk.sources.declarative.types import Config
from airbyte_cdk.sources.message import InMemoryMessageRepository, LogAppenderMessageRepositoryDecorator, MessageRepository
from airbyte_cdk.sources.utils.transform import TypeTransformer
from isodate import parse_duration
from pydantic import BaseModel

@@ -710,9 +713,8 @@ def create_http_requester(self, model: HttpRequesterModel, config: Config, *, na
parameters=model.parameters or {},
)

model_http_method = (
model.http_method if isinstance(model.http_method, str) else model.http_method.value if model.http_method is not None else "GET"
)
assert model.use_cache is not None # for mypy
assert model.http_method is not None # for mypy

assert model.use_cache is not None # for mypy

@@ -722,7 +724,7 @@ def create_http_requester(self, model: HttpRequesterModel, config: Config, *, na
path=model.path,
authenticator=authenticator,
error_handler=error_handler,
http_method=model_http_method,
http_method=HttpMethod[model.http_method.value],
request_options_provider=request_options_provider,
config=config,
disable_retries=self._disable_retries,
@@ -884,16 +886,24 @@ def create_request_option(model: RequestOptionModel, config: Config, **kwargs: A
return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={})

def create_record_selector(
self, model: RecordSelectorModel, config: Config, *, transformations: List[RecordTransformation], **kwargs: Any
self,
model: RecordSelectorModel,
config: Config,
*,
transformations: List[RecordTransformation],
**kwargs: Any,
) -> RecordSelector:
assert model.schema_normalization is not None # for mypy
extractor = self._create_component_from_model(model=model.extractor, config=config)
record_filter = self._create_component_from_model(model.record_filter, config=config) if model.record_filter else None
schema_normalization = TypeTransformer(SCHEMA_TRANSFORMER_TYPE_MAPPING[model.schema_normalization])

return RecordSelector(
extractor=extractor,
config=config,
record_filter=record_filter,
transformations=transformations,
schema_normalization=schema_normalization,
parameters=model.parameters or {},
)

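
In short, the factory now resolves both changed fields into concrete runtime objects. A condensed sketch of the two conversions, using example values rather than a real parsed model:

from airbyte_cdk.sources.declarative.extractors.record_selector import SCHEMA_TRANSFORMER_TYPE_MAPPING
from airbyte_cdk.sources.declarative.models import SchemaNormalization
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
from airbyte_cdk.sources.utils.transform import TypeTransformer

# schema_normalization: declared enum value -> TransformConfig -> TypeTransformer
schema_normalization = TypeTransformer(SCHEMA_TRANSFORMER_TYPE_MAPPING[SchemaNormalization.Default])

# http_method: generated-model enum value -> runtime HttpMethod enum, looked up by name
http_method = HttpMethod["GET"]  # what HttpMethod[model.http_method.value] evaluates to for GET
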
@@ -59,7 +59,7 @@ class HttpRequester(Requester):
config: Config
parameters: InitVar[Mapping[str, Any]]
authenticator: Optional[DeclarativeAuthenticator] = None
http_method: Union[str, HttpMethod] = HttpMethod.GET
http_method: HttpMethod = HttpMethod.GET
request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None
error_handler: Optional[ErrorHandler] = None
disable_retries: bool = False
@@ -80,7 +80,6 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
else:
self._request_options_provider = self.request_options_provider
self._authenticator = self.authenticator or NoAuth(parameters=parameters)
self._http_method = HttpMethod[self.http_method] if isinstance(self.http_method, str) else self.http_method
self.error_handler = self.error_handler
self._parameters = parameters
self.decoder = JsonDecoder(parameters={})
@@ -139,7 +138,7 @@ def get_path(
return path.lstrip("/")

def get_method(self) -> HttpMethod:
return self._http_method
return self.http_method

def interpret_response_status(self, response: requests.Response) -> ResponseStatus:
if self.error_handler is None:
@@ -420,7 +419,7 @@ def _create_prepared_request(
data: Any = None,
) -> requests.PreparedRequest:
url = urljoin(self.get_url_base(), path)
http_method = str(self._http_method.value)
http_method = str(self.http_method.value)
query_params = self.deduplicate_query_params(url, params)
args = {"method": http_method, "url": url, "headers": headers, "params": query_params}
if http_method.upper() in BODY_REQUEST_METHODS:
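
Because http_method is now typed strictly as the HttpMethod enum, callers construct the requester with the enum instead of a string. A hypothetical construction with placeholder name, url_base, and path values:

from airbyte_cdk.sources.declarative.requesters import HttpRequester
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod

# Placeholder values for illustration; before this change http_method="GET" (a plain string) was also accepted.
requester = HttpRequester(
    name="users",
    url_base="https://api.example.com/v1/",
    path="users",
    http_method=HttpMethod.GET,
    config={},
    parameters={},
)
assert requester.get_method() == HttpMethod.GET
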
@@ -4,7 +4,7 @@

from abc import abstractmethod
from dataclasses import dataclass
from typing import Iterable, Optional
from typing import Any, Iterable, Mapping, Optional

from airbyte_cdk.sources.declarative.types import StreamSlice, StreamState
from airbyte_cdk.sources.streams.core import StreamData
@@ -19,15 +19,14 @@ class Retriever:
@abstractmethod
def read_records(
self,
records_schema: Mapping[str, Any],
stream_slice: Optional[StreamSlice] = None,
) -> Iterable[StreamData]:
"""
Fetch a stream's records from an HTTP API source
:param sync_mode: Unused but currently necessary for integrating with HttpStream
:param cursor_field: Unused but currently necessary for integrating with HttpStream
:param records_schema: json schema to describe record
:param stream_slice: The stream slice to read data for
:param stream_state: The initial stream state
:return: The records read from the API source
"""
