Skip to content

Commit

Permalink
Merge branch 'master' into brian/concurrent_declarative_source
Browse files Browse the repository at this point in the history
  • Loading branch information
brianjlai committed Oct 17, 2024
2 parents 1242296 + 0d3bb81 commit b74b942
Show file tree
Hide file tree
Showing 26 changed files with 7,163 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ interface MetadataQuerier : AutoCloseable {
/** An implementation might open a connection to build a [MetadataQuerier] instance. */
fun session(config: T): MetadataQuerier
}

fun commonCursorOrNull(cursorColumnID: String): FieldOrMetaField?
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import io.airbyte.cdk.command.StreamInputState
import io.airbyte.cdk.data.AirbyteSchemaType
import io.airbyte.cdk.data.ArrayAirbyteSchemaType
import io.airbyte.cdk.data.LeafAirbyteSchemaType
import io.airbyte.cdk.discover.CommonMetaField
import io.airbyte.cdk.discover.Field
import io.airbyte.cdk.discover.FieldOrMetaField
import io.airbyte.cdk.discover.MetaField
Expand Down Expand Up @@ -198,12 +197,13 @@ class StateManagerFactory(
if (cursorColumnIDComponents.isEmpty()) {
return null
}

val cursorColumnID: String = cursorColumnIDComponents.joinToString(separator = ".")
if (cursorColumnID == CommonMetaField.CDC_LSN.id) {
return CommonMetaField.CDC_LSN
}
return dataColumnOrNull(cursorColumnID)
val maybeCursorField: FieldOrMetaField? =
metadataQuerier.commonCursorOrNull(cursorColumnID)
return maybeCursorField ?: dataColumnOrNull(cursorColumnID)
}

val configuredPrimaryKey: List<Field>? =
configuredStream.primaryKey?.asSequence()?.let { pkOrNull(it.toList()) }
val configuredCursor: FieldOrMetaField? =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class ResourceDrivenMetadataQuerierFactory(
}

override fun extraChecks() {}
/**
 * Maps [cursorColumnID] to a common meta-field cursor, if one matches.
 *
 * @param cursorColumnID the cursor column identifier to look up.
 * @return [CommonMetaField.CDC_LSN] when the identifier equals its `id`; `null` for any
 * other identifier.
 */
override fun commonCursorOrNull(cursorColumnID: String): FieldOrMetaField? =
    if (cursorColumnID == CommonMetaField.CDC_LSN.id) CommonMetaField.CDC_LSN else null

