🐛 low-code: Fix incremental substreams (#35471)
girarda authored and xiaohansong committed Mar 7, 2024
1 parent b98a947 commit aabd2bd
Showing 23 changed files with 833 additions and 298 deletions.
@@ -4,7 +4,7 @@

import datetime as dt
from dataclasses import InitVar, dataclass, field
from typing import Any, Mapping, Union
from typing import Any, Mapping, Optional, Union

from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
@@ -37,13 +37,13 @@ class MinMaxDatetime:
min_datetime: Union[InterpolatedString, str] = ""
max_datetime: Union[InterpolatedString, str] = ""

def __post_init__(self, parameters: Mapping[str, Any]):
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.datetime = InterpolatedString.create(self.datetime, parameters=parameters or {})
self._parser = DatetimeParser()
self.min_datetime = InterpolatedString.create(self.min_datetime, parameters=parameters) if self.min_datetime else None
self.max_datetime = InterpolatedString.create(self.max_datetime, parameters=parameters) if self.max_datetime else None
self.min_datetime = InterpolatedString.create(self.min_datetime, parameters=parameters) if self.min_datetime else None # type: ignore
self.max_datetime = InterpolatedString.create(self.max_datetime, parameters=parameters) if self.max_datetime else None # type: ignore

def get_datetime(self, config, **additional_parameters) -> dt.datetime:
def get_datetime(self, config: Mapping[str, Any], **additional_parameters: Mapping[str, Any]) -> dt.datetime:
"""
Evaluates and returns the datetime
:param config: The user-provided configuration as specified by the source's spec
@@ -55,29 +55,44 @@ def get_datetime(self, config, **additional_parameters) -> dt.datetime:
if not datetime_format:
datetime_format = "%Y-%m-%dT%H:%M:%S.%f%z"

time = self._parser.parse(str(self.datetime.eval(config, **additional_parameters)), datetime_format)
time = self._parser.parse(str(self.datetime.eval(config, **additional_parameters)), datetime_format) # type: ignore # datetime is always cast to an interpolated string

if self.min_datetime:
min_time = str(self.min_datetime.eval(config, **additional_parameters))
min_time = str(self.min_datetime.eval(config, **additional_parameters)) # type: ignore # min_datetime is always cast to an interpolated string
if min_time:
min_time = self._parser.parse(min_time, datetime_format)
time = max(time, min_time)
min_datetime = self._parser.parse(min_time, datetime_format) # type: ignore # min_datetime is always cast to an interpolated string
time = max(time, min_datetime)
if self.max_datetime:
max_time = str(self.max_datetime.eval(config, **additional_parameters))
max_time = str(self.max_datetime.eval(config, **additional_parameters)) # type: ignore # max_datetime is always cast to an interpolated string
if max_time:
max_time = self._parser.parse(max_time, datetime_format)
time = min(time, max_time)
max_datetime = self._parser.parse(max_time, datetime_format)
time = min(time, max_datetime)
return time

@property
@property # type: ignore # properties don't play well with dataclasses...
def datetime_format(self) -> str:
"""The format of the string representing the datetime"""
return self._datetime_format

@datetime_format.setter
def datetime_format(self, value: str):
def datetime_format(self, value: str) -> None:
"""Setter for the datetime format"""
# Covers the case where datetime_format is not provided in the constructor, which causes the property object
# to be set which we need to avoid doing
if not isinstance(value, property):
self._datetime_format = value

@classmethod
def create(
cls,
interpolated_string_or_min_max_datetime: Union[InterpolatedString, str, "MinMaxDatetime"],
parameters: Optional[Mapping[str, Any]] = None,
) -> "MinMaxDatetime":
if parameters is None:
parameters = {}
if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance(
interpolated_string_or_min_max_datetime, str
):
return MinMaxDatetime(datetime=interpolated_string_or_min_max_datetime, parameters=parameters)
else:
return interpolated_string_or_min_max_datetime
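
For orientation, here is a minimal usage sketch of the new MinMaxDatetime.create factory. The import path, the config expression, and the literal datetime value are assumptions based on the modules referenced in this diff; treat it as an illustration rather than code from the commit.

from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime

# A plain interpolatable string is wrapped into a MinMaxDatetime...
start = MinMaxDatetime.create("{{ config['start_date'] }}", parameters={})

# ...while an existing MinMaxDatetime is passed through unchanged.
bounded = MinMaxDatetime(
    datetime="{{ config['start_date'] }}",
    max_datetime="2024-01-01T00:00:00.000000+0000",
    parameters={},
)
assert MinMaxDatetime.create(bounded) is bounded

Centralizing this normalization is what lets DatetimeBasedCursor.__post_init__ below drop its isinstance checks in favor of MinMaxDatetime.create.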
@@ -10,7 +10,7 @@
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
from airbyte_cdk.sources.declarative.types import Config
from airbyte_cdk.sources.declarative.types import Config, StreamSlice
from airbyte_cdk.sources.streams.core import Stream


@@ -101,6 +101,8 @@ def read_records(
"""
:param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
"""
if not isinstance(stream_slice, StreamSlice):
raise ValueError(f"DeclarativeStream does not support stream_slices that are not StreamSlice. Got {stream_slice}")
yield from self.retriever.read_records(self.get_json_schema(), stream_slice)

def get_json_schema(self) -> Mapping[str, Any]: # type: ignore
@@ -114,7 +116,7 @@ def get_json_schema(self) -> Mapping[str, Any]: # type: ignore

def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
) -> Iterable[Optional[StreamSlice]]:
"""
Override to define the slices for this stream. See the stream slicing section of the docs for more information.
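
For context, a brief sketch of the slice objects DeclarativeStream.read_records now expects. The field names are illustrative; the StreamSlice keyword arguments mirror those used further down in this diff, and this is a sketch rather than code from the commit.

from airbyte_cdk.sources.declarative.types import StreamSlice

# The datetime cursor now emits StreamSlice objects instead of plain dicts:
# no partition, only the cursor window for the current date range.
stream_slice = StreamSlice(
    partition={},
    cursor_slice={"start_time": "2024-01-01T00:00:00.000000+0000", "end_time": "2024-01-31T23:59:59.000000+0000"},
)

# read_records rejects anything that is not a StreamSlice...
assert isinstance(stream_slice, StreamSlice)
# ...and the slice still reads like a mapping, as the request-option code below relies on.
assert stream_slice.get("start_time") == "2024-01-01T00:00:00.000000+0000"
assert stream_slice.partition == {}

The explicit partition/cursor_slice split is what the rest of this change appears to build on for tracking incremental state per substream partition.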
@@ -4,7 +4,7 @@

import datetime
from dataclasses import InitVar, dataclass, field
from typing import Any, Iterable, List, Mapping, Optional, Union
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser
@@ -70,10 +70,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
f"If step is defined, cursor_granularity should be as well and vice-versa. "
f"Right now, step is `{self.step}` and cursor_granularity is `{self.cursor_granularity}`"
)
if not isinstance(self.start_datetime, MinMaxDatetime):
self.start_datetime = MinMaxDatetime(self.start_datetime, parameters)
if self.end_datetime and not isinstance(self.end_datetime, MinMaxDatetime):
self.end_datetime = MinMaxDatetime(self.end_datetime, parameters)
self._start_datetime = MinMaxDatetime.create(self.start_datetime, parameters)
self._end_datetime = None if not self.end_datetime else MinMaxDatetime.create(self.end_datetime, parameters)

self._timezone = datetime.timezone.utc
self._interpolation = JinjaInterpolation()
@@ -84,23 +82,23 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
else datetime.timedelta.max
)
self._cursor_granularity = self._parse_timedelta(self.cursor_granularity)
self.cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters)
self.lookback_window = InterpolatedString.create(self.lookback_window, parameters=parameters)
self.partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters)
self.partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters)
self._cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters)
self._lookback_window = InterpolatedString.create(self.lookback_window, parameters=parameters) if self.lookback_window else None
self._partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters)
self._partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters)
self._parser = DatetimeParser()

# If datetime format is not specified then start/end datetime should inherit it from the stream slicer
if not self.start_datetime.datetime_format:
self.start_datetime.datetime_format = self.datetime_format
if self.end_datetime and not self.end_datetime.datetime_format:
self.end_datetime.datetime_format = self.datetime_format
if not self._start_datetime.datetime_format:
self._start_datetime.datetime_format = self.datetime_format
if self._end_datetime and not self._end_datetime.datetime_format:
self._end_datetime.datetime_format = self.datetime_format

if not self.cursor_datetime_formats:
self.cursor_datetime_formats = [self.datetime_format]

def get_stream_state(self) -> StreamState:
return {self.cursor_field.eval(self.config): self._cursor} if self._cursor else {}
return {self._cursor_field.eval(self.config): self._cursor} if self._cursor else {}

def set_initial_state(self, stream_state: StreamState) -> None:
"""
@@ -109,17 +107,22 @@ def set_initial_state(self, stream_state: StreamState) -> None:
:param stream_state: The state of the stream as returned by get_stream_state
"""
self._cursor = stream_state.get(self.cursor_field.eval(self.config)) if stream_state else None
self._cursor = stream_state.get(self._cursor_field.eval(self.config)) if stream_state else None

def close_slice(self, stream_slice: StreamSlice, most_recent_record: Optional[Record]) -> None:
last_record_cursor_value = most_recent_record.get(self.cursor_field.eval(self.config)) if most_recent_record else None
stream_slice_value_end = stream_slice.get(self.partition_field_end.eval(self.config))
if stream_slice.partition:
raise ValueError(f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}.")
last_record_cursor_value = most_recent_record.get(self._cursor_field.eval(self.config)) if most_recent_record else None
stream_slice_value_end = stream_slice.get(self._partition_field_end.eval(self.config))
potential_cursor_values = [
cursor_value for cursor_value in [self._cursor, last_record_cursor_value, stream_slice_value_end] if cursor_value
]
cursor_value_str_by_cursor_value_datetime = dict(
map(
# we need to ensure the cursor value is preserved as is in the state else the CATs might complain of something like
# 2023-01-04T17:30:19.000Z' <= '2023-01-04T17:30:19.000000Z'
lambda datetime_str: (self.parse_date(datetime_str), datetime_str),
filter(lambda item: item, [self._cursor, last_record_cursor_value, stream_slice_value_end]),
potential_cursor_values,
)
)
self._cursor = (
@@ -142,37 +145,43 @@ def stream_slices(self) -> Iterable[StreamSlice]:
return self._partition_daterange(start_datetime, end_datetime, self._step)

def _calculate_earliest_possible_value(self, end_datetime: datetime.datetime) -> datetime.datetime:
lookback_delta = self._parse_timedelta(self.lookback_window.eval(self.config) if self.lookback_window else "P0D")
earliest_possible_start_datetime = min(self.start_datetime.get_datetime(self.config), end_datetime)
lookback_delta = self._parse_timedelta(self._lookback_window.eval(self.config) if self.lookback_window else "P0D")
earliest_possible_start_datetime = min(self._start_datetime.get_datetime(self.config), end_datetime)
cursor_datetime = self._calculate_cursor_datetime_from_state(self.get_stream_state())
return max(earliest_possible_start_datetime, cursor_datetime) - lookback_delta

def _select_best_end_datetime(self) -> datetime.datetime:
now = datetime.datetime.now(tz=self._timezone)
if not self.end_datetime:
if not self._end_datetime:
return now
return min(self.end_datetime.get_datetime(self.config), now)
return min(self._end_datetime.get_datetime(self.config), now)

def _calculate_cursor_datetime_from_state(self, stream_state: Mapping[str, Any]) -> datetime.datetime:
if self.cursor_field.eval(self.config, stream_state=stream_state) in stream_state:
return self.parse_date(stream_state[self.cursor_field.eval(self.config)])
if self._cursor_field.eval(self.config, stream_state=stream_state) in stream_state:
return self.parse_date(stream_state[self._cursor_field.eval(self.config)])
return datetime.datetime.min.replace(tzinfo=datetime.timezone.utc)

def _format_datetime(self, dt: datetime.datetime) -> str:
return self._parser.format(dt, self.datetime_format)

def _partition_daterange(self, start: datetime.datetime, end: datetime.datetime, step: Union[datetime.timedelta, Duration]):
start_field = self.partition_field_start.eval(self.config)
end_field = self.partition_field_end.eval(self.config)
def _partition_daterange(
self, start: datetime.datetime, end: datetime.datetime, step: Union[datetime.timedelta, Duration]
) -> List[StreamSlice]:
start_field = self._partition_field_start.eval(self.config)
end_field = self._partition_field_end.eval(self.config)
dates = []
while start <= end:
next_start = self._evaluate_next_start_date_safely(start, step)
end_date = self._get_date(next_start - self._cursor_granularity, end, min)
dates.append({start_field: self._format_datetime(start), end_field: self._format_datetime(end_date)})
dates.append(
StreamSlice(
partition={}, cursor_slice={start_field: self._format_datetime(start), end_field: self._format_datetime(end_date)}
)
)
start = next_start
return dates

def _evaluate_next_start_date_safely(self, start, step):
def _evaluate_next_start_date_safely(self, start: datetime.datetime, step: datetime.timedelta) -> datetime.datetime:
"""
Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date
This method assumes that users would never enter a step that would generate an overflow. Given that would be the case, the code
@@ -183,7 +192,12 @@ def _evaluate_next_start_date_safely(self, start, step):
except OverflowError:
return datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)

def _get_date(self, cursor_value, default_date: datetime.datetime, comparator) -> datetime.datetime:
def _get_date(
self,
cursor_value: datetime.datetime,
default_date: datetime.datetime,
comparator: Callable[[datetime.datetime, datetime.datetime], datetime.datetime],
) -> datetime.datetime:
cursor_date = cursor_value or default_date
return comparator(cursor_date, default_date)

@@ -196,7 +210,7 @@ def parse_date(self, date: str) -> datetime.datetime:
raise ValueError(f"No format in {self.cursor_datetime_formats} matching {date}")

@classmethod
def _parse_timedelta(cls, time_str) -> Union[datetime.timedelta, Duration]:
def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta, Duration]:
"""
:return Parses an ISO 8601 durations into datetime.timedelta or Duration objects.
"""
@@ -244,18 +258,20 @@ def request_kwargs(self) -> Mapping[str, Any]:
# Never update kwargs
return {}

def _get_request_options(self, option_type: RequestOptionType, stream_slice: StreamSlice):
options = {}
def _get_request_options(self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]:
options: MutableMapping[str, Any] = {}
if not stream_slice:
return options
if self.start_time_option and self.start_time_option.inject_into == option_type:
options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get(
self.partition_field_start.eval(self.config)
options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string
self._partition_field_start.eval(self.config)
)
if self.end_time_option and self.end_time_option.inject_into == option_type:
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(self.partition_field_end.eval(self.config))
options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(self._partition_field_end.eval(self.config)) # type: ignore # field_name is always casted to an interpolated string
return options

def should_be_synced(self, record: Record) -> bool:
cursor_field = self.cursor_field.eval(self.config)
cursor_field = self._cursor_field.eval(self.config)
record_cursor_value = record.get(cursor_field)
if not record_cursor_value:
self._send_log(
@@ -278,7 +294,7 @@ def _send_log(self, level: Level, message: str) -> None:
)

def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
cursor_field = self.cursor_field.eval(self.config)
cursor_field = self._cursor_field.eval(self.config)
first_cursor_value = first.get(cursor_field)
second_cursor_value = second.get(cursor_field)
if first_cursor_value and second_cursor_value:
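
To make the cursor-preservation comment in close_slice above concrete, here is a standalone sketch. It assumes the default %Y-%m-%dT%H:%M:%S.%f%z-style cursor format and illustrates the idea only; it is not the CDK implementation.

import datetime

# Candidate cursor values as they arrived: the emitted state must keep the exact
# original string, not a reformatted one (e.g. '.000Z' vs '.000000Z').
candidates = ["2023-01-04T17:30:19.000Z", "2023-01-05T09:00:00.000Z"]

# Map parsed datetime -> original string, then pick the latest datetime and emit
# its original representation, mirroring the dict built in close_slice.
parsed_to_original = {
    datetime.datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z"): value for value in candidates
}
new_cursor = parsed_to_original[max(parsed_to_original)]
assert new_cursor == "2023-01-05T09:00:00.000Z"

Comparing parsed datetimes while storing the verbatim string keeps connector acceptance tests from flagging spurious state changes caused by reformatting alone.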