Merge remote-tracking branch 'origin/master' into release-2.3
marsishandsome committed Aug 21, 2020
2 parents 719d02e + 130488e commit 86b9f19
Showing 59 changed files with 1,827 additions and 543 deletions.
@@ -44,6 +44,7 @@ object TiConfigConst {
val TIKV_REGION_SPLIT_SIZE_IN_MB: String = "spark.tispark.tikv.region_split_size_in_mb"
val ISOLATION_READ_ENGINES: String = "spark.tispark.isolation_read_engines"
val PARTITION_PER_SPLIT: String = "spark.tispark.partition_per_split"
+ val KV_CLIENT_CONCURRENCY: String = "spark.tispark.kv_client_concurrency"

val SNAPSHOT_ISOLATION_LEVEL: String = "SI"
val READ_COMMITTED_ISOLATION_LEVEL: String = "RC"
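The new KV_CLIENT_CONCURRENCY key follows the existing spark.tispark.* naming convention and is wired into TiConfiguration in TiUtil below. A minimal usage sketch, assuming an ordinary SparkConf (the value 4 is illustrative, not a documented default):

import org.apache.spark.SparkConf

// Illustrative only: cap the number of concurrent KV client requests at 4.
val sparkConf = new SparkConf()
  .set("spark.tispark.kv_client_concurrency", "4")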
6 changes: 5 additions & 1 deletion core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala
@@ -49,7 +49,7 @@ object TiUtil {
}

def isDataFrameEmpty(df: DataFrame): Boolean = {
- df.limit(1).count() == 0
+ df.rdd.isEmpty()
}

def sparkConfToTiConf(conf: SparkConf): TiConfiguration = {
@@ -130,6 +130,10 @@ object TiUtil {
getIsolationReadEnginesFromString(conf.get(TiConfigConst.ISOLATION_READ_ENGINES)).toList)
}

+ if (conf.contains(TiConfigConst.KV_CLIENT_CONCURRENCY)) {
+ tiConf.setKvClientConcurrency(conf.get(TiConfigConst.KV_CLIENT_CONCURRENCY).toInt)
+ }

tiConf
}

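One plausible motivation for the isDataFrameEmpty change: the old form runs a count aggregate over a limited plan, while RDD.isEmpty returns as soon as take(1) finds a row in some non-empty partition. A hedged equivalence sketch (both predicates agree on any DataFrame; only the execution strategy differs):

import org.apache.spark.sql.DataFrame

// Old check: counts rows of the limited plan.
def isEmptyViaLimit(df: DataFrame): Boolean = df.limit(1).count() == 0
// New check: scans partitions lazily until one yields a row.
def isEmptyViaRdd(df: DataFrame): Boolean = df.rdd.isEmpty()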
@@ -18,8 +18,12 @@ package com.pingcap.tispark.write
import java.util

import com.pingcap.tikv.codec.KeyUtils
+ import com.pingcap.tikv.key.Key
+ import com.pingcap.tikv.util.FastByteComparisons

- class SerializableKey(val bytes: Array[Byte]) extends Serializable {
+ class SerializableKey(val bytes: Array[Byte])
+ extends Comparable[SerializableKey]
+ with Serializable {
override def toString: String = KeyUtils.formatBytes(bytes)

override def equals(that: Any): Boolean =
@@ -30,4 +34,12 @@ class SerializableKey(val bytes: Array[Byte]) extends Serializable {

override def hashCode(): Int =
util.Arrays.hashCode(bytes)

override def compareTo(o: SerializableKey): Int = {
FastByteComparisons.compareTo(bytes, o.bytes)
}

def getRowKey: Key = {
Key.toRawKey(bytes)
}
}
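Making SerializableKey Comparable lets the write path order keys byte-lexicographically, which the split-point sampling in TiBatchWrite.scala below depends on. A minimal usage sketch (byte values are illustrative):

// Keys compare byte-wise, so Array(1, 9) sorts before Array(2, 0).
val keys = Array(
  new SerializableKey(Array[Byte](2, 0)),
  new SerializableKey(Array[Byte](1, 9)))
val sorted = keys.sortWith(_.compareTo(_) < 0)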
151 changes: 143 additions & 8 deletions core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala
@@ -80,6 +80,7 @@ class TiBatchWrite(
private var lockTTLSeconds: Long = _
@transient private var tiDBJDBCClient: TiDBJDBCClient = _
@transient private var tiBatchWriteTables: List[TiBatchWriteTable] = _
@transient private var startMS: Long = _

private def write(): Unit = {
try {
@@ -116,6 +117,8 @@
}

private def doWrite(): Unit = {
startMS = System.currentTimeMillis()

// check if write enable
if (!tiContext.tiConf.isWriteEnable) {
throw new TiBatchWriteException(
@@ -128,7 +131,9 @@
val tikvSupportUpdateTTL = StoreVersion.minTiKVVersion("3.0.5", tiSession.getPDClient)
isTTLUpdate = options.isTTLUpdate(tikvSupportUpdateTTL)
lockTTLSeconds = options.getLockTTLSeconds(tikvSupportUpdateTTL)
- tiDBJDBCClient = new TiDBJDBCClient(TiDBUtils.createConnectionFactory(options.url)())
+ tiDBJDBCClient = new TiDBJDBCClient(
+ TiDBUtils.createConnectionFactory(options.url)(),
+ options.writeSplitRegionFinish)

// init tiBatchWriteTables
tiBatchWriteTables = {
@@ -188,10 +193,16 @@
logger.info(s"startTS: $startTs")

// pre calculate
- val shuffledRDD: RDD[(SerializableKey, Array[Byte])] = {
+ var shuffledRDD: RDD[(SerializableKey, Array[Byte])] = {
val rddList = tiBatchWriteTables.map(_.preCalculate(startTimeStamp))
- tiContext.sparkSession.sparkContext.union(rddList)
+ if (rddList.lengthCompare(1) == 0) {
+ rddList.head
+ } else {
+ tiContext.sparkSession.sparkContext.union(rddList)
+ }
}
+ val shuffledRDDCount = shuffledRDD.count()
+ logger.info(s"write kv data count=$shuffledRDDCount")

// take one row as primary key
val (primaryKey: SerializableKey, primaryRow: Array[Byte]) = {
@@ -211,8 +222,34 @@
!keyValue._1.equals(primaryKey)
}

// split region
if (options.enableRegionSplit && "v2".equals(options.regionSplitMethod)) {
val orderedSplitPoints = getRegionSplitPoints(shuffledRDD, shuffledRDDCount)

try {
tiSession.splitRegionAndScatter(
orderedSplitPoints.map(_.bytes).asJava,
options.splitRegionBackoffMS,
options.scatterWaitMS)
} catch {
case e: Throwable => logger.warn("split region and scatter error!", e)
}

// shuffle according to split points
shuffledRDD = shuffledRDD.partitionBy(new TiReginSplitPartitioner(orderedSplitPoints))
}

// driver primary pre-write
- val ti2PCClient = new TwoPhaseCommitter(tiConf, startTs, lockTTLSeconds * 1000)
+ val ti2PCClient =
+ new TwoPhaseCommitter(
+ tiConf,
+ startTs,
+ lockTTLSeconds * 1000,
+ options.txnPrewriteBatchSize,
+ options.txnCommitBatchSize,
+ options.writeBufferSize,
+ options.writeThreadPerTask,
+ options.retryCommitSecondaryKey)
val prewritePrimaryBackoff =
ConcreteBackOffer.newCustomBackOff(BackOffer.BATCH_PREWRITE_BACKOFF)
logger.info("start to prewritePrimaryKey")
@@ -235,13 +272,24 @@
logger.info("start to prewriteSecondaryKeys")
secondaryKeysRDD.foreachPartition { iterator =>
val ti2PCClientOnExecutor =
- new TwoPhaseCommitter(tiConf, startTs, lockTTLSeconds * 1000)
+ new TwoPhaseCommitter(
+ tiConf,
+ startTs,
+ lockTTLSeconds * 1000,
+ options.txnPrewriteBatchSize,
+ options.txnCommitBatchSize,
+ options.writeBufferSize,
+ options.writeThreadPerTask,
+ options.retryCommitSecondaryKey)

val pairs = iterator.map { keyValue =>
new BytePairWrapper(keyValue._1.bytes, keyValue._2)
}.asJava

- ti2PCClientOnExecutor.prewriteSecondaryKeys(primaryKey.bytes, pairs)
+ ti2PCClientOnExecutor.prewriteSecondaryKeys(
+ primaryKey.bytes,
+ pairs,
+ options.prewriteBackOfferMS)

try {
ti2PCClientOnExecutor.close()
@@ -283,6 +331,11 @@

logger.info("start to commitPrimaryKey")
ti2PCClient.commitPrimaryKey(commitPrimaryBackoff, primaryKey.bytes, commitTs)
try {
ti2PCClient.close()
} catch {
case _: Throwable =>
}
logger.info("commitPrimaryKey success")

// stop primary key ttl update
@@ -297,24 +350,106 @@
if (!options.skipCommitSecondaryKey) {
logger.info("start to commitSecondaryKeys")
secondaryKeysRDD.foreachPartition { iterator =>
- val ti2PCClientOnExecutor = new TwoPhaseCommitter(tiConf, startTs)
+ val ti2PCClientOnExecutor = new TwoPhaseCommitter(
+ tiConf,
+ startTs,
+ lockTTLSeconds * 1000,
+ options.txnPrewriteBatchSize,
+ options.txnCommitBatchSize,
+ options.writeBufferSize,
+ options.writeThreadPerTask,
+ options.retryCommitSecondaryKey)

val keys = iterator.map { keyValue =>
new ByteWrapper(keyValue._1.bytes)
}.asJava

try {
- ti2PCClientOnExecutor.commitSecondaryKeys(keys, commitTs)
+ ti2PCClientOnExecutor.commitSecondaryKeys(keys, commitTs, options.commitBackOfferMS)
} catch {
case e: TiBatchWriteException =>
// ignored
logger.warn(s"commit secondary key error", e)
}

try {
ti2PCClientOnExecutor.close()
} catch {
case _: Throwable =>
}
}
logger.info("commitSecondaryKeys finish")
} else {
logger.info("skipping commit secondary key")
}

val endMS = System.currentTimeMillis()
logger.info(s"batch write cost ${(endMS - startMS) / 1000} seconds")
}
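The partitionBy call above relies on a range-style partitioner built from the ordered split points, so that each output partition maps onto one target region. A hedged sketch of the idea, assuming a simple linear scan over the split points (the class body is illustrative, not TiSpark's actual TiReginSplitPartitioner):

import org.apache.spark.Partitioner

// Illustrative range partitioner: bucket i holds keys below split point i,
// and the final bucket holds everything at or above the last split point.
class SplitPointPartitioner(splits: List[SerializableKey]) extends Partitioner {
  override def numPartitions: Int = splits.length + 1
  override def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[SerializableKey]
    val idx = splits.indexWhere(sp => k.compareTo(sp) < 0)
    if (idx < 0) splits.length else idx
  }
}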

private def getRegionSplitPoints(
rdd: RDD[(SerializableKey, Array[Byte])],
count: Long): List[SerializableKey] = {
if (count < options.regionSplitThreshold) {
return Nil
}

val regionSplitPointNum = if (options.regionSplitNum > 0) {
options.regionSplitNum
} else {
Math.max(
options.minRegionSplitNum,
Math.ceil(count.toDouble / options.regionSplitKeys).toInt)
}
logger.info(s"regionSplitPointNum=$regionSplitPointNum")

val sampleSize = (regionSplitPointNum + 1) * options.sampleSplitFrac
logger.info(s"sampleSize=$sampleSize")

val sampleData = rdd.sample(false, sampleSize.toDouble / count).collect()
logger.info(s"sampleData size=${sampleData.length}")

val finalRegionSplitPointNum = if (options.regionSplitUsingSize) {
val avgSize = getAverageSizeInBytes(sampleData)
logger.info(s"avgSize=$avgSize Bytes")
if (avgSize <= options.bytesPerRegion / options.regionSplitKeys) {
regionSplitPointNum
} else {
Math.min(
Math.floor((count.toDouble / options.bytesPerRegion) * avgSize).toInt,
sampleData.length / 10)
}
} else {
regionSplitPointNum
}
logger.info(s"finalRegionSplitPointNum=$finalRegionSplitPointNum")

val sortedSampleData = sampleData
.map(_._1)
.sorted(new Ordering[SerializableKey] {
override def compare(x: SerializableKey, y: SerializableKey): Int = {
x.compareTo(y)
}
})
val orderedSplitPoints = new Array[SerializableKey](finalRegionSplitPointNum)
val step = Math.floor(sortedSampleData.length.toDouble / (finalRegionSplitPointNum + 1)).toInt
for (i <- 0 until finalRegionSplitPointNum) {
orderedSplitPoints(i) = sortedSampleData((i + 1) * step)
}

logger.info(s"orderedSplitPoints size=${orderedSplitPoints.length}")
orderedSplitPoints.toList
}
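A worked pass through the selection above, with illustrative numbers rather than real defaults: suppose count = 9,600,000, regionSplitKeys = 960,000, and minRegionSplitNum = 8; then regionSplitPointNum = max(8, ceil(9,600,000 / 960,000)) = 10. With sampleSplitFrac = 1000, the sample targets (10 + 1) * 1000 = 11,000 rows, and after sorting, step = floor(11,000 / 11) = 1,000, so the chosen split points are the 1,000th, 2,000th, ..., 10,000th sampled keys, that is, evenly spaced quantile estimates of the key distribution. The picking loop reduces to this sketch:

// Illustrative: pick n evenly spaced keys from a sorted sample.
def pickSplitPoints(sorted: Array[SerializableKey], n: Int): Array[SerializableKey] = {
  val step = sorted.length / (n + 1)
  Array.tabulate(n)(i => sorted((i + 1) * step))
}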

private def getAverageSizeInBytes(keyValues: Array[(SerializableKey, Array[Byte])]): Int = {
var avg: Double = 0
var t: Int = 1
keyValues.foreach { keyValue =>
val keySize: Double = keyValue._1.bytes.length + keyValue._2.length
avg = avg + (keySize - avg) / t
t = t + 1
}
Math.ceil(avg).toInt
}
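getAverageSizeInBytes computes a running mean, avg_t = avg_{t-1} + (x_t - avg_{t-1}) / t, rather than summing all sizes into a single accumulator, which avoids overflow on large samples. A quick check with illustrative sizes:

// Running mean over 100, 300, 200 evolves 100 -> 200 -> 200,
// matching the plain mean of the three values.
val sizes = Array(100.0, 300.0, 200.0)
var avg = 0.0
var t = 1
sizes.foreach { x => avg += (x - avg) / t; t += 1 }
assert(avg == sizes.sum / sizes.length)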

private def getUseTableLock: Boolean = {
(rest of TiBatchWrite.scala and the remaining changed files are not shown)
