Skip to content

Commit

Permalink
Merge branch 'feature/improveMptPerformance' into phase/daedalus
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas Tallar committed Sep 25, 2017
2 parents 514f94b + 104247e commit c1301d9
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 58 deletions.
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,20 @@ val dep = {

val Integration = config("it") extend Test

val Benchmark = config("benchmark") extend Test

val Evm = config("evm") extend Test

val Ets = config("ets") extend Test

val Snappy = config("snappy") extend Test

val root = project.in(file("."))
.configs(Integration, Evm, Ets, Snappy)
.configs(Integration, Benchmark, Evm, Ets, Snappy)
.settings(commonSettings: _*)
.settings(libraryDependencies ++= dep)
.settings(inConfig(Integration)(Defaults.testSettings) : _*)
.settings(inConfig(Benchmark)(Defaults.testSettings) : _*)
.settings(inConfig(Evm)(Defaults.testSettings) : _*)
.settings(inConfig(Ets)(Defaults.testSettings) : _*)
.settings(inConfig(Snappy)(Defaults.testSettings) : _*)
Expand Down
2 changes: 2 additions & 0 deletions circle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ test:
fi
fi

# As the benchmark tests don't test functionality and should be manually ran, having them compile is enough
- sbt benchmark:compile
# snappy test is not run on Circle - this is just to prevent compilation regression
- sbt snappy:compile

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package io.iohk.ethereum.mpt

import java.io.File
import java.nio.file.Files

import io.iohk.ethereum.db.dataSource.{EphemDataSource, LevelDBDataSource, LevelDbConfig}
import io.iohk.ethereum.db.storage.{ArchiveNodeStorage, NodeStorage}
import io.iohk.ethereum.mpt.MerklePatriciaTrie.defaultByteArraySerializable
import io.iohk.ethereum.utils.Logger
import io.iohk.ethereum.{ObjectGenerators, crypto}
import org.scalatest.FunSuite
import org.scalatest.prop.PropertyChecks
import org.spongycastle.util.encoders.Hex

class MerklePatriciaTreeSpeedSpec extends FunSuite
with PropertyChecks
with ObjectGenerators
with Logger
with PersistentStorage {

test("Performance test (From: https://github.com/ethereum/wiki/wiki/Benchmarks)") {
val Rounds = 1000
val Symmetric = true

val start: Long = System.currentTimeMillis
val emptyTrie = MerklePatriciaTrie[Array[Byte], Array[Byte]](new ArchiveNodeStorage(new NodeStorage(EphemDataSource())))
var seed: Array[Byte] = Array.fill(32)(0.toByte)

val trieResult = (0 until Rounds).foldLeft(emptyTrie) { case (recTrie, i) =>
seed = Node.hashFn(seed)
if (!Symmetric) recTrie.put(seed, seed)
else {
val mykey = seed
seed = Node.hashFn(seed)
val myval = if ((seed(0) & 0xFF) % 2 == 1) Array[Byte](seed.last) else seed
recTrie.put(mykey, myval)
}
}
val rootHash = Hex.toHexString(trieResult.getRootHash)

log.debug("Time taken(ms): " + (System.currentTimeMillis - start))
log.debug("Root hash obtained: " + rootHash)

if (Symmetric) assert(rootHash.take(4) == "36f6" && rootHash.drop(rootHash.length - 4) == "93a3")
else assert(rootHash.take(4) == "da8a" && rootHash.drop(rootHash.length - 4) == "0ca4")
}

test("MPT benchmark") {
withNodeStorage { ns =>
val hashFn = crypto.kec256(_: Array[Byte])

val defaultByteArraySer = MerklePatriciaTrie.defaultByteArraySerializable
val EmptyTrie = MerklePatriciaTrie[Array[Byte], Array[Byte]](ns)(defaultByteArraySer, defaultByteArraySer)

var t = System.currentTimeMillis()
(1 to 20000000).foldLeft(EmptyTrie){case (trie, i) =>
val k = hashFn(("hello" + i).getBytes)
val v = hashFn(("world" + i).getBytes)

if (i % 100000 == 0) {
val newT = System.currentTimeMillis()
val delta = (newT - t) / 1000.0
t = newT
log.debug(s"=== $i elements put, time for batch is: $delta sec")
}
trie.put(k, v)
}
}
}

}

trait PersistentStorage {
def withNodeStorage(testCode: NodesKeyValueStorage => Unit): Unit = {
val dbPath = Files.createTempDirectory("testdb").toAbsolutePath.toString
val dataSource = LevelDBDataSource(new LevelDbConfig {
override val verifyChecksums: Boolean = true
override val paranoidChecks: Boolean = true
override val createIfMissing: Boolean = true
override val path: String = dbPath
})

try {
testCode(new ArchiveNodeStorage(new NodeStorage(dataSource)))
} finally {
val dir = new File(dbPath)
!dir.exists() || dir.delete()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ import java.nio.file.Files
import java.security.MessageDigest

import io.iohk.ethereum.ObjectGenerators
import io.iohk.ethereum.crypto.kec256
import io.iohk.ethereum.db.dataSource.{EphemDataSource, LevelDBDataSource, LevelDbConfig}
import io.iohk.ethereum.db.storage.{ArchiveNodeStorage, NodeStorage, ReferenceCountNodeStorage}
import io.iohk.ethereum.db.dataSource.{LevelDBDataSource, LevelDbConfig}
import io.iohk.ethereum.db.storage.{ArchiveNodeStorage, NodeStorage}
import io.iohk.ethereum.mpt.MerklePatriciaTrie.defaultByteArraySerializable
import io.iohk.ethereum.utils.Logger
import org.scalatest.FunSuite
Expand Down Expand Up @@ -117,36 +116,6 @@ class MerklePatriciaTreeIntegrationSuite extends FunSuite
}
}

/* Performance test */
test("Performance test (From: https://github.com/ethereum/wiki/wiki/Benchmarks)") {
withNodeStorage { ns =>
val EmptyTrie = MerklePatriciaTrie[Array[Byte], Array[Byte]](ns)
val Rounds = 1000
val Symmetric = true

val start: Long = System.currentTimeMillis
val emptyTrie = MerklePatriciaTrie[Array[Byte], Array[Byte]](new ArchiveNodeStorage(new NodeStorage(EphemDataSource())))
var seed: Array[Byte] = Array.fill(32)(0.toByte)

val trieResult = (0 until Rounds).foldLeft(emptyTrie) { case (recTrie, i) =>
seed = Node.hashFn(seed)
if (!Symmetric) recTrie.put(seed, seed)
else {
val mykey = seed
seed = Node.hashFn(seed)
val myval = if ((seed(0) & 0xFF) % 2 == 1) Array[Byte](seed.last) else seed
recTrie.put(mykey, myval)
}
}
val rootHash = Hex.toHexString(trieResult.getRootHash)

log.debug("Time taken(ms): " + (System.currentTimeMillis - start))
log.debug("Root hash obtained: " + rootHash)

if (Symmetric) assert(rootHash.take(4) == "36f6" && rootHash.drop(rootHash.length - 4) == "93a3")
else assert(rootHash.take(4) == "da8a" && rootHash.drop(rootHash.length - 4) == "0ca4")
}
}
}

trait PersistentStorage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class DumpChainActor(peerManager: ActorRef, peerMessageBus: ActorRef, startBlock

val children = nodes.flatMap {
case n: BranchNode => n.children.collect { case Some(Left(h)) => h }
case ExtensionNode(_, Left(h)) => Seq(h)
case ExtensionNode(_, Left(h), _, _) => Seq(h)
case n: LeafNode => Seq.empty
case _ => Seq.empty
}
Expand Down Expand Up @@ -125,7 +125,7 @@ class DumpChainActor(peerManager: ActorRef, peerMessageBus: ActorRef, startBlock
val cNodes = NodeData(contractNodes).values.indices.map(i => NodeData(contractNodes).getMptNode(i))
contractChildren = contractChildren ++ cNodes.flatMap {
case n: BranchNode => n.children.collect { case Some(Left(h)) => h }
case ExtensionNode(_, Left(h)) => Seq(h)
case ExtensionNode(_, Left(h), _, _) => Seq(h)
case _ => Seq.empty
}

Expand Down
32 changes: 16 additions & 16 deletions src/main/scala/io/iohk/ethereum/mpt/MerklePatriciaTrie.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object MerklePatriciaTrie {
val nodeEncoded =
if (nodeId.length < 32) nodeId
else source.get(ByteString(nodeId)).getOrElse(throw MPTException(s"Node not found ${Hex.toHexString(nodeId)}, trie is inconsistent"))
decodeRLP[MptNode](nodeEncoded)
decodeRLP[MptNode](nodeEncoded).withCachedHash(nodeId).withCachedRlpEncoded(nodeEncoded)
}

private def matchingLength(a: Array[Byte], b: Array[Byte]): Int = a.zip(b).takeWhile(t => t._1 == t._2).length
Expand Down Expand Up @@ -236,16 +236,16 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]

@tailrec
private def get(node: MptNode, searchKey: Array[Byte]): Option[Array[Byte]] = node match {
case LeafNode(key, value) =>
case LeafNode(key, value, _, _) =>
if (key.toArray[Byte] sameElements searchKey) Some(value.toArray[Byte]) else None
case extNode@ExtensionNode(sharedKey, _) =>
case extNode@ExtensionNode(sharedKey, _, _, _) =>
val (commonKey, remainingKey) = searchKey.splitAt(sharedKey.length)
if (searchKey.length >= sharedKey.length && (sharedKey sameElements commonKey)) {
val nextNode = getNextNode(extNode, nodeStorage)
get(nextNode, remainingKey)
}
else None
case branch@BranchNode(_, terminator) =>
case branch@BranchNode(_, terminator, _, _) =>
if (searchKey.isEmpty) terminator.map(_.toArray[Byte])
else getChild(branch, searchKey(0), nodeStorage) match {
case Some(child) => get(child, searchKey.slice(1, searchKey.length))
Expand All @@ -260,7 +260,7 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]
}

private def putInLeafNode(node: LeafNode, searchKey: Array[Byte], value: Array[Byte]): NodeInsertResult = {
val LeafNode(existingKey, storedValue) = node
val LeafNode(existingKey, storedValue, _, _) = node
matchingLength(existingKey.toArray[Byte], searchKey) match {
case ml if ml == existingKey.length && ml == searchKey.length =>
// We are trying to insert a leaf node that has the same key as this one but different value so we need to
Expand Down Expand Up @@ -304,7 +304,7 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]
}

private def putInExtensionNode(extensionNode: ExtensionNode, searchKey: Array[Byte], value: Array[Byte]): NodeInsertResult = {
val ExtensionNode(sharedKey, next) = extensionNode
val ExtensionNode(sharedKey, next, _, _) = extensionNode
matchingLength(sharedKey.toArray[Byte], searchKey) match {
case 0 =>
// There is no common prefix with the node which means we have to replace it for a branch node
Expand Down Expand Up @@ -349,7 +349,7 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]
}

private def putInBranchNode(branchNode: BranchNode, searchKey: Array[Byte], value: Array[Byte]): NodeInsertResult = {
val BranchNode(children, _) = branchNode
val BranchNode(children, _, _, _) = branchNode
if (searchKey.isEmpty) {
// The key is empty, the branch node should now be a terminator node with the new value asociated with it
val newBranchNode = BranchNode(children, Some(ByteString(value)))
Expand Down Expand Up @@ -395,13 +395,13 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]

private def removeFromBranchNode(node: BranchNode, searchKey: Array[Byte]): NodeRemoveResult = (node, searchKey.isEmpty) match {
// They key matches a branch node but it's value doesn't match the key
case (BranchNode(_, None), true) => NodeRemoveResult(hasChanged = false, newNode = None)
case (BranchNode(_, None, _, _), true) => NodeRemoveResult(hasChanged = false, newNode = None)
// We want to delete Branch node value
case (BranchNode(children, _), true) =>
case (BranchNode(children, _, _, _), true) =>
// We need to remove old node and fix it because we removed the value
val fixedNode = fix(BranchNode(children, None), nodeStorage, Nil)
NodeRemoveResult(hasChanged = true, newNode = Some(fixedNode), toDeleteFromStorage = Seq(node), toUpdateInStorage = Seq(fixedNode))
case (branchNode@BranchNode(children, optStoredValue), false) =>
case (branchNode@BranchNode(children, optStoredValue, _, _), false) =>
// We might be trying to remove a node that's inside one of the 16 mapped nibbles
val searchKeyHead = searchKey(0)
getChild(branchNode, searchKeyHead, nodeStorage) map { child =>
Expand Down Expand Up @@ -435,7 +435,7 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]
}

private def removeFromLeafNode(leafNode: LeafNode, searchKey: Array[Byte]): NodeRemoveResult = {
val LeafNode(existingKey, _) = leafNode
val LeafNode(existingKey, _, _, _) = leafNode
if (existingKey sameElements searchKey) {
// We found the node to delete
NodeRemoveResult(hasChanged = true, newNode = None, toDeleteFromStorage = Seq(leafNode))
Expand All @@ -444,7 +444,7 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]
}

private def removeFromExtensionNode(extensionNode: ExtensionNode, searchKey: Array[Byte]): NodeRemoveResult = {
val ExtensionNode(sharedKey, _) = extensionNode
val ExtensionNode(sharedKey, _, _, _) = extensionNode
val cp = matchingLength(sharedKey.toArray[Byte], searchKey)
if (cp == sharedKey.length) {
// A child node of this extension is removed, so move forward
Expand Down Expand Up @@ -487,7 +487,7 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]
*/
@tailrec
private def fix(node: MptNode, nodeStorage: NodesKeyValueStorage, notStoredYet: Seq[MptNode]): MptNode = node match {
case BranchNode(children, optStoredValue) =>
case BranchNode(children, optStoredValue, _, _) =>
val usedIndexes = children.indices.foldLeft[Seq[Int]](Nil) {
(acc, i) =>
if (children(i).isDefined) i +: acc else acc
Expand All @@ -500,7 +500,7 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]
case (Nil, Some(value)) => LeafNode(ByteString.empty, value)
case _ => node
}
case extensionNode@ExtensionNode(sharedKey, _) =>
case extensionNode@ExtensionNode(sharedKey, _, _, _) =>
val nextNode = extensionNode.next match {
case Left(nextHash) =>
// If the node is not in the extension node then it might be a node to be inserted at the end of this remove
Expand All @@ -512,9 +512,9 @@ class MerklePatriciaTrie[K, V] private (private val rootHash: Option[Array[Byte]
}
val newNode = nextNode match {
// Compact Two extensions into one
case ExtensionNode(subSharedKey, subNext) => ExtensionNode(sharedKey ++ subSharedKey, subNext)
case ExtensionNode(subSharedKey, subNext, _, _) => ExtensionNode(sharedKey ++ subSharedKey, subNext)
// Compact the extension and the leaf into the same leaf node
case LeafNode(subRemainingKey, subValue) => LeafNode(sharedKey ++ subRemainingKey, subValue)
case LeafNode(subRemainingKey, subValue, _, _) => LeafNode(sharedKey ++ subRemainingKey, subValue)
// It's ok
case _: BranchNode => node
}
Expand Down
33 changes: 27 additions & 6 deletions src/main/scala/io/iohk/ethereum/mpt/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ import io.iohk.ethereum.rlp.{encode => encodeRLP}
/**
* Trie elements
*/
sealed trait MptNode {
sealed abstract class MptNode {
val cachedHash: Option[Array[Byte]]
val cachedRlpEncoded: Option[Array[Byte]]

import MerklePatriciaTrie._

lazy val encode: Array[Byte] = encodeRLP[MptNode](this)
def withCachedHash(cachedHash: Array[Byte]): MptNode

lazy val hash: Array[Byte] = Node.hashFn(encode)
def withCachedRlpEncoded(cachedEncode: Array[Byte]): MptNode

lazy val encode: Array[Byte] = cachedRlpEncoded.getOrElse(encodeRLP[MptNode](this))

lazy val hash: Array[Byte] = cachedHash.getOrElse(Node.hashFn(encode))

def capped: ByteString = {
val encoded = encode
Expand All @@ -25,11 +31,26 @@ object Node {
val hashFn: (Array[Byte]) => Array[Byte] = (input: Array[Byte]) => crypto.kec256(input)
}

case class LeafNode(key: ByteString, value: ByteString) extends MptNode
case class LeafNode(key: ByteString, value: ByteString,
cachedHash: Option[Array[Byte]] = None, cachedRlpEncoded: Option[Array[Byte]] = None) extends MptNode {
def withCachedHash(cachedHash: Array[Byte]): MptNode = copy(cachedHash = Some(cachedHash))

def withCachedRlpEncoded(cachedEncode: Array[Byte]): MptNode = copy(cachedRlpEncoded = Some(cachedEncode))
}

case class ExtensionNode(sharedKey: ByteString, next: Either[ByteString, MptNode],
cachedHash: Option[Array[Byte]] = None, cachedRlpEncoded: Option[Array[Byte]] = None) extends MptNode {
def withCachedHash(cachedHash: Array[Byte]): MptNode = copy(cachedHash = Some(cachedHash))

def withCachedRlpEncoded(cachedEncode: Array[Byte]): MptNode = copy(cachedRlpEncoded = Some(cachedEncode))
}

case class BranchNode(children: Seq[Option[Either[ByteString, MptNode]]], terminator: Option[ByteString],
cachedHash: Option[Array[Byte]] = None, cachedRlpEncoded: Option[Array[Byte]] = None) extends MptNode {
def withCachedHash(cachedHash: Array[Byte]): MptNode = copy(cachedHash = Some(cachedHash))

case class ExtensionNode(sharedKey: ByteString, next: Either[ByteString, MptNode]) extends MptNode
def withCachedRlpEncoded(cachedEncode: Array[Byte]): MptNode = copy(cachedRlpEncoded = Some(cachedEncode))

case class BranchNode(children: Seq[Option[Either[ByteString, MptNode]]], terminator: Option[ByteString]) extends MptNode {
require(children.length == 16, "MptBranch childHashes length have to be 16")

/**
Expand Down

0 comments on commit c1301d9

Please sign in to comment.