Skip to content

Commit

Permalink
Issue 32871/extract trace message creation (#33227)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxi297 authored Dec 11, 2023
1 parent d3f2aa5 commit 0c2d43f
Show file tree
Hide file tree
Showing 22 changed files with 384 additions and 49 deletions.
16 changes: 10 additions & 6 deletions airbyte-cdk/python/airbyte_cdk/exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@

import logging
import sys
from types import TracebackType
from typing import Any, Optional

from airbyte_cdk.utils.traced_exception import AirbyteTracedException


def assemble_uncaught_exception(exception_type: type[BaseException], exception_value: BaseException) -> AirbyteTracedException:
if issubclass(exception_type, AirbyteTracedException):
return exception_value # type: ignore # validated as part of the previous line
return AirbyteTracedException.from_exception(exception_value)


def init_uncaught_exception_handler(logger: logging.Logger) -> None:
"""
Handles uncaught exceptions by emitting an AirbyteTraceMessage and making sure they are not
printed to the console without having secrets removed.
"""

def hook_fn(exception_type, exception_value, traceback_):
def hook_fn(exception_type: type[BaseException], exception_value: BaseException, traceback_: Optional[TracebackType]) -> Any:
# For developer ergonomics, we want to see the stack trace in the logs when we do a ctrl-c
if issubclass(exception_type, KeyboardInterrupt):
sys.__excepthook__(exception_type, exception_value, traceback_)
Expand All @@ -23,11 +31,7 @@ def hook_fn(exception_type, exception_value, traceback_):
logger.fatal(exception_value, exc_info=exception_value)

# emit an AirbyteTraceMessage for any exception that gets to this spot
traced_exc = (
exception_value
if issubclass(exception_type, AirbyteTracedException)
else AirbyteTracedException.from_exception(exception_value)
)
traced_exc = assemble_uncaught_exception(exception_type, exception_value)

traced_exc.emit_message()

Expand Down
29 changes: 29 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/test/catalog_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

from typing import Any, Dict, List

from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode


class CatalogBuilder:
def __init__(self) -> None:
self._streams: List[Dict[str, Any]] = []

def with_stream(self, name: str, sync_mode: SyncMode) -> "CatalogBuilder":
self._streams.append(
{
"stream": {
"name": name,
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_primary_key": [["id"]],
},
"primary_key": [["id"]],
"sync_mode": sync_mode.name,
"destination_sync_mode": "overwrite",
}
)
return self

def build(self) -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog.parse_obj({"streams": self._streams})
54 changes: 48 additions & 6 deletions airbyte-cdk/python/airbyte_cdk/test/entrypoint_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,29 @@
import json
import logging
import tempfile
import traceback
from io import StringIO
from pathlib import Path
from typing import Any, List, Mapping, Optional, Union

from airbyte_cdk.entrypoint import AirbyteEntrypoint
from airbyte_cdk.exception_handler import assemble_uncaught_exception
from airbyte_cdk.logger import AirbyteLogFormatter
from airbyte_cdk.sources import Source
from airbyte_protocol.models import AirbyteLogMessage, AirbyteMessage, ConfiguredAirbyteCatalog, Level, TraceType, Type
from airbyte_protocol.models import AirbyteLogMessage, AirbyteMessage, AirbyteStreamStatus, ConfiguredAirbyteCatalog, Level, TraceType, Type
from pydantic.error_wrappers import ValidationError


class EntrypointOutput:
def __init__(self, messages: List[str]):
def __init__(self, messages: List[str], uncaught_exception: Optional[BaseException] = None):
try:
self._messages = [self._parse_message(message) for message in messages]
except ValidationError as exception:
raise ValueError("All messages are expected to be AirbyteMessage") from exception

if uncaught_exception:
self._messages.append(assemble_uncaught_exception(type(uncaught_exception), uncaught_exception).as_airbyte_message())

@staticmethod
def _parse_message(message: str) -> AirbyteMessage:
try:
Expand Down Expand Up @@ -65,15 +70,41 @@ def trace_messages(self) -> List[AirbyteMessage]:

@property
def analytics_messages(self) -> List[AirbyteMessage]:
return [message for message in self._get_message_by_types([Type.TRACE]) if message.trace.type == TraceType.ANALYTICS]
return self._get_trace_message_by_trace_type(TraceType.ANALYTICS)

@property
def errors(self) -> List[AirbyteMessage]:
return self._get_trace_message_by_trace_type(TraceType.ERROR)

def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]:
status_messages = map(
lambda message: message.trace.stream_status.status,
filter(
lambda message: message.trace.stream_status.stream_descriptor.name == stream_name,
self._get_trace_message_by_trace_type(TraceType.STREAM_STATUS),
),
)
return list(status_messages)

