Skip to content

Commit

Permalink
get rid of DatabaseDriver as an enum
Browse files Browse the repository at this point in the history
  • Loading branch information
stephane-airbyte committed Feb 29, 2024
1 parent 96865c4 commit 430bbf7
Show file tree
Hide file tree
Showing 22 changed files with 95 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ public DataSource build() {

final HikariConfig config = new HikariConfig();

config.setDriverClassName(databaseDriver.getDriverClassName());
config.setJdbcUrl(jdbcUrl != null ? jdbcUrl : String.format(databaseDriver.getUrlFormatString(), host, port, database));
config.setDriverClassName(databaseDriver.driverClassName());
config.setJdbcUrl(jdbcUrl != null ? jdbcUrl : String.format(databaseDriver.urlFormatString(), host, port, database));
config.setMaximumPoolSize(maximumPoolSize);
config.setMinimumIdle(minimumPoolSize);
// HikariCP uses milliseconds for all time values:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,41 @@

package io.airbyte.cdk.db.factory;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* Collection of JDBC driver class names and the associated JDBC URL format string.
*/
public enum DatabaseDriver {

CLICKHOUSE("com.clickhouse.jdbc.ClickHouseDriver", "jdbc:clickhouse:%s://%s:%d/%s"),
DATABRICKS("com.databricks.client.jdbc.Driver", "jdbc:databricks://%s:%s;HttpPath=%s;SSL=1;UserAgentEntry=Airbyte"),
DB2("com.ibm.db2.jcc.DB2Driver", "jdbc:db2://%s:%d/%s"),
STARBURST("io.trino.jdbc.TrinoDriver", "jdbc:trino://%s:%s/%s?SSL=true&source=airbyte"),
MARIADB("org.mariadb.jdbc.Driver", "jdbc:mariadb://%s:%d/%s"),
MSSQLSERVER("com.microsoft.sqlserver.jdbc.SQLServerDriver", "jdbc:sqlserver://%s:%d;databaseName=%s"),
MYSQL("com.mysql.cj.jdbc.Driver", "jdbc:mysql://%s:%d/%s"),
ORACLE("oracle.jdbc.OracleDriver", "jdbc:oracle:thin:@%s:%d/%s"),
VERTICA("com.vertica.jdbc.Driver", "jdbc:vertica://%s:%d/%s"),
POSTGRESQL("org.postgresql.Driver", "jdbc:postgresql://%s:%d/%s"),
REDSHIFT("com.amazon.redshift.jdbc.Driver", "jdbc:redshift://%s:%d/%s"),
SNOWFLAKE("net.snowflake.client.jdbc.SnowflakeDriver", "jdbc:snowflake://%s/"),
YUGABYTEDB("com.yugabyte.Driver", "jdbc:yugabytedb://%s:%d/%s"),
EXASOL("com.exasol.jdbc.EXADriver", "jdbc:exa:%s:%d"),
TERADATA("com.teradata.jdbc.TeraDriver", "jdbc:teradata://%s/");

private final String driverClassName;
private final String urlFormatString;

DatabaseDriver(final String driverClassName, final String urlFormatString) {
public record DatabaseDriver(String driverClassName, String urlFormatString) {
public static final DatabaseDriver CLICKHOUSE = new DatabaseDriver("com.clickhouse.jdbc.ClickHouseDriver", "jdbc:clickhouse:%s://%s:%d/%s");
public static final DatabaseDriver DATABRICKS = new DatabaseDriver ("com.databricks.client.jdbc.Driver", "jdbc:databricks://%s:%s;HttpPath=%s;SSL=1;UserAgentEntry=Airbyte");
public static final DatabaseDriver DB2 = new DatabaseDriver ("com.ibm.db2.jcc.DB2Driver", "jdbc:db2://%s:%d/%s");
public static final DatabaseDriver STARBURST = new DatabaseDriver ("io.trino.jdbc.TrinoDriver", "jdbc:trino://%s:%s/%s?SSL=true&source=airbyte");
public static final DatabaseDriver MARIADB = new DatabaseDriver ("org.mariadb.jdbc.Driver", "jdbc:mariadb://%s:%d/%s");
public static final DatabaseDriver MSSQLSERVER = new DatabaseDriver ("com.microsoft.sqlserver.jdbc.SQLServerDriver", "jdbc:sqlserver://%s:%d;databaseName=%s");
public static final DatabaseDriver MYSQL = new DatabaseDriver ("com.mysql.cj.jdbc.Driver", "jdbc:mysql://%s:%d/%s");
public static final DatabaseDriver ORACLE = new DatabaseDriver ("oracle.jdbc.OracleDriver", "jdbc:oracle:thin:@%s:%d/%s");
public static final DatabaseDriver VERTICA = new DatabaseDriver ("com.vertica.jdbc.Driver", "jdbc:vertica://%s:%d/%s");
public static final DatabaseDriver POSTGRESQL = new DatabaseDriver ("org.postgresql.Driver", "jdbc:postgresql://%s:%d/%s");
public static final DatabaseDriver REDSHIFT = new DatabaseDriver ("com.amazon.redshift.jdbc.Driver", "jdbc:redshift://%s:%d/%s");
public static final DatabaseDriver SNOWFLAKE = new DatabaseDriver ("net.snowflake.client.jdbc.SnowflakeDriver", "jdbc:snowflake://%s/");
public static final DatabaseDriver YUGABYTEDB = new DatabaseDriver ("com.yugabyte.Driver", "jdbc:yugabytedb://%s:%d/%s");
public static final DatabaseDriver EXASOL = new DatabaseDriver ("com.exasol.jdbc.EXADriver", "jdbc:exa:%s:%d");
public static final DatabaseDriver TERADATA = new DatabaseDriver ("com.teradata.jdbc.TeraDriver", "jdbc:teradata://%s/");

private static Map<String, DatabaseDriver> DRIVER_BY_CLASS_NAME= new ConcurrentHashMap<>();

public DatabaseDriver(Class<? extends java.sql.Driver> driverClass, String urlFormatString) {
this(driverClass.getCanonicalName(), urlFormatString);
}
public DatabaseDriver(String driverClassName, String urlFormatString) {
this.driverClassName = driverClassName;
this.urlFormatString = urlFormatString;
DRIVER_BY_CLASS_NAME.put(driverClassName, this);
}

public String getDriverClassName() {
return driverClassName;
public static DatabaseDriver findByDriverClassName(String driverClassName) {
return DRIVER_BY_CLASS_NAME.get(driverClassName);
}

public String getUrlFormatString() {
return urlFormatString;
}

/**
* Finds the {@link DatabaseDriver} enumerated value that matches the provided driver class name.
*
* @param driverClassName The driver class name.
* @return The matching {@link DatabaseDriver} enumerated value or {@code null} if no match is
* found.
*/
public static DatabaseDriver findByDriverClassName(final String driverClassName) {
DatabaseDriver selected = null;

for (final DatabaseDriver candidate : values()) {
if (candidate.getDriverClassName().equalsIgnoreCase(driverClassName)) {
selected = candidate;
break;
}
}

return selected;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

package io.airbyte.cdk.integrations;

import static io.airbyte.cdk.db.factory.DatabaseDriver.MSSQLSERVER;
import static io.airbyte.cdk.db.factory.DatabaseDriver.MYSQL;
import static io.airbyte.cdk.db.factory.DatabaseDriver.POSTGRESQL;

import io.airbyte.cdk.db.factory.DatabaseDriver;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
Expand Down Expand Up @@ -44,12 +48,19 @@ protected Duration getConnectionTimeout(final Map<String, String> connectionProp
* @return DataSourceBuilder class used to create dynamic fields for DataSource
*/
public static Duration getConnectionTimeout(final Map<String, String> connectionProperties, String driverClassName) {
final Optional<Duration> parsedConnectionTimeout = switch (DatabaseDriver.findByDriverClassName(driverClassName)) {
case POSTGRESQL -> maybeParseDuration(connectionProperties.get(POSTGRES_CONNECT_TIMEOUT_KEY), ChronoUnit.SECONDS)
final Optional<Duration> parsedConnectionTimeout;
if (driverClassName.equals(POSTGRESQL.driverClassName())) {
parsedConnectionTimeout = maybeParseDuration(
connectionProperties.get(POSTGRES_CONNECT_TIMEOUT_KEY), ChronoUnit.SECONDS)
.or(() -> Optional.of(POSTGRES_CONNECT_TIMEOUT_DEFAULT_DURATION));
case MYSQL -> maybeParseDuration(connectionProperties.get("connectTimeout"), ChronoUnit.MILLIS);
case MSSQLSERVER -> maybeParseDuration(connectionProperties.get("loginTimeout"), ChronoUnit.SECONDS);
default -> maybeParseDuration(connectionProperties.get(CONNECT_TIMEOUT_KEY), ChronoUnit.SECONDS)
} else if (driverClassName.equals(MYSQL.driverClassName())) {
parsedConnectionTimeout = maybeParseDuration(connectionProperties.get("connectTimeout"),
ChronoUnit.MILLIS);
} else if (driverClassName.equals(MSSQLSERVER.driverClassName())) {
parsedConnectionTimeout = maybeParseDuration(connectionProperties.get("loginTimeout"),
ChronoUnit.SECONDS);
} else {
parsedConnectionTimeout = maybeParseDuration(connectionProperties.get(CONNECT_TIMEOUT_KEY), ChronoUnit.SECONDS)
// Enforce minimum timeout duration for unspecified data sources.
.filter(d -> d.compareTo(CONNECT_TIMEOUT_DEFAULT) >= 0);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ private DataSource getDataSourceFromConfig(final JsonNode config) {
return DataSourceFactory.create(
config.get(JdbcUtils.USERNAME_KEY).asText(),
config.get(JdbcUtils.PASSWORD_KEY).asText(),
DatabaseDriver.POSTGRESQL.getDriverClassName(),
String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
DatabaseDriver.POSTGRESQL.driverClassName(),
String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ void setup() throws Exception {
dataSource = DataSourceFactory.create(
config.get(JdbcUtils.USERNAME_KEY).asText(),
config.get(JdbcUtils.PASSWORD_KEY).asText(),
DatabaseDriver.POSTGRESQL.getDriverClassName(),
String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
DatabaseDriver.POSTGRESQL.driverClassName(),
String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ void setup() {
final DataSource connectionPool = DataSourceFactory.create(
config.get(JdbcUtils.USERNAME_KEY).asText(),
config.get(JdbcUtils.PASSWORD_KEY).asText(),
DatabaseDriver.POSTGRESQL.getDriverClassName(),
String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
DatabaseDriver.POSTGRESQL.driverClassName(),
String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ final public T initialized() {
this.dataSource = DataSourceFactory.create(
getUserName(),
getPassword(),
getDatabaseDriver().getDriverClassName(),
getDatabaseDriver().driverClassName(),
getJdbcUrl(),
connectionProperties,
JdbcConnector.getConnectionTimeout(connectionProperties, getDatabaseDriver().getDriverClassName()));
JdbcConnector.getConnectionTimeout(connectionProperties, getDatabaseDriver().driverClassName()));
this.dslContext = DSLContextFactory.create(dataSource, getSqlDialect());
return self();
}
Expand Down Expand Up @@ -179,7 +179,7 @@ final public DSLContext getDslContext() {

public String getJdbcUrl() {
return String.format(
getDatabaseDriver().getUrlFormatString(),
getDatabaseDriver().urlFormatString(),
getContainer().getHost(),
getContainer().getFirstMappedPort(),
getDatabaseName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ void setup() throws Exception {
dataSource = DataSourceFactory.create(
config.get(JdbcUtils.USERNAME_KEY).asText(),
config.get(JdbcUtils.PASSWORD_KEY).asText(),
DatabaseDriver.POSTGRESQL.getDriverClassName(),
String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
DatabaseDriver.POSTGRESQL.driverClassName(),
String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class JdbcSource extends AbstractJdbcSource<JDBCType> implements Source {
private static final Logger LOGGER = LoggerFactory.getLogger(JdbcSource.class);

public JdbcSource() {
super(DatabaseDriver.POSTGRESQL.getDriverClassName(), AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
super(DatabaseDriver.POSTGRESQL.driverClassName(), AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
}

// no-op for JdbcSource since the config it receives is designed to be use for JDBC.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public static class PostgresTestSource extends AbstractJdbcSource<JDBCType> impl

private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class);

static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.getDriverClassName();
static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName();

public PostgresTestSource() {
super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
Expand All @@ -101,7 +101,7 @@ public PostgresTestSource() {
public JsonNode toDatabaseConfig(final JsonNode config) {
final ImmutableMap.Builder<Object, Object> configBuilder = ImmutableMap.builder()
.put(JdbcUtils.USERNAME_KEY, config.get(JdbcUtils.USERNAME_KEY).asText())
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private static class PostgresTestSource extends AbstractJdbcSource<JDBCType> imp

private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class);

static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.getDriverClassName();
static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName();

public PostgresTestSource() {
super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
Expand All @@ -110,7 +110,7 @@ public PostgresTestSource() {
public JsonNode toDatabaseConfig(final JsonNode config) {
final ImmutableMap.Builder<Object, Object> configBuilder = ImmutableMap.builder()
.put(JdbcUtils.USERNAME_KEY, config.get(JdbcUtils.USERNAME_KEY).asText())
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private static class PostgresTestSource extends AbstractJdbcSource<JDBCType> imp

private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class);

static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.getDriverClassName();
static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName();

public PostgresTestSource() {
super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
Expand All @@ -104,7 +104,7 @@ public PostgresTestSource() {
public JsonNode toDatabaseConfig(final JsonNode config) {
final ImmutableMap.Builder<Object, Object> configBuilder = ImmutableMap.builder()
.put(JdbcUtils.USERNAME_KEY, config.get(JdbcUtils.USERNAME_KEY).asText())
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

package io.airbyte.cdk.integrations.source.jdbc.test;

import static io.airbyte.cdk.db.factory.DatabaseDriver.CLICKHOUSE;
import static io.airbyte.cdk.db.factory.DatabaseDriver.MYSQL;
import static io.airbyte.cdk.db.factory.DatabaseDriver.ORACLE;
import static io.airbyte.cdk.db.factory.DatabaseDriver.SNOWFLAKE;
import static io.airbyte.cdk.db.factory.DatabaseDriver.TERADATA;
import static io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils.enquoteIdentifier;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -290,9 +295,8 @@ protected AirbyteCatalog filterOutOtherSchemas(final AirbyteCatalog catalog) {
@Test
protected void testDiscoverWithMultipleSchemas() throws Exception {
// clickhouse and mysql do not have a concept of schemas, so this test does not make sense for them.
switch (testdb.getDatabaseDriver()) {
case MYSQL, CLICKHOUSE, TERADATA:
return;
if (testdb.getDatabaseDriver() == MYSQL || testdb.getDatabaseDriver() == CLICKHOUSE || testdb.getDatabaseDriver() == TERADATA) {
return;
}

// add table and data to a separate schema.
Expand Down Expand Up @@ -750,7 +754,7 @@ public void testIncrementalWithConcurrentInsertion() throws Exception {
.map(r -> r.getRecord().getData().get(COL_NAME).asText())
.toList();
// some databases don't make insertion order guarantee when equal ordering value
if (testdb.getDatabaseDriver().equals(DatabaseDriver.TERADATA) || testdb.getDatabaseDriver().equals(DatabaseDriver.ORACLE)) {
if (testdb.getDatabaseDriver().equals(TERADATA) || testdb.getDatabaseDriver().equals(DatabaseDriver.ORACLE)) {
assertThat(List.of("a", "b"), Matchers.containsInAnyOrder(firstSyncNames.toArray()));
} else {
assertEquals(List.of("a", "b"), firstSyncNames);
Expand Down Expand Up @@ -802,7 +806,7 @@ public void testIncrementalWithConcurrentInsertion() throws Exception {
.toList();

// teradata doesn't make insertion order guarantee when equal ordering value
if (testdb.getDatabaseDriver().equals(DatabaseDriver.TERADATA)) {
if (testdb.getDatabaseDriver().equals(TERADATA)) {
assertThat(List.of("c", "d", "e", "f"), Matchers.containsInAnyOrder(thirdSyncExpectedNames.toArray()));
} else {
assertEquals(List.of("c", "d", "e", "f"), thirdSyncExpectedNames);
Expand Down Expand Up @@ -1009,22 +1013,22 @@ protected void createSchemas() {
}

private JsonNode convertIdBasedOnDatabase(final int idValue) {
return switch (testdb.getDatabaseDriver()) {
case ORACLE, SNOWFLAKE -> Jsons.jsonNode(BigDecimal.valueOf(idValue));
default -> Jsons.jsonNode(idValue);
};
if (testdb.getDatabaseDriver() == ORACLE || testdb.getDatabaseDriver() == SNOWFLAKE) {
return Jsons.jsonNode(BigDecimal.valueOf(idValue));
}
return Jsons.jsonNode(idValue);
}

private String getDefaultSchemaName() {
return supportsSchemas() ? SCHEMA_NAME : null;
}

protected String getDefaultNamespace() {
return switch (testdb.getDatabaseDriver()) {
if (testdb.getDatabaseDriver() == MYSQL || testdb.getDatabaseDriver() == CLICKHOUSE || testdb.getDatabaseDriver() == TERADATA) {
// mysql does not support schemas, it namespaces using database names instead.
case MYSQL, CLICKHOUSE, TERADATA -> testdb.getDatabaseName();
default -> SCHEMA_NAME;
};
return testdb.getDatabaseName();
}
return SCHEMA_NAME;
}

protected static void setEmittedAtToNull(final Iterable<AirbyteMessage> messages) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public class MssqlSource extends AbstractJdbcSource<JDBCType> implements Source
"""
SELECT CASE WHEN (SELECT TOP 1 1 FROM "%s"."%s" WHERE "%s" IS NULL)=1 then 1 else 0 end as %s
""";
public static final String DRIVER_CLASS = DatabaseDriver.MSSQLSERVER.getDriverClassName();
public static final String DRIVER_CLASS = DatabaseDriver.MSSQLSERVER.driverClassName();
public static final String MSSQL_CDC_OFFSET = "mssql_cdc_offset";
public static final String MSSQL_DB_HISTORY = "mssql_db_history";
public static final String IS_COMPRESSED = "is_compressed";
Expand Down
Loading

0 comments on commit 430bbf7

Please sign in to comment.