Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get rid of DatabaseDriver as an enum #35710

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ public DataSource build() {

final HikariConfig config = new HikariConfig();

config.setDriverClassName(databaseDriver.getDriverClassName());
config.setJdbcUrl(jdbcUrl != null ? jdbcUrl : String.format(databaseDriver.getUrlFormatString(), host, port, database));
config.setDriverClassName(databaseDriver.driverClassName());
config.setJdbcUrl(jdbcUrl != null ? jdbcUrl : String.format(databaseDriver.urlFormatString(), host, port, database));
config.setMaximumPoolSize(maximumPoolSize);
config.setMinimumIdle(minimumPoolSize);
// HikariCP uses milliseconds for all time values:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,41 @@

package io.airbyte.cdk.db.factory;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* Collection of JDBC driver class names and the associated JDBC URL format string.
*/
public enum DatabaseDriver {

CLICKHOUSE("com.clickhouse.jdbc.ClickHouseDriver", "jdbc:clickhouse:%s://%s:%d/%s"),
DATABRICKS("com.databricks.client.jdbc.Driver", "jdbc:databricks://%s:%s;HttpPath=%s;SSL=1;UserAgentEntry=Airbyte"),
DB2("com.ibm.db2.jcc.DB2Driver", "jdbc:db2://%s:%d/%s"),
STARBURST("io.trino.jdbc.TrinoDriver", "jdbc:trino://%s:%s/%s?SSL=true&source=airbyte"),
MARIADB("org.mariadb.jdbc.Driver", "jdbc:mariadb://%s:%d/%s"),
MSSQLSERVER("com.microsoft.sqlserver.jdbc.SQLServerDriver", "jdbc:sqlserver://%s:%d;databaseName=%s"),
MYSQL("com.mysql.cj.jdbc.Driver", "jdbc:mysql://%s:%d/%s"),
ORACLE("oracle.jdbc.OracleDriver", "jdbc:oracle:thin:@%s:%d/%s"),
VERTICA("com.vertica.jdbc.Driver", "jdbc:vertica://%s:%d/%s"),
POSTGRESQL("org.postgresql.Driver", "jdbc:postgresql://%s:%d/%s"),
REDSHIFT("com.amazon.redshift.jdbc.Driver", "jdbc:redshift://%s:%d/%s"),
SNOWFLAKE("net.snowflake.client.jdbc.SnowflakeDriver", "jdbc:snowflake://%s/"),
YUGABYTEDB("com.yugabyte.Driver", "jdbc:yugabytedb://%s:%d/%s"),
EXASOL("com.exasol.jdbc.EXADriver", "jdbc:exa:%s:%d"),
TERADATA("com.teradata.jdbc.TeraDriver", "jdbc:teradata://%s/");

private final String driverClassName;
private final String urlFormatString;

DatabaseDriver(final String driverClassName, final String urlFormatString) {
public record DatabaseDriver(String driverClassName, String urlFormatString) {
public static final DatabaseDriver CLICKHOUSE = new DatabaseDriver("com.clickhouse.jdbc.ClickHouseDriver", "jdbc:clickhouse:%s://%s:%d/%s");
public static final DatabaseDriver DATABRICKS = new DatabaseDriver ("com.databricks.client.jdbc.Driver", "jdbc:databricks://%s:%s;HttpPath=%s;SSL=1;UserAgentEntry=Airbyte");
public static final DatabaseDriver DB2 = new DatabaseDriver ("com.ibm.db2.jcc.DB2Driver", "jdbc:db2://%s:%d/%s");
public static final DatabaseDriver STARBURST = new DatabaseDriver ("io.trino.jdbc.TrinoDriver", "jdbc:trino://%s:%s/%s?SSL=true&source=airbyte");
public static final DatabaseDriver MARIADB = new DatabaseDriver ("org.mariadb.jdbc.Driver", "jdbc:mariadb://%s:%d/%s");
public static final DatabaseDriver MSSQLSERVER = new DatabaseDriver ("com.microsoft.sqlserver.jdbc.SQLServerDriver", "jdbc:sqlserver://%s:%d;databaseName=%s");
public static final DatabaseDriver MYSQL = new DatabaseDriver ("com.mysql.cj.jdbc.Driver", "jdbc:mysql://%s:%d/%s");
public static final DatabaseDriver ORACLE = new DatabaseDriver ("oracle.jdbc.OracleDriver", "jdbc:oracle:thin:@%s:%d/%s");
public static final DatabaseDriver VERTICA = new DatabaseDriver ("com.vertica.jdbc.Driver", "jdbc:vertica://%s:%d/%s");
public static final DatabaseDriver POSTGRESQL = new DatabaseDriver ("org.postgresql.Driver", "jdbc:postgresql://%s:%d/%s");
public static final DatabaseDriver REDSHIFT = new DatabaseDriver ("com.amazon.redshift.jdbc.Driver", "jdbc:redshift://%s:%d/%s");
public static final DatabaseDriver SNOWFLAKE = new DatabaseDriver ("net.snowflake.client.jdbc.SnowflakeDriver", "jdbc:snowflake://%s/");
public static final DatabaseDriver YUGABYTEDB = new DatabaseDriver ("com.yugabyte.Driver", "jdbc:yugabytedb://%s:%d/%s");
public static final DatabaseDriver EXASOL = new DatabaseDriver ("com.exasol.jdbc.EXADriver", "jdbc:exa:%s:%d");
public static final DatabaseDriver TERADATA = new DatabaseDriver ("com.teradata.jdbc.TeraDriver", "jdbc:teradata://%s/");

private static Map<String, DatabaseDriver> DRIVER_BY_CLASS_NAME= new ConcurrentHashMap<>();

public DatabaseDriver(Class<? extends java.sql.Driver> driverClass, String urlFormatString) {
this(driverClass.getCanonicalName(), urlFormatString);
}
public DatabaseDriver(String driverClassName, String urlFormatString) {
this.driverClassName = driverClassName;
this.urlFormatString = urlFormatString;
DRIVER_BY_CLASS_NAME.put(driverClassName, this);
}

public String getDriverClassName() {
return driverClassName;
public static DatabaseDriver findByDriverClassName(String driverClassName) {
return DRIVER_BY_CLASS_NAME.get(driverClassName);
}

public String getUrlFormatString() {
return urlFormatString;
}

/**
* Finds the {@link DatabaseDriver} enumerated value that matches the provided driver class name.
*
* @param driverClassName The driver class name.
* @return The matching {@link DatabaseDriver} enumerated value or {@code null} if no match is
* found.
*/
public static DatabaseDriver findByDriverClassName(final String driverClassName) {
DatabaseDriver selected = null;

for (final DatabaseDriver candidate : values()) {
if (candidate.getDriverClassName().equalsIgnoreCase(driverClassName)) {
selected = candidate;
break;
}
}

return selected;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

package io.airbyte.cdk.integrations;

import static io.airbyte.cdk.db.factory.DatabaseDriver.MSSQLSERVER;
import static io.airbyte.cdk.db.factory.DatabaseDriver.MYSQL;
import static io.airbyte.cdk.db.factory.DatabaseDriver.POSTGRESQL;

import io.airbyte.cdk.db.factory.DatabaseDriver;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
Expand Down Expand Up @@ -44,12 +48,19 @@ protected Duration getConnectionTimeout(final Map<String, String> connectionProp
* @return DataSourceBuilder class used to create dynamic fields for DataSource
*/
public static Duration getConnectionTimeout(final Map<String, String> connectionProperties, String driverClassName) {
final Optional<Duration> parsedConnectionTimeout = switch (DatabaseDriver.findByDriverClassName(driverClassName)) {
case POSTGRESQL -> maybeParseDuration(connectionProperties.get(POSTGRES_CONNECT_TIMEOUT_KEY), ChronoUnit.SECONDS)
final Optional<Duration> parsedConnectionTimeout;
if (driverClassName.equals(POSTGRESQL.driverClassName())) {
parsedConnectionTimeout = maybeParseDuration(
connectionProperties.get(POSTGRES_CONNECT_TIMEOUT_KEY), ChronoUnit.SECONDS)
.or(() -> Optional.of(POSTGRES_CONNECT_TIMEOUT_DEFAULT_DURATION));
case MYSQL -> maybeParseDuration(connectionProperties.get("connectTimeout"), ChronoUnit.MILLIS);
case MSSQLSERVER -> maybeParseDuration(connectionProperties.get("loginTimeout"), ChronoUnit.SECONDS);
default -> maybeParseDuration(connectionProperties.get(CONNECT_TIMEOUT_KEY), ChronoUnit.SECONDS)
} else if (driverClassName.equals(MYSQL.driverClassName())) {
parsedConnectionTimeout = maybeParseDuration(connectionProperties.get("connectTimeout"),
ChronoUnit.MILLIS);
} else if (driverClassName.equals(MSSQLSERVER.driverClassName())) {
parsedConnectionTimeout = maybeParseDuration(connectionProperties.get("loginTimeout"),
ChronoUnit.SECONDS);
} else {
parsedConnectionTimeout = maybeParseDuration(connectionProperties.get(CONNECT_TIMEOUT_KEY), ChronoUnit.SECONDS)
// Enforce minimum timeout duration for unspecified data sources.
.filter(d -> d.compareTo(CONNECT_TIMEOUT_DEFAULT) >= 0);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ private DataSource getDataSourceFromConfig(final JsonNode config) {
return DataSourceFactory.create(
config.get(JdbcUtils.USERNAME_KEY).asText(),
config.get(JdbcUtils.PASSWORD_KEY).asText(),
DatabaseDriver.POSTGRESQL.getDriverClassName(),
String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
DatabaseDriver.POSTGRESQL.driverClassName(),
String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ void setup() throws Exception {
dataSource = DataSourceFactory.create(
config.get(JdbcUtils.USERNAME_KEY).asText(),
config.get(JdbcUtils.PASSWORD_KEY).asText(),
DatabaseDriver.POSTGRESQL.getDriverClassName(),
String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
DatabaseDriver.POSTGRESQL.driverClassName(),
String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ void setup() {
final DataSource connectionPool = DataSourceFactory.create(
config.get(JdbcUtils.USERNAME_KEY).asText(),
config.get(JdbcUtils.PASSWORD_KEY).asText(),
DatabaseDriver.POSTGRESQL.getDriverClassName(),
String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
DatabaseDriver.POSTGRESQL.driverClassName(),
String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ final public T initialized() {
this.dataSource = DataSourceFactory.create(
getUserName(),
getPassword(),
getDatabaseDriver().getDriverClassName(),
getDatabaseDriver().driverClassName(),
getJdbcUrl(),
connectionProperties,
JdbcConnector.getConnectionTimeout(connectionProperties, getDatabaseDriver().getDriverClassName()));
JdbcConnector.getConnectionTimeout(connectionProperties, getDatabaseDriver().driverClassName()));
this.dslContext = DSLContextFactory.create(dataSource, getSqlDialect());
return self();
}
Expand Down Expand Up @@ -179,7 +179,7 @@ final public DSLContext getDslContext() {

public String getJdbcUrl() {
return String.format(
getDatabaseDriver().getUrlFormatString(),
getDatabaseDriver().urlFormatString(),
getContainer().getHost(),
getContainer().getFirstMappedPort(),
getDatabaseName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ void setup() throws Exception {
dataSource = DataSourceFactory.create(
config.get(JdbcUtils.USERNAME_KEY).asText(),
config.get(JdbcUtils.PASSWORD_KEY).asText(),
DatabaseDriver.POSTGRESQL.getDriverClassName(),
String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
DatabaseDriver.POSTGRESQL.driverClassName(),
String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class JdbcSource extends AbstractJdbcSource<JDBCType> implements Source {
private static final Logger LOGGER = LoggerFactory.getLogger(JdbcSource.class);

public JdbcSource() {
super(DatabaseDriver.POSTGRESQL.getDriverClassName(), AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
super(DatabaseDriver.POSTGRESQL.driverClassName(), AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
}

// no-op for JdbcSource since the config it receives is designed to be use for JDBC.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public static class PostgresTestSource extends AbstractJdbcSource<JDBCType> impl

private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class);

static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.getDriverClassName();
static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName();

public PostgresTestSource() {
super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
Expand All @@ -101,7 +101,7 @@ public PostgresTestSource() {
public JsonNode toDatabaseConfig(final JsonNode config) {
final ImmutableMap.Builder<Object, Object> configBuilder = ImmutableMap.builder()
.put(JdbcUtils.USERNAME_KEY, config.get(JdbcUtils.USERNAME_KEY).asText())
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private static class PostgresTestSource extends AbstractJdbcSource<JDBCType> imp

private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class);

static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.getDriverClassName();
static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName();

public PostgresTestSource() {
super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
Expand All @@ -110,7 +110,7 @@ public PostgresTestSource() {
public JsonNode toDatabaseConfig(final JsonNode config) {
final ImmutableMap.Builder<Object, Object> configBuilder = ImmutableMap.builder()
.put(JdbcUtils.USERNAME_KEY, config.get(JdbcUtils.USERNAME_KEY).asText())
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private static class PostgresTestSource extends AbstractJdbcSource<JDBCType> imp

private static final Logger LOGGER = LoggerFactory.getLogger(PostgresTestSource.class);

static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.getDriverClassName();
static final String DRIVER_CLASS = DatabaseDriver.POSTGRESQL.driverClassName();

public PostgresTestSource() {
super(DRIVER_CLASS, AdaptiveStreamingQueryConfig::new, JdbcUtils.getDefaultSourceOperations());
Expand All @@ -104,7 +104,7 @@ public PostgresTestSource() {
public JsonNode toDatabaseConfig(final JsonNode config) {
final ImmutableMap.Builder<Object, Object> configBuilder = ImmutableMap.builder()
.put(JdbcUtils.USERNAME_KEY, config.get(JdbcUtils.USERNAME_KEY).asText())
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.getUrlFormatString(),
.put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString(),
config.get(JdbcUtils.HOST_KEY).asText(),
config.get(JdbcUtils.PORT_KEY).asInt(),
config.get(JdbcUtils.DATABASE_KEY).asText()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

package io.airbyte.cdk.integrations.source.jdbc.test;

import static io.airbyte.cdk.db.factory.DatabaseDriver.CLICKHOUSE;
import static io.airbyte.cdk.db.factory.DatabaseDriver.MYSQL;
import static io.airbyte.cdk.db.factory.DatabaseDriver.ORACLE;
import static io.airbyte.cdk.db.factory.DatabaseDriver.SNOWFLAKE;
import static io.airbyte.cdk.db.factory.DatabaseDriver.TERADATA;
import static io.airbyte.cdk.integrations.source.relationaldb.RelationalDbQueryUtils.enquoteIdentifier;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -290,9 +295,8 @@ protected AirbyteCatalog filterOutOtherSchemas(final AirbyteCatalog catalog) {
@Test
protected void testDiscoverWithMultipleSchemas() throws Exception {
// clickhouse and mysql do not have a concept of schemas, so this test does not make sense for them.
switch (testdb.getDatabaseDriver()) {
case MYSQL, CLICKHOUSE, TERADATA:
return;
if (testdb.getDatabaseDriver() == MYSQL || testdb.getDatabaseDriver() == CLICKHOUSE || testdb.getDatabaseDriver() == TERADATA) {
return;
}

// add table and data to a separate schema.
Expand Down Expand Up @@ -750,7 +754,7 @@ public void testIncrementalWithConcurrentInsertion() throws Exception {
.map(r -> r.getRecord().getData().get(COL_NAME).asText())
.toList();
// some databases don't make insertion order guarantee when equal ordering value
if (testdb.getDatabaseDriver().equals(DatabaseDriver.TERADATA) || testdb.getDatabaseDriver().equals(DatabaseDriver.ORACLE)) {
if (testdb.getDatabaseDriver().equals(TERADATA) || testdb.getDatabaseDriver().equals(DatabaseDriver.ORACLE)) {
assertThat(List.of("a", "b"), Matchers.containsInAnyOrder(firstSyncNames.toArray()));
} else {
assertEquals(List.of("a", "b"), firstSyncNames);
Expand Down Expand Up @@ -802,7 +806,7 @@ public void testIncrementalWithConcurrentInsertion() throws Exception {
.toList();

// teradata doesn't make insertion order guarantee when equal ordering value
if (testdb.getDatabaseDriver().equals(DatabaseDriver.TERADATA)) {
if (testdb.getDatabaseDriver().equals(TERADATA)) {
assertThat(List.of("c", "d", "e", "f"), Matchers.containsInAnyOrder(thirdSyncExpectedNames.toArray()));
} else {
assertEquals(List.of("c", "d", "e", "f"), thirdSyncExpectedNames);
Expand Down Expand Up @@ -1009,22 +1013,22 @@ protected void createSchemas() {
}

private JsonNode convertIdBasedOnDatabase(final int idValue) {
return switch (testdb.getDatabaseDriver()) {
case ORACLE, SNOWFLAKE -> Jsons.jsonNode(BigDecimal.valueOf(idValue));
default -> Jsons.jsonNode(idValue);
};
if (testdb.getDatabaseDriver() == ORACLE || testdb.getDatabaseDriver() == SNOWFLAKE) {
return Jsons.jsonNode(BigDecimal.valueOf(idValue));
}
return Jsons.jsonNode(idValue);
}

private String getDefaultSchemaName() {
return supportsSchemas() ? SCHEMA_NAME : null;
}

protected String getDefaultNamespace() {
return switch (testdb.getDatabaseDriver()) {
if (testdb.getDatabaseDriver() == MYSQL || testdb.getDatabaseDriver() == CLICKHOUSE || testdb.getDatabaseDriver() == TERADATA) {
// mysql does not support schemas, it namespaces using database names instead.
case MYSQL, CLICKHOUSE, TERADATA -> testdb.getDatabaseName();
default -> SCHEMA_NAME;
};
return testdb.getDatabaseName();
}
return SCHEMA_NAME;
}

protected static void setEmittedAtToNull(final Iterable<AirbyteMessage> messages) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public class MssqlSource extends AbstractJdbcSource<JDBCType> implements Source
"""
SELECT CASE WHEN (SELECT TOP 1 1 FROM "%s"."%s" WHERE "%s" IS NULL)=1 then 1 else 0 end as %s
""";
public static final String DRIVER_CLASS = DatabaseDriver.MSSQLSERVER.getDriverClassName();
public static final String DRIVER_CLASS = DatabaseDriver.MSSQLSERVER.driverClassName();
public static final String MSSQL_CDC_OFFSET = "mssql_cdc_offset";
public static final String MSSQL_DB_HISTORY = "mssql_db_history";
public static final String IS_COMPRESSED = "is_compressed";
Expand Down
Loading
Loading