From c07626fcc2f5c89479ba68111ea0dd7fd80d4984 Mon Sep 17 00:00:00 2001 From: exqlnet Date: Wed, 13 Mar 2024 14:53:57 +0800 Subject: [PATCH] Support sortedPairs/sortedLeaves/fillDefault options --- merkletree.go | 51 +++++++++++++++++++++++++++++++++++++++------------ parameters.go | 38 ++++++++++++++++++++++++++++++-------- 2 files changed, 69 insertions(+), 20 deletions(-) diff --git a/merkletree.go b/merkletree.go index 48a3c7d..cbdb410 100644 --- a/merkletree.go +++ b/merkletree.go @@ -41,10 +41,9 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "github.com/pkg/errors" "math" "sort" - - "github.com/pkg/errors" ) // MerkleTree is the structure for the Merkle tree. @@ -212,11 +211,14 @@ func NewTree(params ...Parameter) (*MerkleTree, error) { nodes[branchesLen:branchesLen+len(parameters.data)], parameters.hash, parameters.salt, - parameters.sorted, + parameters.sorted || parameters.sortedLeaves, + parameters.fillDefault, ) // Pad the space left after the leaves. - for i := len(parameters.data) + branchesLen; i < len(nodes); i++ { - nodes[i] = make([]byte, parameters.hash.HashLength()) + if parameters.fillDefault { + for i := len(parameters.data) + branchesLen; i < len(nodes); i++ { + nodes[i] = make([]byte, parameters.hash.HashLength()) + } } // Branches. @@ -224,7 +226,7 @@ func NewTree(params ...Parameter) (*MerkleTree, error) { nodes, parameters.hash, branchesLen, - parameters.sorted, + parameters.sorted || parameters.sortedPairs, ) tree := &MerkleTree{ @@ -248,7 +250,7 @@ func New(data [][]byte) (*MerkleTree, error) { // Hashes the data slice, placing the result hashes into dest. // salt adds a salt to the hash using the index. // sorted sorts the leaves and data by the value of the leaf hash. -func createLeaves(data [][]byte, dest [][]byte, hash HashType, salt, sorted bool) { +func createLeaves(data [][]byte, dest [][]byte, hash HashType, salt, sortedLeaves, fillDefault bool) { indexSalt := make([]byte, 4) for i := range data { if salt { @@ -259,7 +261,7 @@ func createLeaves(data [][]byte, dest [][]byte, hash HashType, salt, sorted bool } } - if sorted { + if sortedLeaves { sorter := hashSorter{ data: data, hashes: dest, @@ -268,16 +270,41 @@ func createLeaves(data [][]byte, dest [][]byte, hash HashType, salt, sorted bool } } +func isZeroBytes(s []byte) bool { + for _, v := range s { + if v != 0 { + return false + } + } + return true +} + // Create the branch nodes from the existing leaf data. -func createBranches(nodes [][]byte, hash HashType, leafOffset int, sorted bool) { +func createBranches(nodes [][]byte, hash HashType, leafOffset int, sortedPairs bool) { for leafIndex := leafOffset - 1; leafIndex > 0; leafIndex-- { left := nodes[leafIndex*2] right := nodes[leafIndex*2+1] - if sorted && bytes.Compare(left, right) == 1 { - nodes[leafIndex] = hash.Hash(right, left) + var pairs [][]byte + if len(left) != 0 { + pairs = append(pairs, left) + } + if len(right) != 0 { + pairs = append(pairs, right) + } + + if sortedPairs { + sort.Slice(pairs, func(i, j int) bool { + return bytes.Compare(pairs[i], pairs[j]) < 0 + }) + } + + if len(pairs) == 1 { + nodes[leafIndex] = pairs[0] + } else if len(pairs) > 1 { + nodes[leafIndex] = hash.Hash(pairs...) } else { - nodes[leafIndex] = hash.Hash(left, right) + nodes[leafIndex] = []byte{} } } } diff --git a/parameters.go b/parameters.go index ceb18c9..c731203 100644 --- a/parameters.go +++ b/parameters.go @@ -20,13 +20,16 @@ import ( ) type parameters struct { - data [][]byte - values uint64 - hashes map[uint64][]byte - indices []uint64 - salt bool - sorted bool - hash HashType + data [][]byte + values uint64 + hashes map[uint64][]byte + indices []uint64 + salt bool + sorted bool + sortedPairs bool + sortedLeaves bool + fillDefault bool + hash HashType } // Parameter is the interface for service parameters. @@ -94,10 +97,29 @@ func WithHashType(hash HashType) Parameter { }) } +func WithSortedPairs(sortedPairs bool) Parameter { + return parameterFunc(func(p *parameters) { + p.sortedPairs = sortedPairs + }) +} + +func WithSortedLeaves(sortedLeaves bool) Parameter { + return parameterFunc(func(p *parameters) { + p.sortedLeaves = sortedLeaves + }) +} + +func WithFillDefault(fillDefault bool) Parameter { + return parameterFunc(func(p *parameters) { + p.fillDefault = fillDefault + }) +} + // parseAndCheckTreeParameters parses and checks parameters to ensure that mandatory parameters are present and correct. func parseAndCheckTreeParameters(params ...Parameter) (*parameters, error) { parameters := parameters{ - hash: blake2b.New(), + hash: blake2b.New(), + fillDefault: true, } for _, p := range params { if params != nil {