convert destination-snowflake to Kotlin CDK (#36910)
Not only does this bring snowflake to the latest CDK, it also:
1) brings the `SourceOperation` from the test code into production code. There's really no reason those improvements should stay out of production (and they're already present in source-snowflake)
2) adds `putTimestamp` to the `SourceOperation`, so that snowflake no longer throws an exception on every call (each of those exceptions also spawned a new thread)
3) makes use of the newly added ability to filter orphaned threads on shutdown. We filter out all threads created during calls to `SFStatement.close()`
4) no longer always takes a lock when deleting destinationStates. We now check whether there are any states to delete via a `SELECT` (which takes no table lock) before issuing the `DELETE`; the old behavior caused test contention, and it's a bad idea in general (see the sketch below)
5) only executes queries against `airbyte_internal._airbyte_destination_state` when needed
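
A minimal sketch of the item-4 pattern (illustrative only, not the exact code in this commit; `deleteStatesIfAny` and `getSelectStatesSql` are hypothetical names, while `getDeleteStatesSql` and `queryJsons` do exist in the CDK):

```kotlin
// Sketch: probe with a lock-free SELECT before issuing the lock-taking DELETE.
@Throws(Exception::class)
protected fun deleteStatesIfAny(destinationStates: Map<StreamId, DestinationState>) {
    // getSelectStatesSql is a hypothetical helper that builds
    // "SELECT 1 FROM airbyte_internal._airbyte_destination_state WHERE <same predicate as the DELETE>"
    val matchingStates = jdbcDatabase.queryJsons(getSelectStatesSql(destinationStates))
    if (matchingStates.isNotEmpty()) {
        jdbcDatabase.execute(getDeleteStatesSql(destinationStates))
    }
}
```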
stephane-airbyte authored May 3, 2024
1 parent f23c2e6 commit 7c0a6c5
Showing 31 changed files with 338 additions and 228 deletions.
airbyte-cdk/java/airbyte-cdk/README.md (3 additions, 1 deletion)
@@ -173,7 +173,9 @@ corresponds to that version.
### Java CDK

| Version | Date | Pull Request | Subject |
-|:--------|:-----------| :--------------------------------------------------------- |:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| :------ | :--------- | :--------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| 0.31.7  | 2024-05-02 | [\#36910](https://github.com/airbytehq/airbyte/pull/36910) | changes for destination-snowflake |
| 0.31.6 | 2024-05-02 | [\#37746](https://github.com/airbytehq/airbyte/pull/37746) | debuggability improvements. |
| 0.31.5 | 2024-04-30 | [\#37758](https://github.com/airbytehq/airbyte/pull/37758) | Set debezium max retries to zero |
| 0.31.4 | 2024-04-30 | [\#37754](https://github.com/airbytehq/airbyte/pull/37754) | Add DebeziumEngine notification log |
| 0.31.3 | 2024-04-30 | [\#37726](https://github.com/airbytehq/airbyte/pull/37726) | Remove debezium retries |
@@ -15,6 +15,8 @@ import java.util.function.Consumer
import java.util.function.Function
import java.util.stream.Stream
import java.util.stream.StreamSupport
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory

/** Database object for interacting with a JDBC connection. */
abstract class JdbcDatabase(protected val sourceOperations: JdbcCompatibleSourceOperations<*>?) :
@@ -211,6 +213,7 @@ abstract class JdbcDatabase(protected val sourceOperations: JdbcCompatibleSource
abstract fun <T> executeMetadataQuery(query: Function<DatabaseMetaData?, T>): T

companion object {
+        private val LOGGER: Logger = LoggerFactory.getLogger(JdbcDatabase::class.java)
/**
* Map records returned in a result set. It is an "unsafe" stream because the stream must be
* manually closed. Otherwise, there will be a database connection leak.
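As the docstring above warns, the `unsafe*` query methods return a stream backed by a live connection. A usage sketch (assuming the `unsafeQuery(String, vararg String)` overload; the query is a placeholder):

```kotlin
// Sketch: Stream is AutoCloseable, so `use` releases the underlying connection
// even if consuming a record throws.
database.unsafeQuery("SELECT name, namespace FROM airbyte_internal._airbyte_destination_state")
    .use { stream -> stream.forEach { row -> println(row) } }
```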
@@ -370,8 +370,8 @@ internal constructor(
}

@JvmStatic
-        fun getThreadCreationInfo(thread: Thread): ThreadCreationInfo {
-            return getMethod.invoke(threadCreationInfo, thread) as ThreadCreationInfo
+        fun getThreadCreationInfo(thread: Thread): ThreadCreationInfo? {
+            return getMethod.invoke(threadCreationInfo, thread) as ThreadCreationInfo?
}

/**
@@ -1 +1 @@
-version=0.31.6
+version=0.31.7
@@ -204,7 +204,7 @@ abstract class JdbcSqlOperations : SqlOperations {
}
}

-    fun dropTableIfExistsQuery(schemaName: String?, tableName: String?): String {
+    open fun dropTableIfExistsQuery(schemaName: String?, tableName: String?): String {
return String.format("DROP TABLE IF EXISTS %s.%s;\n", schemaName, tableName)
}

@@ -356,7 +356,7 @@ abstract class JdbcDestinationHandler<DestinationState>(
existingTable.columns[JavaBaseConstants.COLUMN_NAME_AB_META]!!.type
}

-    private fun existingSchemaMatchesStreamConfig(
+    open protected fun existingSchemaMatchesStreamConfig(
stream: StreamConfig?,
existingTable: TableDefinition
): Boolean {
@@ -400,6 +400,29 @@ abstract class JdbcDestinationHandler<DestinationState>(
return actualColumns == intendedColumns
}

+    protected open fun getDeleteStatesSql(
+        destinationStates: Map<StreamId, DestinationState>
+    ): String {
+        return dslContext
+            .deleteFrom(table(quotedName(rawTableSchemaName, DESTINATION_STATE_TABLE_NAME)))
+            .where(
+                destinationStates.keys
+                    .stream()
+                    .map { streamId: StreamId ->
+                        field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAME))
+                            .eq(streamId.originalName)
+                            .and(
+                                field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAMESPACE))
+                                    .eq(streamId.originalNamespace)
+                            )
+                    }
+                    .reduce(DSL.falseCondition()) { obj: Condition, arg2: Condition? ->
+                        obj.or(arg2)
+                    }
+            )
+            .getSQL(ParamType.INLINED)
+    }

@Throws(Exception::class)
override fun commitDestinationStates(destinationStates: Map<StreamId, DestinationState>) {
try {
@@ -408,25 +431,7 @@ abstract class JdbcDestinationHandler<DestinationState>(
}

// Delete all state records where the stream name+namespace match one of our states
-        val deleteStates =
-            dslContext
-                .deleteFrom(table(quotedName(rawTableSchemaName, DESTINATION_STATE_TABLE_NAME)))
-                .where(
-                    destinationStates.keys
-                        .stream()
-                        .map { streamId: StreamId ->
-                            field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAME))
-                                .eq(streamId.originalName)
-                                .and(
-                                    field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAMESPACE))
-                                        .eq(streamId.originalNamespace)
-                                )
-                        }
-                        .reduce(DSL.falseCondition()) { obj: Condition, arg2: Condition? ->
-                            obj.or(arg2)
-                        }
-                )
-                .getSQL(ParamType.INLINED)
+        var deleteStates = getDeleteStatesSql(destinationStates)

// Reinsert all of our states
var insertStatesStep =
@@ -461,12 +466,17 @@ abstract class JdbcDestinationHandler<DestinationState>(
}
val insertStates = insertStatesStep.getSQL(ParamType.INLINED)

-            jdbcDatabase.executeWithinTransaction(listOf(deleteStates, insertStates))
+            executeWithinTransaction(listOf(deleteStates, insertStates))
} catch (e: Exception) {
LOGGER.warn("Failed to commit destination states", e)
}
}

+    @Throws(Exception::class)
+    protected open fun executeWithinTransaction(statements: List<String>) {
+        jdbcDatabase.executeWithinTransaction(statements)
+    }

/**
* Convert to the TYPE_NAME retrieved from [java.sql.DatabaseMetaData.getColumns]
*
Expand All @@ -479,9 +489,9 @@ abstract class JdbcDestinationHandler<DestinationState>(

companion object {
private val LOGGER: Logger = LoggerFactory.getLogger(JdbcDestinationHandler::class.java)
-        private const val DESTINATION_STATE_TABLE_NAME = "_airbyte_destination_state"
-        private const val DESTINATION_STATE_TABLE_COLUMN_NAME = "name"
-        private const val DESTINATION_STATE_TABLE_COLUMN_NAMESPACE = "namespace"
+        protected const val DESTINATION_STATE_TABLE_NAME = "_airbyte_destination_state"
+        protected const val DESTINATION_STATE_TABLE_COLUMN_NAME = "name"
+        protected const val DESTINATION_STATE_TABLE_COLUMN_NAMESPACE = "namespace"
private const val DESTINATION_STATE_TABLE_COLUMN_STATE = "destination_state"
private const val DESTINATION_STATE_TABLE_COLUMN_UPDATED_AT = "updated_at"

@@ -542,6 +552,7 @@ abstract class JdbcDestinationHandler<DestinationState>(
return Optional.of(TableDefinition(retrievedColumnDefns))
}

+        @JvmStatic
fun fromIsNullableIsoString(isNullable: String?): Boolean {
return "YES".equals(isNullable, ignoreCase = true)
}
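The members made `protected`/`open` above are override hooks for dialect-specific handlers. One illustrative use, under the assumption that a subclass wants to batch the statements itself (not code from this commit):

```kotlin
// Sketch: collapse the delete+insert statements into one explicit transaction,
// sent to the database in a single round trip.
override fun executeWithinTransaction(statements: List<String>) {
    jdbcDatabase.execute((listOf("BEGIN;") + statements + "COMMIT;").joinToString("\n"))
}
```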
@@ -61,7 +61,7 @@ class LocalAirbyteDestination(private val dest: Destination) : AirbyteDestinatio
return isClosed
}

-    override val exitValue = 0
+    override var exitValue = 0

override fun attemptRead(): Optional<io.airbyte.protocol.models.AirbyteMessage> {
return Optional.empty()
@@ -72,7 +72,7 @@ interface AirbyteDestination : CheckedConsumer<AirbyteMessage, Exception>, AutoC
* @return exit code of the destination process
* @throws IllegalStateException if the destination process has not exited
*/
-    abstract val exitValue: Int
+    val exitValue: Int

/**
* Attempts to read an AirbyteMessage from the Destination.
@@ -171,7 +171,7 @@ abstract class BaseDestinationV1V2Migrator<DialectTableDefinition> : Destination
* @return whether it exists and is in the correct format
*/
@Throws(Exception::class)
-    protected fun doesValidV1RawTableExist(namespace: String?, tableName: String?): Boolean {
+    protected open fun doesValidV1RawTableExist(namespace: String?, tableName: String?): Boolean {
val existingV1RawTable = getTableIfExists(namespace, tableName)
return existingV1RawTable.isPresent &&
doesV1RawTableMatchExpectedSchema(existingV1RawTable.get())
@@ -80,7 +80,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
* Subclasses should override this method if they need to make changes to the stream ID. For
* example, you could upcase the final table name here.
*/
-    protected fun buildStreamId(
+    open protected fun buildStreamId(
namespace: String,
finalTableName: String,
rawTableName: String
@@ -149,7 +149,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
/** Identical to [BaseTypingDedupingTest.getRawMetadataColumnNames]. */
get() = HashMap()

-    protected val finalMetadataColumnNames: Map<String, String>
+    open protected val finalMetadataColumnNames: Map<String, String>
/** Identical to [BaseTypingDedupingTest.getFinalMetadataColumnNames]. */
get() = HashMap()

@@ -728,7 +728,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
*/
@Test
@Throws(Exception::class)
-    fun ignoreOldRawRecords() {
+    open fun ignoreOldRawRecords() {
createRawTable(streamId)
createFinalTable(incrementalAppendStream, "")
insertRawTableRecords(
@@ -1519,7 +1519,10 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
executeSoftReset(generator, destinationHandler, incrementalAppendStream)
}

-    protected fun migrationAssertions(v1RawRecords: List<JsonNode>, v2RawRecords: List<JsonNode>) {
+    protected open fun migrationAssertions(
+        v1RawRecords: List<JsonNode>,
+        v2RawRecords: List<JsonNode>
+    ) {
val v2RecordMap =
v2RawRecords
.stream()
@@ -1570,7 +1573,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
}

@Throws(Exception::class)
-    protected fun dumpV1RawTableRecords(streamId: StreamId): List<JsonNode> {
+    open protected fun dumpV1RawTableRecords(streamId: StreamId): List<JsonNode> {
return dumpRawTableRecords(streamId)
}

@@ -3,7 +3,7 @@ plugins {
}

airbyteJavaConnector {
-    cdkVersionRequired = '0.27.7'
+    cdkVersionRequired = '0.31.7'
features = ['db-destinations', 's3-destinations', 'typing-deduping']
useLocalCdk = false
}
@@ -1,4 +1,4 @@
# currently limit the number of parallel threads until further investigation into the issues \
# where Snowflake will fail to login using config credentials
-testExecutionConcurrency=4
+testExecutionConcurrency=-1
JunitMethodExecutionTimeout=15 m
@@ -5,7 +5,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
-  dockerImageTag: 3.7.0
+  dockerImageTag: 3.7.1
dockerRepository: airbyte/destination-snowflake
documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
githubIssueLabel: destination-snowflake
@@ -197,7 +197,7 @@ private static String getAccessTokenUsingRefreshToken(final String hostName,
}

public static JdbcDatabase getDatabase(final DataSource dataSource) {
-    return new DefaultJdbcDatabase(dataSource);
+    return new DefaultJdbcDatabase(dataSource, new SnowflakeSourceOperations());
}

private static Runnable getRefreshTokenTask(final HikariDataSource dataSource) {
@@ -7,12 +7,26 @@
import static io.airbyte.integrations.destination.snowflake.SnowflakeDestination.SCHEDULED_EXECUTOR_SERVICE;

import io.airbyte.cdk.integrations.base.AirbyteExceptionHandler;
+import io.airbyte.cdk.integrations.base.IntegrationRunner;
import io.airbyte.cdk.integrations.base.adaptive.AdaptiveDestinationRunner;
+import net.snowflake.client.core.SFSession;
+import net.snowflake.client.core.SFStatement;
import net.snowflake.client.jdbc.SnowflakeSQLException;

public class SnowflakeDestinationRunner {

public static void main(final String[] args) throws Exception {
+    IntegrationRunner.addOrphanedThreadFilter((Thread t) -> {
+      for (StackTraceElement stackTraceElement : IntegrationRunner.getThreadCreationInfo(t).getStack()) {
+        String stackClassName = stackTraceElement.getClassName();
+        String stackMethodName = stackTraceElement.getMethodName();
+        if (SFStatement.class.getCanonicalName().equals(stackClassName) && "close".equals(stackMethodName) ||
+            SFSession.class.getCanonicalName().equals(stackClassName) && "callHeartBeatWithQueryTimeout".equals(stackMethodName)) {
+          return false;
+        }
+      }
+      return true;
+    });
AirbyteExceptionHandler.addThrowableForDeinterpolation(SnowflakeSQLException.class);
AdaptiveDestinationRunner.baseOnEnv()
.withOssDestination(() -> new SnowflakeDestination(OssCloudEnvVarConsts.AIRBYTE_OSS))
@@ -10,6 +10,7 @@
import io.airbyte.cdk.db.jdbc.JdbcUtils;
import io.airbyte.cdk.integrations.base.Destination;
import io.airbyte.cdk.integrations.base.JavaBaseConstants;
+import io.airbyte.cdk.integrations.base.JavaBaseConstants.DestinationColumns;
import io.airbyte.cdk.integrations.base.SerializedAirbyteMessageConsumer;
import io.airbyte.cdk.integrations.base.TypingAndDedupingFlag;
import io.airbyte.cdk.integrations.destination.NamingConventionTransformer;
@@ -132,7 +133,7 @@ public JsonNode toJdbcConfig(final JsonNode config) {
}

@Override
-  protected JdbcSqlGenerator getSqlGenerator() {
+  protected JdbcSqlGenerator getSqlGenerator(final JsonNode config) {
throw new UnsupportedOperationException("Snowflake does not yet use the native JDBC DV2 interface");
}

@@ -209,7 +210,7 @@ public SerializedAirbyteMessageConsumer getSerializedMessageConsumer(final JsonN
typerDeduper,
parsedCatalog,
defaultNamespace,
-        true)
+        DestinationColumns.V2_WITHOUT_META)
.setBufferMemoryLimit(Optional.of(getSnowflakeBufferMemoryLimit()))
.setOptimalBatchSizeBytes(
// The per stream size limit is following recommendations from:
@@ -8,12 +8,26 @@
import static io.airbyte.cdk.db.jdbc.DateTimeConverter.putJavaSQLTime;

import com.fasterxml.jackson.databind.node.ObjectNode;
+import io.airbyte.cdk.db.DataTypeUtils;
import io.airbyte.cdk.db.jdbc.JdbcSourceOperations;
import io.airbyte.commons.json.Jsons;
import java.sql.ResultSet;
import java.sql.SQLException;
+import java.time.OffsetDateTime;
+import java.time.format.DateTimeFormatter;
+import java.time.format.DateTimeFormatterBuilder;

-public class SnowflakeTestSourceOperations extends JdbcSourceOperations {
+public class SnowflakeSourceOperations extends JdbcSourceOperations {

+  private static final DateTimeFormatter SNOWFLAKE_TIMESTAMPTZ_FORMATTER = new DateTimeFormatterBuilder()
+      .parseCaseInsensitive()
+      .append(DateTimeFormatter.ISO_LOCAL_DATE)
+      .appendLiteral(' ')
+      .append(DateTimeFormatter.ISO_LOCAL_TIME)
+      .optionalStart()
+      .appendLiteral(' ')
+      .append(DateTimeFormatter.ofPattern("XX"))
+      .toFormatter();

@Override
public void copyToJsonField(final ResultSet resultSet, final int colIndex, final ObjectNode json) throws SQLException {
@@ -45,4 +59,18 @@ protected void putTime(final ObjectNode node,
putJavaSQLTime(node, columnName, resultSet, index);
}

+  @Override
+  protected void putTimestampWithTimezone(final ObjectNode node, final String columnName, final ResultSet resultSet, final int index)
+      throws SQLException {
+    final String timestampAsString = resultSet.getString(index);
+    OffsetDateTime timestampWithOffset = OffsetDateTime.parse(timestampAsString, SNOWFLAKE_TIMESTAMPTZ_FORMATTER);
+    node.put(columnName, timestampWithOffset.format(DataTypeUtils.TIMESTAMPTZ_FORMATTER));
+  }

+  protected void putTimestamp(final ObjectNode node, final String columnName, final ResultSet resultSet, final int index) throws SQLException {
+    // for backward compatibility
+    var instant = resultSet.getTimestamp(index).toInstant();
+    node.put(columnName, DataTypeUtils.toISO8601StringWithMicroseconds(instant));
+  }

}
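
For reference, the `SNOWFLAKE_TIMESTAMPTZ_FORMATTER` built above accepts Snowflake's space-separated TIMESTAMP_TZ strings with an optional trailing offset. An illustrative parse (made-up value; the field is private to the class, so this is a sketch rather than a supported API):

```kotlin
import java.time.OffsetDateTime

// "XX" matches offsets like +0000, so this parses to 2024-05-02T12:34:56.789Z.
val ts = OffsetDateTime.parse("2024-05-02 12:34:56.789 +0000", SNOWFLAKE_TIMESTAMPTZ_FORMATTER)
```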
@@ -8,7 +8,7 @@
import io.airbyte.cdk.db.jdbc.JdbcDatabase;
import io.airbyte.cdk.integrations.base.DestinationConfig;
import io.airbyte.cdk.integrations.base.JavaBaseConstants;
-import io.airbyte.cdk.integrations.destination.async.partial_messages.PartialAirbyteMessage;
+import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage;
import io.airbyte.cdk.integrations.destination.jdbc.JdbcSqlOperations;
import io.airbyte.cdk.integrations.destination.jdbc.SqlOperations;
import io.airbyte.cdk.integrations.destination.jdbc.SqlOperationsUtils;
@@ -37,10 +37,10 @@ public class SnowflakeSqlOperations extends JdbcSqlOperations implements SqlOper
@Override
public void createSchemaIfNotExists(final JdbcDatabase database, final String schemaName) throws Exception {
try {
-      if (!schemaSet.contains(schemaName) && !isSchemaExists(database, schemaName)) {
+      if (!getSchemaSet().contains(schemaName) && !isSchemaExists(database, schemaName)) {
// 1s1t is assuming a lowercase airbyte_internal schema name, so we need to quote it
database.execute(String.format("CREATE SCHEMA IF NOT EXISTS \"%s\";", schemaName));
-        schemaSet.add(schemaName);
+        getSchemaSet().add(schemaName);
}
} catch (final Exception e) {
throw checkForKnownConfigExceptions(e).orElseThrow(() -> e);