Fix issues with dbschema not being handled for complex types (#368)
* Added handling of dbschema

* Added unit test for schema case
Aryex authored Apr 18, 2022
1 parent dc95d50 commit 2fd2d23
Showing 4 changed files with 105 additions and 32 deletions.
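
At the core of the change, the lookup against Vertica's `columns` system table now carries a `table_schema` filter whenever a `dbschema` is configured, so complex-type columns resolve against the right table when the same table name exists in several schemas. Below is a minimal sketch of that query construction; the object and helper names are illustrative, not the connector's API, while the SQL and the table/schema/column names mirror the strings added in the diff and its tests.

```scala
// Illustrative sketch only: condenses the schema-aware lookup added in this
// commit into a standalone helper. Everything outside the SQL string is
// hypothetical scaffolding.
object SchemaLookupSketch {

  // Mirrors the query built in SchemaTools.queryColumnDef: filter on
  // table_schema only when a dbschema was supplied.
  def columnTypeQuery(tableName: String, dbSchema: String, colName: String): String = {
    val schemaCond = if (dbSchema.nonEmpty) s" AND table_schema='$dbSchema'" else ""
    s"SELECT data_type_id, data_type FROM columns " +
      s"WHERE table_name='$tableName'$schemaCond AND column_name='$colName'"
  }

  def main(args: Array[String]): Unit = {
    // Pre-fix behaviour: no schema filter, regardless of the dbschema option.
    println(columnTypeQuery("dftest_array", "", "a"))
    // Post-fix behaviour when dbschema is set, as the new tests exercise.
    println(columnTypeQuery("dftest_array", "S2VTestSchema", "a"))
  }
}
```

Before this fix, the `WHERE` clause matched on `table_name` and `column_name` only (the form the first `println` produces), even when a `dbschema` option had been supplied.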
@@ -57,6 +57,15 @@ final case class TableName(name: String, dbschema: Option[String]) extends Table
}
}

def getTableName : String = EscapeUtils.sqlEscapeAndQuote(name)

def getDbSchema : String = {
dbschema match {
case None => ""
case Some(schema) => EscapeUtils.sqlEscapeAndQuote(schema)
}
}

/**
* The table's name is used as an identifier for the operation.
*/
@@ -264,17 +264,23 @@
}
}

case class ColumnInfoQueryData(tableName: String, dbSchema: String, emptyQuery: String)
protected def getColumnInfoQueryData(tableSource: TableSource): ColumnInfoQueryData = tableSource match {
case tb: TableName =>
ColumnInfoQueryData(
tb.getTableName.replace("\"",""),
tb.getDbSchema.replace("\"",""),
// Query for an empty result set from Vertica.
// This is simply so we can load the metadata of the result set
// and use this to retrieve the name and type information of each column
"SELECT * FROM " + tb.getFullTableName + " WHERE 1=0")
case TableQuery(query, _) =>
ColumnInfoQueryData("", "" , "SELECT * FROM (" + query + ") AS x WHERE 1=0")
}

def getColumnInfo(jdbcLayer: JdbcLayerInterface, tableSource: TableSource): ConnectorResult[Seq[ColumnDef]] = {
- // Query for an empty result set from Vertica.
- // This is simply so we can load the metadata of the result set
- // and use this to retrieve the name and type information of each column
- val (tableName, query) = tableSource match {
- case tb: TableName =>
- (tb.getFullTableName, "SELECT * FROM " + tb.getFullTableName + " WHERE 1=0")
- case TableQuery(query, _) =>
- ("", "SELECT * FROM (" + query + ") AS x WHERE 1=0")
- }
- jdbcLayer.query(query) match {
+ val tableInfo = getColumnInfoQueryData(tableSource)
+ jdbcLayer.query(tableInfo.emptyQuery) match {
case Left(err) => Left(JdbcSchemaError(err))
case Right(rs) =>
try {
@@ -291,7 +297,7 @@
val metadata = new MetadataBuilder().putString(MetadataKey.NAME, columnLabel).build()
val colType = rsmd.getColumnType(idx)
val colDef = ColumnDef(columnLabel, colType, typeName, fieldSize, fieldScale, isSigned, nullable, metadata)
- checkForComplexType(colDef, tableName, jdbcLayer)
+ checkForComplexType(colDef, tableInfo.tableName, tableInfo.dbSchema, jdbcLayer)
}).toList
colDefsOrErrors
.traverse(_.leftMap(err => NonEmptyList.one(err)).toValidated).toEither
@@ -308,10 +314,10 @@
}
}

- private def checkForComplexType(colDef: ColumnDef, tableName: String, jdbcLayer: JdbcLayerInterface): ConnectorResult[ColumnDef] = {
+ private def checkForComplexType(colDef: ColumnDef, tableName: String, dbSchema: String, jdbcLayer: JdbcLayerInterface): ConnectorResult[ColumnDef] = {
colDef.colType match {
case java.sql.Types.ARRAY |
- java.sql.Types.STRUCT => queryColumnDef(colDef, tableName, jdbcLayer)
+ java.sql.Types.STRUCT => queryColumnDef(colDef, tableName, dbSchema, jdbcLayer)
case _ => Right(colDef)
}
}
@@ -321,11 +327,9 @@
* Vertica system tables. This function takes a ColumnDef of a complex type and injects its corresponding element
* ColumnDefs through a series of JDBC queries to Vertica system tables.
* */
- private def queryColumnDef(complexTypeColDef: ColumnDef, tableName: String, jdbcLayer: JdbcLayerInterface): ConnectorResult[ColumnDef] = {
- val table = tableName.replace("\"", "")
- val queryColType = s"SELECT data_type_id, data_type FROM columns WHERE table_name='$table' AND column_name='${complexTypeColDef.label}'"
- // We first query from column table for column's Vertica type. Note that data_type_id is Vertica's internal type id, not JDBC.
- JdbcUtils.queryAndNext(queryColType, jdbcLayer, (rs) => {
+ private def queryColumnDef(complexTypeColDef: ColumnDef, tableName: String, dbSchema: String, jdbcLayer: JdbcLayerInterface): ConnectorResult[ColumnDef] = {
+ def handleColumnExist(rs: ResultSet): ConnectorResult[ColumnDef] = {
+ // Note that data_type_id is Vertica's internal type id, not JDBC.
val verticaType = rs.getLong("data_type_id")
val typeName = getTypeName(rs.getString("data_type"))
complexTypeColDef.colType match {
@@ -334,7 +338,12 @@
case java.sql.Types.STRUCT => Right(complexTypeColDef)
case _ => Left(MissingSqlConversionError(complexTypeColDef.colType.toString, typeName))
}
- })
+ }
+ // We query from Vertica for the column's Vertica type.
+ val colName = complexTypeColDef.label
+ val schemaCond = if(dbSchema.nonEmpty) s" AND table_schema='$dbSchema'" else ""
+ val queryColType = s"SELECT data_type_id, data_type FROM columns WHERE table_name='$tableName'$schemaCond AND column_name='$colName'"
+ JdbcUtils.queryAndNext(queryColType, jdbcLayer, handleColumnExist)
}

/**
@@ -747,19 +756,14 @@
class SchemaToolsV10() extends SchemaTools {

override def getColumnInfo(jdbcLayer: JdbcLayerInterface, tableSource: TableSource): ConnectorResult[Seq[ColumnDef]] = {
- val tableName = tableSource match {
- case tableName: TableName => tableName.getFullTableName.replaceAll("\"","")
- case _ => ""
- }
-
- // return super.getColumnInfo(jdbcLayer, tableSource)
+ val tableInfo = getColumnInfoQueryData(tableSource)

super.getColumnInfo(jdbcLayer, tableSource) match {
case Left(err) => Left(err)
case Right(colList) =>
colList.map(col => col.colType match {
case java.sql.Types.VARCHAR =>
- checkV10ComplexType(col, tableName, jdbcLayer)
+ checkV10ComplexType(col, tableInfo.tableName, tableInfo.dbSchema, jdbcLayer)
case _ => Right(col)
}).toList
.traverse(_.leftMap(err => NonEmptyList.one(err)).toValidated).toEither
@@ -775,7 +779,7 @@
*
* If column is not of complex type in Vertica, then return the ColumnDef as is.
* */
- private def checkV10ComplexType(colDef: ColumnDef, tableName: String, jdbcLayer: JdbcLayerInterface): ConnectorResult[ColumnDef] = {
+ private def checkV10ComplexType(colDef: ColumnDef, tableName: String, dbSchema: String, jdbcLayer: JdbcLayerInterface): ConnectorResult[ColumnDef] = {
def handleVerticaTypeFound(rs: ResultSet): ConnectorResult[ColumnDef] = {
val verticaType = rs.getLong("data_type_id")
if(verticaType > VERTICA_NATIVE_ARRAY_BASE_ID && verticaType < VERTICA_SET_MAX_ID) {
@@ -789,8 +793,8 @@
JdbcUtils.queryAndNext(queryComplexType, jdbcLayer, handleCTFound, handleCTNotFound)
}
}

val queryColType = s"SELECT data_type_id FROM columns WHERE table_name='$tableName' AND column_name='${colDef.label}'"
val schemaCond = if(dbSchema.nonEmpty) s" AND table_schema='$dbSchema'" else ""
val queryColType = s"SELECT data_type_id FROM columns WHERE table_name='$tableName'$schemaCond AND column_name='${colDef.label}'"
JdbcUtils.queryAndNext(queryColType, jdbcLayer, handleVerticaTypeFound)
}
}
@@ -428,6 +428,33 @@ class SchemaToolsTests extends AnyFlatSpec with MockFactory with org.scalatest.O
}
}

it should "query columns in a table under a schema space" in {
val schema = "schema"
val tableSource = this.tablename.copy(dbschema = Some(schema))
val (jdbcLayer, mockRs, rsmd) = mockJdbcDeps(tableSource)
val testColDef = TestColumnDef(1, "col1", java.sql.Types.ARRAY, "ARRAY", 0, signed = false, nullable = true)
mockColumnMetadata(rsmd, testColDef)

val verticaArrayType = 1506
mockQueryColumns(tableSource.getTableName.replaceAll("\"", ""),
testColDef.name, verticaArrayType, jdbcLayer, schema)
val mockRs1 = mockQueryTypes(verticaArrayType, hasData = true,jdbcLayer)
(mockRs1.getLong: (String) => Long).expects("jdbc_type").returns(java.sql.Types.BIGINT)
(mockRs1.getString: (String)=>String).expects("type_name").returns("Integer")
mockColumnCount(rsmd, 1)

(new SchemaTools).readSchema(jdbcLayer, tableSource) match {
case Left(error) => fail(error.getFullContext)
case Right(schema) =>
val fields = schema.fields
assert(fields.length == 1)
assert(fields(0).dataType.isInstanceOf[ArrayType])
assert(fields(0).dataType.asInstanceOf[ArrayType]
.elementType.isInstanceOf[LongType])
assert(!fields(0).metadata.getBoolean(MetadataKey.IS_VERTICA_SET))
}
}

it should "correctly detect Vertica Set" in {
val (jdbcLayer, mockRs, rsmd) = mockJdbcDeps(tablename)
val testColDef = TestColumnDef(1, "col1", java.sql.Types.ARRAY, "ARRAY", 0, signed = false, nullable = true)
@@ -480,7 +507,7 @@

val verticaArrayType = 1506
val tableName = tablename.getFullTableName.replace("\"", "")
- mockQueryColumns(tableName, testColDef.name,verticaArrayType, jdbcLayer)
+ mockQueryColumns(tableName, testColDef.name, verticaArrayType, jdbcLayer)
mockColumnCount(rsmd, 1)
val verticaElementType = verticaArrayType;
mockQueryTypes(verticaElementType, hasData = false,jdbcLayer)
@@ -491,9 +518,10 @@
}
}

- private[schema] def mockQueryColumns(tableName: String, colName: String, verticaTypeFound: Long, jdbcLayer: JdbcLayerInterface): Unit = {
+ private[schema] def mockQueryColumns(tableName: String, colName: String, verticaTypeFound: Long, jdbcLayer: JdbcLayerInterface, schema: String = ""): Unit = {
val mockRs = mock[ResultSet]
- val queryColumnDef = s"SELECT data_type_id, data_type FROM columns WHERE table_name='$tableName' AND column_name='$colName'"
+ val schemaCond = if(schema.nonEmpty) s" AND table_schema='$schema'" else ""
+ val queryColumnDef = s"SELECT data_type_id, data_type FROM columns WHERE table_name='$tableName'$schemaCond AND column_name='$colName'"
(jdbcLayer.query _)
.expects(queryColumnDef, *)
.returns(Right(mockRs))
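For context, the new end-to-end test below drives this path through the public API by passing the `dbschema` read option. Here is a condensed sketch of that usage, assuming a reachable Vertica instance, the connector on the classpath, and the usual connection options already collected in `connOpts`; the object name and the empty options map are illustrative placeholders.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object ReadWithDbSchemaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("dbschema-read-sketch").getOrCreate()

    // Hypothetical placeholder for host/db/user/password/staging options.
    val connOpts: Map[String, String] = Map()

    // Same shape as the options built in the end-to-end test below.
    val options = connOpts ++ Map("table" -> "dftest_array", "dbschema" -> "S2VTestSchema")

    val df: DataFrame = spark.read
      .format("com.vertica.spark.datasource.VerticaSource")
      .options(options)
      .load()

    // With the fix, the array column is reported as ArrayType(LongType)
    // even though the table lives under S2VTestSchema.
    df.printSchema()
    spark.stop()
  }
}
```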
@@ -1588,8 +1588,40 @@ class EndToEndTests(readOpts: Map[String, String], writeOpts: Map[String, String
case e: Exception => fail(e)
}finally {
stmt.close()
TestUtils.dropTable(conn, tableName1)
}
}

it should "read dataframe with 1D array with schema option" in {
val tableName1 = "dftest_array"
val dbschema = "S2VTestSchema"
val nameWithSchema = s"$dbschema.$tableName1"
val n = 1
val stmt = conn.createStatement
TestUtils.createTableBySQL(conn, tableName1, "create table " + nameWithSchema + " (a array[int])")

val insert = "insert into "+ nameWithSchema + " values(array[2])"
TestUtils.populateTableBySQL(stmt, insert, n)
val options = readOpts + ("table" -> tableName1, "dbschema" -> dbschema)
val result = Try{
val df: DataFrame = spark.read.format("com.vertica.spark.datasource.VerticaSource")
.options(options).load()
assert(df.count() == 1)
assert(df.schema.fields(0).dataType.isInstanceOf[ArrayType])
val dataType = df.schema.fields(0).dataType.asInstanceOf[ArrayType]
assert(dataType.elementType.isInstanceOf[LongType])
df.rdd.foreach(row => assert(row.getAs[mutable.WrappedArray[Long]](0)(0) == 2))
}
stmt.close()
TestUtils.dropTable(conn, tableName1)

result match {
case Success(_) => succeed
case Failure(exp) => exp match {
case e: ConnectorException => fail(e.error.getFullContext)
case e: Throwable => fail(s"Unexpected exception: ", e)
}
}
}

it should "read Vertica SET as ARRAY" in {