Skip to content

Commit

Permalink
bug(cdk) Always return a connection status even if an exception was r…
Browse files Browse the repository at this point in the history
…aised (#45205)
  • Loading branch information
girarda authored Sep 27, 2024
1 parent 65622a9 commit f01a43c
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 14 deletions.
19 changes: 18 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
from airbyte_cdk.logger import init_logger
from airbyte_cdk.models import ( # type: ignore [attr-defined]
AirbyteConnectionStatus,
AirbyteMessage,
AirbyteMessageSerializer,
AirbyteStateStats,
Expand Down Expand Up @@ -139,12 +140,28 @@ def check(self, source_spec: ConnectorSpecification, config: TConfig) -> Iterabl
self.validate_connection(source_spec, config)
except AirbyteTracedException as traced_exc:
connection_status = traced_exc.as_connection_status_message()
# The platform uses the exit code to surface unexpected failures, so we raise the exception if the failure type is not a config error
# If the failure is not exceptional, we'll emit a failed connection status message and return
if traced_exc.failure_type != FailureType.config_error:
raise traced_exc
if connection_status:
yield from self._emit_queued_messages(self.source)
yield connection_status
return

check_result = self.source.check(self.logger, config)
try:
check_result = self.source.check(self.logger, config)
except AirbyteTracedException as traced_exc:
yield traced_exc.as_airbyte_message()
# The platform uses the exit code to surface unexpected failures, so we raise the exception if the failure type is not a config error
# If the failure is not exceptional, we'll emit a failed connection status message and return
if traced_exc.failure_type != FailureType.config_error:
raise traced_exc
else:
yield AirbyteMessage(
type=Type.CONNECTION_STATUS, connectionStatus=AirbyteConnectionStatus(status=Status.FAILED, message=traced_exc.message)
)
return
if check_result.status == Status.SUCCEEDED:
self.logger.info("Check succeeded")
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
from airbyte_cdk import AirbyteTracedException

from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError
from unit_tests.sources.file_based.helpers import (
FailingSchemaValidationPolicy,
Expand Down Expand Up @@ -130,7 +130,7 @@
_base_failure_scenario.copy()
.set_name("error_empty_stream_scenario")
.set_source_builder(_base_failure_scenario.copy().source_builder.copy().set_files({}))
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.EMPTY_STREAM.value)
.set_expected_check_error(None, FileBasedSourceError.EMPTY_STREAM.value)
).build()


Expand All @@ -142,7 +142,7 @@
TestErrorListMatchingFilesInMemoryFilesStreamReader(files=_base_failure_scenario.source_builder._files, file_type="csv")
)
)
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.ERROR_LISTING_FILES.value)
.set_expected_check_error(None, FileBasedSourceError.ERROR_LISTING_FILES.value)
).build()


Expand All @@ -154,7 +154,7 @@
TestErrorOpenFileInMemoryFilesStreamReader(files=_base_failure_scenario.source_builder._files, file_type="csv")
)
)
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.ERROR_READING_FILE.value)
.set_expected_check_error(None, FileBasedSourceError.ERROR_READING_FILE.value)
).build()


Expand Down Expand Up @@ -216,5 +216,5 @@
],
}
)
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.ERROR_READING_FILE.value)
.set_expected_check_error(None, FileBasedSourceError.ERROR_READING_FILE.value)
).build()
Original file line number Diff line number Diff line change
Expand Up @@ -1940,7 +1940,7 @@
}
)
.set_expected_check_status("FAILED")
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_check_error(None, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
).build()
Expand Down Expand Up @@ -2030,7 +2030,7 @@
}
)
.set_expected_check_status("FAILED")
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_check_error(None, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
).build()
Expand Down Expand Up @@ -3240,7 +3240,6 @@
}
)
.set_expected_records(None)
.set_expected_check_error(AirbyteTracedException, None)
).build()

csv_no_records_scenario: TestScenario[InMemoryFilesSource] = (
Expand Down Expand Up @@ -3340,5 +3339,4 @@
}
)
.set_expected_records(None)
.set_expected_check_error(AirbyteTracedException, None)
).build()
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,15 @@ def check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenar
["check", "--config", make_file(tmp_path / "config.json", scenario.config)],
)
captured = capsys.readouterr()
return json.loads(captured.out.splitlines()[0])["connectionStatus"] # type: ignore
return _find_connection_status(captured.out.splitlines())


def _find_connection_status(output: List[str]) -> Mapping[str, Any]:
for line in output:
json_line = json.loads(line)
if "connectionStatus" in json_line:
return json_line["connectionStatus"]
raise ValueError("No valid connectionStatus found in output")


def discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> Dict[str, Any]:
Expand Down
56 changes: 53 additions & 3 deletions airbyte-cdk/python/unit_tests/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from unittest import mock
from unittest.mock import MagicMock, patch

import freezegun
import pytest
import requests
from airbyte_cdk import AirbyteEntrypoint
Expand All @@ -32,6 +33,7 @@
AirbyteStreamStatusTraceMessage,
AirbyteTraceMessage,
ConnectorSpecification,
FailureType,
OrchestratorType,
Status,
StreamDescriptor,
Expand Down Expand Up @@ -151,6 +153,8 @@ def _wrap_message(submessage: Union[AirbyteConnectionStatus, ConnectorSpecificat
message = AirbyteMessage(type=Type.CATALOG, catalog=submessage)
elif isinstance(submessage, AirbyteRecordMessage):
message = AirbyteMessage(type=Type.RECORD, record=submessage)
elif isinstance(submessage, AirbyteTraceMessage):
message = AirbyteMessage(type=Type.TRACE, trace=submessage)
else:
raise Exception(f"Unknown message type: {submessage}")

Expand Down Expand Up @@ -219,13 +223,59 @@ def test_run_check(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock
assert spec_mock.called


@freezegun.freeze_time("1970-01-01T00:00:00.001Z")
def test_run_check_with_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
    """A plain ``ValueError`` raised by ``MockSource.check`` must escape ``entrypoint.run``."""
    exception = ValueError("Any error")
    parsed_args = Namespace(command="check", config="config_path")
    # Patch once: a second patch of the same attribute would simply override the first.
    mocker.patch.object(MockSource, "check", side_effect=exception)

    # Drain the run inside the context manager so the exception is raised there.
    # Nothing may follow the raising call in this block: it would never execute.
    with pytest.raises(ValueError):
        list(entrypoint.run(parsed_args))


@freezegun.freeze_time("1970-01-01T00:00:00.001Z")
def test_run_check_with_traced_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
    """An ``AirbyteTracedException`` from ``MockSource.check`` is re-raised by ``entrypoint.run``.

    The failure type is left as produced by ``from_exception``; it is not
    overridden here.
    """
    check_args = Namespace(command="check", config="config_path")
    traced_error = AirbyteTracedException.from_exception(ValueError("Any error"))
    mocker.patch.object(MockSource, "check", side_effect=traced_error)

    with pytest.raises(AirbyteTracedException):
        # Consume every message so the patched check is actually invoked.
        for _ in entrypoint.run(check_args):
            pass


@freezegun.freeze_time("1970-01-01T00:00:00.001Z")
def test_run_check_with_config_error(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
    """A traced exception flagged ``config_error`` yields messages instead of raising.

    Expected output, in order: the message queued in the repository, the trace
    message for the exception, and a FAILED connection status carrying the
    exception's own ``message``.
    """
    exception = AirbyteTracedException.from_exception(ValueError("Any error"))
    exception.failure_type = FailureType.config_error
    parsed_args = Namespace(command="check", config="config_path")
    mocker.patch.object(MockSource, "check", side_effect=exception)

    messages = list(entrypoint.run(parsed_args))

    expected_trace = exception.as_airbyte_message()
    # Time is frozen 1 ms after the epoch, so pin the timestamps to 1
    # (presumably epoch milliseconds -- confirm against as_airbyte_message).
    expected_trace.emitted_at = 1
    expected_trace.trace.emitted_at = 1
    expected_messages = [
        orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(),
        orjson.dumps(AirbyteMessageSerializer.dump(expected_trace)).decode(),
        # The raised instance is the one the entrypoint catches, so compare
        # against its message directly rather than re-wrapping it.
        _wrap_message(AirbyteConnectionStatus(status=Status.FAILED, message=exception.message)),
    ]
    assert messages == expected_messages


@freezegun.freeze_time("1970-01-01T00:00:00.001Z")
def test_run_check_with_transient_error(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
    """A traced exception flagged ``transient_error`` is re-raised by ``entrypoint.run``."""
    check_args = Namespace(command="check", config="config_path")
    transient_error = AirbyteTracedException.from_exception(ValueError("Any error"))
    transient_error.failure_type = FailureType.transient_error
    mocker.patch.object(MockSource, "check", side_effect=transient_error)

    with pytest.raises(AirbyteTracedException):
        # Consume every message so the patched check is actually invoked.
        for _ in entrypoint.run(check_args):
            pass


def test_run_discover(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
Expand Down

0 comments on commit f01a43c

Please sign in to comment.