From 0534de5118b3ca82f572c3393a477c62b57646e8 Mon Sep 17 00:00:00 2001
From: Liangliang Gu
Date: Wed, 19 Aug 2020 16:27:21 +0800
Subject: [PATCH 1/7] fix batch write bug (#1562)

* set enable region split default value to true
* fix txn heartbeat retry not invalidating region cache
* BatchWrite: add parameter taskNumPerRegion
* try to solve TTLManager TxnLockNotFound problem
* hack: use tispark to resolve locks
* Revert "try to solve TTLManager TxnLockNotFound problem"
  This reverts commit 6c8d7c3bca75ab3d28fdc7686a86b4504bf50436.
* fix ordering null pointer exception
* Revert "hack: use tispark to resolve locks"
  This reverts commit 6cc1eb57aa56bdfdfc08fb83f4cd117b27fe5446.
* 1. set txn_size in precommit request, 2. use sample to split index region
* fix oom - kvclient not closed after batch get
* fmt code
* fix NullPointerException: appendBatchBySize keys=null
* add parameter: spark.tispark.shuffleKeyToSameRegion, default=true
* add parameter: prewriteBackOfferMS default=240000
* add repartition
* fix escape char in jdbc url
* add retry for commit secondary keys
* add configuration to control whether to retry commit secondary keys
* fix TiRegionPartitioner if writeConcurrency is set
* check parameter taskNumPerRegion
* fix index split syntax
* fmt
* fix batch get resolve lock bug
* fix key not found bug (#1531)
* Fix inconsistent index in batch write (#1532)
* fix inconsistent index
* update test
* fix null unique index key encode error (#1529)
* fix null unique index key encode error
* fix bug
* fix bug
* fix bug
* fmt
* do not throw exception when split index failed
* Fix incorrect usage of LinkedList in GroupByKeys (#1530)
* fix split float/double index region (#1533)
* add check for region split when minVal = maxVal (#1537)
* use startTs's previous timestamp to read (#1536)
* fix toString in index split region
* add argument: txnCommitBatchSize & writeTaskNumber & writeBufferSize (#1538)
* fix resolve lock npe (#1539)
* fix resolve lock npe
* continue
* ignore WriteReadSuite test
* set snapshotBatchGetSize default value to 20480
* change TwoPhaseCommitter log level to info
* Revert "set snapshotBatchGetSize default value to 20480"
  This reverts commit 2cd48ff1bea5109087ee34e72cf0daa325b28988.
* add invalidate region for batchGet
* add column name in error message (#1544)
* add argument: writeThreadPerTask (#1545)
* support commit concurrency (#1546)
* increase getRegionById backoffer
* fix Store Not Match error
* fix Store Not Match error in LockResolver
* Revert "fix Store Not Match error"
  This reverts commit 1cfff825355ecb357764baa3fa902d06c302abd8.
* refactor ThreadPool and parameter (#1548)
* fix BatchGet stuck bug (#1549)
* do not use getRegionById in retry logic (#1550)
* add more log for TwoPhaseCommitter (#1551)
* add argument commitBackOfferMS (#1552)
* continue run when meet exception during commit secondary key (#1553)
* Batch Write optimization (#1535)
* fmt
* delete unused PREWRITE_CONCURRENCY
* add variable: tidb write split region finish
* set commitBackOfferMS default value from 60s to 20s
* refactor tidb_wait_split_region_finish & fix sql exec bug
* add more log
* update writeSplitRegionFinish
* add more log
* Revert "continue run when meet exception during commit secondary key (#1553)"
  This reverts commit 841adf7e9f7483f13eaab5336f399c94a6f1abdd.
* region split version2 (#1558)
* add parameter: txnPrewriteBatchSize & txnCommitBatchSize (#1560)
* fix columnar batch (#1559)
* set TIDB_REGION_SPLIT_METHOD default to v2

Co-authored-by: xufei
Co-authored-by: birdstorm
Co-authored-by: xufei
---
 .../tikv/columnar/TiColumnVectorAdapter.java  |  14 +-
 .../pingcap/tikv/datatype/TypeMapping.java    |   5 +
 .../com/pingcap/tispark/TiConfigConst.scala   |   1 +
 .../com/pingcap/tispark/utils/TiUtil.scala    |   6 +-
 .../tispark/write/SerializableKey.scala       |  14 +-
 .../pingcap/tispark/write/TiBatchWrite.scala  | 129 +++-
 .../tispark/write/TiBatchWriteTable.scala     | 653 ++++++++++++------
 .../pingcap/tispark/write/TiDBOptions.scala   | 102 ++-
 .../write/TiReginSplitPartitioner.scala       |  46 ++
 .../tispark/write/TiRegionPartitioner.scala   |  20 +-
 .../spark/sql/execution/CoprocessorRDD.scala  |  16 +-
 .../spark/sql/tispark/TiHandleRDD.scala       |  85 ++-
 .../tispark/BatchWriteIssueSuite.scala        |  47 +-
 .../tispark/concurrency/WriteReadSuite.scala  |  12 +-
 .../datasource/BaseDataSourceTest.scala       |  22 +-
 .../datasource/ExceptionTestSuite.scala       |   2 +-
 .../pingcap/tispark/index/LineItemSuite.scala |   4 +-
 .../tispark/overflow/DateOverflowSuite.scala  |   5 +-
 .../overflow/DateTimeOverflowSuite.scala      |   5 +-
 .../overflow/SignedOverflowSuite.scala        |   9 +-
 .../overflow/UnsignedOverflowSuite.scala      |   9 +-
 ...umerateUniqueIndexDataTypeTestAction.scala |   4 +-
 .../test/generator/ColumnValueGenerator.scala |   3 +-
 .../spark/sql/test/generator/Index.scala      |   5 +
 .../sql/test/generator/IndexColumn.scala      |   8 +-
 .../spark/sql/test/generator/Schema.scala     |  12 +-
 .../test/generator/TestDataGenerator.scala    |  19 +-
 .../main/java/com/pingcap/tikv/KVClient.java  | 180 ++---
 .../main/java/com/pingcap/tikv/PDClient.java  |  93 +++
 .../main/java/com/pingcap/tikv/Snapshot.java  |  37 +-
 .../com/pingcap/tikv/TiConfiguration.java     |  11 +
 .../java/com/pingcap/tikv/TiDBJDBCClient.java |  81 ++-
 .../main/java/com/pingcap/tikv/TiSession.java | 109 ++-
 .../com/pingcap/tikv/TwoPhaseCommitter.java   | 242 +++++--
 .../tikv/allocator/RowIDAllocator.java        |   5 +
 .../com/pingcap/tikv/codec/TableCodecV1.java  |   2 +-
 .../columnar/BatchedTiChunkColumnVector.java  |   4 -
 .../tikv/columnar/TiChunkColumnVector.java    |  27 +-
 .../pingcap/tikv/columnar/TiColumnVector.java |  13 +
 .../exception/ConvertOverflowException.java   |   4 +
 .../tikv/exception/TiDBConvertException.java  |  22 +
 .../java/com/pingcap/tikv/key/IndexKey.java   |  40 +-
 .../main/java/com/pingcap/tikv/key/Key.java   |   2 +-
 .../java/com/pingcap/tikv/key/RowKey.java     |  12 +
 .../com/pingcap/tikv/meta/TiTimestamp.java    |   4 +
 .../tikv/operation/KVErrorHandler.java        |  22 +-
 .../region/AbstractRegionStoreClient.java     |   2 +-
 .../pingcap/tikv/region/RegionManager.java    |  14 +-
 .../tikv/region/RegionStoreClient.java        |  67 +-
 .../com/pingcap/tikv/region/TiRegion.java     |   4 +
 .../tikv/txn/LockResolverClientV2.java        |  28 +-
 .../tikv/txn/LockResolverClientV3.java        |  35 +-
 .../tikv/txn/LockResolverClientV4.java        |  44 +-
 .../com/pingcap/tikv/txn/type/BatchKeys.java  |  13 +-
 .../tikv/types/AbstractDateTimeType.java      |  17 +
 .../com/pingcap/tikv/types/ArrayType.java     | 154 +++++
 .../java/com/pingcap/tikv/types/DataType.java |  25 +-
 .../com/pingcap/tikv/types/DateTimeType.java  |   6 +
 .../java/com/pingcap/tikv/types/DateType.java |  10 +
 .../com/pingcap/tikv/types/MySQLType.java     |   3 +-
 .../com/pingcap/tikv/types/TimestampType.java |   6 +
 .../java/com/pingcap/tikv/util/BackOffer.java |   1 -
 .../pingcap/tikv/util/ConcreteBackOffer.java  |   4 -
 63 files changed, 2048 insertions(+), 552 deletions(-)
 create mode 100644 core/src/main/scala/com/pingcap/tispark/write/TiReginSplitPartitioner.scala
 create mode 100644 tikv-client/src/main/java/com/pingcap/tikv/exception/TiDBConvertException.java
 create mode 100644 tikv-client/src/main/java/com/pingcap/tikv/types/ArrayType.java

diff --git a/core/src/main/java/com/pingcap/tikv/columnar/TiColumnVectorAdapter.java b/core/src/main/java/com/pingcap/tikv/columnar/TiColumnVectorAdapter.java
index 27649d26b9..60130ce49e 100644
--- a/core/src/main/java/com/pingcap/tikv/columnar/TiColumnVectorAdapter.java
+++ b/core/src/main/java/com/pingcap/tikv/columnar/TiColumnVectorAdapter.java
@@ -23,11 +23,17 @@ public class TiColumnVectorAdapter extends ColumnVector {
   private final TiColumnVector tiColumnVector;
+  private final ColumnVector offsets;
 
   /** Sets up the data type of this column vector. */
   public TiColumnVectorAdapter(TiColumnVector tiColumnVector) {
     super(TypeMapping.toSparkType(tiColumnVector.dataType()));
     this.tiColumnVector = tiColumnVector;
+    if (tiColumnVector.getOffset() == null) {
+      this.offsets = null;
+    } else {
+      this.offsets = new TiColumnVectorAdapter(tiColumnVector.getOffset());
+    }
   }
 
   /**
@@ -122,7 +128,13 @@ public double getDouble(int rowId) {
    */
   @Override
   public ColumnarArray getArray(int rowId) {
-    throw new UnsupportedOperationException("TiColumnVectorAdapter is not supported this method");
+    if (tiColumnVector.isNullAt(rowId)) {
+      return null;
+    }
+    int index = rowId * 8;
+    int start = offsets.getInt(index);
+    int end = offsets.getInt(index + 1);
+    return new ColumnarArray(this, start, end - start);
   }
 
   /**
diff --git a/core/src/main/java/com/pingcap/tikv/datatype/TypeMapping.java b/core/src/main/java/com/pingcap/tikv/datatype/TypeMapping.java
index 4bb6dee625..37c114c141 100644
--- a/core/src/main/java/com/pingcap/tikv/datatype/TypeMapping.java
+++ b/core/src/main/java/com/pingcap/tikv/datatype/TypeMapping.java
@@ -18,6 +18,7 @@ import static com.pingcap.tikv.types.MySQLType.TypeLonglong;
 
 import com.pingcap.tikv.types.AbstractDateTimeType;
+import com.pingcap.tikv.types.ArrayType;
 import com.pingcap.tikv.types.BytesType;
 import com.pingcap.tikv.types.DataType;
 import com.pingcap.tikv.types.DateType;
@@ -96,6 +97,10 @@ public static org.apache.spark.sql.types.DataType toSparkType(DataType type) {
       return DataTypes.LongType;
     }
 
+    if (type instanceof ArrayType) {
+      return DataTypes.createArrayType(DataTypes.LongType);
+    }
+
     throw new UnsupportedOperationException(
         String.format("found unsupported type %s", type.getClass().getCanonicalName()));
   }
diff --git a/core/src/main/scala/com/pingcap/tispark/TiConfigConst.scala b/core/src/main/scala/com/pingcap/tispark/TiConfigConst.scala
index 810937250d..75a40036af 100644
--- a/core/src/main/scala/com/pingcap/tispark/TiConfigConst.scala
+++ b/core/src/main/scala/com/pingcap/tispark/TiConfigConst.scala
@@ -44,6 +44,7 @@ object TiConfigConst {
   val TIKV_REGION_SPLIT_SIZE_IN_MB: String = "spark.tispark.tikv.region_split_size_in_mb"
   val ISOLATION_READ_ENGINES: String = "spark.tispark.isolation_read_engines"
   val PARTITION_PER_SPLIT: String = "spark.tispark.partition_per_split"
+  val KV_CLIENT_CONCURRENCY: String = "spark.tispark.kv_client_concurrency"
 
   val SNAPSHOT_ISOLATION_LEVEL: String = "SI"
   val READ_COMMITTED_ISOLATION_LEVEL: String = "RC"
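The `TiColumnVectorAdapter.getArray` change above stops throwing and instead derives each row's array slice from a child offsets vector: `offsets` delimits where row `i`'s elements begin and end inside one flat values buffer. A minimal, self-contained sketch of that offset-encoded layout (the `LongArrayColumn` type and its field names are illustrative, not the tikv-client API):

```scala
// Sketch of offset-encoded arrays in a columnar layout: `values` holds all
// elements back to back, and offsets(i)..offsets(i + 1) delimit row i's
// slice, mirroring how the adapter derives (start, end) from its offsets
// child vector before wrapping them in a ColumnarArray.
object OffsetArraySketch {
  final case class LongArrayColumn(values: Array[Long], offsets: Array[Int]) {
    def getArray(rowId: Int): Array[Long] = {
      val start = offsets(rowId)
      val end = offsets(rowId + 1)
      values.slice(start, end)
    }
  }

  def main(args: Array[String]): Unit = {
    // three rows: [1, 2], [], [3, 4, 5]
    val col = LongArrayColumn(Array(1L, 2L, 3L, 4L, 5L), Array(0, 2, 2, 5))
    (0 until 3).foreach(i => println(col.getArray(i).mkString("[", ",", "]")))
  }
}
```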
diff --git a/core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala b/core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala
index 303eb6b599..e0f6d9737f 100644
--- a/core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala
+++ b/core/src/main/scala/com/pingcap/tispark/utils/TiUtil.scala
@@ -49,7 +49,7 @@ object TiUtil {
   }
 
   def isDataFrameEmpty(df: DataFrame): Boolean = {
-    df.limit(1).count() == 0
+    df.rdd.isEmpty()
   }
 
   def sparkConfToTiConf(conf: SparkConf): TiConfiguration = {
@@ -130,6 +130,10 @@ object TiUtil {
         getIsolationReadEnginesFromString(conf.get(TiConfigConst.ISOLATION_READ_ENGINES)).toList)
     }
 
+    if (conf.contains(TiConfigConst.KV_CLIENT_CONCURRENCY)) {
+      tiConf.setKvClientConcurrency(conf.get(TiConfigConst.KV_CLIENT_CONCURRENCY).toInt)
+    }
+
     tiConf
   }
 
diff --git a/core/src/main/scala/com/pingcap/tispark/write/SerializableKey.scala b/core/src/main/scala/com/pingcap/tispark/write/SerializableKey.scala
index 0db348f670..e909efa92e 100644
--- a/core/src/main/scala/com/pingcap/tispark/write/SerializableKey.scala
+++ b/core/src/main/scala/com/pingcap/tispark/write/SerializableKey.scala
@@ -18,8 +18,12 @@ package com.pingcap.tispark.write
 import java.util
 
 import com.pingcap.tikv.codec.KeyUtils
+import com.pingcap.tikv.key.Key
+import com.pingcap.tikv.util.FastByteComparisons
 
-class SerializableKey(val bytes: Array[Byte]) extends Serializable {
+class SerializableKey(val bytes: Array[Byte])
+    extends Comparable[SerializableKey]
+    with Serializable {
   override def toString: String = KeyUtils.formatBytes(bytes)
 
   override def equals(that: Any): Boolean =
@@ -30,4 +34,12 @@ class SerializableKey(val bytes: Array[Byte]) extends Serializable {
 
   override def hashCode(): Int =
     util.Arrays.hashCode(bytes)
+
+  override def compareTo(o: SerializableKey): Int = {
+    FastByteComparisons.compareTo(bytes, o.bytes)
+  }
+
+  def getRowKey: Key = {
+    Key.toRawKey(bytes)
+  }
}
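`SerializableKey` now implements `Comparable` through `FastByteComparisons`, whose contract (as far as this patch relies on it) is memcmp-style lexicographic comparison treating each byte as unsigned, which is the order TiKV keys sort in. A standalone sketch of that contract, not the tikv-client implementation:

```scala
// Unsigned lexicographic byte-array comparison, the ordering contract
// SerializableKey.compareTo relies on. Standalone sketch only; the real
// implementation lives in com.pingcap.tikv.util.FastByteComparisons.
object ByteCompareSketch {
  def compareTo(a: Array[Byte], b: Array[Byte]): Int = {
    val len = math.min(a.length, b.length)
    var i = 0
    while (i < len) {
      // JVM bytes are signed; mask to compare as unsigned 0..255
      val cmp = (a(i) & 0xff) - (b(i) & 0xff)
      if (cmp != 0) return cmp
      i += 1
    }
    // shared prefix: the shorter array sorts first
    a.length - b.length
  }

  def main(args: Array[String]): Unit = {
    println(compareTo(Array[Byte](1, 2), Array[Byte](1, 3)) < 0) // true
    println(compareTo(Array[Byte](-1), Array[Byte](1)) > 0) // 0xff > 0x01: true
  }
}
```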
diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala
index 91f0930cbe..82fac6638d 100644
--- a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala
+++ b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala
@@ -80,6 +80,7 @@ class TiBatchWrite(
   private var lockTTLSeconds: Long = _
   @transient private var tiDBJDBCClient: TiDBJDBCClient = _
   @transient private var tiBatchWriteTables: List[TiBatchWriteTable] = _
+  @transient private var startMS: Long = _
 
   private def write(): Unit = {
     try {
@@ -116,6 +117,8 @@ class TiBatchWrite(
   }
 
   private def doWrite(): Unit = {
+    startMS = System.currentTimeMillis()
+
     // check if write enable
     if (!tiContext.tiConf.isWriteEnable) {
       throw new TiBatchWriteException(
@@ -128,7 +131,9 @@ class TiBatchWrite(
     val tikvSupportUpdateTTL = StoreVersion.minTiKVVersion("3.0.5", tiSession.getPDClient)
     isTTLUpdate = options.isTTLUpdate(tikvSupportUpdateTTL)
     lockTTLSeconds = options.getLockTTLSeconds(tikvSupportUpdateTTL)
-    tiDBJDBCClient = new TiDBJDBCClient(TiDBUtils.createConnectionFactory(options.url)())
+    tiDBJDBCClient = new TiDBJDBCClient(
+      TiDBUtils.createConnectionFactory(options.url)(),
+      options.writeSplitRegionFinish)
 
     // init tiBatchWriteTables
     tiBatchWriteTables = {
@@ -188,9 +193,33 @@ class TiBatchWrite(
     logger.info(s"startTS: $startTs")
 
     // pre calculate
-    val shuffledRDD: RDD[(SerializableKey, Array[Byte])] = {
+    var shuffledRDD: RDD[(SerializableKey, Array[Byte])] = {
       val rddList = tiBatchWriteTables.map(_.preCalculate(startTimeStamp))
-      tiContext.sparkSession.sparkContext.union(rddList)
+      if (rddList.lengthCompare(1) == 0) {
+        rddList.head
+      } else {
+        tiContext.sparkSession.sparkContext.union(rddList)
+      }
+    }
+    val shuffledRDDCount = shuffledRDD.count()
+    logger.info(s"write kv data count=$shuffledRDDCount")
+
+    if (options.enableRegionSplit && "v2".equals(options.regionSplitMethod)) {
+      // calculate region split points
+      val orderedSplitPoints = getRegionSplitPoints(shuffledRDD, shuffledRDDCount)
+
+      // split region
+      try {
+        tiSession.splitRegionAndScatter(
+          orderedSplitPoints.map(_.bytes).asJava,
+          options.splitRegionBackoffMS,
+          options.scatterWaitMS)
+      } catch {
+        case e: Throwable => logger.warn("split region and scatter error!", e)
+      }
+
+      // shuffle according to split points
+      shuffledRDD = shuffledRDD.partitionBy(new TiReginSplitPartitioner(orderedSplitPoints))
     }
 
     // take one row as primary key
@@ -212,7 +241,16 @@ class TiBatchWrite(
     }
 
     // driver primary pre-write
-    val ti2PCClient = new TwoPhaseCommitter(tiConf, startTs, lockTTLSeconds * 1000)
+    val ti2PCClient =
+      new TwoPhaseCommitter(
+        tiConf,
+        startTs,
+        lockTTLSeconds * 1000,
+        options.txnPrewriteBatchSize,
+        options.txnCommitBatchSize,
+        options.writeBufferSize,
+        options.writeThreadPerTask,
+        options.retryCommitSecondaryKey)
     val prewritePrimaryBackoff =
       ConcreteBackOffer.newCustomBackOff(BackOffer.BATCH_PREWRITE_BACKOFF)
     logger.info("start to prewritePrimaryKey")
@@ -235,13 +273,24 @@ class TiBatchWrite(
     logger.info("start to prewriteSecondaryKeys")
     secondaryKeysRDD.foreachPartition { iterator =>
       val ti2PCClientOnExecutor =
-        new TwoPhaseCommitter(tiConf, startTs, lockTTLSeconds * 1000)
+        new TwoPhaseCommitter(
+          tiConf,
+          startTs,
+          lockTTLSeconds * 1000,
+          options.txnPrewriteBatchSize,
+          options.txnCommitBatchSize,
+          options.writeBufferSize,
+          options.writeThreadPerTask,
+          options.retryCommitSecondaryKey)
 
       val pairs = iterator.map { keyValue =>
         new BytePairWrapper(keyValue._1.bytes, keyValue._2)
      }.asJava
 
-      ti2PCClientOnExecutor.prewriteSecondaryKeys(primaryKey.bytes, pairs)
+      ti2PCClientOnExecutor.prewriteSecondaryKeys(
+        primaryKey.bytes,
+        pairs,
+        options.prewriteBackOfferMS)
 
       try {
         ti2PCClientOnExecutor.close()
@@ -283,6 +332,11 @@ class TiBatchWrite(
     logger.info("start to commitPrimaryKey")
     ti2PCClient.commitPrimaryKey(commitPrimaryBackoff, primaryKey.bytes, commitTs)
+    try {
+      ti2PCClient.close()
+    } catch {
+      case _: Throwable =>
+    }
     logger.info("commitPrimaryKey success")
 
     // stop primary key ttl update
@@ -297,24 +351,83 @@ class TiBatchWrite(
     if (!options.skipCommitSecondaryKey) {
       logger.info("start to commitSecondaryKeys")
       secondaryKeysRDD.foreachPartition { iterator =>
-        val ti2PCClientOnExecutor = new TwoPhaseCommitter(tiConf, startTs)
+        val ti2PCClientOnExecutor = new TwoPhaseCommitter(
+          tiConf,
+          startTs,
+          lockTTLSeconds * 1000,
+          options.txnPrewriteBatchSize,
+          options.txnCommitBatchSize,
+          options.writeBufferSize,
+          options.writeThreadPerTask,
+          options.retryCommitSecondaryKey)
 
         val keys = iterator.map { keyValue =>
           new ByteWrapper(keyValue._1.bytes)
         }.asJava
 
         try {
-          ti2PCClientOnExecutor.commitSecondaryKeys(keys, commitTs)
+          ti2PCClientOnExecutor.commitSecondaryKeys(keys, commitTs, options.commitBackOfferMS)
         } catch {
           case e: TiBatchWriteException =>
             // ignored
             logger.warn(s"commit secondary key error", e)
         }
+
+        try {
+          ti2PCClientOnExecutor.close()
+        } catch {
+          case _: Throwable =>
+        }
       }
       logger.info("commitSecondaryKeys finish")
     } else {
       logger.info("skipping commit secondary key")
     }
+
+    val endMS = System.currentTimeMillis()
+    logger.info(s"batch write cost ${(endMS - startMS) / 1000} seconds")
+  }
+
+  private def getRegionSplitPoints(
+      rdd: RDD[(SerializableKey, Array[Byte])],
+      count: Long): List[SerializableKey] = {
+    if (count < options.regionSplitThreshold) {
+      return Nil
+    }
+
+    val regionSplitPointNum =
+      if (options.regionSplitNum > 0) {
+        options.regionSplitNum
+      } else {
+        Math.max(
+          options.minRegionSplitNum,
+          Math.ceil(count.toDouble / options.regionSplitKeys).toInt)
+      }
+    logger.info(s"regionSplitPointNum=$regionSplitPointNum")
+
+    val sampleSize = (regionSplitPointNum + 1) * options.sampleSplitFrac
+    logger.info(s"sampleSize=$sampleSize")
+
+    val sampleData = rdd
+      .map(_._1)
+      .sample(withReplacement = false, fraction = sampleSize.toDouble / count)
+      .collect()
+    logger.info(s"sampleData size=${sampleData.length}")
+
+    val sortedSampleData = sampleData.sorted(new Ordering[SerializableKey] {
+      override def compare(x: SerializableKey, y: SerializableKey): Int = {
+        x.compareTo(y)
+      }
+    })
+
+    val orderedSplitPoints = new Array[SerializableKey](regionSplitPointNum)
+    val step = Math.floor(sortedSampleData.length.toDouble / (regionSplitPointNum + 1)).toInt
+    for (i <- 0 until regionSplitPointNum) {
+      orderedSplitPoints(i) = sortedSampleData((i + 1) * step)
+    }
+
+    logger.info(s"orderedSplitPoints size=${orderedSplitPoints.length}")
+    orderedSplitPoints.toList
   }
 
   private def getUseTableLock: Boolean = {
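`getRegionSplitPoints` above is the heart of the v2 region-split method: sample the shuffled keys without replacement, sort the sample, then pick evenly spaced elements so the `regionSplitPointNum` points cut the key space into balanced ranges. The same picking step in a Spark-free, runnable sketch (the generic key type and `splitPoints` helper are illustrative, not the patch's API):

```scala
// Evenly spaced split points from a sorted random sample, mirroring
// getRegionSplitPoints: n split points cut the key space into n + 1 ranges,
// so we step through the sorted sample at sample.length / (n + 1) intervals.
object SplitPointSketch {
  def splitPoints[K: Ordering](sample: Array[K], splitPointNum: Int): List[K] = {
    val sorted = sample.sorted
    val step = math.floor(sorted.length.toDouble / (splitPointNum + 1)).toInt
    (1 to splitPointNum).map(i => sorted(i * step)).toList
  }

  def main(args: Array[String]): Unit = {
    val sample = scala.util.Random.shuffle((1 to 1000).toList).take(400).toArray
    // three split points land near the 25th/50th/75th percentiles
    println(splitPoints(sample, 3))
  }
}
```

Oversampling by `sampleSplitFrac` keeps each chosen point close to its true quantile; the `regionSplitThreshold` guard skips splitting entirely for small writes, where the scatter cost would outweigh the balance gain.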
diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWriteTable.scala b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWriteTable.scala
index 866683967f..f59b47a95a 100644
--- a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWriteTable.scala
+++ b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWriteTable.scala
@@ -18,19 +18,30 @@ package com.pingcap.tispark.write
 import java.sql.SQLException
 import java.util
 
-import com.google.protobuf.ByteString
 import com.pingcap.tikv.allocator.RowIDAllocator
 import com.pingcap.tikv.codec.{CodecDataOutput, TableCodec}
-import com.pingcap.tikv.exception.TiBatchWriteException
+import com.pingcap.tikv.exception.{
+  ConvertOverflowException,
+  TiBatchWriteException,
+  TiDBConvertException
+}
 import com.pingcap.tikv.key.{IndexKey, RowKey}
 import com.pingcap.tikv.meta._
 import com.pingcap.tikv.region.TiRegion
 import com.pingcap.tikv.row.ObjectRowImpl
 import com.pingcap.tikv.types.DataType.EncodeType
 import com.pingcap.tikv.types.IntegerType
-import com.pingcap.tikv.{TiBatchWriteUtils, TiConfiguration, TiDBJDBCClient, TiSession}
+import com.pingcap.tikv.util.FastByteComparisons
+import com.pingcap.tikv.{
+  BytePairWrapper,
+  TiBatchWriteUtils,
+  TiConfiguration,
+  TiDBJDBCClient,
+  TiSession
+}
 import com.pingcap.tispark.TiTableReference
 import com.pingcap.tispark.utils.TiUtil
+import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.functions._
@@ -41,7 +52,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 class TiBatchWriteTable(
-    @transient val df: DataFrame,
+    @transient var df: DataFrame,
     @transient val tiContext: TiContext,
     val options: TiDBOptions,
     val tiConf: TiConfiguration,
@@ -60,7 +71,7 @@ class TiBatchWriteTable(
   private var tableColSize: Int = _
   private var colsMapInTiDB: Map[String, TiColumnInfo] = _
   private var colsInDf: List[String] = _
-  private var uniqueIndices: List[TiIndexInfo] = _
+  private var uniqueIndices: Seq[TiIndexInfo] = _
   private var handleCol: TiColumnInfo = _
   private var tableLocked: Boolean = false
@@ -74,16 +85,16 @@ class TiBatchWriteTable(
   colsMapInTiDB = tiTableInfo.getColumns.asScala.map(col => col.getName -> col).toMap
   colsInDf = df.columns.toList.map(_.toLowerCase())
-  uniqueIndices = tiTableInfo.getIndices.asScala.filter(index => index.isUnique).toList
+  uniqueIndices = tiTableInfo.getIndices.asScala.filter(index => index.isUnique)
   handleCol = tiTableInfo.getPKIsHandleColumn
   tableColSize = tiTableInfo.getColumns.size()
 
   def persist(): Unit = {
-    df.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+    df = df.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
   }
 
   def isDFEmpty: Boolean = {
-    if (TiUtil.isDataFrameEmpty(df)) {
+    if (TiUtil.isDataFrameEmpty(df.select(df.columns.head))) {
       logger.warn(s"the dataframe write to $tiTableRef is empty!")
       true
     } else {
       false
     }
   }
@@ -100,6 +111,13 @@ class TiBatchWriteTable(
   }
 
   def preCalculate(startTimeStamp: TiTimestamp): RDD[(SerializableKey, Array[Byte])] = {
+    val sc = tiContext.sparkSession.sparkContext
+
+    df.explain(true)
+
+    val count = df.count
+    logger.info(s"source data count=$count")
+
     // auto increment
     val rdd = if (tiTableInfo.hasAutoIncrementColumn) {
       val isProvidedID = tableColSize == colsInDf.length
@@ -113,22 +131,21 @@ class TiBatchWriteTable(
           "Column size is matched but cannot find auto increment column by name")
       }
 
-      val colOffset =
-        colsInDf.zipWithIndex.find(col => autoIncrementColName.equals(col._1)).get._2
-      val hasNullValue = df
-        .filter(row => row.get(colOffset) == null)
-        .count() > 0
+      val hasNullValue = !df
+        .select(autoIncrementColName)
+        .filter(row => row.get(0) == null)
+        .rdd
+        .isEmpty()
       if (hasNullValue) {
         throw new TiBatchWriteException(
-          "cannot allocate id on the condition of having null value " +
-            "and valid value on auto increment column")
+          "cannot allocate id on the condition of having null value and valid value on auto increment column")
       }
       df.rdd
     } else {
       // if auto increment column is not provided, we need allocate id for it.
       // adding an auto increment column to df
       val newDf = df.withColumn(autoIncrementColName, lit(null).cast("long"))
-      val start = getAutoTableIdStart(df.count)
+      val start = getAutoTableIdStart(count)
 
       // update colsInDF since we just add one column in df
       colsInDf = newDf.columns.toList.map(_.toLowerCase())
@@ -162,13 +179,13 @@ class TiBatchWriteTable(
 
     // currently we only support replace and insert.
     val constraintCheckIsNeeded = handleCol != null || uniqueIndices.nonEmpty
 
-    val encodedTiRowRDD = if (constraintCheckIsNeeded) {
+    val keyValueRDD = if (constraintCheckIsNeeded) {
       val wrappedRowRdd = if (tiTableInfo.isPkHandle) {
         tiRowRdd.map { row =>
           WrappedRow(row, extractHandleId(row))
         }
       } else {
-        val start = getAutoTableIdStart(tiRowRdd.count)
+        val start = getAutoTableIdStart(count)
         tiRowRdd.zipWithIndex.map { data =>
           WrappedRow(data._1, data._2 + start)
         }
@@ -176,67 +193,80 @@ class TiBatchWriteTable(
 
       val distinctWrappedRowRdd = deduplicate(wrappedRowRdd)
 
-      val deletion = if (options.useSnapshotBatchGet) {
-        generateDataToBeRemovedRddV2(distinctWrappedRowRdd, startTimeStamp)
-      } else {
-        generateDataToBeRemovedRddV1(distinctWrappedRowRdd, startTimeStamp)
-      }
+      val deletion = (if (options.useSnapshotBatchGet) {
+                        generateDataToBeRemovedRddV2(distinctWrappedRowRdd, startTimeStamp)
+                      } else {
+                        generateDataToBeRemovedRddV1(distinctWrappedRowRdd, startTimeStamp)
+                      }).persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+
       if (!options.replace && !deletion.isEmpty()) {
         throw new TiBatchWriteException("data to be inserted has conflicts with TiKV data")
       }
 
-      val wrappedEncodedRdd = generateKV(distinctWrappedRowRdd, remove = false)
-      splitTableRegion(wrappedEncodedRdd.filter(r => !r.isIndex))
-      splitIndexRegion(wrappedEncodedRdd.filter(r => r.isIndex))
+      val wrappedEncodedRecordRdd = generateRecordKV(distinctWrappedRowRdd, remove = false)
+      val wrappedEncodedIndexRdds = generateIndexKVs(distinctWrappedRowRdd, remove = false)
+      val wrappedEncodedIndexRdd: RDD[WrappedEncodedRow] = {
+        val list = wrappedEncodedIndexRdds.values.toSeq
+        if (list.isEmpty) {
+          sc.emptyRDD[WrappedEncodedRow]
+        } else if (list.lengthCompare(1) == 0) {
+          list.head
+        } else {
+          sc.union(list)
+        }
+      }
+
+      if ("v1".equals(options.regionSplitMethod)) {
+        splitTableRegion(wrappedEncodedRecordRdd)
+        splitIndexRegion(wrappedEncodedIndexRdds, count)
+      }
 
-      val mergedRDD = wrappedEncodedRdd ++ generateKV(deletion, remove = true)
-      mergedRDD
+      val g1 = (wrappedEncodedRecordRdd ++ generateRecordKV(deletion, remove = true))
         .map(wrappedEncodedRow => (wrappedEncodedRow.encodedKey, wrappedEncodedRow))
-        .groupByKey()
-        .map {
-          case (key, iterable) =>
-            // if rdd contains same key, it means we need first delete the old value and insert the new value associated the
-            // key. We can merge the two operation into one update operation.
-            // Note: the deletion operation's value of kv pair is empty.
-            iterable.find(value => value.encodedValue.nonEmpty) match {
-              case Some(wrappedEncodedRow) =>
-                WrappedEncodedRow(
-                  wrappedEncodedRow.row,
-                  wrappedEncodedRow.handle,
-                  wrappedEncodedRow.encodedKey,
-                  wrappedEncodedRow.encodedValue,
-                  isIndex = wrappedEncodedRow.isIndex,
-                  wrappedEncodedRow.indexId,
-                  remove = false)
-              case None =>
-                WrappedEncodedRow(
-                  iterable.head.row,
-                  iterable.head.handle,
-                  key,
-                  new Array[Byte](0),
-                  isIndex = iterable.head.isIndex,
-                  iterable.head.indexId,
-                  remove = true)
-            }
+        .reduceByKey { (r1, r2) =>
+          // if rdd contains same key, it means we need first delete the old value and insert the new value associated the
+          // key. We can merge the two operation into one update operation.
+          // Note: the deletion operation's value of kv pair is empty.
+          if (r1.encodedValue.isEmpty) r2 else r1
+        }
+        .map(_._2)
+      val g2 = (wrappedEncodedIndexRdd ++ generateIndexKV(sc, deletion, remove = true))
+        .map(wrappedEncodedRow => (wrappedEncodedRow.encodedKey, wrappedEncodedRow))
+        .reduceByKey { (r1, r2) =>
+          if (r1.encodedValue.isEmpty) r2 else r1
         }
+        .map(_._2)
+
+      (g1 ++ g2).map(obj => (obj.encodedKey, obj.encodedValue))
     } else {
-      val start = getAutoTableIdStart(tiRowRdd.count)
+      val start = getAutoTableIdStart(count)
       val wrappedRowRdd = tiRowRdd.zipWithIndex.map { row =>
         WrappedRow(row._1, row._2 + start)
       }
 
-      val wrappedEncodedRdd = generateKV(wrappedRowRdd, remove = false)
-      splitTableRegion(wrappedEncodedRdd.filter(r => !r.isIndex))
-      splitIndexRegion(wrappedEncodedRdd.filter(r => r.isIndex))
+      val wrappedEncodedRecordRdd = generateRecordKV(wrappedRowRdd, remove = false)
+      val wrappedEncodedIndexRdds = generateIndexKVs(wrappedRowRdd, remove = false)
+      val wrappedEncodedIndexRdd = sc.union(wrappedEncodedIndexRdds.values.toSeq)
+
+      if ("v1".equals(options.regionSplitMethod)) {
+        splitTableRegion(wrappedEncodedRecordRdd)
+        splitIndexRegion(wrappedEncodedIndexRdds, count)
+      }
 
-      wrappedEncodedRdd
+      (wrappedEncodedRecordRdd ++ wrappedEncodedIndexRdd).map(obj =>
+        (obj.encodedKey, obj.encodedValue))
     }
 
-    val encodedKVPairRDD =
-      encodedTiRowRDD.map(row => EncodedKVPair(row.encodedKey, row.encodedValue))
-
-    // shuffle data in same task which belong to same region
-    val shuffledRDD = shuffleKeyToSameRegion(encodedKVPairRDD).cache()
-    shuffledRDD
+    // shuffle or persist
+    val shuffledOrPersistedRDD =
+      if (options.enableRegionSplit && "v2".equals(options.regionSplitMethod)) {
+        keyValueRDD.persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+      } else if (options.shuffleKeyToSameRegion) {
+        shuffleKeyToSameRegion(keyValueRDD)
+      } else {
+        shuffleUseRepartition(keyValueRDD)
+      }
    shuffledOrPersistedRDD
   }
 
   def lockTable(): Unit = {
@@ -288,30 +318,38 @@ class TiBatchWriteTable(
       .getStart
   }
 
+  private def isNullUniqueIndexValue(value: Array[Byte]): Boolean = {
+    value.length == 1 && value(0) == '0'
+  }
+
   private def generateDataToBeRemovedRddV1(
       rdd: RDD[WrappedRow],
       startTs: TiTimestamp): RDD[WrappedRow] = {
     rdd
       .mapPartitions { wrappedRows =>
-        val snapshot = TiSession.getInstance(tiConf).createSnapshot(startTs)
+        val snapshot = TiSession.getInstance(tiConf).createSnapshot(startTs.getPrevious)
         wrappedRows.map { wrappedRow =>
           val rowBuf = mutable.ListBuffer.empty[WrappedRow]
 
           // check handle key
           if (handleCol != null) {
             val oldValue = snapshot.get(buildRowKey(wrappedRow.row, wrappedRow.handle).bytes)
-            if (oldValue.nonEmpty) {
+            if (oldValue.nonEmpty && !isNullUniqueIndexValue(oldValue)) {
               val oldRow = TableCodec.decodeRow(oldValue, wrappedRow.handle, tiTableInfo)
               rowBuf += WrappedRow(oldRow, wrappedRow.handle)
             }
           }
 
           uniqueIndices.foreach { index =>
-            val oldValue = snapshot.get(buildUniqueIndexKey(wrappedRow.row, index).bytes)
-            if (oldValue.nonEmpty) {
-              val oldHandle = TableCodec.decodeHandle(oldValue)
-              val oldRowValue = snapshot.get(buildRowKey(wrappedRow.row, oldHandle).bytes)
-              val oldRow = TableCodec.decodeRow(oldRowValue, oldHandle, tiTableInfo)
-              rowBuf += WrappedRow(oldRow, oldHandle)
+            val keyInfo = buildUniqueIndexKey(wrappedRow.row, wrappedRow.handle, index)
+            // if the handle is appended, the key must not exist in the old table
+            if (!keyInfo._2) {
+              val oldValue = snapshot.get(keyInfo._1.bytes)
+              if (oldValue.nonEmpty && !isNullUniqueIndexValue(oldValue)) {
+                val oldHandle = TableCodec.decodeHandle(oldValue)
+                val oldRowValue = snapshot.get(buildRowKey(wrappedRow.row, oldHandle).bytes)
+                val oldRow = TableCodec.decodeRow(oldRowValue, oldHandle, tiTableInfo)
+                rowBuf += WrappedRow(oldRow, oldHandle)
+              }
             }
           }
           rowBuf
@@ -324,18 +362,18 @@ class TiBatchWriteTable(
       rdd: RDD[WrappedRow],
       startTs: TiTimestamp): RDD[WrappedRow] = {
     rdd.mapPartitions { wrappedRows =>
-      val snapshot = TiSession.getInstance(tiConf).createSnapshot(startTs)
+      val snapshot = TiSession.getInstance(tiConf).createSnapshot(startTs.getPrevious)
       var rowBuf = mutable.ListBuffer.empty[WrappedRow]
-      var rowBufIndex = 0
+      var rowBufIterator = rowBuf.iterator
 
       new Iterator[WrappedRow] {
         override def hasNext: Boolean = {
           while (true) {
-            if (!wrappedRows.hasNext && !(rowBufIndex < rowBuf.size)) {
+            if (!wrappedRows.hasNext && !rowBufIterator.hasNext) {
              return false
            }
 
-            if (rowBufIndex < rowBuf.size) {
+            if (rowBufIterator.hasNext) {
              return true
            }
 
@@ -347,8 +385,7 @@ class TiBatchWriteTable(
 
         override def next(): WrappedRow = {
           if (hasNext) {
-            rowBufIndex = rowBufIndex + 1
-            rowBuf(rowBufIndex - 1)
+            rowBufIterator.next
           } else {
             null
           }
@@ -365,87 +402,100 @@ class TiBatchWriteTable(
         }
 
         def genNextHandleBatch(
-            batch: List[WrappedRow]): (util.List[Array[Byte]], util.Map[ByteString, Long]) = {
+            batch: List[WrappedRow]): (util.List[Array[Byte]], util.List[Long]) = {
           val list = new util.ArrayList[Array[Byte]]()
-          val map = new util.HashMap[ByteString, Long]()
+          val handles = new util.ArrayList[Long]()
 
           batch.foreach { wrappedRow =>
             val bytes = buildRowKey(wrappedRow.row, wrappedRow.handle).bytes
-            val key = ByteString.copyFrom(bytes)
             list.add(bytes)
-            map.put(key, wrappedRow.handle)
+            handles.add(wrappedRow.handle)
           }
 
-          (list, map)
+          (list, handles)
         }
 
-        def genNextIndexBatch(
+        def genNextUniqueIndexBatch(
             batch: List[WrappedRow],
-            index: TiIndexInfo): (util.List[Array[Byte]], util.Map[ByteString, TiRow]) = {
-          val list = new util.ArrayList[Array[Byte]]()
-          val map = new util.HashMap[ByteString, TiRow]()
+            index: TiIndexInfo): (util.List[Array[Byte]], util.List[TiRow]) = {
+          val keyList = new util.ArrayList[Array[Byte]]()
+          val rowList = new util.ArrayList[TiRow]()
 
           batch.foreach { wrappedRow =>
-            val bytes = buildUniqueIndexKey(wrappedRow.row, index).bytes
-            val key = ByteString.copyFrom(bytes)
-            list.add(bytes)
-            map.put(key, wrappedRow.row)
+            val encodeResult = buildUniqueIndexKey(wrappedRow.row, wrappedRow.handle, index)
+            if (!encodeResult._2) {
+              // only add the key if the handle is not appended, since if the handle is appended,
+              // the value must be a new value
+              val bytes = encodeResult._1.bytes
+              keyList.add(bytes)
+              rowList.add(wrappedRow.row)
+            }
+          }
+          (keyList, rowList)
+        }
+
+        def decodeHandle(row: Array[Byte]): Long = {
+          RowKey.decode(row).getHandle
+        }
+
+        def processHandleDelete(
+            oldValueList: java.util.List[BytePairWrapper],
+            handleList: java.util.List[Long]): Unit = {
+          for (i <- 0 until oldValueList.size) {
+            val oldValuePair = oldValueList.get(i)
+            val oldValue = oldValuePair.getValue
+            val handle = handleList.get(i)
+
+            if (oldValue.nonEmpty && !isNullUniqueIndexValue(oldValue)) {
+              val oldRow = TableCodec.decodeRow(oldValue, handle, tiTableInfo)
+              rowBuf += WrappedRow(oldRow, handle)
+            }
          }
        }
 
        def processNextBatch(): Unit = {
          rowBuf = mutable.ListBuffer.empty[WrappedRow]
-          rowBufIndex = 0
 
          val batch = getNextBatch(wrappedRows)
 
          if (handleCol != null) {
-            val (batchHandle, handleMap) = genNextHandleBatch(batch)
-            val oldValueList = snapshot.batchGet(batchHandle)
-            (0 until oldValueList.size()).foreach { i =>
-              val oldValuePair = oldValueList.get(i)
-              val oldValue = oldValuePair.getValue
-              val key = oldValuePair.getKey
-              val handle = handleMap.get(key)
-
-              val oldRow = TableCodec.decodeRow(oldValue, handle, tiTableInfo)
-              rowBuf += WrappedRow(oldRow, handle)
-            }
+            val (batchHandle, handleList) = genNextHandleBatch(batch)
+            val oldValueList = snapshot.batchGet(options.batchGetBackOfferMS, batchHandle)
+            processHandleDelete(oldValueList, handleList)
          }
 
          val oldIndicesBatch: util.List[Array[Byte]] = new util.ArrayList[Array[Byte]]()
-          val oldIndicesMap: mutable.HashMap[SerializableKey, Long] = new mutable.HashMap()
          uniqueIndices.foreach { index =>
-            val (batchIndices, rowMap) = genNextIndexBatch(batch, index)
-            val oldValueList = snapshot.batchGet(batchIndices)
-            (0 until oldValueList.size()).foreach { i =>
+            val (batchIndices, rowList) = genNextUniqueIndexBatch(batch, index)
+            val oldValueList = snapshot.batchGet(options.batchGetBackOfferMS, batchIndices)
+            for (i <- 0 until oldValueList.size) {
              val oldValuePair = oldValueList.get(i)
              val oldValue = oldValuePair.getValue
-              val key = oldValuePair.getKey
-              val oldHandle = TableCodec.decodeHandle(oldValue)
-              val tiRow = rowMap.get(key)
-
-              oldIndicesBatch.add(buildRowKey(tiRow, oldHandle).bytes)
-              oldIndicesMap.put(
-                new SerializableKey(buildRowKey(tiRow, oldHandle).bytes),
-                oldHandle)
+              if (oldValue.nonEmpty && !isNullUniqueIndexValue(oldValue)) {
+                val oldHandle = TableCodec.decodeHandle(oldValue)
+                val tiRow = rowList.get(i)
+
+                oldIndicesBatch.add(buildRowKey(tiRow, oldHandle).bytes)
+              }
            }
          }
 
-          val oldIndicesRowPairs = snapshot.batchGet(oldIndicesBatch)
-          (0 until oldIndicesRowPairs.size()).foreach { i =>
-            val oldIndicesRowPair = oldIndicesRowPairs.get(i)
+          val oldIndicesRowPairs = snapshot.batchGet(options.batchGetBackOfferMS, oldIndicesBatch)
+          oldIndicesRowPairs.asScala.foreach { oldIndicesRowPair =>
            val oldRowKey = oldIndicesRowPair.getKey
            val oldRowValue = oldIndicesRowPair.getValue
-            val oldHandle = oldIndicesMap(new SerializableKey(oldRowKey))
-            val oldRow = TableCodec.decodeRow(oldRowValue, oldHandle, tiTableInfo)
-            rowBuf += WrappedRow(oldRow, oldHandle)
+            if (oldRowValue.nonEmpty && !isNullUniqueIndexValue(oldRowValue)) {
+              val oldHandle = decodeHandle(oldRowKey)
+              val oldRow = TableCodec.decodeRow(oldRowValue, oldHandle, tiTableInfo)
+              rowBuf += WrappedRow(oldRow, oldHandle)
+            }
          }
+
+          rowBufIterator = rowBuf.iterator
        }
      }
    }
  }
 
   private def checkValueNotNull(rdd: RDD[TiRow]): Unit = {
-    val nullRowCount = rdd
+    val nullRows = !rdd
       .filter { row =>
         colsMapInTiDB.exists {
           case (_, v) =>
             if (v.getType.isNotNull && row.get(v.getOffset, v.getType) == null) {
               true
             } else {
               false
             }
         }
       }
-      .count()
+      .isEmpty()
 
-    if (nullRowCount > 0) {
+    if (nullRows) {
       throw new TiBatchWriteException(
-        s"Insert null value to not null column! $nullRowCount rows contain illegal null values!")
+        s"Insert null value to not null column! rows contain illegal null values!")
     }
   }
 
@@ -474,18 +524,18 @@ class TiBatchWriteTable(
         val rowKey = buildRowKey(wrappedRow.row, wrappedRow.handle)
         (rowKey, wrappedRow)
       }
-      .groupByKey()
-      .map(_._2.head)
+      .reduceByKey((r1, _) => r1)
+      .map(_._2)
     }
 
     uniqueIndices.foreach { index =>
       {
         mutableRdd = mutableRdd
           .map { wrappedRow =>
-            val indexKey = buildUniqueIndexKey(wrappedRow.row, index)
+            val indexKey = buildUniqueIndexKey(wrappedRow.row, wrappedRow.handle, index)._1
             (indexKey, wrappedRow)
           }
-          .groupByKey()
-          .map(_._2.head)
+          .reduceByKey((r1, _) => r1)
+          .map(_._2)
       }
     }
     mutableRdd
   }
 
@@ -493,15 +543,23 @@ class TiBatchWriteTable(
   @throws(classOf[NoSuchTableException])
   private def shuffleKeyToSameRegion(
-      rdd: RDD[EncodedKVPair]): RDD[(SerializableKey, Array[Byte])] = {
+      rdd: RDD[(SerializableKey, Array[Byte])]): RDD[(SerializableKey, Array[Byte])] = {
     val regions = getRegions
     assert(regions.size() > 0)
-    val tiRegionPartitioner = new TiRegionPartitioner(regions, options.writeConcurrency)
+    val tiRegionPartitioner =
+      new TiRegionPartitioner(regions, options.writeConcurrency, options.taskNumPerRegion)
 
-    rdd
-      .map(obj => (obj.encodedKey, obj.encodedValue))
-      // remove duplicate rows if key equals (should not happen, cause already deduplicated)
-      .reduceByKey(tiRegionPartitioner, (a: Array[Byte], _: Array[Byte]) => a)
+    rdd.partitionBy(tiRegionPartitioner)
+  }
+
+  @throws(classOf[NoSuchTableException])
+  private def shuffleUseRepartition(
+      rdd: RDD[(SerializableKey, Array[Byte])]): RDD[(SerializableKey, Array[Byte])] = {
+    if (options.writeTaskNumber > 0) {
+      rdd.repartition(options.writeTaskNumber)
+    } else {
+      rdd
+    }
   }
 
   private def getRegions: util.List[TiRegion] = {
@@ -529,10 +587,19 @@ class TiBatchWriteTable(
     val tiRow = ObjectRowImpl.create(fieldCount)
     for (i <- 0 until fieldCount) {
       // TODO: add tiDataType back
-      tiRow.set(
-        colsMapInTiDB(colsInDf(i)).getOffset,
-        null,
-        colsMapInTiDB(colsInDf(i)).getType.convertToTiDBType(sparkRow(i)))
+      try {
+        tiRow.set(
+          colsMapInTiDB(colsInDf(i)).getOffset,
+          null,
+          colsMapInTiDB(colsInDf(i)).getType.convertToTiDBType(sparkRow(i)))
+      } catch {
+        case e: ConvertOverflowException =>
+          throw new ConvertOverflowException(
+            e.getMessage,
+            new TiDBConvertException(colsMapInTiDB(colsInDf(i)).getName, e))
+        case e: Throwable =>
+          throw new TiDBConvertException(colsMapInTiDB(colsInDf(i)).getName, e)
+      }
     }
     tiRow
   }
@@ -572,13 +639,20 @@ class TiBatchWriteTable(
       handle: Long,
       index: TiIndexInfo,
       remove: Boolean): (SerializableKey, Array[Byte]) = {
-    val indexKey = buildUniqueIndexKey(row, index)
+    val encodeResult = buildUniqueIndexKey(row, handle, index)
+    val indexKey = encodeResult._1
     val value = if (remove) {
       new Array[Byte](0)
     } else {
-      val cdo = new CodecDataOutput()
-      cdo.writeLong(handle)
-      cdo.toBytes
+      if (encodeResult._2) {
+        val value = new Array[Byte](1)
+        value(0) = '0'
+        value
+      } else {
+        val cdo = new CodecDataOutput()
+        cdo.writeLong(handle)
+        cdo.toBytes
+      }
     }
 
     (indexKey, value)
@@ -589,7 +663,8 @@ class TiBatchWriteTable(
       handle: Long,
       index: TiIndexInfo,
       remove: Boolean): (SerializableKey, Array[Byte]) = {
-    val keys = IndexKey.encodeIndexDataValues(row, index.getIndexColumns, tiTableInfo)
+    val keys =
+      IndexKey.encodeIndexDataValues(row, index.getIndexColumns, handle, false, tiTableInfo).keys
     val cdo = new CodecDataOutput()
     cdo.write(IndexKey.toIndexKey(locatePhysicalTable(row), index.getId, keys: _*).getBytes)
     IntegerType.BIGINT.encode(cdo, EncodeType.KEY, handle)
@@ -607,12 +682,21 @@ class TiBatchWriteTable(
     new SerializableKey(RowKey.toRowKey(locatePhysicalTable(row), handle).getBytes)
   }
 
-  private def buildUniqueIndexKey(row: TiRow, index: TiIndexInfo): SerializableKey = {
-    val keys =
-      IndexKey.encodeIndexDataValues(row, index.getIndexColumns, tiTableInfo)
+  private def buildUniqueIndexKey(
+      row: TiRow,
+      handle: Long,
+      index: TiIndexInfo): (SerializableKey, Boolean) = {
+    // NULL is only allowed in unique key, primary key does not allow NULL value
+    val encodeResult = IndexKey.encodeIndexDataValues(
+      row,
+      index.getIndexColumns,
+      handle,
+      index.isUnique && !index.isPrimary,
+      tiTableInfo)
+    val keys = encodeResult.keys
     val indexKey = IndexKey.toIndexKey(locatePhysicalTable(row), index.getId, keys: _*)
-    new SerializableKey(indexKey.getBytes)
+    (new SerializableKey(indexKey.getBytes), encodeResult.appendHandle)
   }
 
   private def generateRowKey(
@@ -628,13 +712,72 @@ class TiBatchWriteTable(
     }
   }
 
-  private def generateKV(rdd: RDD[WrappedRow], remove: Boolean): RDD[WrappedEncodedRow] = {
+  private def calcSize(rdd: RDD[WrappedEncodedRow]): Long = {
+    rdd.aggregate(0L)(
+      (prev, r) => prev + r.encodedKey.bytes.length + r.encodedValue.length,
+      (r1, r2) => r1 + r2)
+  }
+
+  private def calcRecordMinMax(rdd: RDD[WrappedEncodedRow]): (Long, Long) = {
+    rdd.aggregate((Long.MaxValue, Long.MinValue))(
+      (prev, r) => (Math.min(prev._1, r.handle), Math.max(prev._2, r.handle)),
+      (r1, r2) => (Math.min(r1._1, r2._1), Math.max(r1._2, r2._2)))
+  }
+
+  private def calcIndexMinMax(rdd: RDD[WrappedEncodedRow], index: TiIndexInfo): (Any, Any) = {
+    val colName = index.getIndexColumns.get(0).getName
+    val tiColumn = tiTableInfo.getColumn(colName)
+    val colOffset = tiColumn.getOffset
+    val dataType = tiColumn.getType
+
+    def compare(x: Any, y: Any): Int = {
+      x match {
+        case _: java.lang.Integer =>
+          x.asInstanceOf[java.lang.Integer]
+            .compareTo(y.asInstanceOf[java.lang.Integer])
+        case _: java.lang.Long =>
+          x.asInstanceOf[java.lang.Long]
+            .compareTo(y.asInstanceOf[java.lang.Long])
+        case _: java.lang.Double =>
+          x.asInstanceOf[java.lang.Double]
+            .compareTo(y.asInstanceOf[java.lang.Double])
+        case _: java.lang.Float =>
+          x.asInstanceOf[java.lang.Float]
+            .compareTo(y.asInstanceOf[java.lang.Float])
+        case _: Array[Byte] =>
+          FastByteComparisons.compareTo(x.asInstanceOf[Array[Byte]], y.asInstanceOf[Array[Byte]])
+        case _ => x.toString.compareTo(y.toString)
+      }
+    }
+
+    def min(x: Any, y: Any): Any = {
+      if (x == null) y
+      else if (y == null) x
+      else if (compare(x, y) < 0) x
+      else y
+    }
+
+    def max(x: Any, y: Any): Any = {
+      if (x == null) y
+      else if (y == null) x
+      else if (compare(x, y) > 0) x
+      else y
+    }
+
+    rdd.aggregate((null: Any, null: Any))(
+      (prev, r) => {
+        val v = r.row.get(colOffset, dataType)
+        (min(prev._1, v), max(prev._2, v))
+      },
+      (r1, r2) => (min(r1._1, r2._1), max(r1._2, r2._2)))
+  }
+
+  private def generateRecordKV(rdd: RDD[WrappedRow], remove: Boolean): RDD[WrappedEncodedRow] = {
     rdd
       .map { row =>
         {
-          val kvBuf = mutable.ListBuffer.empty[WrappedEncodedRow]
           val (encodedKey, encodedValue) = generateRowKey(row.row, row.handle, remove)
-          kvBuf += WrappedEncodedRow(
+          WrappedEncodedRow(
             row.row,
             row.handle,
             encodedKey,
@@ -642,35 +785,62 @@ class TiBatchWriteTable(
             encodedValue,
             isIndex = false,
             -1,
             remove)
-          tiTableInfo.getIndices.asScala.foreach { index =>
-            if (index.isUnique) {
-              val (encodedKey, encodedValue) =
-                generateUniqueIndexKey(row.row, row.handle, index, remove)
-              kvBuf += WrappedEncodedRow(
-                row.row,
-                row.handle,
-                encodedKey,
-                encodedValue,
-                isIndex = true,
-                index.getId,
-                remove)
-            } else {
-              val (encodedKey, encodedValue) =
-                generateSecondaryIndexKey(row.row, row.handle, index, remove)
-              kvBuf += WrappedEncodedRow(
-                row.row,
-                row.handle,
-                encodedKey,
-                encodedValue,
-                isIndex = true,
-                index.getId,
-                remove)
-            }
-          }
-          kvBuf
         }
       }
-      .flatMap(identity)
+  }
+
+  private def generateIndexRDD(
+      rdd: RDD[WrappedRow],
+      index: TiIndexInfo,
+      remove: Boolean): RDD[WrappedEncodedRow] = {
+    if (index.isUnique) {
+      rdd.map { row =>
+        val (encodedKey, encodedValue) =
+          generateUniqueIndexKey(row.row, row.handle, index, remove)
+        WrappedEncodedRow(
+          row.row,
+          row.handle,
+          encodedKey,
+          encodedValue,
+          isIndex = true,
+          index.getId,
+          remove)
+      }
+    } else {
+      rdd.map { row =>
+        val (encodedKey, encodedValue) =
+          generateSecondaryIndexKey(row.row, row.handle, index, remove)
+        WrappedEncodedRow(
+          row.row,
+          row.handle,
+          encodedKey,
+          encodedValue,
+          isIndex = true,
+          index.getId,
+          remove)
+      }
+    }
+  }
+
+  private def generateIndexKVs(
+      rdd: RDD[WrappedRow],
+      remove: Boolean): Map[Long, RDD[WrappedEncodedRow]] = {
+    tiTableInfo.getIndices.asScala
+      .map(index => (index.getId, generateIndexRDD(rdd, index, remove)))
+      .toMap
+  }
+
+  private def unionAll(
+      sc: SparkContext,
+      rdds: Map[Long, RDD[WrappedEncodedRow]]): RDD[WrappedEncodedRow] = {
+    rdds.values.foldLeft(sc.emptyRDD[WrappedEncodedRow])(_ ++ _)
+  }
+
+  private def generateIndexKV(
+      sc: SparkContext,
+      rdd: RDD[WrappedRow],
+      remove: Boolean): RDD[WrappedEncodedRow] = {
+    unionAll(sc, generateIndexKVs(rdd, remove))
   }
 
   // TODO: support physical table later. Need use partition info and row value to
@@ -679,12 +849,9 @@ class TiBatchWriteTable(
     tiTableInfo.getId
   }
 
-  private def estimateRegionSplitNum(wrappedEncodedRdd: RDD[WrappedEncodedRow]): Long = {
-    val totalSize =
-      wrappedEncodedRdd.map(r => r.encodedKey.bytes.length + r.encodedValue.length).sum()
-
+  private def estimateRegionSplitNum(totalSize: Long): Long = {
     //TODO: replace 96 with actual value read from pd https://github.com/pingcap/tispark/issues/890
-    Math.ceil(totalSize / (tiContext.tiConf.getTikvRegionSplitSizeInMB * 1024 * 1024)).toLong
+    Math.ceil(totalSize / (tiContext.tiConf.getTikvRegionSplitSizeInMB * 1024.0 * 1024)).toLong
   }
 
   private def checkTidbRegionSplitContidion(
@@ -694,7 +861,16 @@ class TiBatchWriteTable(
     maxHandle - minHandle > regionSplitNum * 1000
   }
 
+  private def toString(value: Any): String = {
+    value match {
+      case a: Array[Byte] => java.util.Arrays.toString(a)
+      case _ => value.toString
+    }
+  }
+
-  private def splitIndexRegion(wrappedEncodedRdd: RDD[WrappedEncodedRow]): Unit = {
+  private def splitIndexRegion(
+      wrappedEncodedRdd: Map[Long, RDD[WrappedEncodedRow]],
+      count: Long): Unit = {
     if (options.enableRegionSplit && isEnableSplitRegion) {
       val indices = tiTableInfo.getIndices.asScala
 
@@ -704,44 +880,97 @@ class TiBatchWriteTable(
         val colOffset = tiColumn.getOffset
         val dataType = tiColumn.getType
 
-        val ordering = new Ordering[WrappedEncodedRow] {
+        val ordering: Ordering[WrappedEncodedRow] = new Ordering[WrappedEncodedRow] {
           override def compare(x: WrappedEncodedRow, y: WrappedEncodedRow): Int = {
             val xIndex = x.row.get(colOffset, dataType)
             val yIndex = y.row.get(colOffset, dataType)
-            xIndex.toString.compare(yIndex.toString)
+            xIndex match {
+              case _: java.lang.Integer =>
+                xIndex
+                  .asInstanceOf[java.lang.Integer]
+                  .compareTo(yIndex.asInstanceOf[java.lang.Integer])
+              case _: java.lang.Long =>
+                xIndex
+                  .asInstanceOf[java.lang.Long]
+                  .compareTo(yIndex.asInstanceOf[java.lang.Long])
+              case _: java.lang.Double =>
+                xIndex
+                  .asInstanceOf[java.lang.Double]
+                  .compareTo(yIndex.asInstanceOf[java.lang.Double])
+              case _: java.lang.Float =>
+                xIndex
+                  .asInstanceOf[java.lang.Float]
+                  .compareTo(yIndex.asInstanceOf[java.lang.Float])
+              case _: Array[Byte] =>
+                FastByteComparisons.compareTo(
+                  xIndex.asInstanceOf[Array[Byte]],
+                  yIndex.asInstanceOf[Array[Byte]])
+              case _ => xIndex.toString.compareTo(yIndex.toString)
+            }
           }
        }
 
-        val rdd = wrappedEncodedRdd.filter(_.indexId == index.getId)
+        val rdd = wrappedEncodedRdd(index.getId)
+
        val regionSplitNum = if (options.regionSplitNum != 0) {
          options.regionSplitNum
        } else {
-          estimateRegionSplitNum(rdd)
+          val indexSize = calcSize(rdd)
+          logger.info(s"count=$count indexSize=$indexSize")
+          val splitNum = estimateRegionSplitNum(indexSize)
+          logger.info(s"index region split num=$splitNum")
+          splitNum
        }
 
        // region split
        if (regionSplitNum > 1) {
-          val minIndexValue = rdd.min()(ordering).row.get(colOffset, dataType).toString
-          val maxIndexValue = rdd.max()(ordering).row.get(colOffset, dataType).toString
          logger.info(
            s"index region split, regionSplitNum=$regionSplitNum, indexName=${index.getName}")
-          try {
-            tiDBJDBCClient
-              .splitIndexRegion(
-                options.database,
-                options.table,
-                index.getName,
-                minIndexValue,
-                maxIndexValue,
-                regionSplitNum)
-          } catch {
-            case e: SQLException =>
-              if (options.isTest) {
-                throw e
+          if (count > (regionSplitNum * 1000 + 1) * 10) {
+            logger.info("split by sample data")
+            val frac = options.sampleSplitFrac
+            val sampleData = rdd.takeSample(false, (regionSplitNum * frac + 1).toInt)
+            val sortedSampleData = sampleData.sorted(ordering)
+            val buf = new StringBuilder
+            for (i <- 1 until regionSplitNum.toInt) {
+              val indexValue = toString(sortedSampleData(i * frac).row.get(colOffset, dataType))
+              buf.append(" (")
+              buf.append("\"")
+              buf.append(indexValue)
+              buf.append("\"")
+              buf.append(")")
+              if (i != regionSplitNum - 1) {
+                buf.append(",")
              }
+            }
+            try {
+              tiDBJDBCClient
+                .splitIndexRegion(options.database, options.table, index.getName, buf.toString())
+            } catch {
+              case e: SQLException => throw e
+            }
+          } else {
+            logger.info("split by min/max data")
+            val (minIndexValue, maxIndexValue) = calcIndexMinMax(rdd, index)
+            logger.info(s"index min=$minIndexValue max=$maxIndexValue")
+            try {
+              tiDBJDBCClient
+                .splitIndexRegion(
+                  options.database,
+                  options.table,
+                  index.getName,
+                  minIndexValue.toString,
+                  maxIndexValue.toString,
+                  regionSplitNum)
+            } catch {
+              case e: SQLException =>
+                if (options.isTest) {
+                  throw e
+                }
+            }
+          }
        } else {
-          logger.warn(
+          logger.info(
            s"skip index split index, regionSplitNum=$regionSplitNum, indexName=${index.getName}")
        }
      }
@@ -769,15 +998,19 @@ class TiBatchWriteTable(
        }
      }
    } else {
+
      val regionSplitNum = if (options.regionSplitNum != 0) {
        options.regionSplitNum
      } else {
-        estimateRegionSplitNum(wrappedRowRdd)
+        val recordSize = calcSize(wrappedRowRdd)
+        val splitNum = estimateRegionSplitNum(recordSize)
+        logger.info(s"record region split num=$splitNum")
+        splitNum
      }
      // region split
      if (regionSplitNum > 1) {
-        val minHandle = wrappedRowRdd.min().handle
-        val maxHandle = wrappedRowRdd.max().handle
+        val (minHandle, maxHandle) = calcRecordMinMax(wrappedRowRdd)
+        logger.info(s"record min=$minHandle max=$maxHandle")
        if (checkTidbRegionSplitContidion(minHandle, maxHandle, regionSplitNum)) {
          logger.info(s"table region split is enabled, regionSplitNum=$regionSplitNum")
          try {
@@ -795,7 +1028,7 @@ class TiBatchWriteTable(
          }
        }
      } else {
-        logger.warn("table region split is skipped")
+        logger.info("table region split is skipped")
      }
    }
  }
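The `groupByKey` to `reduceByKey` rewrite in `preCalculate` is worth calling out: a deletion carries an empty value and an insertion a non-empty one, so merging two records for the same encoded key just means preferring the non-empty value, and `reduceByKey` does that map-side without buffering whole groups. A runnable illustration of the merge rule (table and key names are made up for the example):

```scala
import org.apache.spark.sql.SparkSession

// The delete+insert merge from preCalculate in isolation: an empty value
// marks a deletion, a non-empty value an insertion, and reduceByKey keeps
// the non-empty side so both operations collapse into one update.
object MergeKVSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("merge-kv").getOrCreate()
    val sc = spark.sparkContext

    val kvs = sc.parallelize(
      Seq(
        ("k1", Array.empty[Byte]), // deletion of k1's old row
        ("k1", Array[Byte](1, 2)), // insertion of k1's new row
        ("k2", Array.empty[Byte]))) // pure deletion of k2

    // reduceByKey combines map-side and never materializes a whole group,
    // unlike groupByKey, which buffers every value for a key
    val merged = kvs.reduceByKey((v1, v2) => if (v1.isEmpty) v2 else v1)

    merged.collect().foreach { case (k, v) => println(s"$k -> ${v.mkString(",")}") }
    spark.stop()
  }
}
```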
diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala b/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala
index 31ebfa009c..c1d09115a8 100644
--- a/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala
+++ b/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala
@@ -30,6 +30,8 @@ class TiDBOptions(@transient val parameters: CaseInsensitiveMap[String]) extends
 
   import com.pingcap.tispark.write.TiDBOptions._
 
+  private val optParamPrefix = "spark.tispark."
+
   // ------------------------------------------------------------
   // Required parameters
   // ------------------------------------------------------------
@@ -37,7 +39,7 @@ class TiDBOptions(@transient val parameters: CaseInsensitiveMap[String]) extends
   val port: String = checkAndGet(TIDB_PORT)
   val user: String = checkAndGet(TIDB_USER)
   val password: String = checkAndGet(TIDB_PASSWORD)
-  val multiTables: Boolean = parameters.getOrElse(TIDB_MULTI_TABLES, "false").toBoolean
+  val multiTables: Boolean = getOrDefault(TIDB_MULTI_TABLES, "false").toBoolean
   val database: String = if (!multiTables) {
     checkAndGet(TIDB_DATABASE)
   } else {
@@ -51,39 +53,71 @@ class TiDBOptions(@transient val parameters: CaseInsensitiveMap[String]) extends
   // ------------------------------------------------------------
   // Optional parameters only for writing
   // ------------------------------------------------------------
-  val replace: Boolean = parameters.getOrElse(TIDB_REPLACE, "false").toBoolean
+  val replace: Boolean = getOrDefault(TIDB_REPLACE, "false").toBoolean
 
   // It is an optimize by the nature of 2pc protocol
   // We leave other txn, gc or read to resolve locks.
   val skipCommitSecondaryKey: Boolean =
-    parameters.getOrElse(TIDB_SKIP_COMMIT_SECONDARY_KEY, "false").toBoolean
-  val regionSplitNum: Int = parameters.getOrElse(TIDB_REGION_SPLIT_NUM, "0").toInt
-  val enableRegionSplit: Boolean =
-    parameters.getOrElse(TIDB_ENABLE_REGION_SPLIT, "true").toBoolean
-  val writeConcurrency: Int = parameters.getOrElse(TIDB_WRITE_CONCURRENCY, "0").toInt
+    getOrDefault(TIDB_SKIP_COMMIT_SECONDARY_KEY, "false").toBoolean
+  val writeConcurrency: Int = getOrDefault(TIDB_WRITE_CONCURRENCY, "0").toInt
 
   // ttlMode = { "FIXED", "UPDATE", "DEFAULT" }
-  val ttlMode: String = parameters.getOrElse(TIDB_TTL_MODE, "DEFAULT").toUpperCase()
-  val useSnapshotBatchGet: Boolean =
-    parameters.getOrElse(TIDB_USE_SNAPSHOT_BATCH_GET, "true").toBoolean
-  val snapshotBatchGetSize: Int = parameters.getOrElse(TIDB_SNAPSHOT_BATCH_GET_SIZE, "2048").toInt
+  val ttlMode: String = getOrDefault(TIDB_TTL_MODE, "DEFAULT").toUpperCase()
+  val useSnapshotBatchGet: Boolean = getOrDefault(TIDB_USE_SNAPSHOT_BATCH_GET, "true").toBoolean
+  // 20k
+  val snapshotBatchGetSize: Int = getOrDefault(TIDB_SNAPSHOT_BATCH_GET_SIZE, "20480").toInt
+  val batchGetBackOfferMS: Int = getOrDefault(TIDB_BATCH_GET_BACKOFFER_MS, "60000").toInt
   val sleepAfterPrewritePrimaryKey: Long =
-    parameters.getOrElse(TIDB_SLEEP_AFTER_PREWRITE_PRIMARY_KEY, "0").toLong
+    getOrDefault(TIDB_SLEEP_AFTER_PREWRITE_PRIMARY_KEY, "0").toLong
   val sleepAfterPrewriteSecondaryKey: Long =
-    parameters.getOrElse(TIDB_SLEEP_AFTER_PREWRITE_SECONDARY_KEY, "0").toLong
-  val sleepAfterGetCommitTS: Long =
-    parameters.getOrElse(TIDB_SLEEP_AFTER_GET_COMMIT_TS, "0").toLong
-  val isTest: Boolean = parameters.getOrElse(TIDB_IS_TEST, "false").toBoolean
+    getOrDefault(TIDB_SLEEP_AFTER_PREWRITE_SECONDARY_KEY, "0").toLong
+  val sleepAfterGetCommitTS: Long = getOrDefault(TIDB_SLEEP_AFTER_GET_COMMIT_TS, "0").toLong
+  val isTest: Boolean = getOrDefault(TIDB_IS_TEST, "false").toBoolean
+  val taskNumPerRegion: Int = {
+    val num = getOrDefault(TIDB_TASK_NUM_PER_REGION, "5").toInt
+    if (num <= 0) {
+      5
+    } else {
+      num
+    }
+  }
+  val shuffleKeyToSameRegion: Boolean =
+    getOrDefault(TIDB_SHUFFLE_KEY_TO_SAME_REGION, "true").toBoolean
+  val writeTaskNumber: Int = getOrDefault(TIDB_WRITE_TASK_NUMBER, "0").toInt
+  val prewriteBackOfferMS: Int = getOrDefault(TIDB_PREWRITE_BACKOFFER_MS, "240000").toInt
+  val commitBackOfferMS: Int = getOrDefault(TIDB_COMMIT_BACKOFFER_MS, "20000").toInt
+  // 768 * 1024
+  val txnPrewriteBatchSize: Long = getOrDefault(TIDB_TXN_PREWITE_BATCH_SIZE, "786432").toLong
+  // 768 * 1024
+  val txnCommitBatchSize: Long = getOrDefault(TIDB_TXN_COMMIT_BATCH_SIZE, "786432").toLong
+  // 32 * 1024
+  val writeBufferSize: Int = getOrDefault(TIDB_WRITE_BUFFER_SIZE, "32768").toInt
+  val writeThreadPerTask: Int = getOrDefault(TIDB_WRITE_THREAD_PER_TASK, "1").toInt
+  val retryCommitSecondaryKey: Boolean =
+    getOrDefault(TIDB_RETRY_COMMIT_SECONDARY_KEY, "true").toBoolean
+
+  // region split
+  val enableRegionSplit: Boolean = getOrDefault(TIDB_ENABLE_REGION_SPLIT, "true").toBoolean
+  val regionSplitNum: Int = getOrDefault(TIDB_REGION_SPLIT_NUM, "0").toInt
+  val sampleSplitFrac: Int = getOrDefault(TIDB_SAMPLE_SPLIT_FRAC, "1000").toInt
+  val writeSplitRegionFinish: Int = getOrDefault(TIDB_WRITE_SPLIT_REGION_FINISH, "-1").toInt
+  val regionSplitMethod: String = getOrDefault(TIDB_REGION_SPLIT_METHOD, "v2")
+  val scatterWaitMS: Int = getOrDefault(TIDB_SCATTER_WAIT_MS, "300000").toInt
+  val regionSplitKeys: Int = getOrDefault(TIDB_REGION_SPLIT_KEYS, "960000").toInt
+  val minRegionSplitNum: Int = getOrDefault(TIDB_MIN_REGION_SPLIT_NUM, "4").toInt
+  val regionSplitThreshold: Int = getOrDefault(TIDB_REGION_SPLIT_THRESHOLD, "100000").toInt
+  val splitRegionBackoffMS: Int = getOrDefault(TIDB_SPLIT_REGION_BACKOFFER_MS, "120000").toInt
 
   // ------------------------------------------------------------
   // Calculated parameters
   // ------------------------------------------------------------
   val url: String =
     s"jdbc:mysql://address=(protocol=tcp)(host=$address)(port=$port)/?user=$user&password=$password&useSSL=false&rewriteBatchedStatements=true"
-
-  private val optParamPrefix = "spark.tispark."
+      .replaceAll("%", "%25")
 
   def useTableLock(isV4: Boolean): Boolean = {
     if (isV4) {
-      parameters.getOrElse(TIDB_USE_TABLE_LOCK, "false").toBoolean
+      getOrDefault(TIDB_USE_TABLE_LOCK, "false").toBoolean
     } else {
-      parameters.getOrElse(TIDB_USE_TABLE_LOCK, "true").toBoolean
+      getOrDefault(TIDB_USE_TABLE_LOCK, "true").toBoolean
     }
   }
@@ -148,7 +182,6 @@ class TiDBOptions(@transient val parameters: CaseInsensitiveMap[String]) extends
 }
 
 object TiDBOptions {
-  private final val _tidbOptionNames = collection.mutable.Set[String]()
   val TIDB_ADDRESS: String = newOption("tidb.addr")
   val TIDB_PORT: String = newOption("tidb.port")
   val TIDB_USER: String = newOption("tidb.user")
@@ -157,14 +190,36 @@ object TiDBOptions {
   val TIDB_TABLE: String = newOption("table")
   val TIDB_REPLACE: String = newOption("replace")
   val TIDB_SKIP_COMMIT_SECONDARY_KEY: String = newOption("skipCommitSecondaryKey")
-  val TIDB_ENABLE_REGION_SPLIT: String = newOption("enableRegionSplit")
-  val TIDB_REGION_SPLIT_NUM: String = newOption("regionSplitNum")
   val TIDB_WRITE_CONCURRENCY: String = newOption("writeConcurrency")
   val TIDB_TTL_MODE: String = newOption("ttlMode")
   val TIDB_USE_SNAPSHOT_BATCH_GET: String = newOption("useSnapshotBatchGet")
   val TIDB_SNAPSHOT_BATCH_GET_SIZE: String = newOption("snapshotBatchGetSize")
+  val TIDB_BATCH_GET_BACKOFFER_MS: String = newOption("batchGetBackOfferMS")
   val TIDB_USE_TABLE_LOCK: String = newOption("useTableLock")
   val TIDB_MULTI_TABLES: String = newOption("multiTables")
+  val TIDB_TASK_NUM_PER_REGION: String = newOption("taskNumPerRegion")
+  val TIDB_SHUFFLE_KEY_TO_SAME_REGION: String = newOption("shuffleKeyToSameRegion")
+  val TIDB_WRITE_TASK_NUMBER: String = newOption("writeTaskNumber")
+  val TIDB_PREWRITE_BACKOFFER_MS: String = newOption("prewriteBackOfferMS")
+  val TIDB_COMMIT_BACKOFFER_MS: String = newOption("commitBackOfferMS")
+  val TIDB_TXN_PREWITE_BATCH_SIZE: String = newOption("txnPrewriteBatchSize")
+  val TIDB_TXN_COMMIT_BATCH_SIZE: String = newOption("txnCommitBatchSize")
+  val TIDB_WRITE_BUFFER_SIZE: String = newOption("writeBufferSize")
+  val TIDB_WRITE_THREAD_PER_TASK: String = newOption("writeThreadPerTask")
+  val TIDB_RETRY_COMMIT_SECONDARY_KEY: String = newOption("retryCommitSecondaryKey")
+
+  // region split
+  val TIDB_ENABLE_REGION_SPLIT: String = newOption("enableRegionSplit")
+  val TIDB_REGION_SPLIT_NUM: String = newOption("regionSplitNum")
+  val TIDB_SAMPLE_SPLIT_FRAC: String = newOption("sampleSplitFrac")
+  val TIDB_WRITE_SPLIT_REGION_FINISH: String = newOption("writeSplitRegionFinish")
+  val TIDB_REGION_SPLIT_METHOD: String = newOption("regionSplitMethod")
+  val TIDB_SCATTER_WAIT_MS: String = newOption("scatterWaitSecondes")
+  val TIDB_REGION_SPLIT_KEYS: String = newOption("regionSplitKeys")
+  val TIDB_MIN_REGION_SPLIT_NUM: String = newOption("minRegionSplitNum")
+  val TIDB_REGION_SPLIT_THRESHOLD: String = newOption("regionSplitThreshold")
+  val TIDB_SPLIT_REGION_BACKOFFER_MS: String = newOption("splitRegionBackoffMS")
+
   // ------------------------------------------------------------
   // parameters only for test
   // ------------------------------------------------------------
@@ -176,8 +231,7 @@ object TiDBOptions {
   val TIDB_SLEEP_AFTER_GET_COMMIT_TS: String = newOption("sleepAfterGetCommitTS")
 
   private def newOption(name: String): String = {
-    _tidbOptionNames += name.toLowerCase(Locale.ROOT)
-    name
+    name.toLowerCase(Locale.ROOT)
   }
 
   private def mergeWithSparkConf(parameters: Map[String, String]): Map[String, String] = {
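The options above all switch from `parameters.getOrElse` to `getOrDefault`, whose definition is not part of this hunk. A plausible sketch of what such a lookup does, assuming the usual TiSpark convention that every option can also be set through a `spark.tispark.`-prefixed Spark conf key (the map-based signature here is illustrative only):

```scala
// Hypothetical stand-in for TiDBOptions.getOrDefault: try the bare option
// name first, then the "spark.tispark."-prefixed variant, else the default.
// Keys are lower-cased to match the newOption(...) normalization above.
object OptionLookupSketch {
  private val optParamPrefix = "spark.tispark."

  def getOrDefault(parameters: Map[String, String], name: String, default: String): String =
    parameters
      .get(name.toLowerCase)
      .orElse(parameters.get((optParamPrefix + name).toLowerCase))
      .getOrElse(default)

  def main(args: Array[String]): Unit = {
    val params = Map("spark.tispark.replace" -> "true")
    println(getOrDefault(params, "replace", "false")) // true
    println(getOrDefault(params, "writeConcurrency", "0")) // 0
  }
}
```

Note also the `.replaceAll("%", "%25")` on the JDBC URL: it percent-escapes literal `%` characters (for example in passwords) so the MySQL connector does not misparse them, which is the "fix escape char in jdbc url" item from the commit message.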
a/core/src/main/scala/com/pingcap/tispark/write/TiReginSplitPartitioner.scala b/core/src/main/scala/com/pingcap/tispark/write/TiReginSplitPartitioner.scala new file mode 100644 index 0000000000..090cf659ac --- /dev/null +++ b/core/src/main/scala/com/pingcap/tispark/write/TiReginSplitPartitioner.scala @@ -0,0 +1,46 @@ +/* + * Copyright 2020 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + package com.pingcap.tispark.write + + import com.pingcap.tikv.key.Key + import org.apache.spark.Partitioner + + class TiReginSplitPartitioner(orderedSplitPoints: List[SerializableKey]) extends Partitioner { + override def getPartition(key: Any): Int = { + val serializableKey = key.asInstanceOf[SerializableKey] + val rawKey = Key.toRawKey(serializableKey.bytes) + binarySearch(rawKey) % numPartitions + } + + def binarySearch(key: Key): Int = { + var l = 0 + var r = orderedSplitPoints.size + while (l < r) { + val mid = l + (r - l) / 2 + val splitPointKey = orderedSplitPoints(mid).getRowKey + if (splitPointKey.compareTo(key) < 0) { + l = mid + 1 + } else { + r = mid + } + } + l + } + + override def numPartitions: Int = { + orderedSplitPoints.size + 1 + } +} diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiRegionPartitioner.scala b/core/src/main/scala/com/pingcap/tispark/write/TiRegionPartitioner.scala index fe49569b6d..4bcdf5af04 100644 --- a/core/src/main/scala/com/pingcap/tispark/write/TiRegionPartitioner.scala +++ b/core/src/main/scala/com/pingcap/tispark/write/TiRegionPartitioner.scala @@ -21,13 +21,22 @@ import com.pingcap.tikv.key.Key import com.pingcap.tikv.region.TiRegion import org.apache.spark.Partitioner -class TiRegionPartitioner(regions: util.List[TiRegion], writeConcurrency: Int) +class TiRegionPartitioner( + regions: util.List[TiRegion], + writeConcurrency: Int, + taskNumPerRegion: Int) extends Partitioner { override def getPartition(key: Any): Int = { val serializableKey = key.asInstanceOf[SerializableKey] val rawKey = Key.toRawKey(serializableKey.bytes) - binarySearch(rawKey) % numPartitions + if (writeConcurrency <= 0) { + val regionNumber = binarySearch(rawKey) + val offset = Math.abs(rawKey.hashCode()) % taskNumPerRegion + (regionNumber * taskNumPerRegion + offset) % numPartitions + } else { + binarySearch(rawKey) % numPartitions + } } def binarySearch(key: Key): Int = { @@ -39,7 +48,7 @@ class TiRegionPartitioner(regions: util.List[TiRegion], writeConcurrency: Int) while (l < r) { val mid = l + (r - l) / 2 val region = regions.get(mid) - if (Key.toRawKey(region.getEndKey).compareTo(key) <= 0) { + if (region.getRowEndKey.compareTo(key) <= 0) { l = mid + 1 } else { r = mid @@ -49,6 +58,7 @@ class TiRegionPartitioner(regions: util.List[TiRegion], writeConcurrency: Int) l } - override def numPartitions: Int = - if (writeConcurrency <= 0) regions.size() else writeConcurrency + override def numPartitions: Int = { + if (writeConcurrency <= 0) regions.size() * taskNumPerRegion else writeConcurrency + } } diff --git a/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala
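When writeConcurrency <= 0, the TiRegionPartitioner above now fans each region's keys out across taskNumPerRegion partitions instead of one. A standalone sketch of that arithmetic (function and argument names are hypothetical, for illustration only):

    // Pure-function restatement of the partition index computed above. Like the
    // original Math.abs call, math.abs(Int.MinValue) stays negative; real key
    // hashes rarely hit that value, but the caveat carries over.
    def partitionFor(regionIdx: Int, keyHash: Int, regionCount: Int, taskNumPerRegion: Int): Int = {
      val numPartitions = regionCount * taskNumPerRegion
      val offset = math.abs(keyHash) % taskNumPerRegion
      (regionIdx * taskNumPerRegion + offset) % numPartitions
    }

    // 3 regions with taskNumPerRegion = 2 yield 6 partitions; keys of region 1
    // land in partition 2 or 3, so one region's load spreads over two tasks.
    assert(Set(2, 3).contains(partitionFor(regionIdx = 1, keyHash = 42, regionCount = 3, taskNumPerRegion = 2)))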
b/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala index d734da6213..c1b0e47bde 100644 --- a/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala +++ b/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala @@ -187,7 +187,7 @@ case class ColumnarRegionTaskExec( val batchSize = tiConf.getIndexScanBatchSize val downgradeThreshold = tiConf.getDowngradeThreshold - iter.flatMap { row => + def computeWithRowIterator(row: InternalRow): Iterator[InternalRow] = { val handles = row.getArray(1).toLongArray() val handleIterator: util.Iterator[Long] = handles.iterator var taskCount = 0 @@ -376,6 +376,20 @@ case class ColumnarRegionTaskExec( } }.asInstanceOf[Iterator[InternalRow]] } + + iter match { + case batch: Iterator[ColumnarBatch] => + batch.asInstanceOf[Iterator[ColumnarBatch]].flatMap { it => + it.rowIterator().flatMap { row => + computeWithRowIterator(row) + } + } + case _: Iterator[InternalRow] => + iter.flatMap { row => + computeWithRowIterator(row) + } + } + } override protected def doExecute(): RDD[InternalRow] = { diff --git a/core/src/main/scala/org/apache/spark/sql/tispark/TiHandleRDD.scala b/core/src/main/scala/org/apache/spark/sql/tispark/TiHandleRDD.scala index da949d375b..d05edc8450 100644 --- a/core/src/main/scala/org/apache/spark/sql/tispark/TiHandleRDD.scala +++ b/core/src/main/scala/org/apache/spark/sql/tispark/TiHandleRDD.scala @@ -15,19 +15,31 @@ package org.apache.spark.sql.tispark +import java.nio.ByteBuffer + +import com.pingcap.tikv.codec.CodecDataOutput +import com.pingcap.tikv.columnar.{ + TiChunk, + TiChunkColumnVector, + TiColumnVector, + TiColumnarBatchHelper +} import com.pingcap.tikv.meta.TiDAGRequest +import com.pingcap.tikv.types.{ArrayType, DataType, IntegerType} import com.pingcap.tikv.util.RangeSplitter import com.pingcap.tikv.{TiConfiguration, TiSession} -import com.pingcap.tispark.utils.TiUtil import com.pingcap.tispark.{TiPartition, TiTableReference} import gnu.trove.list.array.TLongArrayList import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.SparkSession import org.apache.spark.{Partition, TaskContext, TaskKilledException} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer /** * RDD used for retrieving handles from TiKV. 
Result is arranged as @@ -53,7 +65,7 @@ class TiHandleRDD( outputTypes.map(CatalystTypeConverters.createToCatalystConverter) override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = - new Iterator[InternalRow] { + new Iterator[ColumnarBatch] { checkTimezone() private val tiPartition = split.asInstanceOf[TiPartition] @@ -61,7 +73,7 @@ class TiHandleRDD( private val snapshot = session.createSnapshot(dagRequest.getStartTs) private[this] val tasks = tiPartition.tasks - private val handleIterator = snapshot.indexHandleRead(dagRequest, tasks) + private val handleIterator = snapshot.indexHandleReadRow(dagRequest, tasks) private val regionManager = session.getRegionManager private lazy val handleList = { val lst = new TLongArrayList() @@ -92,14 +104,63 @@ class TiHandleRDD( iterator.hasNext } - override def next(): InternalRow = { - val next = iterator.next - val regionId = next._1 - val handleList = next._2 + override def next(): ColumnarBatch = { + var numRows = 0 + val batchSize = 20480 + val cdi0 = new CodecDataOutput() + val cdi1 = new CodecDataOutput() + var offsets = new mutable.ArrayBuffer[Long] + var curOffset = 0L + while (hasNext && numRows < batchSize) { + val next = iterator.next + val regionId = next._1 + val handleList = next._2 + if (!handleList.isEmpty) { + // Returns RegionId:[handle1, handle2, handle3...] K-V pair +// val sparkRow = Row.apply(regionId, handleList.toArray()) +// TiUtil.rowToInternalRow(sparkRow, outputTypes, converters) + cdi0.writeLong(regionId) + cdi1.writeLong(handleList.size()) + for (i <- 0 until handleList.size()) { + cdi1.writeLong(handleList.get(i)) + } + offsets += curOffset + curOffset += handleList.size().toLong + numRows += 1 + } + } + offsets += curOffset + + val buffer0 = ByteBuffer.wrap(cdi0.toBytes) + val buffer1 = ByteBuffer.wrap(cdi1.toBytes) + + val nullBitMaps = DataType.setAllNotNullBitMapWithNumRows(numRows) + + val regionIdType = IntegerType.BIGINT + val handleListType = ArrayType.ARRAY - // Returns RegionId:[handle1, handle2, handle3...] K-V pair - val sparkRow = Row.apply(regionId, handleList.toArray()) - TiUtil.rowToInternalRow(sparkRow, outputTypes, converters) + val childColumnVectors = new ArrayBuffer[TiColumnVector] + childColumnVectors += + new TiChunkColumnVector( + regionIdType, + regionIdType.getFixLen, + numRows, + 0, + nullBitMaps, + null, + buffer0) + childColumnVectors += + // any type will do? 
actual type is array[Long] + new TiChunkColumnVector( + handleListType, + 8, + curOffset.toInt, + 0, + nullBitMaps, + offsets.toArray, + buffer1) + val chunk = new TiChunk(childColumnVectors.toArray) + TiColumnarBatchHelper.createColumnarBatch(chunk) } - } + }.asInstanceOf[Iterator[InternalRow]] } diff --git a/core/src/test/scala/com/pingcap/tispark/BatchWriteIssueSuite.scala b/core/src/test/scala/com/pingcap/tispark/BatchWriteIssueSuite.scala index 60b346bf98..2688c78932 100644 --- a/core/src/test/scala/com/pingcap/tispark/BatchWriteIssueSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/BatchWriteIssueSuite.scala @@ -17,7 +17,13 @@ package com.pingcap.tispark import com.pingcap.tispark.datasource.BaseDataSourceTest import org.apache.spark.sql.Row -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ + IntegerType, + StringType, + StructField, + StructType, + TimestampType +} class BatchWriteIssueSuite extends BaseDataSourceTest("test_batchwrite_issue") { override def beforeAll(): Unit = { @@ -36,6 +42,45 @@ class BatchWriteIssueSuite extends BaseDataSourceTest("test_batchwrite_issue") { doTestNullValues(s"create table $dbtable(a int, b varchar(64), PRIMARY KEY (a))") } + test("Index for timestamp was written multiple times") { + if (!supportBatchWrite) { + cancel + } + + val schema = StructType( + List( + StructField("a", IntegerType), + StructField("b", StringType), + StructField("c", TimestampType))) + val options = Some(Map("replace" -> "true")) + + dropTable() + jdbcUpdate( + s"create table $dbtable(a int, b varchar(64), c datetime, CONSTRAINT xx UNIQUE (b), key `dt_index` (c))") + + for (_ <- 0 to 1) { + val row1 = Row(10, "1", java.sql.Timestamp.valueOf("2001-12-29 22:44:04")) + val row2 = Row(20, "2", java.sql.Timestamp.valueOf("2001-12-29 23:10:31")) + val row3 = Row(30, "3", java.sql.Timestamp.valueOf("2001-12-29 23:27:14")) + val row4 = Row(40, "4", java.sql.Timestamp.valueOf("2001-12-29 23:18:46")) + val row5 = Row(50, "5", java.sql.Timestamp.valueOf("2001-12-29 23:21:45")) + val row6 = Row(50, "5", java.sql.Timestamp.valueOf("2001-12-29 23:21:45")) + tidbWrite(List(row1, row2, row3, row4, row5, row6), schema, options) + + try { + assert(spark.sql(s"select count(c) from $table").collect().head.get(0) === 5) + assert(spark.sql(s"select count(a) from $table").collect().head.get(0) === 5) + } finally { + spark.sql(s"select * from $table").show(false) + spark.sql(s"select count(c) from $table").show(false) + spark.sql(s"select count(c) from $table").explain + spark.sql(s"select count(a) from $table").show(false) + spark.sql(s"select count(a) from $table").explain + } + + } + } + override def afterAll(): Unit = try { dropTable() diff --git a/core/src/test/scala/com/pingcap/tispark/concurrency/WriteReadSuite.scala b/core/src/test/scala/com/pingcap/tispark/concurrency/WriteReadSuite.scala index 1c99957f8f..ec58c730ad 100644 --- a/core/src/test/scala/com/pingcap/tispark/concurrency/WriteReadSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/concurrency/WriteReadSuite.scala @@ -17,27 +17,27 @@ package com.pingcap.tispark.concurrency class WriteReadSuite extends ConcurrencyTest { - test("read conflict using jdbc") { + ignore("read conflict using jdbc") { doTestJDBC(s"create table $dbtable(i int, s varchar(128))") } - test("read conflict using jdbc: primary key") { + ignore("read conflict using jdbc: primary key") { doTestJDBC(s"create table $dbtable(i int, s varchar(128), PRIMARY KEY(i))") } - test("read 
conflict using jdbc: unique key") { + ignore("read conflict using jdbc: unique key") { doTestJDBC(s"create table $dbtable(i int, s varchar(128), UNIQUE KEY(i))") } - test("read conflict using tispark") { + ignore("read conflict using tispark") { doTestTiSpark(s"create table $dbtable(i int, s varchar(128))") } - test("read conflict using tispark: primary key") { + ignore("read conflict using tispark: primary key") { doTestTiSpark(s"create table $dbtable(i int, s varchar(128), PRIMARY KEY(i))") } - test("read conflict using tispark: unique key") { + ignore("read conflict using tispark: unique key") { doTestTiSpark(s"create table $dbtable(i int, s varchar(128), UNIQUE KEY(i))") } diff --git a/core/src/test/scala/com/pingcap/tispark/datasource/BaseDataSourceTest.scala b/core/src/test/scala/com/pingcap/tispark/datasource/BaseDataSourceTest.scala index 722d855718..c54d8bf175 100644 --- a/core/src/test/scala/com/pingcap/tispark/datasource/BaseDataSourceTest.scala +++ b/core/src/test/scala/com/pingcap/tispark/datasource/BaseDataSourceTest.scala @@ -105,14 +105,16 @@ class BaseDataSourceTest(val table: String, val database: String = "tispark_test caughtTiDB.getCause.getClass.equals(tidbErrorClass), s"${caughtTiDB.getCause.getClass.getName} not equals to ${tidbErrorClass.getName}") - if (!msgStartWith) { - assert( - Objects.equals(caughtTiDB.getCause.getMessage, tidbErrorMsg), - s"${caughtTiDB.getCause.getMessage} not equals to $tidbErrorMsg") - } else { - assert( - startWith(caughtTiDB.getCause.getMessage, tidbErrorMsg), - s"${caughtTiDB.getCause.getMessage} not start with $tidbErrorMsg") + if (tidbErrorMsg != null) { + if (!msgStartWith) { + assert( + Objects.equals(caughtTiDB.getCause.getMessage, tidbErrorMsg), + s"${caughtTiDB.getCause.getMessage} not equals to $tidbErrorMsg") + } else { + assert( + startWith(caughtTiDB.getCause.getMessage, tidbErrorMsg), + s"${caughtTiDB.getCause.getMessage} not start with $tidbErrorMsg") + } } } @@ -222,8 +224,8 @@ class BaseDataSourceTest(val table: String, val database: String = "tispark_test if (!compResult(jdbcResult, tidbResult)) { logger.error(s"""Failed on $tblName\n - |DataSourceAPI result: ${listToString(jdbcResult)}\n - |TiDB via JDBC result: ${listToString(tidbResult)}""".stripMargin) + |TiDB via JDBC result: ${listToString(jdbcResult)}\n + |DataSourceAPI result: ${listToString(tidbResult)}""".stripMargin) fail() } } diff --git a/core/src/test/scala/com/pingcap/tispark/datasource/ExceptionTestSuite.scala b/core/src/test/scala/com/pingcap/tispark/datasource/ExceptionTestSuite.scala index 78fceec1bb..b126020044 100644 --- a/core/src/test/scala/com/pingcap/tispark/datasource/ExceptionTestSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/datasource/ExceptionTestSuite.scala @@ -110,7 +110,7 @@ class ExceptionTestSuite extends BaseDataSourceTest("test_datasource_exception_t } assert( caught.getMessage - .equals("Insert null value to not null column! 1 rows contain illegal null values!")) + .equals("Insert null value to not null column! 
rows contain illegal null values!")) } } diff --git a/core/src/test/scala/com/pingcap/tispark/index/LineItemSuite.scala b/core/src/test/scala/com/pingcap/tispark/index/LineItemSuite.scala index a6d836a823..24a917e2a2 100644 --- a/core/src/test/scala/com/pingcap/tispark/index/LineItemSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/index/LineItemSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.functions._ class LineItemSuite extends BaseTiSparkTest { private val table = "LINEITEM" - private val where = "where L_PARTKEY < 2100000" + private val where = "where L_PARTKEY < 3100000" private val batchWriteTablePrefix = "BATCH.WRITE" private val isPkHandlePrefix = "isPkHandle" private val replacePKHandlePrefix = "replacePKHandle" @@ -54,7 +54,7 @@ class LineItemSuite extends BaseTiSparkTest { df, ti, new TiDBOptions( - tidbOptions + ("database" -> s"$database", "table" -> tableToWrite, "isTest" -> "true"))) + tidbOptions + ("database" -> s"$database", "table" -> tableToWrite, "isTest" -> "true", "regionSplitMethod" -> "v2"))) // refresh refreshConnections(TestTables(database, tableToWrite)) diff --git a/core/src/test/scala/com/pingcap/tispark/overflow/DateOverflowSuite.scala b/core/src/test/scala/com/pingcap/tispark/overflow/DateOverflowSuite.scala index 82a99c2c56..49031a80c8 100644 --- a/core/src/test/scala/com/pingcap/tispark/overflow/DateOverflowSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/overflow/DateOverflowSuite.scala @@ -15,6 +15,7 @@ package com.pingcap.tispark.overflow +import com.pingcap.tikv.exception.TiDBConvertException import com.pingcap.tispark.datasource.BaseDataSourceTest import org.apache.spark.sql.Row import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ class DateOverflowSuite extends BaseDataSourceTest("test_data_type_date_overflow val row = Row("10000-01-01") val schema = StructType(List(StructField("c1", StringType))) val jdbcErrorClass = classOf[java.sql.BatchUpdateException] - val tidbErrorClass = classOf[java.lang.IllegalArgumentException] + val tidbErrorClass = classOf[TiDBConvertException] val tidbErrorMsg = null compareTiDBWriteFailureWithJDBC( @@ -93,7 +94,7 @@ class DateOverflowSuite extends BaseDataSourceTest("test_data_type_date_overflow val row = Row("2019-13-01") val schema = StructType(List(StructField("c1", StringType))) val jdbcErrorClass = classOf[java.sql.BatchUpdateException] - val tidbErrorClass = classOf[java.lang.IllegalArgumentException] + val tidbErrorClass = classOf[TiDBConvertException] val tidbErrorMsgStart = null compareTiDBWriteFailureWithJDBC( diff --git a/core/src/test/scala/com/pingcap/tispark/overflow/DateTimeOverflowSuite.scala b/core/src/test/scala/com/pingcap/tispark/overflow/DateTimeOverflowSuite.scala index 6886c64298..47b0aa2f86 100644 --- a/core/src/test/scala/com/pingcap/tispark/overflow/DateTimeOverflowSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/overflow/DateTimeOverflowSuite.scala @@ -15,6 +15,7 @@ package com.pingcap.tispark.overflow +import com.pingcap.tikv.exception.TiDBConvertException import com.pingcap.tispark.datasource.BaseDataSourceTest import org.apache.spark.sql.Row import org.apache.spark.sql.types._ @@ -58,8 +59,8 @@ class DateTimeOverflowSuite extends BaseDataSourceTest("test_data_type_datetime_ val row = Row("10000-11-11 11:11:11") val schema = StructType(List(StructField("c1", StringType))) val jdbcErrorClass = classOf[java.sql.BatchUpdateException] - val tidbErrorClass = classOf[java.lang.IllegalArgumentException] - val tidbErrorMsg = "Timestamp format must be 
yyyy-mm-dd hh:mm:ss[.fffffffff]" + val tidbErrorClass = classOf[TiDBConvertException] + val tidbErrorMsg = null compareTiDBWriteFailureWithJDBC( List(row), diff --git a/core/src/test/scala/com/pingcap/tispark/overflow/SignedOverflowSuite.scala b/core/src/test/scala/com/pingcap/tispark/overflow/SignedOverflowSuite.scala index f1655f6d14..b37f6ef117 100644 --- a/core/src/test/scala/com/pingcap/tispark/overflow/SignedOverflowSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/overflow/SignedOverflowSuite.scala @@ -15,6 +15,7 @@ package com.pingcap.tispark.overflow +import com.pingcap.tikv.exception.TiDBConvertException import com.pingcap.tispark.datasource.BaseDataSourceTest import org.apache.spark.sql.Row import org.apache.spark.sql.types._ @@ -364,8 +365,8 @@ class SignedOverflowSuite extends BaseDataSourceTest("test_data_type_signed_over val row = Row("9223372036854775808") val schema = StructType(List(StructField("c1", StringType))) val jdbcErrorClass = classOf[java.sql.BatchUpdateException] - val tidbErrorClass = classOf[java.lang.NumberFormatException] - val tidbErrorMsg = "For input string: \"9223372036854775808\"" + val tidbErrorClass = classOf[TiDBConvertException] + val tidbErrorMsg = null compareTiDBWriteFailureWithJDBC( List(row), @@ -400,8 +401,8 @@ class SignedOverflowSuite extends BaseDataSourceTest("test_data_type_signed_over val row = Row("-9223372036854775809") val schema = StructType(List(StructField("c1", StringType))) val jdbcErrorClass = classOf[java.sql.BatchUpdateException] - val tidbErrorClass = classOf[java.lang.NumberFormatException] - val tidbErrorMsg = "For input string: \"-9223372036854775809\"" + val tidbErrorClass = classOf[TiDBConvertException] + val tidbErrorMsg = null compareTiDBWriteFailureWithJDBC( List(row), diff --git a/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala b/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala index 15a2364eaf..7384e4bd80 100644 --- a/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/overflow/UnsignedOverflowSuite.scala @@ -15,6 +15,7 @@ package com.pingcap.tispark.overflow +import com.pingcap.tikv.exception.TiDBConvertException import com.pingcap.tispark.datasource.BaseDataSourceTest import org.apache.spark.sql.Row import org.apache.spark.sql.types._ @@ -365,8 +366,8 @@ class UnsignedOverflowSuite extends BaseDataSourceTest("test_data_type_unsigned_ val row = Row("18446744073709551616") val schema = StructType(List(StructField("c1", StringType))) val jdbcErrorClass = classOf[java.sql.BatchUpdateException] - val tidbErrorClass = classOf[java.lang.NumberFormatException] - val tidbErrorMsg = "Too large for unsigned long: 18446744073709551616" + val tidbErrorClass = classOf[TiDBConvertException] + val tidbErrorMsg = null compareTiDBWriteFailureWithJDBC( List(row), @@ -387,8 +388,8 @@ class UnsignedOverflowSuite extends BaseDataSourceTest("test_data_type_unsigned_ val row = Row("-1") val schema = StructType(List(StructField("c1", StringType))) val jdbcErrorClass = classOf[java.sql.BatchUpdateException] - val tidbErrorClass = classOf[java.lang.NumberFormatException] - val tidbErrorMsg = "-1" + val tidbErrorClass = classOf[TiDBConvertException] + val tidbErrorMsg = null compareTiDBWriteFailureWithJDBC( List(row), diff --git a/core/src/test/scala/org/apache/spark/sql/insertion/EnumerateUniqueIndexDataTypeTestAction.scala 
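The overflow suites above now expect TiDBConvertException and pass a null expected message, relying on the BaseDataSourceTest change earlier that skips the message comparison when the expectation is null. A condensed sketch of that assertion shape (helper name hypothetical):

    // Only the exception class is mandatory; the message is checked when given.
    def assertWriteFailure(caught: Throwable, expectedClass: Class[_], expectedMsg: String): Unit = {
      assert(caught.getClass == expectedClass, s"${caught.getClass.getName} not equals to ${expectedClass.getName}")
      if (expectedMsg != null) {
        assert(caught.getMessage == expectedMsg, s"${caught.getMessage} not equals to $expectedMsg")
      }
    }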
b/core/src/test/scala/org/apache/spark/sql/insertion/EnumerateUniqueIndexDataTypeTestAction.scala index 138953b145..3a7eff39eb 100644 --- a/core/src/test/scala/org/apache/spark/sql/insertion/EnumerateUniqueIndexDataTypeTestAction.scala +++ b/core/src/test/scala/org/apache/spark/sql/insertion/EnumerateUniqueIndexDataTypeTestAction.scala @@ -26,7 +26,7 @@ trait EnumerateUniqueIndexDataTypeTestAction extends BaseEnumerateDataTypesTestS override def genIndex(dataTypes: List[ReflectedDataType], r: Random): List[List[Index]] = { val size = dataTypes.length // the first step is generate all possible keys - val keyList = scala.collection.mutable.ListBuffer.empty[List[Key]] + val keyList = scala.collection.mutable.ListBuffer.empty[List[UniqueKey]] for (i <- 1 until 3) { val combination = new Combinations(size, i) //(i, size) @@ -44,7 +44,7 @@ trait EnumerateUniqueIndexDataTypeTestAction extends BaseEnumerateDataTypesTestS } } - keyList += Key(indexColumnList.toList) :: Nil + keyList += UniqueKey(indexColumnList.toList) :: Nil } } diff --git a/core/src/test/scala/org/apache/spark/sql/test/generator/ColumnValueGenerator.scala b/core/src/test/scala/org/apache/spark/sql/test/generator/ColumnValueGenerator.scala index 1f720b7e26..abdcd13b38 100644 --- a/core/src/test/scala/org/apache/spark/sql/test/generator/ColumnValueGenerator.scala +++ b/core/src/test/scala/org/apache/spark/sql/test/generator/ColumnValueGenerator.scala @@ -143,8 +143,7 @@ case class ColumnValueGenerator( val nullString = if (!nullable) " not null" else "" val defaultString = if (!noDefault) s" default $default" else "" val unsignedString = if (isUnsigned) " unsigned" else "" - val uniqueString = if (isUnique) " unique" else "" - s"$unsignedString$nullString$uniqueString$defaultString" + s"$unsignedString$nullString$defaultString" } private var generatedRandomValues: List[Any] = List.empty[Any] private var curPos = 0 diff --git a/core/src/test/scala/org/apache/spark/sql/test/generator/Index.scala b/core/src/test/scala/org/apache/spark/sql/test/generator/Index.scala index df98b4ed2d..a88ce77cd2 100644 --- a/core/src/test/scala/org/apache/spark/sql/test/generator/Index.scala +++ b/core/src/test/scala/org/apache/spark/sql/test/generator/Index.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.test.generator trait Index { def indexColumns: List[IndexColumn] val isPrimaryKey: Boolean = false + val isUnique: Boolean = false } case class Key(indexColumns: List[IndexColumn]) extends Index {} @@ -27,3 +28,7 @@ case class Key(indexColumns: List[IndexColumn]) extends Index {} case class PrimaryKey(indexColumns: List[IndexColumn]) extends Index { override val isPrimaryKey: Boolean = true } + +case class UniqueKey(indexColumns: List[IndexColumn]) extends Index { + override val isUnique: Boolean = true +} diff --git a/core/src/test/scala/org/apache/spark/sql/test/generator/IndexColumn.scala b/core/src/test/scala/org/apache/spark/sql/test/generator/IndexColumn.scala index 60a13d2871..9c4fc47e19 100644 --- a/core/src/test/scala/org/apache/spark/sql/test/generator/IndexColumn.scala +++ b/core/src/test/scala/org/apache/spark/sql/test/generator/IndexColumn.scala @@ -101,11 +101,17 @@ case class IndexColumnInfo(column: String, length: Integer) { } } -case class IndexInfo(indexName: String, indexColumns: List[IndexColumnInfo], isPrimary: Boolean) { +case class IndexInfo( + indexName: String, + indexColumns: List[IndexColumnInfo], + isPrimary: Boolean, + isUnique: Boolean) { override def toString: String = { val indexColumnString = 
indexColumns.mkString("(", ",", ")") if (isPrimary) { s"PRIMARY KEY $indexColumnString" + } else if (isUnique) { + s"UNIQUE KEY $indexColumnString" } else { s"KEY `$indexName`$indexColumnString" } diff --git a/core/src/test/scala/org/apache/spark/sql/test/generator/Schema.scala b/core/src/test/scala/org/apache/spark/sql/test/generator/Schema.scala index a978e3df21..aae02045b3 100644 --- a/core/src/test/scala/org/apache/spark/sql/test/generator/Schema.scala +++ b/core/src/test/scala/org/apache/spark/sql/test/generator/Schema.scala @@ -34,7 +34,7 @@ case class Schema( tableName: String, columnNames: List[String], columnDesc: Map[String, (ReflectedDataType, (Integer, Integer), String)], - indexColumns: Map[String, (List[(String, Integer)], Boolean)]) { + indexColumns: Map[String, (List[(String, Integer)], Boolean, Boolean)]) { // validations assert(columnDesc.size == columnNames.size, "columnDesc size not equal to column name size") @@ -46,7 +46,8 @@ case class Schema( idx._2._1.map { x => IndexColumnInfo(x._1, x._2) }, - idx._2._2) + idx._2._2, + idx._2._3) }.toList assert(indexInfo.count(_.isPrimary) <= 1, "more than one primary key exist in schema") @@ -58,10 +59,17 @@ case class Schema( pkIndexInfo.head.indexColumns.map(_.column).mkString(",") } + val uniqueIndexInfo: List[IndexInfo] = indexInfo.filter(_.isUnique) + val uniqueColumnNames: List[String] = uniqueIndexInfo.map { indexInfo => + indexInfo.indexColumns.map(_.column).mkString(",") + } + val columnInfo: List[ColumnInfo] = columnNames.map { col => val x = columnDesc(col) if (col == pkColumnName) { ColumnInfo(col, x._1, x._2, x._3 + " primary key") + } else if (uniqueColumnNames.contains(col)) { + ColumnInfo(col, x._1, x._2, x._3 + " unique key") } else { ColumnInfo(col, x._1, x._2, x._3) } diff --git a/core/src/test/scala/org/apache/spark/sql/test/generator/TestDataGenerator.scala b/core/src/test/scala/org/apache/spark/sql/test/generator/TestDataGenerator.scala index 1846d23847..c449aaf8d6 100644 --- a/core/src/test/scala/org/apache/spark/sql/test/generator/TestDataGenerator.scala +++ b/core/src/test/scala/org/apache/spark/sql/test/generator/TestDataGenerator.scala @@ -246,17 +246,18 @@ object TestDataGenerator { } // single column primary key defined in schema - val singleColumnPrimaryKey: Map[String, (List[(String, Integer)], Boolean)] = columnDesc.toMap - .filter { colDesc => - colDesc._2._3.contains("primary key") - } - .map { x => - (x._1, (List((x._1, null)), true)) - } + val singleColumnPrimaryKey: Map[String, (List[(String, Integer)], Boolean, Boolean)] = + columnDesc.toMap + .filter { colDesc => + colDesc._2._3.contains("primary key") + } + .map { x => + (x._1, (List((x._1, null)), true, true)) + } assert(singleColumnPrimaryKey.size <= 1, "More than one primary key present in schema") - val idxColumns: Map[String, (List[(String, Integer)], Boolean)] = + val idxColumns: Map[String, (List[(String, Integer)], Boolean, Boolean)] = singleColumnPrimaryKey ++ indices.map { idx => val columns = idx.indexColumns.map(x => (columnNames(x.getId), x.getLength)) @@ -269,7 +270,7 @@ object TestDataGenerator { generateIndexName(columns.map { _._1 }), - (columns, idx.isPrimaryKey)) + (columns, idx.isPrimaryKey, idx.isUnique)) }.toMap Schema(database, table, columnNames, columnDesc.toMap, idxColumns) diff --git a/tikv-client/src/main/java/com/pingcap/tikv/KVClient.java b/tikv-client/src/main/java/com/pingcap/tikv/KVClient.java index f209c4246f..9fe5be5adf 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/KVClient.java +++ 
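The KVClient diff that follows replaces fixed-count batching with a byte-budget split in appendBatches. A standalone sketch of the loop (names hypothetical; the real code batches ByteString keys per region):

    // Each batch takes consecutive keys until their accumulated size reaches
    // the budget; a single oversized key still forms its own one-element batch.
    def splitByByteBudget(keys: Vector[Array[Byte]], maxBytes: Int): Vector[Vector[Array[Byte]]] = {
      val batches = Vector.newBuilder[Vector[Array[Byte]]]
      var start = 0
      while (start < keys.length) {
        var end = start
        var size = 0
        while (end < keys.length && size < maxBytes) {
          size += keys(end).length
          end += 1
        }
        batches += keys.slice(start, end)
        start = end
      }
      batches.result()
    }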
b/tikv-client/src/main/java/com/pingcap/tikv/KVClient.java @@ -17,6 +17,7 @@ package com.pingcap.tikv; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.protobuf.ByteString; import com.pingcap.tikv.exception.GrpcException; import com.pingcap.tikv.exception.TiKVException; @@ -28,18 +29,14 @@ import com.pingcap.tikv.util.BackOffer; import com.pingcap.tikv.util.ConcreteBackOffer; import java.util.ArrayList; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -58,9 +55,10 @@ public KVClient(TiConfiguration conf, RegionStoreClientBuilder clientBuilder) { Objects.requireNonNull(clientBuilder, "clientBuilder is null"); this.conf = conf; this.clientBuilder = clientBuilder; - // TODO: ExecutorService executors = - // Executors.newFixedThreadPool(conf.getKVClientConcurrency()); - executorService = Executors.newFixedThreadPool(20); + executorService = + Executors.newFixedThreadPool( + conf.getKvClientConcurrency(), + new ThreadFactoryBuilder().setNameFormat("kvclient-pool-%d").setDaemon(true).build()); } @Override @@ -91,25 +89,15 @@ public ByteString get(ByteString key, long version) throws GrpcException { /** * Get a set of key-value pair by keys from TiKV * - * @param keys keys + * @param backOffer back offer used to retry the underlying requests + * @param keys keys to fetch + * @param version snapshot version (start timestamp) to read at + * @return the fetched key-value pairs + * @throws GrpcException if a gRPC error occurs */ - public List batchGet(List keys, long version) throws GrpcException { - return batchGet(ConcreteBackOffer.newBatchGetMaxBackOff(), keys, version); - } - - private List batchGet(BackOffer backOffer, List keys, long version) { - Set set = new HashSet<>(keys); - return batchGet(backOffer, set, version); - } - - private List batchGet(BackOffer backOffer, Set keys, long version) { - Map> groupKeys = groupKeysByRegion(keys); - List batches = new ArrayList<>(); - - for (Map.Entry> entry : groupKeys.entrySet()) { - appendBatches(batches, entry.getKey(), entry.getValue(), BATCH_GET_SIZE); - } - return sendBatchGet(backOffer, batches, version); + public List batchGet(BackOffer backOffer, List keys, long version) + throws GrpcException { + return doSendBatchGet(backOffer, keys, version); } /** @@ -147,26 +135,101 @@ public List scan(ByteString startKey, long version) throws GrpcE return scan(startKey, version, Integer.MAX_VALUE); } + private List doSendBatchGet(BackOffer backOffer, List keys, long version) { + ExecutorCompletionService> completionService = + new ExecutorCompletionService<>(executorService); + + Map> groupKeys = groupKeysByRegion(keys); + List batches = new ArrayList<>(); + + for (Map.Entry> entry : groupKeys.entrySet()) { + appendBatches(batches, entry.getKey(), entry.getValue(), BATCH_GET_SIZE); + } + + for (Batch batch : batches) { + BackOffer singleBatchBackOffer = ConcreteBackOffer.create(backOffer); + completionService.submit( + () -> doSendBatchGetInBatchesWithRetry(singleBatchBackOffer, batch, version)); + } + + try { + List result = new ArrayList<>(); + for (int i = 0; i < batches.size(); i++) { + result.addAll(completionService.take().get()); + } + return result; + } catch (InterruptedException e) { + 
Thread.currentThread().interrupt(); + throw new TiKVException("Current thread interrupted.", e); + } catch (ExecutionException e) { + throw new TiKVException("Execution exception met.", e); + } + } + + private List doSendBatchGetInBatchesWithRetry( + BackOffer backOffer, Batch batch, long version) { + TiRegion oldRegion = batch.region; + TiRegion currentRegion = + clientBuilder.getRegionManager().getRegionByKey(oldRegion.getStartKey()); + + if (oldRegion.equals(currentRegion)) { + RegionStoreClient client = clientBuilder.build(batch.region); + try { + return client.batchGet(backOffer, batch.keys, version); + } catch (final TiKVException e) { + backOffer.doBackOff(BackOffFunction.BackOffFuncType.BoRegionMiss, e); + clientBuilder.getRegionManager().invalidateRegion(batch.region.getId()); + logger.warn("ReSplitting ranges for BatchGetRequest", e); + + // retry + return doSendBatchGetWithRefetchRegion(backOffer, batch, version); + } + } else { + return doSendBatchGetWithRefetchRegion(backOffer, batch, version); + } + } + + private List doSendBatchGetWithRefetchRegion( + BackOffer backOffer, Batch batch, long version) { + Map> groupKeys = groupKeysByRegion(batch.keys); + List retryBatches = new ArrayList<>(); + + for (Map.Entry> entry : groupKeys.entrySet()) { + appendBatches(retryBatches, entry.getKey(), entry.getValue(), BATCH_GET_SIZE); + } + + ArrayList results = new ArrayList<>(); + for (Batch retryBatch : retryBatches) { + // recursive calls + List batchResult = doSendBatchGetInBatchesWithRetry(backOffer, retryBatch, version); + results.addAll(batchResult); + } + return results; + } + /** * Append batch to list and split them according to batch limit * * @param batches a grouped batch * @param region region * @param keys keys - * @param limit batch max limit + * @param batchGetMaxSizeInByte batch max limit */ private void appendBatches( - List batches, TiRegion region, List keys, int limit) { - List tmpKeys = new ArrayList<>(); - for (int i = 0; i < keys.size(); i++) { - if (i >= limit) { - batches.add(new Batch(region, tmpKeys)); - tmpKeys.clear(); - } - tmpKeys.add(keys.get(i)); + List batches, TiRegion region, List keys, int batchGetMaxSizeInByte) { + int start; + int end; + if (keys == null) { + return; } - if (!tmpKeys.isEmpty()) { - batches.add(new Batch(region, tmpKeys)); + int len = keys.size(); + for (start = 0; start < len; start = end) { + int size = 0; + for (end = start; end < len && size < batchGetMaxSizeInByte; end++) { + size += keys.get(end).size(); + } + Batch batch = new Batch(region, keys.subList(start, end)); + batches.add(batch); } } @@ -176,54 +239,11 @@ private void appendBatches( * @param keys keys * @return a mapping of keys and their region */ - private Map> groupKeysByRegion(Set keys) { + private Map> groupKeysByRegion(List keys) { return keys.stream() .collect(Collectors.groupingBy(clientBuilder.getRegionManager()::getRegionByKey)); } - /** - * Send batchPut request concurrently - * - * @param backOffer current backOffer - * @param batches list of batch to send - */ - private List sendBatchGet(BackOffer backOffer, List batches, long version) { - ExecutorCompletionService> completionService = - new ExecutorCompletionService<>(executorService); - for (Batch batch : batches) { - completionService.submit( - () -> { - RegionStoreClient client = clientBuilder.build(batch.region); - BackOffer singleBatchBackOffer = ConcreteBackOffer.create(backOffer); - List keys = batch.keys; - try { - return client.batchGet(singleBatchBackOffer, keys, version); - } catch (final 
TiKVException e) { - // TODO: any elegant way to re-split the ranges if fails? - singleBatchBackOffer.doBackOff(BackOffFunction.BackOffFuncType.BoRegionMiss, e); - logger.warn("ReSplitting ranges for BatchGetRequest"); - // recursive calls - return batchGet(singleBatchBackOffer, batch.keys, version); - } - }); - } - try { - List result = new ArrayList<>(); - for (int i = 0; i < batches.size(); i++) { - result.addAll( - completionService.take().get(BackOffer.BATCH_GET_MAX_BACKOFF, TimeUnit.SECONDS)); - } - return result; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new TiKVException("Current thread interrupted.", e); - } catch (TimeoutException e) { - throw new TiKVException("TimeOut Exceeded for current operation. ", e); - } catch (ExecutionException e) { - throw new TiKVException("Execution exception met.", e); - } - } - private Iterator scanIterator( TiConfiguration conf, RegionStoreClientBuilder builder, diff --git a/tikv-client/src/main/java/com/pingcap/tikv/PDClient.java b/tikv-client/src/main/java/com/pingcap/tikv/PDClient.java index d52467d7dd..2cb0b39151 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/PDClient.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/PDClient.java @@ -24,14 +24,18 @@ import com.google.protobuf.ByteString; import com.pingcap.tikv.codec.Codec.BytesCodec; import com.pingcap.tikv.codec.CodecDataOutput; +import com.pingcap.tikv.codec.KeyUtils; import com.pingcap.tikv.exception.GrpcException; import com.pingcap.tikv.exception.TiClientInternalException; import com.pingcap.tikv.meta.TiTimestamp; +import com.pingcap.tikv.operation.NoopHandler; import com.pingcap.tikv.operation.PDErrorHandler; import com.pingcap.tikv.pd.PDUtils; import com.pingcap.tikv.region.TiRegion; +import com.pingcap.tikv.util.BackOffFunction.BackOffFuncType; import com.pingcap.tikv.util.BackOffer; import com.pingcap.tikv.util.ChannelFactory; +import com.pingcap.tikv.util.ConcreteBackOffer; import com.pingcap.tikv.util.FutureObserver; import io.etcd.jetcd.ByteSequence; import io.etcd.jetcd.Client; @@ -57,15 +61,23 @@ import org.tikv.kvproto.PDGrpc; import org.tikv.kvproto.PDGrpc.PDBlockingStub; import org.tikv.kvproto.PDGrpc.PDStub; +import org.tikv.kvproto.Pdpb.Error; +import org.tikv.kvproto.Pdpb.ErrorType; import org.tikv.kvproto.Pdpb.GetAllStoresRequest; import org.tikv.kvproto.Pdpb.GetMembersRequest; import org.tikv.kvproto.Pdpb.GetMembersResponse; +import org.tikv.kvproto.Pdpb.GetOperatorRequest; +import org.tikv.kvproto.Pdpb.GetOperatorResponse; import org.tikv.kvproto.Pdpb.GetRegionByIDRequest; import org.tikv.kvproto.Pdpb.GetRegionRequest; import org.tikv.kvproto.Pdpb.GetRegionResponse; import org.tikv.kvproto.Pdpb.GetStoreRequest; import org.tikv.kvproto.Pdpb.GetStoreResponse; +import org.tikv.kvproto.Pdpb.OperatorStatus; import org.tikv.kvproto.Pdpb.RequestHeader; +import org.tikv.kvproto.Pdpb.ResponseHeader; +import org.tikv.kvproto.Pdpb.ScatterRegionRequest; +import org.tikv.kvproto.Pdpb.ScatterRegionResponse; import org.tikv.kvproto.Pdpb.Timestamp; import org.tikv.kvproto.Pdpb.TsoRequest; import org.tikv.kvproto.Pdpb.TsoResponse; @@ -112,6 +124,87 @@ public TiTimestamp getTimestamp(BackOffer backOffer) { return new TiTimestamp(timestamp.getPhysical(), timestamp.getLogical()); } + /** + * Sends request to pd to scatter region. 
+ * + * @param region the region to scatter + * @param backOffer back offer used when retrying the request + */ + void scatterRegion(TiRegion region, BackOffer backOffer) { + Supplier request = + () -> + ScatterRegionRequest.newBuilder().setHeader(header).setRegionId(region.getId()).build(); + + PDErrorHandler handler = + new PDErrorHandler<>( + r -> r.getHeader().hasError() ? buildFromPdpbError(r.getHeader().getError()) : null, + this); + + ScatterRegionResponse resp = + callWithRetry(backOffer, PDGrpc.getScatterRegionMethod(), request, handler); + // TODO: maybe we should retry here; this needs digging into pd's codebase. + if (resp.hasHeader() && resp.getHeader().hasError()) { + throw new TiClientInternalException( + String.format("failed to scatter region because %s", resp.getHeader().getError())); + } + } + + /** + * Wait until the scatter operation on the given region finishes. + * + * @param region the region whose scatter operator to wait for + * @param backOffer back offer bounding how long to wait + */ + void waitScatterRegionFinish(TiRegion region, BackOffer backOffer) { + for (; ; ) { + GetOperatorResponse resp = getOperator(region.getId()); + if (resp != null) { + if (isScatterRegionFinish(resp)) { + logger.info(String.format("wait scatter region on %d is finished", region.getId())); + return; + } else { + backOffer.doBackOff( + BackOffFuncType.BoRegionMiss, new GrpcException("waiting scatter region")); + logger.info( + String.format( + "wait scatter region %d at key %s is %s", + region.getId(), + KeyUtils.formatBytes(resp.getDesc().toByteArray()), + resp.getStatus().toString())); + } + } + } + } + + private GetOperatorResponse getOperator(long regionId) { + Supplier request = + () -> GetOperatorRequest.newBuilder().setHeader(header).setRegionId(regionId).build(); + // getOperator needs neither error handling nor a back offer. + return callWithRetry( + ConcreteBackOffer.newCustomBackOff(0), + PDGrpc.getGetOperatorMethod(), + request, + new NoopHandler<>()); + } + + private boolean isScatterRegionFinish(GetOperatorResponse resp) { + // If the current operator of region is not `scatter-region`, we could assume + // that `scatter-operator` has finished or timed out. 
+ boolean finished = + !resp.getDesc().equals(ByteString.copyFromUtf8("scatter-region")) + || resp.getStatus() != OperatorStatus.RUNNING; + + if (resp.hasHeader()) { + ResponseHeader header = resp.getHeader(); + if (header.hasError()) { + Error error = header.getError(); + // heartbeat may not send to PD + if (error.getType() == ErrorType.REGION_NOT_FOUND) { + finished = true; + } + } + } + return finished; + } + @Override public TiRegion getRegionByKey(BackOffer backOffer, ByteString key) { CodecDataOutput cdo = new CodecDataOutput(); diff --git a/tikv-client/src/main/java/com/pingcap/tikv/Snapshot.java b/tikv-client/src/main/java/com/pingcap/tikv/Snapshot.java index 976c857ecd..62fcfe40f0 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/Snapshot.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/Snapshot.java @@ -27,6 +27,7 @@ import com.pingcap.tikv.operation.iterator.ConcreteScanIterator; import com.pingcap.tikv.operation.iterator.IndexScanIterator; import com.pingcap.tikv.row.Row; +import com.pingcap.tikv.util.ConcreteBackOffer; import com.pingcap.tikv.util.RangeSplitter; import com.pingcap.tikv.util.RangeSplitter.RegionTask; import java.util.ArrayList; @@ -66,25 +67,28 @@ public byte[] get(byte[] key) { } public ByteString get(ByteString key) { - return new KVClient(session.getConf(), session.getRegionStoreClientBuilder()) - .get(key, timestamp.getVersion()); + try (KVClient client = new KVClient(session.getConf(), session.getRegionStoreClientBuilder())) { + return client.get(key, timestamp.getVersion()); + } } - public List batchGet(List keys) { + public List batchGet(int backOffer, List keys) { List list = new ArrayList<>(); for (byte[] key : keys) { list.add(ByteString.copyFrom(key)); } - - List kvPairList = - new KVClient(session.getConf(), session.getRegionStoreClientBuilder()) - .batchGet(list, timestamp.getVersion()); - return kvPairList - .stream() - .map( - kvPair -> - new BytePairWrapper(kvPair.getKey().toByteArray(), kvPair.getValue().toByteArray())) - .collect(Collectors.toList()); + try (KVClient client = new KVClient(session.getConf(), session.getRegionStoreClientBuilder())) { + List kvPairList = + client.batchGet( + ConcreteBackOffer.newCustomBackOff(backOffer), list, timestamp.getVersion()); + return kvPairList + .stream() + .map( + kvPair -> + new BytePairWrapper( + kvPair.getKey().toByteArray(), kvPair.getValue().toByteArray())) + .collect(Collectors.toList()); + } } public Iterator tableReadChunk( @@ -127,6 +131,11 @@ private Iterator tableReadRow(TiDAGRequest dagRequest, List tas } } + public Iterator indexHandleReadChunk( + TiDAGRequest dagRequest, List tasks, int numOfRows) { + return getTiChunkIterator(dagRequest, tasks, getSession(), numOfRows); + } + /** * Below is lower level API for env like Spark which already did key range split Perform handle * scan @@ -135,7 +144,7 @@ private Iterator tableReadRow(TiDAGRequest dagRequest, List tas * @param tasks RegionTask of the coprocessor request to send * @return Row iterator to iterate over resulting rows */ - public Iterator indexHandleRead(TiDAGRequest dagRequest, List tasks) { + public Iterator indexHandleReadRow(TiDAGRequest dagRequest, List tasks) { return getHandleIterator(dagRequest, tasks, session); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/TiConfiguration.java b/tikv-client/src/main/java/com/pingcap/tikv/TiConfiguration.java index 3212d1ad81..500df5ea1d 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/TiConfiguration.java +++ 
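The PDClient additions above decide scatter completion from the region's current operator: anything other than a still-running `scatter-region` operator counts as finished, and a REGION_NOT_FOUND error is treated the same way because the region's heartbeat may not have reached PD yet. Boiled down to a predicate (names hypothetical):

    def scatterFinished(operatorDesc: String, operatorRunning: Boolean, regionNotFound: Boolean): Boolean =
      operatorDesc != "scatter-region" || !operatorRunning || regionNotFound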
b/tikv-client/src/main/java/com/pingcap/tikv/TiConfiguration.java @@ -54,6 +54,7 @@ public class TiConfiguration implements Serializable { private static final boolean DEF_WRITE_WITHOUT_LOCK_TABLE = false; private static final int DEF_TIKV_REGION_SPLIT_SIZE_IN_MB = 96; private static final int DEF_PARTITION_PER_SPLIT = 1; + private static final int DEF_KV_CLIENT_CONCURRENCY = 10; private static final List DEF_ISOLATION_READ_ENGINES = ImmutableList.of(TiStoreType.TiKV, TiStoreType.TiFlash); @@ -79,6 +80,8 @@ public class TiConfiguration implements Serializable { private int tikvRegionSplitSizeInMB = DEF_TIKV_REGION_SPLIT_SIZE_IN_MB; private int partitionPerSplit = DEF_PARTITION_PER_SPLIT; + private int kvClientConcurrency = DEF_KV_CLIENT_CONCURRENCY; + private List isolationReadEngines = DEF_ISOLATION_READ_ENGINES; public static TiConfiguration createDefault(String pdAddrsStr) { @@ -291,4 +294,12 @@ public List getIsolationReadEngines() { public void setIsolationReadEngines(List isolationReadEngines) { this.isolationReadEngines = isolationReadEngines; } + + public int getKvClientConcurrency() { + return kvClientConcurrency; + } + + public void setKvClientConcurrency(int kvClientConcurrency) { + this.kvClientConcurrency = kvClientConcurrency; + } } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/TiDBJDBCClient.java b/tikv-client/src/main/java/com/pingcap/tikv/TiDBJDBCClient.java index bc9881753a..4d1ac592d9 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/TiDBJDBCClient.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/TiDBJDBCClient.java @@ -40,12 +40,21 @@ public class TiDBJDBCClient implements AutoCloseable { private static final String ENABLE_SPLIT_TABLE_KEY = "split-table"; private static final Boolean ENABLE_SPLIT_TABLE_DEFAULT = false; private static final String TIDB_ROW_FORMAT_VERSION_SQL = "select @@tidb_row_format_version"; + private static final String TIDB_SET_WAIT_SPLIT_REGION_FINISH = + "set @@tidb_wait_split_region_finish=%d;"; private static final int TIDB_ROW_FORMAT_VERSION_DEFAULT = 1; private final Logger logger = LoggerFactory.getLogger(getClass().getName()); private final Connection connection; + private final int waitSplitRegionFinish; public TiDBJDBCClient(Connection connection) { this.connection = connection; + this.waitSplitRegionFinish = 1; + } + + public TiDBJDBCClient(Connection connection, int waitSplitRegionFinish) { + this.connection = connection; + this.waitSplitRegionFinish = waitSplitRegionFinish; } public boolean isEnableTableLock() throws IOException, SQLException { @@ -96,6 +105,12 @@ public int getRowFormatVersion() { } } + private void setTiDBWriteSplitRegionFinish() throws SQLException { + if (waitSplitRegionFinish == 0 || waitSplitRegionFinish == 1) { + executeUpdate(String.format(TIDB_SET_WAIT_SPLIT_REGION_FINISH, waitSplitRegionFinish)); + } + } + public boolean lockTableWriteLocal(String databaseName, String tableName) throws SQLException { try (Statement tidbStmt = connection.createStatement()) { String sql = "lock tables `" + databaseName + "`.`" + tableName + "` write local"; @@ -144,19 +159,42 @@ public boolean isEnableSplitRegion() throws IOException, SQLException { */ public void splitTableRegion( String dbName, String tblName, long minVal, long maxVal, long regionNum) throws SQLException { + + setTiDBWriteSplitRegionFinish(); + + if (minVal < maxVal) { + try (Statement tidbStmt = connection.createStatement()) { + String sql = + String.format( + "split table `%s`.`%s` between (%d) and (%d) regions %d", + dbName, tblName, 
minVal, maxVal, regionNum); + logger.warn("split table region: " + sql); + tidbStmt.execute(sql); + } catch (SQLException e) { + logger.warn("failed to split table region", e); + throw e; + } + } else { + logger.warn("try to split table region with minVal >= maxVal, skipped"); + } + } + + public void splitIndexRegion(String dbName, String tblName, String idxName, String valueList) + throws SQLException { + + setTiDBWriteSplitRegionFinish(); + try (Statement tidbStmt = connection.createStatement()) { String sql = String.format( - "split table `%s`.`%s` between (%d) and (%d) regions %d", - dbName, tblName, minVal, maxVal, regionNum); - logger.info("split table region: " + sql); + "split table `%s`.`%s` index `%s` by %s", dbName, tblName, idxName, valueList); + logger.warn("split index region: " + sql); tidbStmt.execute(sql); } catch (SQLException e) { - logger.warn("failed to split table region", e); + logger.warn("failed to split index region", e); throw e; } } - /** * split index region by calling tidb jdbc command `SPLIT TABLE`, e.g. SPLIT TABLE t INDEX idx * BETWEEN ("2010-01-01 00:00:00") AND ("2020-01-01 00:00:00") REGIONS 16; @@ -176,16 +214,23 @@ public void splitIndexRegion( String maxIndexVal, long regionNum) throws SQLException { - try (Statement tidbStmt = connection.createStatement()) { - String sql = - String.format( - "split table `%s`.`%s` index %s between (\"%s\") and (\"%s\") regions %d", - dbName, tblName, idxName, minIndexVal, maxIndexVal, regionNum); - logger.info("split index region: " + sql); - tidbStmt.execute(sql); - } catch (SQLException e) { - logger.warn("failed to split index region", e); - throw e; + + setTiDBWriteSplitRegionFinish(); + + if (!minIndexVal.equals(maxIndexVal)) { + try (Statement tidbStmt = connection.createStatement()) { + String sql = + String.format( + "split table `%s`.`%s` index `%s` between (\"%s\") and (\"%s\") regions %d", + dbName, tblName, idxName, minIndexVal, maxIndexVal, regionNum); + logger.warn("split index region: " + sql); + tidbStmt.execute(sql); + } catch (SQLException e) { + logger.warn("failed to split index region", e); + throw e; + } + } else { + logger.warn("try to split index region with minVal = maxVal, skipped"); } } @@ -198,6 +243,12 @@ public void close() throws Exception { connection.close(); } + private int executeUpdate(String sql) throws SQLException { + try (Statement tidbStmt = connection.createStatement()) { + return tidbStmt.executeUpdate(sql); + } + } + private List> queryTiDBViaJDBC(String query) throws SQLException { ArrayList> result = new ArrayList<>(); diff --git a/tikv-client/src/main/java/com/pingcap/tikv/TiSession.java b/tikv-client/src/main/java/com/pingcap/tikv/TiSession.java index beba4206a4..28d5734138 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/TiSession.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/TiSession.java @@ -16,21 +16,33 @@ package com.pingcap.tikv; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import com.google.protobuf.ByteString; import com.pingcap.tikv.catalog.Catalog; import com.pingcap.tikv.event.CacheInvalidateEvent; +import com.pingcap.tikv.exception.TiClientInternalException; +import com.pingcap.tikv.exception.TiKVException; +import com.pingcap.tikv.key.Key; import com.pingcap.tikv.meta.TiTimestamp; import com.pingcap.tikv.region.RegionManager; import com.pingcap.tikv.region.RegionStoreClient; +import com.pingcap.tikv.region.TiRegion; import com.pingcap.tikv.txn.TxnKVClient; +import com.pingcap.tikv.util.BackOffFunction; +import 
com.pingcap.tikv.util.BackOffer; import com.pingcap.tikv.util.ChannelFactory; import com.pingcap.tikv.util.ConcreteBackOffer; +import com.pingcap.tikv.util.Pair; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.Function; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.tikv.kvproto.Metapb; public class TiSession implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(TiSession.class); @@ -149,7 +161,10 @@ public ExecutorService getThreadPoolForIndexScan() { indexScanThreadPool = Executors.newFixedThreadPool( conf.getIndexScanConcurrency(), - new ThreadFactoryBuilder().setDaemon(true).build()); + new ThreadFactoryBuilder() + .setNameFormat("index-scan-pool-%d") + .setDaemon(true) + .build()); } res = indexScanThreadPool; } @@ -182,6 +197,98 @@ public void injectCallBackFunc(Function callBackFunc this.cacheInvalidateCallback = callBackFunc; } + /** + * split region and scatter + * + * @param splitKeys + */ + public void splitRegionAndScatter( + List splitKeys, int splitRegionBackoffMS, int scatterWaitMS) { + logger.info(String.format("split key's size is %d", splitKeys.size())); + long startMS = System.currentTimeMillis(); + + BackOffer splitRegionBackoff = ConcreteBackOffer.newCustomBackOff(splitRegionBackoffMS); + // split region + List newRegions = + splitRegion( + splitKeys + .stream() + .map(k -> Key.toRawKey(k).next().toByteString()) + .collect(Collectors.toList()), + splitRegionBackoff); + + // scatter region + for (TiRegion newRegion : newRegions) { + getPDClient().scatterRegion(newRegion, splitRegionBackoff); + } + + // wait scatter region finish + if (scatterWaitMS > 0) { + logger.info("start to wait scatter region finish"); + long scatterRegionStartMS = System.currentTimeMillis(); + for (TiRegion newRegion : newRegions) { + long remainMS = (scatterRegionStartMS + scatterWaitMS) - System.currentTimeMillis(); + if (remainMS <= 0) { + logger.warn("wait scatter region timeout"); + return; + } + getPDClient() + .waitScatterRegionFinish(newRegion, ConcreteBackOffer.newCustomBackOff((int) remainMS)); + } + } else { + logger.info("skip to wait scatter region finish"); + } + + long endMS = System.currentTimeMillis(); + logger.info("splitRegionAndScatter cost {} seconds", (endMS - startMS) / 1000); + } + + private List splitRegion(List splitKeys, BackOffer backOffer) { + List regions = new ArrayList<>(); + + Map> groupKeys = groupKeysByRegion(splitKeys); + for (Map.Entry> entry : groupKeys.entrySet()) { + + Pair pair = + getRegionManager().getRegionStorePairByKey(entry.getKey().getStartKey()); + TiRegion region = pair.first; + Metapb.Store store = pair.second; + List splits = + entry + .getValue() + .stream() + .filter(k -> !k.equals(region.getStartKey()) && !k.equals(region.getEndKey())) + .collect(Collectors.toList()); + + if (splits.isEmpty()) { + logger.warn( + "split key equal to region start key or end key. 
Region splitting is not needed."); + } else { + logger.info("start to split region id={}, split size={}", region.getId(), splits.size()); + List newRegions; + try { + newRegions = getRegionStoreClientBuilder().build(region, store).splitRegion(splits); + } catch (final TiKVException | TiClientInternalException e) { + // retry + logger.warn("ReSplitting ranges for splitRegion", e); + clientBuilder.getRegionManager().invalidateRegion(region.getId()); + backOffer.doBackOff(BackOffFunction.BackOffFuncType.BoRegionMiss, e); + newRegions = splitRegion(splits, backOffer); + } + logger.info("region id={}, new region size={}", region.getId(), newRegions.size()); + regions.addAll(newRegions); + } + } + + logger.info("splitRegion: return region size={}", regions.size()); + return regions; + } + + private Map> groupKeysByRegion(List keys) { + return keys.stream() + .collect(Collectors.groupingBy(clientBuilder.getRegionManager()::getRegionByKey)); + } + @Override public synchronized void close() throws Exception { if (isClosed) { diff --git a/tikv-client/src/main/java/com/pingcap/tikv/TwoPhaseCommitter.java b/tikv-client/src/main/java/com/pingcap/tikv/TwoPhaseCommitter.java index 2e136b3e45..4e0da2b567 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/TwoPhaseCommitter.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/TwoPhaseCommitter.java @@ -15,6 +15,7 @@ package com.pingcap.tikv; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.protobuf.ByteString; import com.pingcap.tikv.codec.KeyUtils; import com.pingcap.tikv.exception.GrpcException; @@ -37,6 +38,10 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.tikv.kvproto.Kvrpcpb; @@ -65,24 +70,62 @@ public class TwoPhaseCommitter { /** unit is millisecond */ private final long lockTTL; + private final boolean retryCommitSecondaryKeys; + private final TxnKVClient kvClient; private final RegionManager regionManager; + private final long txnPrewriteBatchSize; + private final long txnCommitBatchSize; + private final int writeBufferSize; + private final int writeThreadPerTask; + private final ExecutorService executorService; + public TwoPhaseCommitter(TiConfiguration conf, long startTime) { this.kvClient = TiSession.getInstance(conf).createTxnClient(); this.regionManager = kvClient.getRegionManager(); this.startTs = startTime; this.lockTTL = DEFAULT_BATCH_WRITE_LOCK_TTL; + this.retryCommitSecondaryKeys = true; + this.txnPrewriteBatchSize = TXN_COMMIT_BATCH_SIZE; + this.txnCommitBatchSize = TXN_COMMIT_BATCH_SIZE; + this.writeBufferSize = WRITE_BUFFER_SIZE; + this.writeThreadPerTask = 1; + this.executorService = createExecutorService(); } - public TwoPhaseCommitter(TiConfiguration conf, long startTime, long lockTTL) { + public TwoPhaseCommitter( + TiConfiguration conf, + long startTime, + long lockTTL, + long txnPrewriteBatchSize, + long txnCommitBatchSize, + int writeBufferSize, + int writeThreadPerTask, + boolean retryCommitSecondaryKeys) { this.kvClient = TiSession.getInstance(conf).createTxnClient(); this.regionManager = kvClient.getRegionManager(); this.startTs = startTime; this.lockTTL = lockTTL; + this.retryCommitSecondaryKeys = retryCommitSecondaryKeys; + this.txnPrewriteBatchSize = txnPrewriteBatchSize; + this.txnCommitBatchSize = 
txnCommitBatchSize; + this.writeBufferSize = writeBufferSize; + this.writeThreadPerTask = writeThreadPerTask; + this.executorService = createExecutorService(); } - public void close() throws Exception {} + private ExecutorService createExecutorService() { + return Executors.newFixedThreadPool( + writeThreadPerTask, + new ThreadFactoryBuilder().setNameFormat("2pc-pool-%d").setDaemon(true).build()); + } + + public void close() throws Exception { + if (executorService != null) { + executorService.shutdownNow(); + } + } /** * 2pc - prewrite primary key @@ -138,7 +181,7 @@ private void doPrewritePrimaryKeyWithRetry(BackOffer backOffer, ByteString key, } } - LOG.debug("prewrite primary key {} successfully", KeyUtils.formatBytes(key)); + LOG.info("prewrite primary key {} successfully", KeyUtils.formatBytes(key)); } /** @@ -178,7 +221,7 @@ private void doCommitPrimaryKeyWithRetry(BackOffer backOffer, ByteString key, lo } } - LOG.debug("commit primary key {} successfully", KeyUtils.formatBytes(key)); + LOG.info("commit primary key {} successfully", KeyUtils.formatBytes(key)); } /** @@ -188,7 +231,8 @@ private void doCommitPrimaryKeyWithRetry(BackOffer backOffer, ByteString key, lo * @param pairs * @return */ - public void prewriteSecondaryKeys(byte[] primaryKey, Iterator pairs) + public void prewriteSecondaryKeys( + byte[] primaryKey, Iterator pairs, int maxBackOfferMS) throws TiBatchWriteException { Iterator> byteStringKeys = new Iterator>() { @@ -206,28 +250,54 @@ public Pair next() { } }; - doPrewriteSecondaryKeys(ByteString.copyFrom(primaryKey), byteStringKeys); + doPrewriteSecondaryKeys(ByteString.copyFrom(primaryKey), byteStringKeys, maxBackOfferMS); } private void doPrewriteSecondaryKeys( - ByteString primaryKey, Iterator> pairs) + ByteString primaryKey, Iterator> pairs, int maxBackOfferMS) throws TiBatchWriteException { - int totalSize = 0; - while (pairs.hasNext()) { - ByteString[] keyBytes = new ByteString[WRITE_BUFFER_SIZE]; - ByteString[] valueBytes = new ByteString[WRITE_BUFFER_SIZE]; - int size = 0; - while (size < WRITE_BUFFER_SIZE && pairs.hasNext()) { - Pair pair = pairs.next(); - keyBytes[size] = pair.first; - valueBytes[size] = pair.second; - size++; + try { + int taskBufferSize = writeThreadPerTask * 2; + int totalSize = 0, cnt = 0; + Pair pair; + ExecutorCompletionService completionService = + new ExecutorCompletionService<>(executorService); + while (pairs.hasNext()) { + int size = 0; + ByteString[] keyBytes = new ByteString[writeBufferSize]; + ByteString[] valueBytes = new ByteString[writeBufferSize]; + while (size < writeBufferSize && pairs.hasNext()) { + pair = pairs.next(); + keyBytes[size] = pair.first; + valueBytes[size] = pair.second; + size++; + } + int curSize = size; + cnt++; + if (cnt > taskBufferSize) { + // consume one task if reaches task limit + completionService.take().get(); + } + BackOffer backOffer = ConcreteBackOffer.newCustomBackOff(maxBackOfferMS); + completionService.submit( + () -> { + doPrewriteSecondaryKeysInBatchesWithRetry( + backOffer, primaryKey, keyBytes, valueBytes, curSize, 0); + return null; + }); + + totalSize = totalSize + size; } - BackOffer backOffer = ConcreteBackOffer.newCustomBackOff(BackOffer.BATCH_PREWRITE_BACKOFF); - doPrewriteSecondaryKeysInBatchesWithRetry( - backOffer, primaryKey, keyBytes, valueBytes, size, 0); - totalSize = totalSize + size; + for (int i = 0; i < Math.min(taskBufferSize, cnt); i++) { + completionService.take().get(); + } + + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new 
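/*
 * The loop above is a bounded producer/consumer: at most writeThreadPerTask * 2
 * prewrite batches are in flight, and once that window is full one result is
 * drained before each new submission. A stripped-down sketch of the same
 * pattern, with illustrative names:
 *
 *   ExecutorCompletionService<Void> cs = new ExecutorCompletionService<>(pool);
 *   int inFlightLimit = threads * 2, submitted = 0;
 *   while (batches.hasNext()) {
 *     Batch batch = batches.next();
 *     submitted++;
 *     if (submitted > inFlightLimit) {
 *       cs.take().get();                 // block until one batch finishes
 *     }
 *     cs.submit(() -> { write(batch); return null; });
 *   }
 *   for (int i = 0; i < Math.min(inFlightLimit, submitted); i++) {
 *     cs.take().get();                   // drain the tail
 *   }
 */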
TiBatchWriteException("Current thread interrupted.", e); + } catch (ExecutionException e) { + throw new TiBatchWriteException("Execution exception met.", e); } } @@ -264,16 +334,17 @@ private void doPrewriteSecondaryKeysInBatchesWithRetry( GroupKeyResult groupResult = this.groupKeysByRegion(keys, size); List batchKeyList = new LinkedList<>(); Map, List> groupKeyMap = groupResult.getGroupsResult(); - for (Pair pair : groupKeyMap.keySet()) { - TiRegion tiRegion = pair.first; - Metapb.Store store = pair.second; - this.appendBatchBySize(batchKeyList, tiRegion, store, groupKeyMap.get(pair), true, mutations); + + for (Map.Entry, List> entry : groupKeyMap.entrySet()) { + TiRegion tiRegion = entry.getKey().first; + Metapb.Store store = entry.getKey().second; + this.appendBatchBySize(batchKeyList, tiRegion, store, entry.getValue(), true, mutations); } // For prewrite, stop sending other requests after receiving first error. for (BatchKeys batchKeys : batchKeyList) { TiRegion oldRegion = batchKeys.getRegion(); - TiRegion currentRegion = this.regionManager.getRegionById(oldRegion.getId()); + TiRegion currentRegion = this.regionManager.getRegionByKey(oldRegion.getStartKey()); if (oldRegion.equals(currentRegion)) { doPrewriteSecondaryKeySingleBatchWithRetry(backOffer, primaryKey, batchKeys, mutations); } else { @@ -283,7 +354,7 @@ private void doPrewriteSecondaryKeysInBatchesWithRetry( "> max retry number %s, oldRegion=%s, currentRegion=%s", MAX_RETRY_TIMES, oldRegion, currentRegion)); } - LOG.debug( + LOG.info( String.format( "oldRegion=%s != currentRegion=%s, will re-fetch region info and retry", oldRegion, currentRegion)); @@ -318,7 +389,11 @@ private void doPrewriteSecondaryKeySingleBatchWithRetry( BatchKeys batchKeys, Map mutations) throws TiBatchWriteException { - LOG.debug("start prewrite secondary key, size={}", batchKeys.getKeys().size()); + LOG.info( + "start prewrite secondary key, row={}, size={}KB, regionId={}", + batchKeys.getKeys().size(), + batchKeys.getSizeInKB(), + batchKeys.getRegion().getId()); List keyList = batchKeys.getKeys(); int batchSize = keyList.size(); @@ -343,7 +418,7 @@ private void doPrewriteSecondaryKeySingleBatchWithRetry( "prewrite secondary key error", prewriteResult.getException()); } if (prewriteResult.isRetry()) { - LOG.debug("prewrite secondary key fail, will backoff and retry"); + LOG.info("prewrite secondary key fail, will backoff and retry"); try { backOffer.doBackOff( BackOffFunction.BackOffFuncType.BoRegionMiss, @@ -362,7 +437,11 @@ private void doPrewriteSecondaryKeySingleBatchWithRetry( throw new TiBatchWriteException(errorMsg, e); } } - LOG.debug("prewrite secondary key successfully, size={}", batchKeys.getKeys().size()); + LOG.info( + "prewrite secondary key successfully, row={}, size={}KB, regionId={}", + batchKeys.getKeys().size(), + batchKeys.getSizeInKB(), + batchKeys.getRegion().getId()); } private void appendBatchBySize( @@ -372,19 +451,24 @@ private void appendBatchBySize( List keys, boolean sizeIncludeValue, Map mutations) { + long commitBatchSize = sizeIncludeValue ? 
txnPrewriteBatchSize : txnCommitBatchSize; + int start; int end; + if (keys == null) { + return; + } int len = keys.size(); for (start = 0; start < len; start = end) { - int size = 0; - for (end = start; end < len && size < TXN_COMMIT_BATCH_SIZE; end++) { + int sizeInBytes = 0; + for (end = start; end < len && sizeInBytes < commitBatchSize; end++) { if (sizeIncludeValue) { - size += this.keyValueSize(keys.get(end), mutations); + sizeInBytes += this.keyValueSize(keys.get(end), mutations); } else { - size += this.keySize(keys.get(end)); + sizeInBytes += this.keySize(keys.get(end)); } } - BatchKeys batchKeys = new BatchKeys(tiRegion, store, keys.subList(start, end)); + BatchKeys batchKeys = new BatchKeys(tiRegion, store, keys.subList(start, end), sizeInBytes); batchKeyList.add(batchKeys); } } @@ -410,7 +494,7 @@ private long keySize(ByteString key) { * @param commitTs * @return */ - public void commitSecondaryKeys(Iterator keys, long commitTs) + public void commitSecondaryKeys(Iterator keys, long commitTs, int commitBackOfferMS) throws TiBatchWriteException { Iterator byteStringKeys = @@ -427,35 +511,53 @@ public ByteString next() { } }; - doCommitSecondaryKeys(byteStringKeys, commitTs); + doCommitSecondaryKeys(byteStringKeys, commitTs, commitBackOfferMS); } - private void doCommitSecondaryKeys(Iterator keys, long commitTs) + private void doCommitSecondaryKeys( + Iterator keys, long commitTs, int commitBackOfferMS) throws TiBatchWriteException { - LOG.debug("start commit secondary key"); - - int totalSize = 0; - while (keys.hasNext()) { - ByteString[] keyBytes = new ByteString[WRITE_BUFFER_SIZE]; - int size = 0; - for (int i = 0; i < WRITE_BUFFER_SIZE; i++) { - if (keys.hasNext()) { + try { + int taskBufferSize = writeThreadPerTask * 2; + int totalSize = 0, cnt = 0; + ExecutorCompletionService completionService = + new ExecutorCompletionService<>(executorService); + while (keys.hasNext()) { + int size = 0; + ByteString[] keyBytes = new ByteString[writeBufferSize]; + while (size < writeBufferSize && keys.hasNext()) { keyBytes[size] = keys.next(); size++; - } else { - break; } + int curSize = size; + cnt++; + if (cnt > taskBufferSize) { + // consume one task if reaches task limit + completionService.take().get(); + } + BackOffer backOffer = ConcreteBackOffer.newCustomBackOff(commitBackOfferMS); + completionService.submit( + () -> { + doCommitSecondaryKeysWithRetry(backOffer, keyBytes, curSize, commitTs); + return null; + }); + + totalSize = totalSize + size; } - totalSize = totalSize + size; - BackOffer backOffer = ConcreteBackOffer.newCustomBackOff(BackOffer.BATCH_COMMIT_BACKOFF); - doCommitSecondaryKeys(backOffer, keyBytes, size, commitTs); - } + for (int i = 0; i < Math.min(taskBufferSize, cnt); i++) { + completionService.take().get(); + } - LOG.debug("commit secondary key successfully, total size={}", totalSize); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new TiBatchWriteException("Current thread interrupted.", e); + } catch (ExecutionException e) { + throw new TiBatchWriteException("Execution exception met.", e); + } } - private void doCommitSecondaryKeys( + private void doCommitSecondaryKeysWithRetry( BackOffer backOffer, ByteString[] keys, int size, long commitTs) throws TiBatchWriteException { if (keys == null || keys.length == 0 || size <= 0) { @@ -464,23 +566,27 @@ private void doCommitSecondaryKeys( // groups keys by region GroupKeyResult groupResult = this.groupKeysByRegion(keys, size); - List batchKeyList = new LinkedList<>(); + List 
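/*
 * LinkedList is swapped for ArrayList here and in groupKeysByRegion below:
 * both lists are only appended to and then iterated, so ArrayList's
 * contiguous backing array is strictly cheaper (no per-node allocation,
 * better locality) and nothing relied on LinkedList semantics.
 */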
batchKeyList = new ArrayList<>(); Map, List> groupKeyMap = groupResult.getGroupsResult(); - for (Pair pair : groupKeyMap.keySet()) { - TiRegion tiRegion = pair.first; - Metapb.Store store = pair.second; - this.appendBatchBySize(batchKeyList, tiRegion, store, groupKeyMap.get(pair), false, null); + for (Map.Entry, List> entry : groupKeyMap.entrySet()) { + TiRegion tiRegion = entry.getKey().first; + Metapb.Store store = entry.getKey().second; + this.appendBatchBySize(batchKeyList, tiRegion, store, entry.getValue(), false, null); } - // For prewrite, stop sending other requests after receiving first error. for (BatchKeys batchKeys : batchKeyList) { - doCommitSecondaryKeySingleBatch(backOffer, batchKeys, commitTs); + doCommitSecondaryKeySingleBatchWithRetry(backOffer, batchKeys, commitTs); } } - private void doCommitSecondaryKeySingleBatch( + private void doCommitSecondaryKeySingleBatchWithRetry( BackOffer backOffer, BatchKeys batchKeys, long commitTs) throws TiBatchWriteException { + LOG.info( + "start commit secondary key, row={}, size={}KB, regionId={}", + batchKeys.getKeys().size(), + batchKeys.getSizeInKB(), + batchKeys.getRegion().getId()); List keysCommit = batchKeys.getKeys(); ByteString[] keys = new ByteString[keysCommit.size()]; keysCommit.toArray(keys); @@ -488,13 +594,19 @@ private void doCommitSecondaryKeySingleBatch( ClientRPCResult commitResult = this.kvClient.commit( backOffer, keys, this.startTs, commitTs, batchKeys.getRegion(), batchKeys.getStore()); - if (!commitResult.isSuccess()) { + if (retryCommitSecondaryKeys && commitResult.isRetry()) { + doCommitSecondaryKeysWithRetry(backOffer, keys, keysCommit.size(), commitTs); + } else if (!commitResult.isSuccess()) { String error = String.format("Txn commit secondary key error, regionId=%s", batchKeys.getRegion()); LOG.warn(error); throw new TiBatchWriteException("commit secondary key error", commitResult.getException()); } - LOG.debug("commit {} rows successfully", batchKeys.getKeys().size()); + LOG.info( + "commit {} rows successfully, size={}KB, regionId={}", + batchKeys.getKeys().size(), + batchKeys.getSizeInKB(), + batchKeys.getRegion().getId()); } private GroupKeyResult groupKeysByRegion(ByteString[] keys, int size) @@ -506,7 +618,7 @@ private GroupKeyResult groupKeysByRegion(ByteString[] keys, int size) ByteString key = keys[index]; Pair pair = this.regionManager.getRegionStorePairByKey(key); if (pair != null) { - groups.computeIfAbsent(pair, e -> new LinkedList<>()).add(key); + groups.computeIfAbsent(pair, e -> new ArrayList<>()).add(key); } } } catch (Exception e) { diff --git a/tikv-client/src/main/java/com/pingcap/tikv/allocator/RowIDAllocator.java b/tikv-client/src/main/java/com/pingcap/tikv/allocator/RowIDAllocator.java index 8bf38a8cc8..d44b3fae91 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/allocator/RowIDAllocator.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/allocator/RowIDAllocator.java @@ -80,6 +80,11 @@ private void set(ByteString key, byte[] value) { ConcreteBackOffer.newCustomBackOff(BackOffer.BATCH_COMMIT_BACKOFF), key.toByteArray(), session.getTimestamp().getVersion()); + + try { + twoPhaseCommitter.close(); + } catch (Throwable e) { + } } private void updateMeta(ByteString key, byte[] oldVal, Snapshot snapshot) { diff --git a/tikv-client/src/main/java/com/pingcap/tikv/codec/TableCodecV1.java b/tikv-client/src/main/java/com/pingcap/tikv/codec/TableCodecV1.java index 61e0fd6079..1522d8e395 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/codec/TableCodecV1.java +++ 
b/tikv-client/src/main/java/com/pingcap/tikv/codec/TableCodecV1.java @@ -66,7 +66,7 @@ protected static Row decodeRow(byte[] value, Long handle, TiTableInfo tableInfo) Object[] res = new Object[colSize]; while (!cdi.eof()) { long colID = (long) IntegerType.BIGINT.decode(cdi); - Object colValue = idToColumn.get(colID).getType().decode(cdi); + Object colValue = idToColumn.get(colID).getType().decodeForBatchWrite(cdi); decodedDataMap.put(colID, colValue); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/columnar/BatchedTiChunkColumnVector.java b/tikv-client/src/main/java/com/pingcap/tikv/columnar/BatchedTiChunkColumnVector.java index e8437f2cfb..583bdbd187 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/columnar/BatchedTiChunkColumnVector.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/columnar/BatchedTiChunkColumnVector.java @@ -44,10 +44,6 @@ public BatchedTiChunkColumnVector(List child, int numOfRows } } - public final String typeName() { - return dataType().getType().name(); - } - // TODO: once we switch off_heap mode, we need control memory access pattern. public void free() {} diff --git a/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiChunkColumnVector.java b/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiChunkColumnVector.java index b81e9b2203..6a19ce1bcd 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiChunkColumnVector.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiChunkColumnVector.java @@ -17,8 +17,10 @@ import com.google.common.primitives.UnsignedLong; import com.pingcap.tikv.codec.CodecDataInput; +import com.pingcap.tikv.codec.CodecDataOutput; import com.pingcap.tikv.codec.MyDecimal; import com.pingcap.tikv.types.AbstractDateTimeType; +import com.pingcap.tikv.types.ArrayType; import com.pingcap.tikv.types.BitType; import com.pingcap.tikv.types.DataType; import com.pingcap.tikv.types.DateTimeType; @@ -55,7 +57,7 @@ public TiChunkColumnVector( byte[] nullBitMaps, long[] offsets, ByteBuffer data) { - super(dataType, numOfRows); + super(dataType, numOfRows, buildColumnVectorFromOffsets(numOfRows, offsets)); this.fixLength = fixLength; this.numOfNulls = numOfNulls; this.nullBitMaps = nullBitMaps; @@ -63,8 +65,25 @@ public TiChunkColumnVector( this.offsets = offsets; } - public final String typeName() { - return dataType().getType().name(); + private static TiChunkColumnVector buildColumnVectorFromOffsets(int numOfRows, long[] offsets) { + if (offsets == null) { + return null; + } else { + DataType type = IntegerType.BIGINT; + int fixLength = type.getFixLen(); + CodecDataOutput cdo = new CodecDataOutput(); + for (long offset : offsets) { + cdo.writeLong(offset); + } + return new TiChunkColumnVector( + type, + fixLength, + numOfRows, + 0, + DataType.setAllNotNullBitMapWithNumRows(numOfRows), + null, + ByteBuffer.wrap(cdo.toBytes())); + } } // TODO: once we switch off_heap mode, we need control memory access pattern. 
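/*
 * buildColumnVectorFromOffsets above re-encodes the chunk's offset array as a
 * plain BIGINT column so variable-length data can be addressed column-style.
 * A hedged sketch of how a reader could recover one row's bounds from it,
 * assuming row i's bytes occupy [offsets[i], offsets[i+1]):
 *
 *   TiColumnVector offsetCol = vector.getOffset();
 *   long begin = offsetCol.getLong(rowId);
 *   long end = offsetCol.getLong(rowId + 1);  // needs numOfRows + 1 offsets
 *   int length = (int) (end - begin);
 *
 * Note that the columnar-batch fix later in this series removes this offsets
 * path again.
 */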
@@ -176,6 +195,8 @@ public long getLong(int rowId) { return getTime(rowId); } else if (type instanceof TimeType) { return data.getLong(rowId * fixLength); + } else if (type instanceof ArrayType) { + return data.getLong(rowId * fixLength + fixLength); } throw new UnsupportedOperationException("only IntegerType and Time related are supported."); diff --git a/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiColumnVector.java b/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiColumnVector.java index c22b13aa41..c83283212b 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiColumnVector.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiColumnVector.java @@ -37,6 +37,7 @@ public abstract class TiColumnVector implements AutoCloseable { private final int numOfRows; + private final TiColumnVector offsets; /** Data type for this column. */ protected DataType type; @@ -44,6 +45,14 @@ public abstract class TiColumnVector implements AutoCloseable { protected TiColumnVector(DataType type, int numOfRows) { this.type = type; this.numOfRows = numOfRows; + this.offsets = null; + } + + /** Sets up the data type of this column vector. */ + protected TiColumnVector(DataType type, int numOfRows, TiColumnVector offsets) { + this.type = type; + this.numOfRows = numOfRows; + this.offsets = offsets; } /** Returns the data type of this column vector. */ @@ -218,4 +227,8 @@ public double[] getDoubles(int rowId, int count) { public int numOfRows() { return numOfRows; } + + public TiColumnVector getOffset() { + return offsets; + } } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/exception/ConvertOverflowException.java b/tikv-client/src/main/java/com/pingcap/tikv/exception/ConvertOverflowException.java index dd4c7027c0..7db8664fda 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/exception/ConvertOverflowException.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/exception/ConvertOverflowException.java @@ -20,6 +20,10 @@ private ConvertOverflowException(String msg) { super(msg); } + public ConvertOverflowException(String msg, Throwable e) { + super(msg, e); + } + public static ConvertOverflowException newMaxLengthException(String value, long maxLength) { return new ConvertOverflowException("value " + value + " length > max length " + maxLength); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/exception/TiDBConvertException.java b/tikv-client/src/main/java/com/pingcap/tikv/exception/TiDBConvertException.java new file mode 100644 index 0000000000..5e47f47653 --- /dev/null +++ b/tikv-client/src/main/java/com/pingcap/tikv/exception/TiDBConvertException.java @@ -0,0 +1,22 @@ +/* + * Copyright 2020 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.pingcap.tikv.exception; + +public class TiDBConvertException extends RuntimeException { + public TiDBConvertException(String columnName, Throwable e) { + super("convert to tidb data error for column '" + columnName + "'", e); + } +} diff --git a/tikv-client/src/main/java/com/pingcap/tikv/key/IndexKey.java b/tikv-client/src/main/java/com/pingcap/tikv/key/IndexKey.java index e8aa3278e6..d3dbdaae83 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/key/IndexKey.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/key/IndexKey.java @@ -23,6 +23,7 @@ import com.pingcap.tikv.meta.TiTableInfo; import com.pingcap.tikv.row.Row; import com.pingcap.tikv.types.DataType; +import com.pingcap.tikv.types.IntegerType; import java.util.List; public class IndexKey extends Key { @@ -39,13 +40,39 @@ private IndexKey(long tableId, long indexId, Key[] dataKeys) { this.dataKeys = dataKeys; } + public static class EncodeIndexDataResult { + public EncodeIndexDataResult(Key[] keys, boolean appendHandle) { + this.keys = keys; + this.appendHandle = appendHandle; + } + + public Key[] keys; + public boolean appendHandle; + } + public static IndexKey toIndexKey(long tableId, long indexId, Key... dataKeys) { return new IndexKey(tableId, indexId, dataKeys); } - public static Key[] encodeIndexDataValues( - Row row, List indexColumns, TiTableInfo tableInfo) { - Key[] keys = new Key[indexColumns.size()]; + public static EncodeIndexDataResult encodeIndexDataValues( + Row row, + List indexColumns, + long handle, + boolean appendHandleIfContainsNull, + TiTableInfo tableInfo) { + // when appendHandleIfContainsNull is true, append handle column if any of the index column is + // NULL + boolean appendHandle = false; + if (appendHandleIfContainsNull) { + for (TiIndexColumn col : indexColumns) { + DataType colTp = tableInfo.getColumn(col.getOffset()).getType(); + if (row.get(col.getOffset(), colTp) == null) { + appendHandle = true; + break; + } + } + } + Key[] keys = new Key[indexColumns.size() + (appendHandle ? 
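/*
 * Rationale for appendHandleIfContainsNull: TiDB treats a unique-index entry
 * that contains any NULL column as non-unique, so the row handle must be
 * appended to keep each encoded key distinct; otherwise two rows with NULL in
 * the unique column would collide on the same index key.
 */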
1 : 0)]; for (int i = 0; i < indexColumns.size(); i++) { TiIndexColumn col = indexColumns.get(i); DataType colTp = tableInfo.getColumn(col.getOffset()).getType(); @@ -53,7 +80,12 @@ public static Key[] encodeIndexDataValues( Key key = TypedKey.toTypedKey(row.get(col.getOffset(), colTp), colTp, (int) col.getLength()); keys[i] = key; } - return keys; + if (appendHandle) { + Key key = TypedKey.toTypedKey(handle, IntegerType.BIGINT); + keys[keys.length - 1] = key; + } + + return new EncodeIndexDataResult(keys, appendHandle); } private static byte[] encode(long tableId, long indexId, Key[] dataKeys) { diff --git a/tikv-client/src/main/java/com/pingcap/tikv/key/Key.java b/tikv-client/src/main/java/com/pingcap/tikv/key/Key.java index a10d99126b..1d1fd271bd 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/key/Key.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/key/Key.java @@ -179,7 +179,7 @@ public Key append(Key other) { @Override public int hashCode() { - return Arrays.hashCode(value) * infFlag; + return Arrays.hashCode(value); } public byte[] getBytes() { diff --git a/tikv-client/src/main/java/com/pingcap/tikv/key/RowKey.java b/tikv-client/src/main/java/com/pingcap/tikv/key/RowKey.java index 809208939c..7480093225 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/key/RowKey.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/key/RowKey.java @@ -17,6 +17,8 @@ import static com.pingcap.tikv.codec.Codec.IntegerCodec.writeLong; +import com.pingcap.tikv.codec.Codec.IntegerCodec; +import com.pingcap.tikv.codec.CodecDataInput; import com.pingcap.tikv.codec.CodecDataOutput; import com.pingcap.tikv.exception.TiClientInternalException; import com.pingcap.tikv.exception.TiExpressionException; @@ -68,6 +70,16 @@ public static RowKey createBeyondMax(long tableId) { return new RowKey(tableId); } + public static RowKey decode(byte[] value) { + CodecDataInput cdi = new CodecDataInput(value); + cdi.readByte(); + long tableId = IntegerCodec.readLong(cdi); // tableId + cdi.readByte(); + cdi.readByte(); + long handle = IntegerCodec.readLong(cdi); // handle + return toRowKey(tableId, handle); + } + private static byte[] encode(long tableId, long handle) { CodecDataOutput cdo = new CodecDataOutput(); encodePrefix(cdo, tableId); diff --git a/tikv-client/src/main/java/com/pingcap/tikv/meta/TiTimestamp.java b/tikv-client/src/main/java/com/pingcap/tikv/meta/TiTimestamp.java index 1a587af866..b10b803a83 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/meta/TiTimestamp.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/meta/TiTimestamp.java @@ -46,6 +46,10 @@ public long getLogical() { return this.logical; } + public TiTimestamp getPrevious() { + return new TiTimestamp(physical, logical - 1); + } + @Override public boolean equals(Object other) { if (other == this) { diff --git a/tikv-client/src/main/java/com/pingcap/tikv/operation/KVErrorHandler.java b/tikv-client/src/main/java/com/pingcap/tikv/operation/KVErrorHandler.java index 802735960d..0f06d5d49b 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/operation/KVErrorHandler.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/operation/KVErrorHandler.java @@ -143,17 +143,19 @@ private void notifyStoreCacheInvalidate(long storeId) { } private void resolveLock(BackOffer backOffer, Lock lock) { - logger.warn("resolving lock"); + if (lockResolverClient != null) { + logger.warn("resolving lock"); - ResolveLockResult resolveLockResult = - lockResolverClient.resolveLocks( - backOffer, callerStartTS, Collections.singletonList(lock), 
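/*
 * Guarding on lockResolverClient != null above turns lock resolution into a
 * no-op for callers that never injected a resolver, instead of failing with a
 * NullPointerException; such callers simply fall through to the normal
 * backoff-and-retry path.
 */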
forWrite); - resolveLockResultCallback.apply(resolveLockResult); - long msBeforeExpired = resolveLockResult.getMsBeforeTxnExpired(); - if (msBeforeExpired > 0) { - // if not resolve all locks, we wait and retry - backOffer.doBackOffWithMaxSleep( - BoTxnLockFast, msBeforeExpired, new KeyException(lock.toString())); + ResolveLockResult resolveLockResult = + lockResolverClient.resolveLocks( + backOffer, callerStartTS, Collections.singletonList(lock), forWrite); + resolveLockResultCallback.apply(resolveLockResult); + long msBeforeExpired = resolveLockResult.getMsBeforeTxnExpired(); + if (msBeforeExpired > 0) { + // if not resolve all locks, we wait and retry + backOffer.doBackOffWithMaxSleep( + BoTxnLockFast, msBeforeExpired, new KeyException(lock.toString())); + } } } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/region/AbstractRegionStoreClient.java b/tikv-client/src/main/java/com/pingcap/tikv/region/AbstractRegionStoreClient.java index 54347083a8..5e6c460f85 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/region/AbstractRegionStoreClient.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/region/AbstractRegionStoreClient.java @@ -78,7 +78,7 @@ public boolean onNotLeader(Metapb.Store newStore) { if (logger.isDebugEnabled()) { logger.debug(region + ", new leader = " + newStore.getId()); } - TiRegion cachedRegion = regionManager.getRegionById(region.getId()); + TiRegion cachedRegion = regionManager.getRegionByKey(region.getStartKey()); // When switch leader fails or the region changed its key range, // it would be necessary to re-split task's key range for new region. if (!region.getStartKey().equals(cachedRegion.getStartKey()) diff --git a/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java b/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java index b2706eb31d..77def7f48f 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java @@ -28,6 +28,7 @@ import com.pingcap.tikv.exception.GrpcException; import com.pingcap.tikv.exception.TiClientInternalException; import com.pingcap.tikv.key.Key; +import com.pingcap.tikv.util.BackOffer; import com.pingcap.tikv.util.ConcreteBackOffer; import com.pingcap.tikv.util.Pair; import java.util.ArrayList; @@ -67,8 +68,15 @@ public TiRegion getRegionByKey(ByteString key) { return cache.getRegionByKey(key); } + @Deprecated + // Do not use GetRegionByID when retrying request. + // + // A,B |_______|_____| + // A |_____________| + // Consider region A, B. After merge of (A, B) -> A, region ID B does not exist. + // This request is unrecoverable. 
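/*
 * The retry-safe alternative used throughout this patch is to re-resolve by
 * key, which survives merges because the key space stays stable even when a
 * region ID disappears:
 *
 *   TiRegion current = regionManager.getRegionByKey(oldRegion.getStartKey());
 */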
public TiRegion getRegionById(long regionId) { - return cache.getRegionById(regionId); + return cache.getRegionById(ConcreteBackOffer.newGetBackOff(), regionId); } public Pair getRegionStorePairByKey(ByteString key) { @@ -206,13 +214,13 @@ private synchronized boolean putRegion(TiRegion region) { return true; } - private synchronized TiRegion getRegionById(long regionId) { + private synchronized TiRegion getRegionById(BackOffer backOffer, long regionId) { TiRegion region = regionCache.get(regionId); if (logger.isDebugEnabled()) { logger.debug(String.format("getRegionByKey ID[%s] -> Region[%s]", regionId, region)); } if (region == null) { - region = pdClient.getRegionByID(ConcreteBackOffer.newGetBackOff(), regionId); + region = pdClient.getRegionByID(backOffer, regionId); if (!putRegion(region)) { throw new TiClientInternalException("Invalid Region: " + region.toString()); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/region/RegionStoreClient.java b/tikv-client/src/main/java/com/pingcap/tikv/region/RegionStoreClient.java index e32f54fe4d..5981e931af 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/region/RegionStoreClient.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/region/RegionStoreClient.java @@ -59,6 +59,7 @@ import java.util.Queue; import java.util.Set; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.tikv.kvproto.Coprocessor; @@ -79,6 +80,8 @@ import org.tikv.kvproto.Kvrpcpb.PrewriteResponse; import org.tikv.kvproto.Kvrpcpb.ScanRequest; import org.tikv.kvproto.Kvrpcpb.ScanResponse; +import org.tikv.kvproto.Kvrpcpb.SplitRegionRequest; +import org.tikv.kvproto.Kvrpcpb.SplitRegionResponse; import org.tikv.kvproto.Kvrpcpb.TxnHeartBeatRequest; import org.tikv.kvproto.Kvrpcpb.TxnHeartBeatResponse; import org.tikv.kvproto.Metapb.Store; @@ -278,15 +281,11 @@ private List handleBatchGetResponse( ResolveLockResult resolveLockResult = lockResolverClient.resolveLocks(backOffer, version, locks, forWrite); addResolvedLocks(version, resolveLockResult.getResolvedLocks()); - long msBeforeExpired = resolveLockResult.getMsBeforeTxnExpired(); - if (msBeforeExpired > 0) { - // resolveLocks already retried, just throw error to upper logic. - throw new TiKVException("locks not resolved, retry"); - } - - // FIXME: we should retry + // resolveLocks already retried, just throw error to upper logic. + throw new TiKVException("locks not resolved, retry"); + } else { + return resp.getPairsList(); } - return resp.getPairsList(); } public List scan( @@ -408,6 +407,7 @@ public void prewrite( .setLockTtl(ttl) .setSkipConstraintCheck(skipConstraintCheck) .setMinCommitTs(startTs) + .setTxnSize(16) .build(); KVErrorHandler handler = new KVErrorHandler<>( @@ -754,6 +754,57 @@ public Iterator coprocessStreaming( return doCoprocessor(responseIterator); } + /** + * Send SplitRegion request to tikv split a region at splitKey. splitKey must between current + * region's start key and end key. + * + * @param splitKeys is the split points for a specific region. + * @return a split region info. + */ + public List splitRegion(Iterable splitKeys) { + Supplier request = + () -> + SplitRegionRequest.newBuilder() + .setContext(region.getContext()) + .addAllSplitKeys(splitKeys) + .build(); + + KVErrorHandler handler = + new KVErrorHandler<>( + regionManager, + this, + null, + region, + resp -> resp.hasRegionError() ? 
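/*
 * Usage sketch for splitRegion (illustrative): the caller owns key
 * validation, since every split key must lie strictly inside the target
 * region's key range; TiSession.splitRegion above filters out keys equal to
 * the region's start or end key before reaching this RPC.
 *
 *   RegionStoreClient client = builder.build(region, store);
 *   List<TiRegion> newRegions = client.splitRegion(splitKeys);
 */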
resp.getRegionError() : null, + resp -> null, + resolveLockResult -> null, + 0L, + false); + + SplitRegionResponse resp = + callWithRetry( + ConcreteBackOffer.newGetBackOff(), TikvGrpc.getSplitRegionMethod(), request, handler); + + if (resp == null) { + this.regionManager.onRequestFail(region); + throw new TiClientInternalException("SplitRegion Response failed without a cause"); + } + + if (resp.hasRegionError()) { + throw new TiClientInternalException( + String.format( + "failed to split region %d because %s", + region.getId(), resp.getRegionError().toString())); + } + + return resp.getRegionsList() + .stream() + .map( + region -> + new TiRegion(region, null, conf.getIsolationLevel(), conf.getCommandPriority())) + .collect(Collectors.toList()); + } + public enum RequestTypes { REQ_TYPE_SELECT(101), REQ_TYPE_INDEX(102), diff --git a/tikv-client/src/main/java/com/pingcap/tikv/region/TiRegion.java b/tikv-client/src/main/java/com/pingcap/tikv/region/TiRegion.java index be2b8dcc27..41908431c7 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/region/TiRegion.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/region/TiRegion.java @@ -113,6 +113,10 @@ public ByteString getEndKey() { return meta.getEndKey(); } + public Key getRowEndKey() { + return Key.toRawKey(getEndKey()); + } + public Kvrpcpb.Context getContext() { return getContext(java.util.Collections.emptySet()); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/txn/LockResolverClientV2.java b/tikv-client/src/main/java/com/pingcap/tikv/txn/LockResolverClientV2.java index ca0889d71f..6cb5823bed 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/txn/LockResolverClientV2.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/txn/LockResolverClientV2.java @@ -23,6 +23,7 @@ import com.pingcap.tikv.TiConfiguration; import com.pingcap.tikv.exception.KeyException; import com.pingcap.tikv.exception.RegionException; +import com.pingcap.tikv.exception.TiClientInternalException; import com.pingcap.tikv.operation.KVErrorHandler; import com.pingcap.tikv.region.AbstractRegionStoreClient; import com.pingcap.tikv.region.RegionManager; @@ -142,6 +143,16 @@ private Long getTxnStatus(BackOffer bo, Long txnID, ByteString primary) { CleanupResponse resp = callWithRetry(bo, TikvGrpc.getKvCleanupMethod(), factory, handler); status = 0L; + + if (resp == null) { + logger.error("getKvCleanupMethod failed without a cause"); + regionManager.onRequestFail(region); + bo.doBackOff( + BoRegionMiss, + new TiClientInternalException("getKvCleanupMethod failed without a cause")); + continue; + } + if (resp.hasRegionError()) { bo.doBackOff(BoRegionMiss, new RegionException(resp.getRegionError())); continue; @@ -249,10 +260,13 @@ private void resolveLock(BackOffer bo, Lock lock, long txnStatus, Set handler = new KVErrorHandler<>( regionManager, - this, - this, + primaryKeyRegionStoreClient, + primaryKeyRegionStoreClient.lockResolverClient, primaryKeyRegion, resp -> resp.hasRegionError() ? resp.getRegionError() : null, resp -> resp.hasError() ? 
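/*
 * The same defensive shape is applied to every lock-resolver RPC touched by
 * this patch: a null response means callWithRetry gave up without a concrete
 * error, so the code marks the region failed, takes a BoRegionMiss backoff,
 * and loops again rather than dereferencing null.
 */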
resp.getError() : null, @@ -243,12 +255,19 @@ private TxnStatus getTxnStatus(BackOffer bo, Long txnID, ByteString primary, Lon 0L, false); - // new RegionStoreClient for PrimaryKey - RegionStoreClient primaryKeyRegionStoreClient = clientBuilder.build(primary); CleanupResponse resp = primaryKeyRegionStoreClient.callWithRetry( bo, TikvGrpc.getKvCleanupMethod(), factory, handler); + if (resp == null) { + logger.error("getKvCleanupMethod failed without a cause"); + regionManager.onRequestFail(primaryKeyRegion); + bo.doBackOff( + BoRegionMiss, + new TiClientInternalException("getKvCleanupMethod failed without a cause")); + continue; + } + if (resp.hasRegionError()) { bo.doBackOff(BoRegionMiss, new RegionException(resp.getRegionError())); continue; diff --git a/tikv-client/src/main/java/com/pingcap/tikv/txn/LockResolverClientV4.java b/tikv-client/src/main/java/com/pingcap/tikv/txn/LockResolverClientV4.java index 1fb631e11a..739fb25eac 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/txn/LockResolverClientV4.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/txn/LockResolverClientV4.java @@ -25,6 +25,7 @@ import com.pingcap.tikv.TiConfiguration; import com.pingcap.tikv.exception.KeyException; import com.pingcap.tikv.exception.RegionException; +import com.pingcap.tikv.exception.TiClientInternalException; import com.pingcap.tikv.exception.TxnNotFoundException; import com.pingcap.tikv.exception.WriteConflictException; import com.pingcap.tikv.operation.KVErrorHandler; @@ -189,6 +190,15 @@ private void resolvePessimisticLock(BackOffer bo, Lock lock, Set cl Kvrpcpb.PessimisticRollbackResponse resp = callWithRetry(bo, TikvGrpc.getKVPessimisticRollbackMethod(), factory, handler); + if (resp == null) { + logger.error("getKVPessimisticRollbackMethod failed without a cause"); + regionManager.onRequestFail(region); + bo.doBackOff( + BoRegionMiss, + new TiClientInternalException("getKVPessimisticRollbackMethod failed without a cause")); + continue; + } + if (resp.hasRegionError()) { bo.doBackOff(BoRegionMiss, new RegionException(resp.getRegionError())); continue; @@ -290,11 +300,13 @@ private TxnStatus getTxnStatus( while (true) { TiRegion primaryKeyRegion = regionManager.getRegionByKey(primary); + // new RegionStoreClient for PrimaryKey + RegionStoreClient primaryKeyRegionStoreClient = clientBuilder.build(primary); KVErrorHandler handler = new KVErrorHandler<>( regionManager, - this, - this, + primaryKeyRegionStoreClient, + primaryKeyRegionStoreClient.lockResolverClient, primaryKeyRegion, resp -> resp.hasRegionError() ? resp.getRegionError() : null, resp -> resp.hasError() ? 
resp.getError() : null, @@ -302,12 +314,19 @@ private TxnStatus getTxnStatus( callerStartTS, false); - // new RegionStoreClient for PrimaryKey - RegionStoreClient primaryKeyRegionStoreClient = clientBuilder.build(primary); Kvrpcpb.CheckTxnStatusResponse resp = primaryKeyRegionStoreClient.callWithRetry( bo, TikvGrpc.getKvCheckTxnStatusMethod(), factory, handler); + if (resp == null) { + logger.error("getKvCheckTxnStatusMethod failed without a cause"); + regionManager.onRequestFail(primaryKeyRegion); + bo.doBackOff( + BoRegionMiss, + new TiClientInternalException("getKvCheckTxnStatusMethod failed without a cause")); + continue; + } + if (resp.hasRegionError()) { bo.doBackOff(BoRegionMiss, new RegionException(resp.getRegionError())); continue; @@ -377,10 +396,13 @@ private void resolveLock( Kvrpcpb.ResolveLockResponse resp = callWithRetry(bo, TikvGrpc.getKvResolveLockMethod(), factory, handler); - if (resp.hasError()) { - logger.error( - String.format("unexpected resolveLock err: %s, lock: %s", resp.getError(), lock)); - throw new KeyException(resp.getError()); + if (resp == null) { + logger.error("getKvResolveLockMethod failed without a cause"); + regionManager.onRequestFail(region); + bo.doBackOff( + BoRegionMiss, + new TiClientInternalException("getKvResolveLockMethod failed without a cause")); + continue; } if (resp.hasRegionError()) { @@ -388,6 +410,12 @@ private void resolveLock( continue; } + if (resp.hasError()) { + logger.error( + String.format("unexpected resolveLock err: %s, lock: %s", resp.getError(), lock)); + throw new KeyException(resp.getError()); + } + if (cleanWholeRegion) { cleanRegion.add(region.getVerID()); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/txn/type/BatchKeys.java b/tikv-client/src/main/java/com/pingcap/tikv/txn/type/BatchKeys.java index a35aae0eba..1b06af528f 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/txn/type/BatchKeys.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/txn/type/BatchKeys.java @@ -25,12 +25,15 @@ public class BatchKeys { private final TiRegion region; private final Metapb.Store store; private List keys; + private final int sizeInBytes; - public BatchKeys(TiRegion region, Metapb.Store store, List keysInput) { + public BatchKeys( + TiRegion region, Metapb.Store store, List keysInput, int sizeInBytes) { this.region = region; this.store = store; this.keys = new ArrayList<>(); this.keys.addAll(keysInput); + this.sizeInBytes = sizeInBytes; } public List getKeys() { @@ -48,4 +51,12 @@ public TiRegion getRegion() { public Metapb.Store getStore() { return store; } + + public int getSizeInBytes() { + return sizeInBytes; + } + + public float getSizeInKB() { + return ((float) sizeInBytes) / 1024; + } } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/AbstractDateTimeType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/AbstractDateTimeType.java index a3f460907e..97d8a0f6fc 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/AbstractDateTimeType.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/AbstractDateTimeType.java @@ -63,12 +63,29 @@ long decodeDateTime(int flag, CodecDataInput cdi) { if (extendedDateTime == null) { Timestamp ts = DateTimeCodec.createExtendedDateTime(getTimezone(), 1, 1, 1, 0, 0, 0, 0).toTimeStamp(); + // by dividing 1000 on milliseconds, we have eliminated fraction part of ts return ts.getTime() / 1000 * 100000 + ts.getNanos() / 1000; } Timestamp ts = extendedDateTime.toTimeStamp(); return ts.getTime() / 1000 * 1000000 + ts.getNanos() / 1000; } + Timestamp 
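/*
 * Worked example of the packing above for 2020-08-19 16:27:21.5:
 * ts.getTime() / 1000 strips milliseconds down to epoch seconds, * 1000000
 * shifts left to make room for six digits of microseconds, and
 * ts.getNanos() / 1000 fills them in, so the 0.5s fraction contributes
 * 500000 to the low digits. The batch-write variant below skips the packing
 * and returns the java.sql.Timestamp itself.
 */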
decodeDateTimeForBatchWrite(int flag, CodecDataInput cdi) { + ExtendedDateTime extendedDateTime; + if (flag == Codec.UVARINT_FLAG) { + extendedDateTime = DateTimeCodec.readFromUVarInt(cdi, getTimezone()); + } else if (flag == Codec.UINT_FLAG) { + extendedDateTime = DateTimeCodec.readFromUInt(cdi, getTimezone()); + } else { + throw new InvalidCodecFormatException( + "Invalid Flag type for " + getClass().getSimpleName() + ": " + flag); + } + if (extendedDateTime == null) { + return DateTimeCodec.createExtendedDateTime(getTimezone(), 1, 1, 1, 0, 0, 0, 0).toTimeStamp(); + } + return extendedDateTime.toTimeStamp(); + } + /** Decode Date from packed long value */ LocalDate decodeDate(int flag, CodecDataInput cdi) { LocalDate date; diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/ArrayType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/ArrayType.java new file mode 100644 index 0000000000..f75cf1ba5d --- /dev/null +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/ArrayType.java @@ -0,0 +1,154 @@ +/* + * + * Copyright 2020 PingCAP, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.pingcap.tikv.types; + +import com.pingcap.tidb.tipb.ExprType; +import com.pingcap.tikv.codec.Codec; +import com.pingcap.tikv.codec.Codec.IntegerCodec; +import com.pingcap.tikv.codec.CodecDataInput; +import com.pingcap.tikv.codec.CodecDataOutput; +import com.pingcap.tikv.exception.ConvertNotSupportException; +import com.pingcap.tikv.exception.ConvertOverflowException; +import com.pingcap.tikv.exception.TypeException; +import com.pingcap.tikv.meta.Collation; +import com.pingcap.tikv.meta.TiColumnInfo; + +public class ArrayType extends DataType { + public static final ArrayType ARRAY = new ArrayType(MySQLType.TypeArray); + + public static final MySQLType[] subTypes = new MySQLType[] {MySQLType.TypeArray}; + + protected ArrayType(MySQLType type, int flag, int len, int decimal) { + super(type, flag, len, decimal, "", Collation.DEF_COLLATION_CODE); + } + + protected ArrayType(MySQLType tp) { + super(tp); + } + + protected ArrayType(TiColumnInfo.InternalTypeHolder holder) { + super(holder); + } + + @Override + protected Object doConvertToTiDBType(Object value) + throws ConvertNotSupportException, ConvertOverflowException { + // TODO: support write to YEAR + if (this.getType() == MySQLType.TypeYear) { + throw new ConvertNotSupportException(value.getClass().getName(), this.getClass().getName()); + } + + Long result; + if (this.isUnsigned()) { + result = Converter.safeConvertToUnsigned(value, this.unsignedUpperBound()); + } else { + result = + Converter.safeConvertToSigned(value, this.signedLowerBound(), this.signedUpperBound()); + } + + return result; + } + + public boolean isSameCatalog(DataType other) { + return other instanceof IntegerType; + } + + /** {@inheritDoc} */ + @Override + protected Object decodeNotNull(int flag, CodecDataInput cdi) { + int count = cdi.readInt(); + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + long ret; + switch (flag) { + case Codec.UVARINT_FLAG: + ret = IntegerCodec.readUVarLong(cdi); + 
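/*
 * Wire layout assumed by this decoder: a 4-byte element count, then `count`
 * values all encoded under the single codec flag passed in (UVARINT, UINT,
 * VARINT or INT), collected into a long[]. This appears to back the handle
 * array consumed by the columnar index-scan path.
 */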
break; + case Codec.UINT_FLAG: + ret = IntegerCodec.readULong(cdi); + break; + case Codec.VARINT_FLAG: + ret = IntegerCodec.readVarLong(cdi); + break; + case Codec.INT_FLAG: + ret = IntegerCodec.readLong(cdi); + break; + default: + throw new TypeException("Invalid IntegerType flag: " + flag); + } + res[i] = ret; + } + return res; + } + + /** {@inheritDoc} */ + @Override + protected void encodeKey(CodecDataOutput cdo, Object value) { + long longVal = Converter.convertToLong(value); + if (isUnsigned()) { + IntegerCodec.writeULongFully(cdo, longVal, true); + } else { + IntegerCodec.writeLongFully(cdo, longVal, true); + } + } + + /** {@inheritDoc} */ + @Override + protected void encodeValue(CodecDataOutput cdo, Object value) { + long longVal = Converter.convertToLong(value); + if (isUnsigned()) { + IntegerCodec.writeULongFully(cdo, longVal, false); + } else { + IntegerCodec.writeLongFully(cdo, longVal, false); + } + } + + /** {@inheritDoc} */ + @Override + protected void encodeProto(CodecDataOutput cdo, Object value) { + long longVal = Converter.convertToLong(value); + if (isUnsigned()) { + IntegerCodec.writeULong(cdo, longVal); + } else { + IntegerCodec.writeLong(cdo, longVal); + } + } + + @Override + public String getName() { + if (isUnsigned()) { + return "UNSIGNED LONG"; + } + return "LONG"; + } + + @Override + public ExprType getProtoExprType() { + return isUnsigned() ? ExprType.Uint64 : ExprType.Int64; + } + + public boolean isUnsignedLong() { + return tp == MySQLType.TypeLonglong && isUnsigned(); + } + + /** {@inheritDoc} */ + @Override + public Object getOriginDefaultValueNonNull(String value, long version) { + return Long.parseLong(value); + } +} diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/DataType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/DataType.java index a2c6cb2c91..853c02f79a 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/DataType.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/DataType.java @@ -83,7 +83,7 @@ public abstract class DataType implements Serializable { protected final long length; private final String charset; private final List elems; - private final byte[] allNotNullBitMap = initAllNotNullBitMap(); + private static final byte[] allNotNullBitMap = initAllNotNullBitMap(); private final byte[] readBuffer = new byte[8]; public DataType(MySQLType tp, int prec, int scale) { @@ -199,7 +199,11 @@ public Long unsignedUpperBound() throws TypeException { protected abstract Object decodeNotNull(int flag, CodecDataInput cdi); - private int getFixLen() { + protected Object decodeNotNullForBatchWrite(int flag, CodecDataInput cdi) { + return decodeNotNull(flag, cdi); + } + + public int getFixLen() { switch (this.getType()) { case TypeFloat: return 4; @@ -225,7 +229,12 @@ private int getFixLen() { } } - private byte[] setAllNotNull(int numNullBitMapBytes) { + public static byte[] setAllNotNullBitMapWithNumRows(int numRows) { + int numNullBitmapBytes = (numRows + 7) / 8; + return setAllNotNull(numNullBitmapBytes); + } + + private static byte[] setAllNotNull(int numNullBitMapBytes) { byte[] nullBitMaps = new byte[numNullBitMapBytes]; for (int i = 0; i < numNullBitMapBytes; ) { // allNotNullBitNMap's actual length @@ -237,7 +246,7 @@ private byte[] setAllNotNull(int numNullBitMapBytes) { return nullBitMaps; } - private byte[] initAllNotNullBitMap() { + private static byte[] initAllNotNullBitMap() { byte[] allNotNullBitMap = new byte[128]; Arrays.fill(allNotNullBitMap, (byte) 0xFF); return allNotNullBitMap; @@ -319,6 +328,14 
@@ public Object decode(CodecDataInput cdi) { return decodeNotNull(flag, cdi); } + public Object decodeForBatchWrite(CodecDataInput cdi) { + int flag = cdi.readUnsignedByte(); + if (isNullFlag(flag)) { + return null; + } + return decodeNotNullForBatchWrite(flag, cdi); + } + public boolean isNextNull(CodecDataInput cdi) { return isNullFlag(cdi.peekByte()); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/DateTimeType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/DateTimeType.java index e0df9b7f64..d059f2c50e 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/DateTimeType.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/DateTimeType.java @@ -21,6 +21,7 @@ import com.pingcap.tikv.exception.ConvertNotSupportException; import com.pingcap.tikv.exception.ConvertOverflowException; import com.pingcap.tikv.meta.TiColumnInfo; +import java.sql.Timestamp; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -65,6 +66,11 @@ protected Long decodeNotNull(int flag, CodecDataInput cdi) { return decodeDateTime(flag, cdi); } + @Override + protected Timestamp decodeNotNullForBatchWrite(int flag, CodecDataInput cdi) { + return decodeDateTimeForBatchWrite(flag, cdi); + } + @Override public DateTime getOriginDefaultValueNonNull(String value, long version) { return Converter.convertToDateTime(value).getDateTime(); diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/DateType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/DateType.java index a938ad53ce..b21cc56f8c 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/DateType.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/DateType.java @@ -101,4 +101,14 @@ protected Long decodeNotNull(int flag, CodecDataInput cdi) { // return how many days from EPOCH return Math.floorDiv(date.toDate().getTime(), AbstractDateTimeType.MILLS_PER_DAY); } + + @Override + protected Date decodeNotNullForBatchWrite(int flag, CodecDataInput cdi) { + LocalDate date = decodeDate(flag, cdi); + + if (date == null) { + return null; + } + return new Date(date.toDate().getTime()); + } } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/MySQLType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/MySQLType.java index 7fb59adf37..8e1b733fd0 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/MySQLType.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/MySQLType.java @@ -52,7 +52,8 @@ public enum MySQLType { TypeBlob(0xfc, 65535, 2, -1), TypeVarString(0xfd, 255, 1, -1), TypeString(0xfe, 255, 1, 1), - TypeGeometry(0xff, 1024, 1, -1); + TypeGeometry(0xff, 1024, 1, -1), + TypeArray(100, 65536, 2, -1); private static final Map typeMap = new HashMap<>(); private static final Map sizeMap = new HashMap<>(); diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/TimestampType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/TimestampType.java index c90bee98fc..ae33d3d5e7 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/TimestampType.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/TimestampType.java @@ -26,6 +26,7 @@ import com.pingcap.tikv.exception.ConvertNotSupportException; import com.pingcap.tikv.exception.ConvertOverflowException; import com.pingcap.tikv.meta.TiColumnInfo; +import java.sql.Timestamp; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.joda.time.LocalDateTime; @@ -85,6 +86,11 @@ protected Long decodeNotNull(int flag, CodecDataInput cdi) { return decodeDateTime(flag, cdi); } + @Override + 
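/*
 * With these overrides the decode path forks per use case:
 * DataType.decodeForBatchWrite falls back to decodeNotNull unless a subclass
 * overrides decodeNotNullForBatchWrite, and the date/time family returns
 * java.sql.Timestamp / java.sql.Date objects instead of the packed longs used
 * on the read path. TableCodecV1.decodeRow switches to this variant above.
 */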
protected Timestamp decodeNotNullForBatchWrite(int flag, CodecDataInput cdi) { + return decodeDateTimeForBatchWrite(flag, cdi); + } + @Override public DateTime getOriginDefaultValueNonNull(String value, long version) { if (version >= DataType.COLUMN_VERSION_FLAG) { diff --git a/tikv-client/src/main/java/com/pingcap/tikv/util/BackOffer.java b/tikv-client/src/main/java/com/pingcap/tikv/util/BackOffer.java index 9f06ec4b09..46ac03fe8e 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/util/BackOffer.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/util/BackOffer.java @@ -37,7 +37,6 @@ public interface BackOffer { int SPLIT_REGION_BACKOFF = 20 * seconds; int BATCH_PREWRITE_BACKOFF = TTLManager.MANAGED_LOCK_TTL; int BATCH_COMMIT_BACKOFF = 10 * seconds; - int WAIT_SCATTER_REGION_FINISH = 120 * seconds; int PD_INFO_BACKOFF = 5 * seconds; /** diff --git a/tikv-client/src/main/java/com/pingcap/tikv/util/ConcreteBackOffer.java b/tikv-client/src/main/java/com/pingcap/tikv/util/ConcreteBackOffer.java index 38b8eb0704..111872bb18 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/util/ConcreteBackOffer.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/util/ConcreteBackOffer.java @@ -67,10 +67,6 @@ public static ConcreteBackOffer newGetBackOff() { return new ConcreteBackOffer(GET_MAX_BACKOFF); } - public static ConcreteBackOffer newWaitScatterRegionBackOff() { - return new ConcreteBackOffer(WAIT_SCATTER_REGION_FINISH); - } - public static ConcreteBackOffer newRawKVBackOff() { return new ConcreteBackOffer(RAWKV_MAX_BACKOFF); } From a68934a074da77083390132c876c6347f6752f5f Mon Sep 17 00:00:00 2001 From: Liangliang Gu Date: Wed, 19 Aug 2020 19:53:12 +0800 Subject: [PATCH 2/7] scatterWaitSecondes -> scatterWaitMS (#1563) --- .../main/scala/com/pingcap/tispark/write/TiDBOptions.scala | 2 +- .../com/pingcap/tispark/datasource/RegionSplitSuite.scala | 6 ++++-- .../scala/org/apache/spark/sql/SequenceTestSuite.scala | 7 ++++++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala b/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala index c1d09115a8..1e28a90104 100644 --- a/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala +++ b/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala @@ -214,7 +214,7 @@ object TiDBOptions { val TIDB_SAMPLE_SPLIT_FRAC: String = newOption("sampleSplitFrac") val TIDB_WRITE_SPLIT_REGION_FINISH: String = newOption("writeSplitRegionFinish") val TIDB_REGION_SPLIT_METHOD: String = newOption("regionSplitMethod") - val TIDB_SCATTER_WAIT_MS: String = newOption("scatterWaitSecondes") + val TIDB_SCATTER_WAIT_MS: String = newOption("scatterWaitMS") val TIDB_REGION_SPLIT_KEYS: String = newOption("regionSplitKeys") val TIDB_MIN_REGION_SPLIT_NUM: String = newOption("minRegionSplitNum") val TIDB_REGION_SPLIT_THRESHOLD: String = newOption("regionSplitThreshold") diff --git a/core/src/test/scala/com/pingcap/tispark/datasource/RegionSplitSuite.scala b/core/src/test/scala/com/pingcap/tispark/datasource/RegionSplitSuite.scala index 3c9fb729d1..b2a417cf59 100644 --- a/core/src/test/scala/com/pingcap/tispark/datasource/RegionSplitSuite.scala +++ b/core/src/test/scala/com/pingcap/tispark/datasource/RegionSplitSuite.scala @@ -39,7 +39,8 @@ class RegionSplitSuite extends BaseDataSourceTest("region_split_test") { jdbcUpdate( s"CREATE TABLE $dbtable ( `a` int(11), unique index(a)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin") - val options = Some(Map("enableRegionSplit" -> 
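/*
 * After the rename above, callers pass the scatter-wait budget in
 * milliseconds under the corrected key, e.g. (illustrative):
 *
 *   df.write
 *     .format("tidb")
 *     .option("scatterWaitMS", "300000")
 *     ...
 */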
"true", "regionSplitNum" -> "3")) + val options = Some( + Map("enableRegionSplit" -> "true", "regionSplitMethod" -> "v1", "regionSplitNum" -> "3")) tidbWrite(List(row1, row2, row3), schema, options) @@ -65,7 +66,8 @@ class RegionSplitSuite extends BaseDataSourceTest("region_split_test") { jdbcUpdate( s"CREATE TABLE $dbtable ( `a` int(11) DEFAULT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin") - val options = Some(Map("enableRegionSplit" -> "true", "regionSplitNum" -> "3")) + val options = Some( + Map("enableRegionSplit" -> "true", "regionSplitMethod" -> "v1", "regionSplitNum" -> "3")) tidbWrite(List(row1, row2, row3), schema, options) diff --git a/core/src/test/scala/org/apache/spark/sql/SequenceTestSuite.scala b/core/src/test/scala/org/apache/spark/sql/SequenceTestSuite.scala index 8502694c8f..c53d8797d1 100644 --- a/core/src/test/scala/org/apache/spark/sql/SequenceTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/SequenceTestSuite.scala @@ -38,7 +38,12 @@ class SequenceTestSuite extends BaseTiSparkTest { } override def afterAll(): Unit = { - dropTbl() + tidbStmt.execute(s"drop table if exists $table") + try { + tidbStmt.execute("drop sequence if exists sq_test") + } catch { + case _: Exception => + } } private def dropTbl() = { From 4facb56029b07a5624558004baf210c07ba9a771 Mon Sep 17 00:00:00 2001 From: Liangliang Gu Date: Thu, 20 Aug 2020 10:32:56 +0800 Subject: [PATCH 3/7] fix columnar batch v3 (#1566) --- .../tikv/columnar/TiColumnVectorAdapter.java | 14 +- .../pingcap/tikv/datatype/TypeMapping.java | 5 - .../pingcap/tispark/write/TiBatchWrite.scala | 5 +- .../spark/sql/execution/CoprocessorRDD.scala | 18 +- .../spark/sql/tispark/TiHandleRDD.scala | 85 ++-------- .../datasource/BaseDataSourceTest.scala | 4 +- .../main/java/com/pingcap/tikv/Snapshot.java | 7 +- .../columnar/BatchedTiChunkColumnVector.java | 4 + .../tikv/columnar/TiChunkColumnVector.java | 27 +-- .../pingcap/tikv/columnar/TiColumnVector.java | 13 -- .../com/pingcap/tikv/types/ArrayType.java | 154 ------------------ .../java/com/pingcap/tikv/types/DataType.java | 13 +- .../com/pingcap/tikv/types/MySQLType.java | 3 +- 13 files changed, 32 insertions(+), 320 deletions(-) delete mode 100644 tikv-client/src/main/java/com/pingcap/tikv/types/ArrayType.java diff --git a/core/src/main/java/com/pingcap/tikv/columnar/TiColumnVectorAdapter.java b/core/src/main/java/com/pingcap/tikv/columnar/TiColumnVectorAdapter.java index 60130ce49e..27649d26b9 100644 --- a/core/src/main/java/com/pingcap/tikv/columnar/TiColumnVectorAdapter.java +++ b/core/src/main/java/com/pingcap/tikv/columnar/TiColumnVectorAdapter.java @@ -23,17 +23,11 @@ public class TiColumnVectorAdapter extends ColumnVector { private final TiColumnVector tiColumnVector; - private final ColumnVector offsets; /** Sets up the data type of this column vector. 
*/ public TiColumnVectorAdapter(TiColumnVector tiColumnVector) { super(TypeMapping.toSparkType(tiColumnVector.dataType())); this.tiColumnVector = tiColumnVector; - if (tiColumnVector.getOffset() == null) { - this.offsets = null; - } else { - this.offsets = new TiColumnVectorAdapter(tiColumnVector.getOffset()); - } } /** @@ -128,13 +122,7 @@ public double getDouble(int rowId) { */ @Override public ColumnarArray getArray(int rowId) { - if (tiColumnVector.isNullAt(rowId)) { - return null; - } - int index = rowId * 8; - int start = offsets.getInt(index); - int end = offsets.getInt(index + 1); - return new ColumnarArray(this, start, end - start); + throw new UnsupportedOperationException("TiColumnVectorAdapter is not supported this method"); } /** diff --git a/core/src/main/java/com/pingcap/tikv/datatype/TypeMapping.java b/core/src/main/java/com/pingcap/tikv/datatype/TypeMapping.java index 37c114c141..4bb6dee625 100644 --- a/core/src/main/java/com/pingcap/tikv/datatype/TypeMapping.java +++ b/core/src/main/java/com/pingcap/tikv/datatype/TypeMapping.java @@ -18,7 +18,6 @@ import static com.pingcap.tikv.types.MySQLType.TypeLonglong; import com.pingcap.tikv.types.AbstractDateTimeType; -import com.pingcap.tikv.types.ArrayType; import com.pingcap.tikv.types.BytesType; import com.pingcap.tikv.types.DataType; import com.pingcap.tikv.types.DateType; @@ -97,10 +96,6 @@ public static org.apache.spark.sql.types.DataType toSparkType(DataType type) { return DataTypes.LongType; } - if (type instanceof ArrayType) { - return DataTypes.createArrayType(DataTypes.LongType); - } - throw new UnsupportedOperationException( String.format("found unsupported type %s", type.getClass().getCanonicalName())); } diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala index 82fac6638d..02b61ddbaa 100644 --- a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala +++ b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala @@ -408,10 +408,7 @@ class TiBatchWrite( logger.info(s"sampleSize=$sampleSize") val sampleData = - rdd - .map(_._1) - .sample(withReplacement = false, fraction = sampleSize.toDouble / count) - .collect() + rdd.map(_._1).sample(withReplacement = false, sampleSize.toDouble / count).collect() logger.info(s"sampleData size=${sampleData.length}") val sortedSampleData = sampleData.sorted(new Ordering[SerializableKey] { diff --git a/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala b/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala index c1b0e47bde..6abbef4a11 100644 --- a/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala +++ b/core/src/main/scala/org/apache/spark/sql/execution/CoprocessorRDD.scala @@ -84,6 +84,8 @@ case class ColumnarCoprocessorRDD( override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(sparkContext.union(internalRDDs)) + override protected def supportsBatch: Boolean = !fetchHandle + override protected def doExecute(): RDD[InternalRow] = { if (!fetchHandle) { WholeStageCodegenExec(this)(codegenStageId = 0).execute() @@ -187,7 +189,7 @@ case class ColumnarRegionTaskExec( val batchSize = tiConf.getIndexScanBatchSize val downgradeThreshold = tiConf.getDowngradeThreshold - def computeWithRowIterator(row: InternalRow): Iterator[InternalRow] = { + iter.flatMap { row => val handles = row.getArray(1).toLongArray() val handleIterator: util.Iterator[Long] = handles.iterator var taskCount = 0 @@ -376,20 +378,6 @@ case class 
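// Editor's sketch (hedged) of the sample-based split-point selection that the
// TiBatchWrite hunk above adjusts: draw a small uniform sample of keys, sort
// it, then take evenly spaced elements as region split points. Names below
// are illustrative; the real code threads options and logging through.
//
//   def pickSplitPoints(keys: RDD[SerializableKey], count: Long,
//                       pointNum: Int, sampleSize: Int): Seq[SerializableKey] = {
//     val sample = keys
//       .sample(withReplacement = false, sampleSize.toDouble / count)
//       .collect()
//       .sortWith(_.compareTo(_) < 0)
//     val step = sample.length / (pointNum + 1)
//     (1 to pointNum).map(i => sample(i * step))
//   }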
ColumnarRegionTaskExec( } }.asInstanceOf[Iterator[InternalRow]] } - - iter match { - case batch: Iterator[ColumnarBatch] => - batch.asInstanceOf[Iterator[ColumnarBatch]].flatMap { it => - it.rowIterator().flatMap { row => - computeWithRowIterator(row) - } - } - case _: Iterator[InternalRow] => - iter.flatMap { row => - computeWithRowIterator(row) - } - } - } override protected def doExecute(): RDD[InternalRow] = { diff --git a/core/src/main/scala/org/apache/spark/sql/tispark/TiHandleRDD.scala b/core/src/main/scala/org/apache/spark/sql/tispark/TiHandleRDD.scala index d05edc8450..da949d375b 100644 --- a/core/src/main/scala/org/apache/spark/sql/tispark/TiHandleRDD.scala +++ b/core/src/main/scala/org/apache/spark/sql/tispark/TiHandleRDD.scala @@ -15,31 +15,19 @@ package org.apache.spark.sql.tispark -import java.nio.ByteBuffer - -import com.pingcap.tikv.codec.CodecDataOutput -import com.pingcap.tikv.columnar.{ - TiChunk, - TiChunkColumnVector, - TiColumnVector, - TiColumnarBatchHelper -} import com.pingcap.tikv.meta.TiDAGRequest -import com.pingcap.tikv.types.{ArrayType, DataType, IntegerType} import com.pingcap.tikv.util.RangeSplitter import com.pingcap.tikv.{TiConfiguration, TiSession} +import com.pingcap.tispark.utils.TiUtil import com.pingcap.tispark.{TiPartition, TiTableReference} import gnu.trove.list.array.TLongArrayList import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.{Partition, TaskContext, TaskKilledException} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer /** * RDD used for retrieving handles from TiKV. Result is arranged as @@ -65,7 +53,7 @@ class TiHandleRDD( outputTypes.map(CatalystTypeConverters.createToCatalystConverter) override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = - new Iterator[ColumnarBatch] { + new Iterator[InternalRow] { checkTimezone() private val tiPartition = split.asInstanceOf[TiPartition] @@ -73,7 +61,7 @@ class TiHandleRDD( private val snapshot = session.createSnapshot(dagRequest.getStartTs) private[this] val tasks = tiPartition.tasks - private val handleIterator = snapshot.indexHandleReadRow(dagRequest, tasks) + private val handleIterator = snapshot.indexHandleRead(dagRequest, tasks) private val regionManager = session.getRegionManager private lazy val handleList = { val lst = new TLongArrayList() @@ -104,63 +92,14 @@ class TiHandleRDD( iterator.hasNext } - override def next(): ColumnarBatch = { - var numRows = 0 - val batchSize = 20480 - val cdi0 = new CodecDataOutput() - val cdi1 = new CodecDataOutput() - var offsets = new mutable.ArrayBuffer[Long] - var curOffset = 0L - while (hasNext && numRows < batchSize) { - val next = iterator.next - val regionId = next._1 - val handleList = next._2 - if (!handleList.isEmpty) { - // Returns RegionId:[handle1, handle2, handle3...] 
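// Editor's note (hedged): the revert around this hunk drops the hand-rolled
// columnar encoding of handles and returns one InternalRow per region. The
// emitted row shape, with illustrative types, is:
//
//   column 0: regionId: Long
//   column 1: handles:  Array[Long], the row handles living in that region
//
// which matches the consumer in ColumnarRegionTaskExec calling
// row.getArray(1).toLongArray().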
K-V pair -// val sparkRow = Row.apply(regionId, handleList.toArray()) -// TiUtil.rowToInternalRow(sparkRow, outputTypes, converters) - cdi0.writeLong(regionId) - cdi1.writeLong(handleList.size()) - for (i <- 0 until handleList.size()) { - cdi1.writeLong(handleList.get(i)) - } - offsets += curOffset - curOffset += handleList.size().toLong - numRows += 1 - } - } - offsets += curOffset - - val buffer0 = ByteBuffer.wrap(cdi0.toBytes) - val buffer1 = ByteBuffer.wrap(cdi1.toBytes) - - val nullBitMaps = DataType.setAllNotNullBitMapWithNumRows(numRows) - - val regionIdType = IntegerType.BIGINT - val handleListType = ArrayType.ARRAY + override def next(): InternalRow = { + val next = iterator.next + val regionId = next._1 + val handleList = next._2 - val childColumnVectors = new ArrayBuffer[TiColumnVector] - childColumnVectors += - new TiChunkColumnVector( - regionIdType, - regionIdType.getFixLen, - numRows, - 0, - nullBitMaps, - null, - buffer0) - childColumnVectors += - // any type will do? actual type is array[Long] - new TiChunkColumnVector( - handleListType, - 8, - curOffset.toInt, - 0, - nullBitMaps, - offsets.toArray, - buffer1) - val chunk = new TiChunk(childColumnVectors.toArray) - TiColumnarBatchHelper.createColumnarBatch(chunk) + // Returns RegionId:[handle1, handle2, handle3...] K-V pair + val sparkRow = Row.apply(regionId, handleList.toArray()) + TiUtil.rowToInternalRow(sparkRow, outputTypes, converters) } - }.asInstanceOf[Iterator[InternalRow]] + } } diff --git a/core/src/test/scala/com/pingcap/tispark/datasource/BaseDataSourceTest.scala b/core/src/test/scala/com/pingcap/tispark/datasource/BaseDataSourceTest.scala index c54d8bf175..d48c5bde75 100644 --- a/core/src/test/scala/com/pingcap/tispark/datasource/BaseDataSourceTest.scala +++ b/core/src/test/scala/com/pingcap/tispark/datasource/BaseDataSourceTest.scala @@ -224,8 +224,8 @@ class BaseDataSourceTest(val table: String, val database: String = "tispark_test if (!compResult(jdbcResult, tidbResult)) { logger.error(s"""Failed on $tblName\n - |TiDB via JDBC result: ${listToString(jdbcResult)}\n - |DataSourceAPI result: ${listToString(tidbResult)}""".stripMargin) + |DataSourceAPI result: ${listToString(jdbcResult)}\n + |TiDB via JDBC result: ${listToString(tidbResult)}""".stripMargin) fail() } } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/Snapshot.java b/tikv-client/src/main/java/com/pingcap/tikv/Snapshot.java index 62fcfe40f0..e63d3dc557 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/Snapshot.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/Snapshot.java @@ -131,11 +131,6 @@ private Iterator tableReadRow(TiDAGRequest dagRequest, List tas } } - public Iterator indexHandleReadChunk( - TiDAGRequest dagRequest, List tasks, int numOfRows) { - return getTiChunkIterator(dagRequest, tasks, getSession(), numOfRows); - } - /** * Below is lower level API for env like Spark which already did key range split Perform handle * scan @@ -144,7 +139,7 @@ public Iterator indexHandleReadChunk( * @param tasks RegionTask of the coprocessor request to send * @return Row iterator to iterate over resulting rows */ - public Iterator indexHandleReadRow(TiDAGRequest dagRequest, List tasks) { + public Iterator indexHandleRead(TiDAGRequest dagRequest, List tasks) { return getHandleIterator(dagRequest, tasks, session); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/columnar/BatchedTiChunkColumnVector.java b/tikv-client/src/main/java/com/pingcap/tikv/columnar/BatchedTiChunkColumnVector.java index 583bdbd187..e8437f2cfb 100644 --- 
a/tikv-client/src/main/java/com/pingcap/tikv/columnar/BatchedTiChunkColumnVector.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/columnar/BatchedTiChunkColumnVector.java @@ -44,6 +44,10 @@ public BatchedTiChunkColumnVector(List child, int numOfRows } } + public final String typeName() { + return dataType().getType().name(); + } + // TODO: once we switch off_heap mode, we need control memory access pattern. public void free() {} diff --git a/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiChunkColumnVector.java b/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiChunkColumnVector.java index 6a19ce1bcd..b81e9b2203 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiChunkColumnVector.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiChunkColumnVector.java @@ -17,10 +17,8 @@ import com.google.common.primitives.UnsignedLong; import com.pingcap.tikv.codec.CodecDataInput; -import com.pingcap.tikv.codec.CodecDataOutput; import com.pingcap.tikv.codec.MyDecimal; import com.pingcap.tikv.types.AbstractDateTimeType; -import com.pingcap.tikv.types.ArrayType; import com.pingcap.tikv.types.BitType; import com.pingcap.tikv.types.DataType; import com.pingcap.tikv.types.DateTimeType; @@ -57,7 +55,7 @@ public TiChunkColumnVector( byte[] nullBitMaps, long[] offsets, ByteBuffer data) { - super(dataType, numOfRows, buildColumnVectorFromOffsets(numOfRows, offsets)); + super(dataType, numOfRows); this.fixLength = fixLength; this.numOfNulls = numOfNulls; this.nullBitMaps = nullBitMaps; @@ -65,25 +63,8 @@ public TiChunkColumnVector( this.offsets = offsets; } - private static TiChunkColumnVector buildColumnVectorFromOffsets(int numOfRows, long[] offsets) { - if (offsets == null) { - return null; - } else { - DataType type = IntegerType.BIGINT; - int fixLength = type.getFixLen(); - CodecDataOutput cdo = new CodecDataOutput(); - for (long offset : offsets) { - cdo.writeLong(offset); - } - return new TiChunkColumnVector( - type, - fixLength, - numOfRows, - 0, - DataType.setAllNotNullBitMapWithNumRows(numOfRows), - null, - ByteBuffer.wrap(cdo.toBytes())); - } + public final String typeName() { + return dataType().getType().name(); } // TODO: once we switch off_heap mode, we need control memory access pattern. @@ -195,8 +176,6 @@ public long getLong(int rowId) { return getTime(rowId); } else if (type instanceof TimeType) { return data.getLong(rowId * fixLength); - } else if (type instanceof ArrayType) { - return data.getLong(rowId * fixLength + fixLength); } throw new UnsupportedOperationException("only IntegerType and Time related are supported."); diff --git a/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiColumnVector.java b/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiColumnVector.java index c83283212b..c22b13aa41 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiColumnVector.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/columnar/TiColumnVector.java @@ -37,7 +37,6 @@ public abstract class TiColumnVector implements AutoCloseable { private final int numOfRows; - private final TiColumnVector offsets; /** Data type for this column. */ protected DataType type; @@ -45,14 +44,6 @@ public abstract class TiColumnVector implements AutoCloseable { protected TiColumnVector(DataType type, int numOfRows) { this.type = type; this.numOfRows = numOfRows; - this.offsets = null; - } - - /** Sets up the data type of this column vector. 
*/ - protected TiColumnVector(DataType type, int numOfRows, TiColumnVector offsets) { - this.type = type; - this.numOfRows = numOfRows; - this.offsets = offsets; } /** Returns the data type of this column vector. */ @@ -227,8 +218,4 @@ public double[] getDoubles(int rowId, int count) { public int numOfRows() { return numOfRows; } - - public TiColumnVector getOffset() { - return offsets; - } } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/ArrayType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/ArrayType.java deleted file mode 100644 index f75cf1ba5d..0000000000 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/ArrayType.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * - * Copyright 2020 PingCAP, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package com.pingcap.tikv.types; - -import com.pingcap.tidb.tipb.ExprType; -import com.pingcap.tikv.codec.Codec; -import com.pingcap.tikv.codec.Codec.IntegerCodec; -import com.pingcap.tikv.codec.CodecDataInput; -import com.pingcap.tikv.codec.CodecDataOutput; -import com.pingcap.tikv.exception.ConvertNotSupportException; -import com.pingcap.tikv.exception.ConvertOverflowException; -import com.pingcap.tikv.exception.TypeException; -import com.pingcap.tikv.meta.Collation; -import com.pingcap.tikv.meta.TiColumnInfo; - -public class ArrayType extends DataType { - public static final ArrayType ARRAY = new ArrayType(MySQLType.TypeArray); - - public static final MySQLType[] subTypes = new MySQLType[] {MySQLType.TypeArray}; - - protected ArrayType(MySQLType type, int flag, int len, int decimal) { - super(type, flag, len, decimal, "", Collation.DEF_COLLATION_CODE); - } - - protected ArrayType(MySQLType tp) { - super(tp); - } - - protected ArrayType(TiColumnInfo.InternalTypeHolder holder) { - super(holder); - } - - @Override - protected Object doConvertToTiDBType(Object value) - throws ConvertNotSupportException, ConvertOverflowException { - // TODO: support write to YEAR - if (this.getType() == MySQLType.TypeYear) { - throw new ConvertNotSupportException(value.getClass().getName(), this.getClass().getName()); - } - - Long result; - if (this.isUnsigned()) { - result = Converter.safeConvertToUnsigned(value, this.unsignedUpperBound()); - } else { - result = - Converter.safeConvertToSigned(value, this.signedLowerBound(), this.signedUpperBound()); - } - - return result; - } - - public boolean isSameCatalog(DataType other) { - return other instanceof IntegerType; - } - - /** {@inheritDoc} */ - @Override - protected Object decodeNotNull(int flag, CodecDataInput cdi) { - int count = cdi.readInt(); - long[] res = new long[count]; - for (int i = 0; i < count; i++) { - long ret; - switch (flag) { - case Codec.UVARINT_FLAG: - ret = IntegerCodec.readUVarLong(cdi); - break; - case Codec.UINT_FLAG: - ret = IntegerCodec.readULong(cdi); - break; - case Codec.VARINT_FLAG: - ret = IntegerCodec.readVarLong(cdi); - break; - case Codec.INT_FLAG: - ret = IntegerCodec.readLong(cdi); - break; - default: - throw new TypeException("Invalid IntegerType flag: " + flag); - } - res[i] = ret; - } - return res; - } - - 
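// Editor's note (hedged): the deleted decodeNotNull above dispatches on
// TiKV's one-byte integer codec flags. A standalone sketch of the same
// dispatch for a single value, reusing the existing IntegerCodec helpers:
//
//   static long decodeOne(int flag, CodecDataInput cdi) {
//     switch (flag) {
//       case Codec.UVARINT_FLAG: return IntegerCodec.readUVarLong(cdi);
//       case Codec.UINT_FLAG:    return IntegerCodec.readULong(cdi);
//       case Codec.VARINT_FLAG:  return IntegerCodec.readVarLong(cdi);
//       case Codec.INT_FLAG:     return IntegerCodec.readLong(cdi);
//       default: throw new TypeException("Invalid IntegerType flag: " + flag);
//     }
//   }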
/** {@inheritDoc} */ - @Override - protected void encodeKey(CodecDataOutput cdo, Object value) { - long longVal = Converter.convertToLong(value); - if (isUnsigned()) { - IntegerCodec.writeULongFully(cdo, longVal, true); - } else { - IntegerCodec.writeLongFully(cdo, longVal, true); - } - } - - /** {@inheritDoc} */ - @Override - protected void encodeValue(CodecDataOutput cdo, Object value) { - long longVal = Converter.convertToLong(value); - if (isUnsigned()) { - IntegerCodec.writeULongFully(cdo, longVal, false); - } else { - IntegerCodec.writeLongFully(cdo, longVal, false); - } - } - - /** {@inheritDoc} */ - @Override - protected void encodeProto(CodecDataOutput cdo, Object value) { - long longVal = Converter.convertToLong(value); - if (isUnsigned()) { - IntegerCodec.writeULong(cdo, longVal); - } else { - IntegerCodec.writeLong(cdo, longVal); - } - } - - @Override - public String getName() { - if (isUnsigned()) { - return "UNSIGNED LONG"; - } - return "LONG"; - } - - @Override - public ExprType getProtoExprType() { - return isUnsigned() ? ExprType.Uint64 : ExprType.Int64; - } - - public boolean isUnsignedLong() { - return tp == MySQLType.TypeLonglong && isUnsigned(); - } - - /** {@inheritDoc} */ - @Override - public Object getOriginDefaultValueNonNull(String value, long version) { - return Long.parseLong(value); - } -} diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/DataType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/DataType.java index 853c02f79a..c7cee3d2c1 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/DataType.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/DataType.java @@ -83,7 +83,7 @@ public abstract class DataType implements Serializable { protected final long length; private final String charset; private final List elems; - private static final byte[] allNotNullBitMap = initAllNotNullBitMap(); + private final byte[] allNotNullBitMap = initAllNotNullBitMap(); private final byte[] readBuffer = new byte[8]; public DataType(MySQLType tp, int prec, int scale) { @@ -203,7 +203,7 @@ protected Object decodeNotNullForBatchWrite(int flag, CodecDataInput cdi) { return decodeNotNull(flag, cdi); } - public int getFixLen() { + private int getFixLen() { switch (this.getType()) { case TypeFloat: return 4; @@ -229,12 +229,7 @@ public int getFixLen() { } } - public static byte[] setAllNotNullBitMapWithNumRows(int numRows) { - int numNullBitmapBytes = (numRows + 7) / 8; - return setAllNotNull(numNullBitmapBytes); - } - - private static byte[] setAllNotNull(int numNullBitMapBytes) { + private byte[] setAllNotNull(int numNullBitMapBytes) { byte[] nullBitMaps = new byte[numNullBitMapBytes]; for (int i = 0; i < numNullBitMapBytes; ) { // allNotNullBitNMap's actual length @@ -246,7 +241,7 @@ private static byte[] setAllNotNull(int numNullBitMapBytes) { return nullBitMaps; } - private static byte[] initAllNotNullBitMap() { + private byte[] initAllNotNullBitMap() { byte[] allNotNullBitMap = new byte[128]; Arrays.fill(allNotNullBitMap, (byte) 0xFF); return allNotNullBitMap; diff --git a/tikv-client/src/main/java/com/pingcap/tikv/types/MySQLType.java b/tikv-client/src/main/java/com/pingcap/tikv/types/MySQLType.java index 8e1b733fd0..7fb59adf37 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/types/MySQLType.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/types/MySQLType.java @@ -52,8 +52,7 @@ public enum MySQLType { TypeBlob(0xfc, 65535, 2, -1), TypeVarString(0xfd, 255, 1, -1), TypeString(0xfe, 255, 1, 1), - TypeGeometry(0xff, 1024, 1, -1), - 
TypeArray(100, 65536, 2, -1); + TypeGeometry(0xff, 1024, 1, -1); private static final Map typeMap = new HashMap<>(); private static final Map sizeMap = new HashMap<>(); From 327a6b01bec5fdcf471394c36df9b3423697ff9a Mon Sep 17 00:00:00 2001 From: Liangliang Gu Date: Thu, 20 Aug 2020 14:19:29 +0800 Subject: [PATCH 4/7] BatchWrite: add parameter regionSplitUsingSize (#1568) --- .../pingcap/tispark/write/TiBatchWrite.scala | 44 +++++++++++++++---- .../pingcap/tispark/write/TiDBOptions.scala | 6 +++ 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala index 02b61ddbaa..0a0a9a18ea 100644 --- a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala +++ b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala @@ -407,19 +407,34 @@ class TiBatchWrite( val sampleSize = (regionSplitPointNum + 1) * options.sampleSplitFrac logger.info(s"sampleSize=$sampleSize") - val sampleData = - rdd.map(_._1).sample(withReplacement = false, sampleSize.toDouble / count).collect() + val sampleData = rdd.sample(false, sampleSize.toDouble / count).collect() logger.info(s"sampleData size=${sampleData.length}") - val sortedSampleData = sampleData.sorted(new Ordering[SerializableKey] { - override def compare(x: SerializableKey, y: SerializableKey): Int = { - x.compareTo(y) + val finalRegionSplitPointNum = if (options.regionSplitUsingSize) { + val avgSize = getAverageSizeInBytes(sampleData) + logger.info(s"avgSize=$avgSize Bytes") + if (avgSize <= options.bytesPerRegion / options.regionSplitKeys) { + regionSplitPointNum + } else { + Math.min( + Math.floor((count.toDouble / options.bytesPerRegion) * avgSize).toInt, + sampleData.length / 10) } - }) + } else { + regionSplitPointNum + } + logger.info(s"finalRegionSplitPointNum=$finalRegionSplitPointNum") - val orderedSplitPoints = new Array[SerializableKey](regionSplitPointNum) - val step = Math.floor(sortedSampleData.length.toDouble / (regionSplitPointNum + 1)).toInt - for (i <- 0 until regionSplitPointNum) { + val sortedSampleData = sampleData + .map(_._1) + .sorted(new Ordering[SerializableKey] { + override def compare(x: SerializableKey, y: SerializableKey): Int = { + x.compareTo(y) + } + }) + val orderedSplitPoints = new Array[SerializableKey](finalRegionSplitPointNum) + val step = Math.floor(sortedSampleData.length.toDouble / (finalRegionSplitPointNum + 1)).toInt + for (i <- 0 until finalRegionSplitPointNum) { orderedSplitPoints(i) = sortedSampleData((i + 1) * step) } @@ -427,6 +442,17 @@ class TiBatchWrite( orderedSplitPoints.toList } + private def getAverageSizeInBytes(keyValues: Array[(SerializableKey, Array[Byte])]): Int = { + var avg: Double = 0 + var t: Int = 1 + keyValues.foreach { keyValue => + val keySize: Double = keyValue._1.bytes.length + keyValue._2.length + avg = avg + (keySize - avg) / t + t = t + 1 + } + Math.ceil(avg).toInt + } + private def getUseTableLock: Boolean = { if (!options.useTableLock(StoreVersion.minTiKVVersion("4.0.0", tiSession.getPDClient))) { false diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala b/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala index 1e28a90104..f20fcace7d 100644 --- a/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala +++ b/core/src/main/scala/com/pingcap/tispark/write/TiDBOptions.scala @@ -105,6 +105,10 @@ class TiDBOptions(@transient val parameters: CaseInsensitiveMap[String]) extends val 
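// Editor's note (hedged): getAverageSizeInBytes in the TiBatchWrite hunk
// above is the standard one-pass running mean, avg += (x - avg) / t, which
// avoids accumulating a large sum. An equivalent compact sketch with
// illustrative names:
//
//   def avgSizeInBytes(kvs: Array[(SerializableKey, Array[Byte])]): Int = {
//     var avg = 0.0
//     kvs.zipWithIndex.foreach { case ((k, v), i) =>
//       avg += (k.bytes.length + v.length - avg) / (i + 1)
//     }
//     math.ceil(avg).toInt
//   }
//
// With regionSplitUsingSize=true, the split-point count then becomes roughly
// totalBytes / bytesPerRegion (default 100663296 bytes, i.e. 96 MiB), capped
// at a tenth of the sample size.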
minRegionSplitNum: Int = getOrDefault(TIDB_MIN_REGION_SPLIT_NUM, "4").toInt val regionSplitThreshold: Int = getOrDefault(TIDB_REGION_SPLIT_THRESHOLD, "100000").toInt val splitRegionBackoffMS: Int = getOrDefault(TIDB_SPLIT_REGION_BACKOFFER_MS, "120000").toInt + val regionSplitUsingSize: Boolean = + getOrDefault(TIDB_REGION_SPLIT_USING_SIZE, "false").toBoolean + //96M + val bytesPerRegion: Int = getOrDefault(TIDB_BYTES_PER_REGION, "100663296").toInt // ------------------------------------------------------------ // Calculated parameters @@ -219,6 +223,8 @@ object TiDBOptions { val TIDB_MIN_REGION_SPLIT_NUM: String = newOption("minRegionSplitNum") val TIDB_REGION_SPLIT_THRESHOLD: String = newOption("regionSplitThreshold") val TIDB_SPLIT_REGION_BACKOFFER_MS: String = newOption("splitRegionBackoffMS") + val TIDB_REGION_SPLIT_USING_SIZE: String = newOption("regionSplitUsingSize") + val TIDB_BYTES_PER_REGION: String = newOption("bytesPerRegion") // ------------------------------------------------------------ // parameters only for test From c3730e9321ea2310cbf4cf806e3c33abd0dbd2ab Mon Sep 17 00:00:00 2001 From: Liangliang Gu Date: Fri, 21 Aug 2020 12:04:48 +0800 Subject: [PATCH 5/7] BatchWrite: add backoffer for getRegion operation (#1569) --- .../com/pingcap/tikv/TwoPhaseCommitter.java | 16 +++++---- .../pingcap/tikv/region/RegionManager.java | 35 ++++++++++++++----- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/tikv-client/src/main/java/com/pingcap/tikv/TwoPhaseCommitter.java b/tikv-client/src/main/java/com/pingcap/tikv/TwoPhaseCommitter.java index 4e0da2b567..5c771d0a00 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/TwoPhaseCommitter.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/TwoPhaseCommitter.java @@ -143,7 +143,7 @@ public void prewritePrimaryKey(BackOffer backOffer, byte[] primaryKey, byte[] va private void doPrewritePrimaryKeyWithRetry(BackOffer backOffer, ByteString key, ByteString value) throws TiBatchWriteException { - Pair pair = this.regionManager.getRegionStorePairByKey(key); + Pair pair = this.regionManager.getRegionStorePairByKey(key, backOffer); TiRegion tiRegion = pair.first; Metapb.Store store = pair.second; @@ -198,7 +198,7 @@ public void commitPrimaryKey(BackOffer backOffer, byte[] key, long commitTs) private void doCommitPrimaryKeyWithRetry(BackOffer backOffer, ByteString key, long commitTs) throws TiBatchWriteException { - Pair pair = this.regionManager.getRegionStorePairByKey(key); + Pair pair = this.regionManager.getRegionStorePairByKey(key, backOffer); TiRegion tiRegion = pair.first; Metapb.Store store = pair.second; ByteString[] keys = new ByteString[] {key}; @@ -331,7 +331,7 @@ private void doPrewriteSecondaryKeysInBatchesWithRetry( } // groups keys by region - GroupKeyResult groupResult = this.groupKeysByRegion(keys, size); + GroupKeyResult groupResult = this.groupKeysByRegion(keys, size, backOffer); List batchKeyList = new LinkedList<>(); Map, List> groupKeyMap = groupResult.getGroupsResult(); @@ -344,7 +344,8 @@ private void doPrewriteSecondaryKeysInBatchesWithRetry( // For prewrite, stop sending other requests after receiving first error. 
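// Editor's sketch (hedged) of the region grouping that PATCH 5/7 threads a
// BackOffer through: keys are bucketed by the (region, store) pair that
// currently owns them, so each prewrite or commit RPC targets a single
// region. Simplified from groupKeysByRegion below; error handling elided:
//
//   Map<Pair<TiRegion, Metapb.Store>, List<ByteString>> groups = new HashMap<>();
//   for (ByteString key : keys) {
//     Pair<TiRegion, Metapb.Store> owner =
//         regionManager.getRegionStorePairByKey(key, backOffer);
//     groups.computeIfAbsent(owner, o -> new ArrayList<>()).add(key);
//   }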
for (BatchKeys batchKeys : batchKeyList) { TiRegion oldRegion = batchKeys.getRegion(); - TiRegion currentRegion = this.regionManager.getRegionByKey(oldRegion.getStartKey()); + TiRegion currentRegion = + this.regionManager.getRegionByKey(oldRegion.getStartKey(), backOffer); if (oldRegion.equals(currentRegion)) { doPrewriteSecondaryKeySingleBatchWithRetry(backOffer, primaryKey, batchKeys, mutations); } else { @@ -565,7 +566,7 @@ private void doCommitSecondaryKeysWithRetry( } // groups keys by region - GroupKeyResult groupResult = this.groupKeysByRegion(keys, size); + GroupKeyResult groupResult = this.groupKeysByRegion(keys, size, backOffer); List batchKeyList = new ArrayList<>(); Map, List> groupKeyMap = groupResult.getGroupsResult(); @@ -609,14 +610,15 @@ private void doCommitSecondaryKeySingleBatchWithRetry( batchKeys.getRegion().getId()); } - private GroupKeyResult groupKeysByRegion(ByteString[] keys, int size) + private GroupKeyResult groupKeysByRegion(ByteString[] keys, int size, BackOffer backOffer) throws TiBatchWriteException { Map, List> groups = new HashMap<>(); int index = 0; try { for (; index < size; index++) { ByteString key = keys[index]; - Pair pair = this.regionManager.getRegionStorePairByKey(key); + Pair pair = + this.regionManager.getRegionStorePairByKey(key, backOffer); if (pair != null) { groups.computeIfAbsent(pair, e -> new ArrayList<>()).add(key); } diff --git a/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java b/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java index 77def7f48f..2d95df94c0 100644 --- a/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java +++ b/tikv-client/src/main/java/com/pingcap/tikv/region/RegionManager.java @@ -65,7 +65,11 @@ public Function getCacheInvalidateCallback() { } public TiRegion getRegionByKey(ByteString key) { - return cache.getRegionByKey(key); + return getRegionByKey(key, ConcreteBackOffer.newGetBackOff()); + } + + public TiRegion getRegionByKey(ByteString key, BackOffer backOffer) { + return cache.getRegionByKey(key, backOffer); } @Deprecated @@ -79,12 +83,21 @@ public TiRegion getRegionById(long regionId) { return cache.getRegionById(ConcreteBackOffer.newGetBackOff(), regionId); } + public Pair getRegionStorePairByKey(ByteString key, BackOffer backOffer) { + return getRegionStorePairByKey(key, TiStoreType.TiKV, backOffer); + } + public Pair getRegionStorePairByKey(ByteString key) { return getRegionStorePairByKey(key, TiStoreType.TiKV); } public Pair getRegionStorePairByKey(ByteString key, TiStoreType storeType) { - TiRegion region = cache.getRegionByKey(key); + return getRegionStorePairByKey(key, storeType, ConcreteBackOffer.newGetBackOff()); + } + + public Pair getRegionStorePairByKey( + ByteString key, TiStoreType storeType, BackOffer backOffer) { + TiRegion region = cache.getRegionByKey(key, backOffer); if (region == null) { throw new TiClientInternalException("Region not exist for key:" + formatBytesUTF8(key)); } @@ -95,11 +108,11 @@ public Pair getRegionStorePairByKey(ByteString key, TiStoreType Store store = null; if (storeType == TiStoreType.TiKV) { Peer leader = region.getLeader(); - store = cache.getStoreById(leader.getStoreId()); + store = cache.getStoreById(leader.getStoreId(), backOffer); } else { outerLoop: for (Peer peer : region.getLearnerList()) { - Store s = getStoreById(peer.getStoreId()); + Store s = getStoreById(peer.getStoreId(), backOffer); for (Metapb.StoreLabel label : s.getLabelsList()) { if (label.getKey().equals(storeType.getLabelKey()) && 
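// (Editor's note, hedged: for non-TiKV store types this lookup scans the
// region's learner peers and matches store labels, e.g. an engine=tiflash
// label would select a TiFlash learner; the exact key and value come from
// TiStoreType.getLabelKey()/getLabelValue(). PATCH 5/7 only threads the
// caller's BackOffer through, so repeated PD lookups share one retry budget
// instead of creating a fresh default backoffer per call.)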
label.getValue().equals(storeType.getLabelValue())) { @@ -123,7 +136,11 @@ public Pair getRegionStorePairByKey(ByteString key, TiStoreType } public Store getStoreById(long id) { - return cache.getStoreById(id); + return getStoreById(id, ConcreteBackOffer.newGetBackOff()); + } + + public Store getStoreById(long id, BackOffer backOffer) { + return cache.getStoreById(id, backOffer); } public void onRegionStale(long regionId) { @@ -181,7 +198,7 @@ public RegionCache(ReadOnlyPDClient pdClient) { this.pdClient = pdClient; } - public synchronized TiRegion getRegionByKey(ByteString key) { + public synchronized TiRegion getRegionByKey(ByteString key, BackOffer backOffer) { Long regionId; regionId = keyToRegionIdCache.get(Key.toRawKey(key)); if (logger.isDebugEnabled()) { @@ -191,7 +208,7 @@ public synchronized TiRegion getRegionByKey(ByteString key) { if (regionId == null) { logger.debug("Key not find in keyToRegionIdCache:" + formatBytesUTF8(key)); - TiRegion region = pdClient.getRegionByKey(ConcreteBackOffer.newGetBackOff(), key); + TiRegion region = pdClient.getRegionByKey(backOffer, key); if (!putRegion(region)) { throw new TiClientInternalException("Invalid Region: " + region.toString()); } @@ -264,11 +281,11 @@ public synchronized void invalidateStore(long storeId) { storeCache.remove(storeId); } - public synchronized Store getStoreById(long id) { + public synchronized Store getStoreById(long id, BackOffer backOffer) { try { Store store = storeCache.get(id); if (store == null) { - store = pdClient.getStore(ConcreteBackOffer.newGetBackOff(), id); + store = pdClient.getStore(backOffer, id); } if (store.getState().equals(StoreState.Tombstone)) { return null; From f3f1d103e6229eb375cde989d76681191ef3e45b Mon Sep 17 00:00:00 2001 From: birdstorm Date: Fri, 21 Aug 2020 12:05:27 +0800 Subject: [PATCH 6/7] Fix covering index generates incorrect plan when first column is not included in index (#1573) --- .../sql/catalyst/expressions/ExprUtils.scala | 8 ++++++-- .../org/apache/spark/sql/IssueTestSuite.scala | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 67cc962087..9bd2c91dc5 100644 --- a/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/core/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -80,8 +80,12 @@ object ExprUtils { val col = meta.getColumns.asScala.filter(col => col.isPrimaryKey).head ColumnRef.create(col.getName, meta) } else { - val firstCol = meta.getColumns.get(0) - ColumnRef.create(firstCol.getName, meta) + if (dagRequest.getFields.isEmpty) { + val firstCol = meta.getColumns.get(0) + ColumnRef.create(firstCol.getName, meta) + } else { + dagRequest.getFields.head + } } dagRequest.addRequiredColumn(firstColRef) diff --git a/core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala b/core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala index 1dc4df3f87..6ad17796a3 100644 --- a/core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/IssueTestSuite.scala @@ -19,6 +19,21 @@ import com.pingcap.tispark.TiConfigConst import org.apache.spark.sql.functions.{col, sum} class IssueTestSuite extends BaseTiSparkTest { + // https://github.com/pingcap/tispark/issues/1570 + test("Fix covering index generates incorrect plan when first column is not included in index") { + 
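// Editor's summary (hedged) of the fix this test exercises, per the
// ExprUtils hunk above: when a table has no usable primary-key handle and
// the DAG request already selects index-covered columns, reuse one of those
// columns as the required column instead of defaulting to the table's first
// column, which the covering index may not contain. In sketch form, with
// the first branch's condition abbreviated:
//
//   val firstColRef =
//     if (hasPkHandle) ColumnRef.create(pkCol.getName, meta)
//     else if (dagRequest.getFields.isEmpty)
//       ColumnRef.create(meta.getColumns.get(0).getName, meta)
//     else dagRequest.getFields.head
//   dagRequest.addRequiredColumn(firstColRef)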
tidbStmt.execute("DROP TABLE IF EXISTS tt") + tidbStmt.execute("create table tt(c varchar(12), dt datetime, key dt_index(dt))") + tidbStmt.execute(""" + |INSERT INTO tt VALUES + | ('aa', '2007-09-01 00:00:00'), + | ('bb', '2007-09-02 00:00:00'), + | ('cc', '2007-09-03 00:00:00'), + | ('dd', '2007-09-04 00:00:00')""".stripMargin) + + runTest( + "select count(*) from tt where dt >= timestamp '2007-09-01 00:00:01' and dt <= timestamp '2007-09-04 00:00:00'") + } + test("fix partition table pruning when partdef contains bigint") { tidbStmt.execute("DROP TABLE IF EXISTS t") tidbStmt.execute( From 130488e49dd828069a346a2981dfdf04f4a7d909 Mon Sep 17 00:00:00 2001 From: Liangliang Gu Date: Fri, 21 Aug 2020 12:05:50 +0800 Subject: [PATCH 7/7] BatchWrite: split region before prewrite (#1572) --- .../pingcap/tispark/write/TiBatchWrite.scala | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala index 0a0a9a18ea..4f80963a46 100644 --- a/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala +++ b/core/src/main/scala/com/pingcap/tispark/write/TiBatchWrite.scala @@ -204,24 +204,6 @@ class TiBatchWrite( val shuffledRDDCount = shuffledRDD.count() logger.info(s"write kv data count=$shuffledRDDCount") - if (options.enableRegionSplit && "v2".equals(options.regionSplitMethod)) { - // calculate region split points - val orderedSplitPoints = getRegionSplitPoints(shuffledRDD, shuffledRDDCount) - - // split region - try { - tiSession.splitRegionAndScatter( - orderedSplitPoints.map(_.bytes).asJava, - options.splitRegionBackoffMS, - options.scatterWaitMS) - } catch { - case e: Throwable => logger.warn("split region and scatter error!", e) - } - - // shuffle according to split points - shuffledRDD = shuffledRDD.partitionBy(new TiReginSplitPartitioner(orderedSplitPoints)) - } - // take one row as primary key val (primaryKey: SerializableKey, primaryRow: Array[Byte]) = { val takeOne = shuffledRDD.take(1) @@ -240,6 +222,23 @@ class TiBatchWrite( !keyValue._1.equals(primaryKey) } + // split region + if (options.enableRegionSplit && "v2".equals(options.regionSplitMethod)) { + val orderedSplitPoints = getRegionSplitPoints(shuffledRDD, shuffledRDDCount) + + try { + tiSession.splitRegionAndScatter( + orderedSplitPoints.map(_.bytes).asJava, + options.splitRegionBackoffMS, + options.scatterWaitMS) + } catch { + case e: Throwable => logger.warn("split region and scatter error!", e) + } + + // shuffle according to split points + shuffledRDD = shuffledRDD.partitionBy(new TiReginSplitPartitioner(orderedSplitPoints)) + } + // driver primary pre-write val ti2PCClient = new TwoPhaseCommitter(
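// (Editor's note, hedged: PATCH 7/7 reorders the write path so that region
// split and scatter run after the primary key is chosen and filtered out,
// but before prewrite starts:
//
//   1. shuffle rows into a key/value RDD and count it
//   2. take one row as the primary key and drop it from the RDD
//   3. if enableRegionSplit and regionSplitMethod == "v2", split and scatter
//      at sampled points, then repartition with TiReginSplitPartitioner
//   4. prewrite the primary key via the TwoPhaseCommitter built here
//
// Presumably this keeps prewrite RPCs from racing with region splits, which
// would otherwise invalidate cached region info mid-commit.)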