def _get_message_by_types(self, message_types: List[Type]) -> List[AirbyteMessage]:
return [message for message in self._messages if message.type in message_types]

def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]:
return [message for message in self._get_message_by_types([Type.TRACE]) if message.trace.type == trace_type]


def read(source: Source, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog, state: Optional[Any] = None) -> EntrypointOutput:
def read(
source: Source,
config: Mapping[str, Any],
catalog: ConfiguredAirbyteCatalog,
state: Optional[Any] = None,
expecting_exception: bool = False,
) -> EntrypointOutput:
"""
config and state must be json serializable
:param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
provide expecting_exception=True so that the test output logs are cleaner
"""
log_capture_buffer = StringIO()
stream_handler = logging.StreamHandler(log_capture_buffer)
Expand All @@ -100,12 +131,23 @@ def read(source: Source, config: Mapping[str, Any], catalog: ConfiguredAirbyteCa
)
source_entrypoint = AirbyteEntrypoint(source)
parsed_args = source_entrypoint.parse_args(args)
messages = list(source_entrypoint.run(parsed_args))

messages = []
uncaught_exception = None
try:
for message in source_entrypoint.run(parsed_args):
messages.append(message)
except Exception as exception:
if not expecting_exception:
print("Printing unexpected error from entrypoint_wrapper")
print("".join(traceback.format_exception(None, exception, exception.__traceback__)))
uncaught_exception = exception

captured_logs = log_capture_buffer.getvalue().split("\n")[:-1]

parent_logger.removeHandler(stream_handler)

return EntrypointOutput(messages + captured_logs)
return EntrypointOutput(messages + captured_logs, uncaught_exception)


def make_file(path: Path, file_contents: Optional[Union[str, Mapping[str, Any], List[Mapping[str, Any]]]]) -> str:
Expand Down
6 changes: 0 additions & 6 deletions airbyte-cdk/python/airbyte_cdk/test/http/__init__.py

This file was deleted.

6 changes: 6 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/test/mock_http/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from airbyte_cdk.test.mock_http.matcher import HttpRequestMatcher
from airbyte_cdk.test.mock_http.request import HttpRequest
from airbyte_cdk.test.mock_http.response import HttpResponse
from airbyte_cdk.test.mock_http.mocker import HttpMocker

__all__ = ["HttpMocker", "HttpRequest", "HttpRequestMatcher", "HttpResponse"]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

from airbyte_cdk.test.http.request import HttpRequest
from airbyte_cdk.test.mock_http.request import HttpRequest


class HttpRequestMatcher:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Callable, List, Optional, Union

import requests_mock
from airbyte_cdk.test.http import HttpRequest, HttpRequestMatcher, HttpResponse
from airbyte_cdk.test.mock_http import HttpRequest, HttpRequestMatcher, HttpResponse


