🐛 🎉 Source Shopify: Add parent state tracking for supported BULK sub-streams (#46552)
bazarnov authored Oct 14, 2024
1 parent 75dfdba commit c9c8190
Showing 11 changed files with 328 additions and 128 deletions.
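
The change threads a parent stream's cursor value through the BULK job machinery so that records of a sub-stream can carry their parent's state. As a rough, self-contained sketch of the end result (the helper name, the alias convention, and the record shape below are assumptions for illustration, not the connector's API):

```python
# Minimal sketch of parent state tracking on sub-stream records, assuming the
# parent's cursor value arrives on each record under a predictable alias such
# as "products_updated_at". All names here are illustrative only.
from typing import Any, Dict, Iterable, Iterator


def attach_parent_state(
    records: Iterable[Dict[str, Any]],
    parent_stream_name: str,
    parent_stream_cursor: str,
    parent_cursor_alias: str,
) -> Iterator[Dict[str, Any]]:
    """Move the aliased parent cursor value into a nested `{parent: {cursor: value}}` block."""
    for record in records:
        cursor_value = record.pop(parent_cursor_alias, None)
        record[parent_stream_name] = {parent_stream_cursor: cursor_value}
        yield record


if __name__ == "__main__":
    rows = [{"id": 1, "products_updated_at": "2024-10-14T00:00:00+00:00"}]
    print(next(attach_parent_state(rows, "products", "updated_at", "products_updated_at")))
    # {'id': 1, 'products': {'updated_at': '2024-10-14T00:00:00+00:00'}}
```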

@@ -11,7 +11,7 @@ data:
connectorSubtype: api
connectorType: source
definitionId: 9da77001-af33-4bcd-be46-6252bf9342b9
- dockerImageTag: 2.5.6
+ dockerImageTag: 2.5.7
dockerRepository: airbyte/source-shopify
documentationUrl: https://docs.airbyte.com/integrations/sources/shopify
erdUrl: https://dbdocs.io/airbyteio/source-shopify?view=relationships

@@ -3,7 +3,7 @@ requires = [ "poetry-core>=1.0.0",]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
version = "2.5.6"
version = "2.5.7"
name = "source-shopify"
description = "Source CDK implementation for Shopify."
authors = [ "Airbyte <[email protected]>",]

@@ -2,7 +2,6 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

- import logging
from dataclasses import dataclass, field
from datetime import datetime
from time import sleep, time
@@ -12,7 +11,7 @@
import requests
from airbyte_cdk.sources.streams.http import HttpClient
from requests.exceptions import JSONDecodeError
- from source_shopify.utils import ApiTypeEnum
+ from source_shopify.utils import LOGGER, ApiTypeEnum
from source_shopify.utils import ShopifyRateLimiter as limiter

from .exceptions import AirbyteTracedException, ShopifyBulkExceptions
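
This hunk replaces the per-instance `logging` setup with a shared `LOGGER` imported from `source_shopify.utils`. That module is not part of this excerpt; a module-level constant along these lines is the assumed definition:

```python
# Assumed shape of the shared logger in source_shopify.utils (not shown in this
# excerpt); it takes over from the per-dataclass `logger` field removed below.
import logging
from typing import Final

LOGGER: Final[logging.Logger] = logging.getLogger("airbyte")
```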
@@ -32,8 +31,8 @@ class ShopifyBulkManager:
job_size: float
job_checkpoint_interval: int

- # default logger
- logger: Final[logging.Logger] = logging.getLogger("airbyte")
+ parent_stream_name: Optional[str] = None
+ parent_stream_cursor: Optional[str] = None

# 10Mb chunk size to save the file
_retrieve_chunk_size: Final[int] = 1024 * 1024 * 10
@@ -94,7 +93,7 @@ def __post_init__(self) -> None:
# how many records should be collected before we use the checkpoining
self._job_checkpoint_interval = self.job_checkpoint_interval
# define Record Producer instance
- self.record_producer: ShopifyBulkRecord = ShopifyBulkRecord(self.query)
+ self.record_producer: ShopifyBulkRecord = ShopifyBulkRecord(self.query, self.parent_stream_name, self.parent_stream_cursor)

@property
def _tools(self) -> BulkTools:
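
With the two new optional fields in place, `__post_init__` simply forwards them to the record producer. A minimal sketch of that forwarding pattern, using stand-in classes rather than the connector's `ShopifyBulkManager` and `ShopifyBulkRecord`:

```python
# Stand-in classes illustrating the forwarding pattern: the parent-stream fields
# default to None, so streams without a parent construct the manager as before.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ProducerSketch:
    query: str
    parent_stream_name: Optional[str] = None
    parent_stream_cursor: Optional[str] = None


@dataclass
class ManagerSketch:
    query: str
    parent_stream_name: Optional[str] = None
    parent_stream_cursor: Optional[str] = None
    producer: ProducerSketch = field(init=False)

    def __post_init__(self) -> None:
        # mirrors ShopifyBulkRecord(self.query, self.parent_stream_name, self.parent_stream_cursor)
        self.producer = ProducerSketch(self.query, self.parent_stream_name, self.parent_stream_cursor)


print(ManagerSketch("{ orders { edges { node { id } } } }", "products", "updated_at").producer)
```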
@@ -251,9 +250,9 @@ def _log_job_state_with_count(self) -> None:
def _log_state(self, message: Optional[str] = None) -> None:
pattern = f"Stream: `{self.http_client.name}`, the BULK Job: `{self._job_id}` is {self._job_state}"
if message:
self.logger.info(f"{pattern}. {message}.")
LOGGER.info(f"{pattern}. {message}.")
else:
- self.logger.info(pattern)
+ LOGGER.info(pattern)

def _job_get_result(self, response: Optional[requests.Response] = None) -> Optional[str]:
parsed_response = response.json().get("data", {}).get("node", {}) if response else None
@@ -309,13 +308,13 @@ def _on_canceling_job(self, **kwargs) -> None:
sleep(self._job_check_interval)

def _cancel_on_long_running_job(self) -> None:
- self.logger.info(
+ LOGGER.info(
f"Stream: `{self.http_client.name}` the BULK Job: {self._job_id} runs longer than expected ({self._job_max_elapsed_time} sec). Retry with the reduced `Slice Size` after self-cancelation."
)
self._job_cancel()

def _cancel_on_checkpointing(self) -> None:
self.logger.info(f"Stream: `{self.http_client.name}`, checkpointing after >= `{self._job_checkpoint_interval}` rows collected.")
LOGGER.info(f"Stream: `{self.http_client.name}`, checkpointing after >= `{self._job_checkpoint_interval}` rows collected.")
# set the flag to adjust the next slice from the checkpointed cursor value
self._job_cancel()
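
`_cancel_on_checkpointing` self-cancels the running BULK job once at least `job_checkpoint_interval` rows have been collected, so the stream can emit state and resume from the checkpointed cursor. The decision itself is a plain threshold check, roughly:

```python
# Illustrative threshold check only; the connector tracks the collected-row
# count on the manager and cancels the in-flight BULK job when it is reached.
def should_checkpoint(rows_collected: int, checkpoint_interval: int) -> bool:
    return rows_collected >= checkpoint_interval


assert should_checkpoint(200_000, 200_000)
assert not should_checkpoint(199_999, 200_000)
```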

@@ -434,7 +433,7 @@ def _should_switch_shop_name(self, response: requests.Response) -> bool:
return True
return False

- @bulk_retry_on_exception(logger)
+ @bulk_retry_on_exception()
def _job_check_state(self) -> None:
while not self._job_completed():
if self._job_canceled():
@@ -444,7 +443,7 @@ def _job_check_state(self) -> None:
else:
self._job_track_running()

- @bulk_retry_on_exception(logger)
+ @bulk_retry_on_exception()
def create_job(self, stream_slice: Mapping[str, str], filter_field: str) -> None:
if stream_slice:
query = self.query.get(filter_field, stream_slice["start"], stream_slice["end"])
@@ -484,7 +483,7 @@ def _job_process_created(self, response: requests.Response) -> None:
self._job_id = bulk_response.get("id")
self._job_created_at = bulk_response.get("createdAt")
self._job_state = ShopifyBulkJobStatus.CREATED.value
self.logger.info(f"Stream: `{self.http_client.name}`, the BULK Job: `{self._job_id}` is {ShopifyBulkJobStatus.CREATED.value}")
LOGGER.info(f"Stream: `{self.http_client.name}`, the BULK Job: `{self._job_id}` is {ShopifyBulkJobStatus.CREATED.value}")

def job_size_normalize(self, start: datetime, end: datetime) -> datetime:
# adjust slice size when it's bigger than the loop point when it should end,
@@ -522,7 +521,7 @@ def _emit_final_job_message(self, job_current_elapsed_time: int) -> None:
final_message = final_message + lines_collected_message

# emit final Bulk job status message
self.logger.info(f"{final_message}")
LOGGER.info(f"{final_message}")

def _process_bulk_results(self) -> Iterable[Mapping[str, Any]]:
if self._job_result_filename:

@@ -80,17 +80,7 @@ def prepare(query: str) -> str:
@dataclass
class ShopifyBulkQuery:
config: Mapping[str, Any]
- parent_stream_name: Optional[str] = None
- parent_stream_cursor: Optional[str] = None
-
- @property
- def has_parent_stream(self) -> bool:
- return True if self.parent_stream_name and self.parent_stream_cursor else False
-
- @property
- def parent_cursor_key(self) -> Optional[str]:
- if self.has_parent_stream:
- return f"{self.parent_stream_name}_{self.parent_stream_cursor}"
+ parent_stream_cursor_alias: Optional[str] = None

@property
def shop_id(self) -> int:
@@ -143,38 +133,12 @@ def query_nodes(self) -> Optional[Union[List[Field], List[str]]]:
"""
return ["__typename", "id"]

- def _inject_parent_cursor_field(self, nodes: List[Field], key: str = "updatedAt", index: int = 2) -> List[Field]:
- if self.has_parent_stream:
+ def inject_parent_cursor_field(self, nodes: List[Field], key: str = "updatedAt", index: int = 2) -> List[Field]:
+ if self.parent_stream_cursor_alias:
# inject parent cursor key as alias to the `updatedAt` parent cursor field
nodes.insert(index, Field(name="updatedAt", alias=self.parent_cursor_key))

+ nodes.insert(index, Field(name=key, alias=self.parent_stream_cursor_alias))
return nodes

- def _add_parent_record_state(self, record: MutableMapping[str, Any], items: List[dict], to_rfc3339: bool = False) -> List[dict]:
- """
- Adds a parent cursor value to each item in the list.
- This method iterates over a list of dictionaries and adds a new key-value pair to each dictionary.
- The key is the value of `self.query_name`, and the value is another dictionary with a single key "updated_at"
- and the provided `parent_cursor_value`.
- Args:
- items (List[dict]): A list of dictionaries to which the parent cursor value will be added.
- parent_cursor_value (str): The value to be set for the "updated_at" key in the nested dictionary.
- Returns:
- List[dict]: The modified list of dictionaries with the added parent cursor values.
- """
-
- if self.has_parent_stream:
- parent_cursor_value: Optional[str] = record.get(self.parent_cursor_key, None)
- parent_state = self.tools._datetime_str_to_rfc3339(parent_cursor_value) if to_rfc3339 and parent_cursor_value else None
-
- for item in items:
- item[self.parent_stream_name] = {self.parent_stream_cursor: parent_state}
-
- return items

def get(self, filter_field: Optional[str] = None, start: Optional[str] = None, end: Optional[str] = None) -> str:
# define filter query string, if passed
filter_query = f"{filter_field}:>='{start}' AND {filter_field}:<='{end}'" if filter_field else None
@@ -339,7 +303,7 @@ def query_nodes(self) -> List[Field]:
elif isinstance(self.type.value, str):
nodes = [*nodes, metafield_node]

- nodes = self._inject_parent_cursor_field(nodes)
+ nodes = self.inject_parent_cursor_field(nodes)

return nodes

@@ -372,9 +336,6 @@ def record_process_components(self, record: MutableMapping[str, Any]) -> Iterabl
else:
metafields = record_components.get("Metafield", [])
if len(metafields) > 0:
- if self.has_parent_stream:
- # add parent state to each metafield
- metafields = self._add_parent_record_state(record, metafields, to_rfc3339=True)
yield from self._process_components(metafields)


@@ -637,7 +598,7 @@ def query_nodes(self) -> List[Field]:
media_node = self.get_edge_node("media", media_fields)

fields: List[Field] = ["__typename", "id", media_node]
- fields = self._inject_parent_cursor_field(fields)
+ fields = self.inject_parent_cursor_field(fields)

return fields

@@ -2422,7 +2383,7 @@ class ProductImage(ShopifyBulkQuery):

@property
def query_nodes(self) -> List[Field]:
- return self._inject_parent_cursor_field(self.nodes)
+ return self.inject_parent_cursor_field(self.nodes)

def _process_component(self, entity: List[dict]) -> List[dict]:
for item in entity:
@@ -2499,8 +2460,6 @@ def record_process_components(self, record: MutableMapping[str, Any]) -> Iterabl

# add the product_id to each `Image`
record["images"] = self._add_product_id(record.get("images", []), record.get("id"))
- # add the product cursor to each `Image`
- record["images"] = self._add_parent_record_state(record, record.get("images", []), to_rfc3339=True)
record["images"] = self._merge_with_media(record_components)
record.pop("record_components")