override fun close() {
isClosed = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import io.airbyte.cdk.load.task.implementor.FailSyncTaskFactory
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicReference
import kotlinx.coroutines.CancellationException

/**
Expand Down Expand Up @@ -64,6 +65,8 @@ T : LeveledTask,
T : ScopedTask {
val log = KotlinLogging.logger {}

val onException = AtomicReference(suspend {})

inner class SyncTaskWrapper(
private val syncManager: SyncManager,
override val innerTask: ScopedTask,
Expand Down Expand Up @@ -142,7 +145,11 @@ T : ScopedTask {
}
}

override fun withExceptionHandling(task: T): WrappedTask<ScopedTask> {
/**
 * Stores [callback] in [onException], replacing the callback held there before.
 *
 * @param callback suspending action to keep for later invocation.
 */
override suspend fun setCallback(callback: suspend () -> Unit) = onException.set(callback)

override suspend fun withExceptionHandling(task: T): WrappedTask<ScopedTask> {
return when (task) {
is SyncLevel -> SyncTaskWrapper(syncManager, task)
is StreamLevel -> StreamTaskWrapper(task.stream, syncManager, task)
Expand All @@ -167,6 +174,6 @@ T : ScopedTask {
}

override suspend fun handleSyncFailed() {
taskScopeProvider.kill()
onException.get().invoke()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import io.airbyte.cdk.load.task.internal.UpdateCheckpointsTask
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

Expand Down Expand Up @@ -105,12 +106,15 @@ class DefaultDestinationTaskLauncher(
private val log = KotlinLogging.logger {}

private val batchUpdateLock = Mutex()
private val succeeded = Channel<Boolean>(Channel.UNLIMITED)

/**
 * Wraps [task] with the exception handler, then launches the wrapped task through the
 * task scope provider.
 *
 * @param task the task to wrap and launch.
 */
private suspend fun enqueue(task: LeveledTask) {
    val wrapped = exceptionHandler.withExceptionHandling(task)
    taskScopeProvider.launch(wrapped)
}

override suspend fun start() {
override suspend fun run() {
exceptionHandler.setCallback { succeeded.send(false) }

// Start the input consumer ASAP
log.info { "Starting input consumer task" }
enqueue(inputConsumerTask)
Expand All @@ -127,11 +131,19 @@ class DefaultDestinationTaskLauncher(
enqueue(spillTask)
}

// Start the checkpoint management tasks
log.info { "Starting timed flush task" }
enqueue(timedFlushTask)

log.info { "Starting checkpoint update task" }
enqueue(updateCheckpointsTask)

// Await completion
if (succeeded.receive()) {
taskScopeProvider.close()
} else {
taskScopeProvider.kill()
}
}

/** Called when the initial destination setup completes. */
Expand Down Expand Up @@ -207,6 +219,6 @@ class DefaultDestinationTaskLauncher(

/** Called exactly once when all streams are closed. */
override suspend fun handleTeardownComplete() {
taskScopeProvider.close()
succeeded.send(true)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.AtomicReference
import kotlin.system.measureTimeMillis
import kotlinx.coroutines.CompletableJob
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
Expand All @@ -33,21 +35,13 @@ import kotlinx.coroutines.withTimeoutOrNull
* - should not block internal tasks (esp reading from stdin)
* - should complete if possible even when failing the sync
* ```
* - [ShutdownScope]: special case of [ImplementorScope]
* ```
* - tasks that should run during shutdown
* - handles canceling/joining other tasks
* - (and so should not cancel themselves)
* ```
*/
sealed interface ScopedTask : Task

interface InternalScope : ScopedTask

interface ImplementorScope : ScopedTask

interface ShutdownScope : ScopedTask

@Singleton
@Secondary
class DestinationTaskScopeProvider(config: DestinationConfiguration) :
Expand All @@ -57,42 +51,65 @@ class DestinationTaskScopeProvider(config: DestinationConfiguration) :
private val timeoutMs = config.gracefulCancellationTimeoutMs

data class ControlScope(
val job: Job,
val dispatcher: CoroutineDispatcher,
val name: String,
val job: CompletableJob,
val dispatcher: CoroutineDispatcher
) {
val scope: CoroutineScope = CoroutineScope(dispatcher + job)
)
val runningJobs: AtomicLong = AtomicLong(0)
}

private val internalScope = ControlScope(Job(), Dispatchers.IO)
private val internalScope = ControlScope("internal", Job(), Dispatchers.IO)

private val implementorScope =
ControlScope(
SupervisorJob(),
"implementor",
Job(),
Executors.newFixedThreadPool(config.maxNumImplementorTaskThreads)
.asCoroutineDispatcher()
)

override suspend fun launch(task: WrappedTask<ScopedTask>) {
when (task.innerTask) {
is InternalScope -> internalScope.scope.launch { execute(task, "internal") }
is ImplementorScope -> implementorScope.scope.launch { execute(task, "implementor") }
is ShutdownScope -> implementorScope.scope.launch { execute(task, "shutdown") }
val scope =
when (task.innerTask) {
is InternalScope -> internalScope
is ImplementorScope -> implementorScope
}
scope.scope.launch {
var nJobs = scope.runningJobs.incrementAndGet()
log.info { "Launching task $task in scope ${scope.name} ($nJobs now running)" }
val elapsed = measureTimeMillis { task.execute() }
nJobs = scope.runningJobs.decrementAndGet()
log.info { "Task $task completed in $elapsed ms ($nJobs now running)" }
}
}

private suspend fun execute(task: WrappedTask<ScopedTask>, scope: String) {
log.info { "Launching task $task in scope $scope" }
val elapsed = measureTimeMillis { task.execute() }
log.info { "Task $task completed in $elapsed ms" }
}

override suspend fun close() {
log.info { "Closing task scopes" }
// Under normal operation, all tasks should be complete
// (except things like force flush, which loop). So
// - it's safe to force cancel the internal tasks
// - implementor scope should join immediately
log.info { "Closing task scopes (${implementorScope.runningJobs.get()} remaining)" }
val uncaughtExceptions = AtomicReference<Throwable>()
implementorScope.job.children.forEach {
it.invokeOnCompletion { cause ->
if (cause != null) {
log.error { "Uncaught exception in implementor task: $cause" }
uncaughtExceptions.set(cause)
}
}
}
implementorScope.job.complete()
implementorScope.job.join()
log.info { "Implementor tasks completed, cancelling internal tasks." }
if (uncaughtExceptions.get() != null) {
throw IllegalStateException(
"Uncaught exceptions in implementor tasks",
uncaughtExceptions.get()
)
}
log.info {
"Implementor tasks completed, cancelling internal tasks (${internalScope.runningJobs.get()} remaining)."
}
internalScope.job.cancel()
}

Expand All @@ -104,6 +121,7 @@ class DestinationTaskScopeProvider(config: DestinationConfiguration) :
log.info {
"Cancelled internal tasks, waiting ${timeoutMs}ms for implementor tasks to complete"
}
implementorScope.job.complete()
implementorScope.job.join()
log.info { "Implementor tasks completed" }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,30 @@ interface Task {
* transitions between tasks.
*/
interface TaskLauncher {
suspend fun start()
/**
* Execute the task workflow. Should dispatch tasks asynchronously and suspend until the
* workflow is complete.
*/
suspend fun run()
}

/**
* Wraps tasks with exception handling. It should provide an exception handling workflow and take
* responsibility for closing scopes, etc.
* Wraps tasks with exception handling. It should perform all necessary exception handling, then
* execute the provided callback.
*/
interface TaskExceptionHandler<T : Task, U : Task> {
fun withExceptionHandling(task: T): U
// Wrap a task with exception handling.
suspend fun withExceptionHandling(task: T): U

// Set a callback that will be invoked when any exception handling is done.
suspend fun setCallback(callback: suspend () -> Unit)
}

/** Provides the scope(s) in which tasks run. */
interface TaskScopeProvider<T : Task> : CloseableCoroutine {
/** Launch a task in the correct scope. */
suspend fun launch(task: T)

/** Unlike close, may attempt to fail gracefully, but should guarantee return. */
/** Unlike [close], may attempt to fail gracefully, but should guarantee return. */
suspend fun kill()
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.state.StreamIncompleteResult
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskExceptionHandler
import io.airbyte.cdk.load.task.ShutdownScope
import io.airbyte.cdk.load.task.ImplementorScope
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton

interface FailStreamTask : ShutdownScope
interface FailStreamTask : ImplementorScope

/**
* FailStreamTask is a task that is executed when a stream fails. It is responsible for cleaning up
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ package io.airbyte.cdk.load.task.implementor
import io.airbyte.cdk.load.state.CheckpointManager
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskExceptionHandler
import io.airbyte.cdk.load.task.ShutdownScope
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.util.setOnce
import io.airbyte.cdk.load.write.DestinationWriter
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicBoolean

interface FailSyncTask : ShutdownScope
interface FailSyncTask : ImplementorScope

/**
* FailSyncTask is a task that is executed when a sync fails. It is responsible for cleaning up
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package io.airbyte.cdk.load.task.implementor
import io.airbyte.cdk.load.state.CheckpointManager
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.ShutdownScope
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.task.SyncLevel
import io.airbyte.cdk.load.util.setOnce
import io.airbyte.cdk.load.write.DestinationWriter
Expand All @@ -16,7 +16,7 @@ import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.atomic.AtomicBoolean

interface TeardownTask : SyncLevel, ShutdownScope
interface TeardownTask : SyncLevel, ImplementorScope

/**
* Wraps @[DestinationWriter.teardown] and stops the task launcher.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class WriteOperation(
val log = KotlinLogging.logger {}

override fun execute() = runBlocking {
taskLauncher.start()
taskLauncher.run()

when (val result = syncManager.awaitSyncResult()) {
is SyncSuccess -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,10 @@ class DestinationTaskExceptionHandlerTest<T> where T : LeveledTask, T : ScopedTa
}

@Test
fun testHandleSyncFailed(scopeProvider: MockScopeProvider) = runTest {
fun testHandleSyncFailed() = runTest {
val wasHandled = Channel<Boolean>(Channel.UNLIMITED)
exceptionHandler.setCallback { wasHandled.send(true) }
exceptionHandler.handleSyncFailed()
Assertions.assertTrue(scopeProvider.didKill)
Assertions.assertTrue(wasHandled.receive())
}
}
Loading

0 comments on commit b74b942

Please sign in to comment.