class HttpMocker(contextlib.ContextDecorator):
Expand Down Expand Up @@ -75,7 +75,7 @@ def wrapper(*args, **kwargs): # type: ignore # this is a very generic wrapper
except requests_mock.NoMockAddress as no_mock_exception:
matchers_as_string = "\n\t".join(map(lambda matcher: str(matcher.request), self._matchers))
raise ValueError(
f"No matcher matches {no_mock_exception.args[0]}. Matchers currently configured are:\n\t{matchers_as_string}"
f"No matcher matches {no_mock_exception.args[0]} with headers `{no_mock_exception.request.headers}`. Matchers currently configured are:\n\t{matchers_as_string}"
) from no_mock_exception
except AssertionError as test_assertion:
assertion_error = test_assertion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path as FilePath
from typing import Any, Dict, List, Optional, Tuple, Union

from airbyte_cdk.test.http import HttpResponse
from airbyte_cdk.test.mock_http import HttpResponse


def _extract(path: List[str], response_template: Dict[str, Any]) -> Any:
Expand Down
13 changes: 7 additions & 6 deletions airbyte-cdk/python/airbyte_cdk/utils/traced_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import traceback
from datetime import datetime
from typing import Optional

from airbyte_cdk.models import (
AirbyteConnectionStatus,
Expand All @@ -25,10 +26,10 @@ class AirbyteTracedException(Exception):

def __init__(
self,
internal_message: str = None,
message: str = None,
internal_message: Optional[str] = None,
message: Optional[str] = None,
failure_type: FailureType = FailureType.system_error,
exception: BaseException = None,
exception: Optional[BaseException] = None,
):
"""
:param internal_message: the internal error that caused the failure
Expand Down Expand Up @@ -71,7 +72,7 @@ def as_connection_status_message(self) -> AirbyteMessage:
)
return output_message

def emit_message(self):
def emit_message(self) -> None:
"""
Prints the exception as an AirbyteTraceMessage.
Note that this will be called automatically on uncaught exceptions when using the airbyte_cdk entrypoint.
Expand All @@ -81,9 +82,9 @@ def emit_message(self):
print(filtered_message)

@classmethod
def from_exception(cls, exc: Exception, *args, **kwargs) -> "AirbyteTracedException":
def from_exception(cls, exc: BaseException, *args, **kwargs) -> "AirbyteTracedException": # type: ignore # ignoring because of args and kwargs
"""
Helper to create an AirbyteTracedException from an existing exception
:param exc: the exception that caused the error
"""
return cls(internal_message=str(exc), exception=exc, *args, **kwargs)
return cls(internal_message=str(exc), exception=exc, *args, **kwargs) # type: ignore # ignoring because of args and kwargs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from airbyte_cdk.models import AirbyteAnalyticsTraceMessage
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat
from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError
from airbyte_cdk.test.catalog_builder import CatalogBuilder
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from airbyte_protocol.models import SyncMode
from unit_tests.sources.file_based.helpers import EmptySchemaParser, LowInferenceLimitDiscoveryPolicy
from unit_tests.sources.file_based.in_memory_files_source import InMemoryFilesSource
from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder
Expand Down Expand Up @@ -1639,6 +1641,7 @@
)
.set_file_type("csv")
)
.set_catalog(CatalogBuilder().with_stream("stream1", SyncMode.full_refresh).build())
.set_expected_catalog(
{
"streams": [
Expand Down Expand Up @@ -1712,6 +1715,7 @@
)
.set_file_type("csv")
)
.set_catalog(CatalogBuilder().with_stream("stream1", SyncMode.full_refresh).with_stream("stream2", SyncMode.full_refresh).build())
.set_expected_catalog(
{
"streams": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from airbyte_cdk.models import AirbyteAnalyticsTraceMessage, SyncMode
from airbyte_cdk.sources import AbstractSource
from airbyte_protocol.models import ConfiguredAirbyteCatalog


@dataclass
Expand Down Expand Up @@ -46,11 +47,13 @@ def __init__(
incremental_scenario_config: Optional[IncrementalScenarioConfig],
expected_analytics: Optional[List[AirbyteAnalyticsTraceMessage]] = None,
log_levels: Optional[Set[str]] = None,
catalog: Optional[ConfiguredAirbyteCatalog] = None,
):
if log_levels is None:
log_levels = {"ERROR", "WARN", "WARNING"}
self.name = name
self.config = config
self.catalog = catalog
self.source = source
self.expected_spec = expected_spec
self.expected_check_status = expected_check_status
Expand All @@ -67,16 +70,15 @@ def __init__(

def validate(self) -> None:
assert self.name
if not self.expected_catalog:
return
if self.expected_read_error or self.expected_check_error:
return
# Only verify the streams if no errors are expected
streams = set([s.name for s in self.source.streams(self.config)])
expected_streams = {s["name"] for s in self.expected_catalog["streams"]}
assert expected_streams <= streams

def configured_catalog(self, sync_mode: SyncMode) -> Optional[Mapping[str, Any]]:
# The preferred way of returning the catalog for the TestScenario is by providing it at the initialization. The previous solution
# relied on `self.source.streams` which might raise an exception hence screwing the tests results as the user might expect the
# exception to be raised as part of the actual check/discover/read commands
# Note that to avoid a breaking change, we still attempt to automatically generate the catalog based on the streams
if self.catalog:
return self.catalog.dict() # type: ignore # dict() is not typed

catalog: Mapping[str, Any] = {"streams": []}
for stream in self.source.streams(self.config):
catalog["streams"].append(
Expand Down Expand Up @@ -108,6 +110,7 @@ class TestScenarioBuilder(Generic[SourceType]):
def __init__(self) -> None:
self._name = ""
self._config: Mapping[str, Any] = {}
self._catalog: Optional[ConfiguredAirbyteCatalog] = None
self._expected_spec: Optional[Mapping[str, Any]] = None
self._expected_check_status: Optional[str] = None
self._expected_catalog: Mapping[str, Any] = {}
Expand All @@ -133,6 +136,10 @@ def set_expected_spec(self, expected_spec: Mapping[str, Any]) -> "TestScenarioBu
self._expected_spec = expected_spec
return self

def set_catalog(self, catalog: ConfiguredAirbyteCatalog) -> "TestScenarioBuilder[SourceType]":
self._catalog = catalog
return self

def set_expected_check_status(self, expected_check_status: str) -> "TestScenarioBuilder[SourceType]":
self._expected_check_status = expected_check_status
return self
Expand Down Expand Up @@ -201,6 +208,7 @@ def build(self) -> "TestScenario[SourceType]":
self._incremental_scenario_config,
self._expected_analytics,
self._log_levels,
self._catalog,
)

def _configured_catalog(self, sync_mode: SyncMode) -> Optional[Mapping[str, Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@


from airbyte_cdk.sources.file_based.exceptions import ConfigValidationError, FileBasedSourceError
from airbyte_cdk.test.catalog_builder import CatalogBuilder
from airbyte_protocol.models import SyncMode
from unit_tests.sources.file_based.scenarios.file_based_source_builder import FileBasedSourceBuilder
from unit_tests.sources.file_based.scenarios.scenario_builder import TestScenarioBuilder

Expand Down Expand Up @@ -116,6 +118,7 @@
]
}
)
.set_catalog(CatalogBuilder().with_stream("stream1", SyncMode.full_refresh).build())
.set_expected_check_status("FAILED")
.set_expected_check_error(None, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value)
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value)
Expand Down Expand Up @@ -439,6 +442,7 @@
]
}
)
.set_catalog(CatalogBuilder().with_stream("stream1", SyncMode.full_refresh).with_stream("stream2", SyncMode.full_refresh).with_stream("stream3", SyncMode.full_refresh).build())
.set_expected_check_status("FAILED")
.set_expected_check_error(None, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value)
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.ERROR_PARSING_USER_PROVIDED_SCHEMA.value)
Expand Down
Loading

0 comments on commit 0c2d43f

Please sign in to comment.