Skip to content

Commit

Permalink
File CDK: Don't fetch full file list for availability check (#31651)
Browse files Browse the repository at this point in the history
Co-authored-by: flash1293 <[email protected]>
  • Loading branch information
Joe Reuter and flash1293 authored Oct 23, 2023
1 parent 1f564ff commit d474827
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
import traceback
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

from airbyte_cdk.sources import Source
from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy
Expand Down Expand Up @@ -55,28 +55,32 @@ def check_availability_and_parsability(
"""
parser = stream.get_parser()
try:
files = self._check_list_files(stream)
file = self._check_list_files(stream)
if not parser.parser_max_n_files_for_parsability == 0:
self._check_parse_record(stream, files[0], logger)
self._check_parse_record(stream, file, logger)
else:
# If the parser is set to not check parsability, we still want to check that we can open the file.
handle = stream.stream_reader.open_file(files[0], parser.file_read_mode, None, logger)
handle = stream.stream_reader.open_file(file, parser.file_read_mode, None, logger)
handle.close()
except CheckAvailabilityError:
return False, "".join(traceback.format_exc())

return True, None

def _check_list_files(self, stream: "AbstractFileBasedStream") -> List[RemoteFile]:
def _check_list_files(self, stream: "AbstractFileBasedStream") -> RemoteFile:
"""
Check that we can list files from the stream.
Returns the first file if successful, otherwise raises a CheckAvailabilityError.
"""
try:
files = stream.list_files()
file = next(iter(stream.get_files()))
except StopIteration:
raise CheckAvailabilityError(FileBasedSourceError.EMPTY_STREAM, stream=stream.name)
except Exception as exc:
raise CheckAvailabilityError(FileBasedSourceError.ERROR_LISTING_FILES, stream=stream.name) from exc

if not files:
raise CheckAvailabilityError(FileBasedSourceError.EMPTY_STREAM, stream=stream.name)

return files
return file

def _check_parse_record(self, stream: "AbstractFileBasedStream", file: RemoteFile, logger: logging.Logger) -> None:
parser = stream.get_parser()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#

from abc import abstractmethod
from functools import cached_property, lru_cache
from functools import cache, cached_property, lru_cache
from typing import Any, Dict, Iterable, List, Mapping, Optional, Type

from airbyte_cdk.models import SyncMode
Expand Down Expand Up @@ -59,10 +59,21 @@ def __init__(
def primary_key(self) -> PrimaryKeyType:
...

@abstractmethod
@cache
def list_files(self) -> List[RemoteFile]:
"""
List all files that belong to the stream.
The output of this method is cached so we don't need to list the files more than once.
        This means we won't pick up changes to the files during a sync. This method uses the
        get_files method which is implemented by the concrete stream class.
"""
return list(self.get_files())

@abstractmethod
def get_files(self) -> Iterable[RemoteFile]:
"""
List all files that belong to the stream as defined by the stream's globs.
"""
...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,11 @@ def _get_raw_json_schema(self) -> JsonSchema:

return schema

@cache
def list_files(self) -> List[RemoteFile]:
def get_files(self) -> Iterable[RemoteFile]:
"""
List all files that belong to the stream as defined by the stream's globs.
The output of this method is cached so we don't need to list the files more than once.
This means we won't pick up changes to the files during a sync.
Return all files that belong to the stream as defined by the stream's globs.
"""
return list(self.stream_reader.get_matching_files(self.config.globs or [], self.config.legacy_prefix, self.logger))
return self.stream_reader.get_matching_files(self.config.globs or [], self.config.legacy_prefix, self.logger)

def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
loop = asyncio.get_event_loop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,30 @@ def test_given_file_extension_does_not_match_when_check_availability_and_parsabi
        example we've seen was for the JSONL parser but the file extension was just `.json`. Note that there was more than one record extracted
        from this stream so it's not just that the file is one JSON object
"""
self._stream.list_files.return_value = [_FILE_WITH_UNKNOWN_EXTENSION]
self._stream.get_files.return_value = [_FILE_WITH_UNKNOWN_EXTENSION]
self._parser.parse_records.return_value = [{"a record": 1}]

is_available, reason = self._strategy.check_availability_and_parsability(self._stream, Mock(), Mock())

assert is_available

def test_not_available_given_no_files(self) -> None:
"""
If no files are returned, then the stream is not available.
"""
self._stream.get_files.return_value = []

is_available, reason = self._strategy.check_availability_and_parsability(self._stream, Mock(), Mock())

assert not is_available
assert "No files were identified in the stream" in reason

def test_parse_records_is_not_called_with_parser_max_n_files_for_parsability_set(self) -> None:
"""
If the stream parser sets parser_max_n_files_for_parsability to 0, then we should not call parse_records on it
"""
self._parser.parser_max_n_files_for_parsability = 0
self._stream.list_files.return_value = [_FILE_WITH_UNKNOWN_EXTENSION]
self._stream.get_files.return_value = [_FILE_WITH_UNKNOWN_EXTENSION]

is_available, reason = self._strategy.check_availability_and_parsability(self._stream, Mock(), Mock())

Expand Down

0 comments on commit d474827

Please sign in to comment.