Skip to content

Commit

Permalink
memdb: use atomic u64 addr to reduce allocation (#1453)
Browse files Browse the repository at this point in the history
 

Signed-off-by: you06 <[email protected]>
  • Loading branch information
you06 authored Aug 28, 2024
1 parent 41d133b commit 6b1453c
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 23 deletions.
17 changes: 13 additions & 4 deletions internal/unionstore/arena/arena.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,26 @@ const (
)

var (
Tombstone = []byte{}
NullAddr = MemdbArenaAddr{math.MaxUint32, math.MaxUint32}
BadAddr = MemdbArenaAddr{math.MaxUint32 - 1, math.MaxUint32}
endian = binary.LittleEndian
Tombstone = []byte{}
NullAddr = MemdbArenaAddr{math.MaxUint32, math.MaxUint32}
NullU64Addr uint64 = math.MaxUint64
BadAddr = MemdbArenaAddr{math.MaxUint32 - 1, math.MaxUint32}
endian = binary.LittleEndian
)

type MemdbArenaAddr struct {
idx uint32
off uint32
}

func U64ToAddr(u64 uint64) MemdbArenaAddr {
return MemdbArenaAddr{uint32(u64 >> 32), uint32(u64)}
}

func (addr MemdbArenaAddr) AsU64() uint64 {
return uint64(addr.idx)<<32 | uint64(addr.off)
}

func (addr MemdbArenaAddr) IsNull() bool {
// Combine all checks into a single condition
return addr == NullAddr || addr.idx == math.MaxUint32 || addr.off == math.MaxUint32
Expand Down
2 changes: 1 addition & 1 deletion internal/unionstore/memdb_art.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (db *artDBWithContext) SnapshotIter(lower, upper []byte) Iterator {

// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
func (db *artDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator {
return db.ART.SnapshotIter(upper, lower)
return db.ART.SnapshotIterReverse(upper, lower)
}

// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
Expand Down
2 changes: 1 addition & 1 deletion internal/unionstore/memdb_rbt.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (db *rbtDBWithContext) SnapshotIter(lower, upper []byte) Iterator {

// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer.
func (db *rbtDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator {
return db.RBT.SnapshotIter(upper, lower)
return db.RBT.SnapshotIterReverse(upper, lower)
}

// SnapshotGetter returns a Getter for a snapshot of MemBuffer.
Expand Down
30 changes: 23 additions & 7 deletions internal/unionstore/memdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,26 +844,38 @@ func TestUnsetTemporaryFlag(t *testing.T) {

func TestSnapshotGetIter(t *testing.T) {
assert := assert.New(t)
buffer := NewMemDB()
db := NewMemDB()
var getters []Getter
var iters []Iterator
var reverseIters []Iterator
for i := 0; i < 100; i++ {
assert.Nil(buffer.Set([]byte{byte(0)}, []byte{byte(i)}))
assert.Nil(db.Set([]byte{byte(0)}, []byte{byte(i)}))
assert.Nil(db.Set([]byte{byte(1)}, []byte{byte(i)}))

// getter
getter := buffer.SnapshotGetter()
getter := db.SnapshotGetter()
val, err := getter.Get(context.Background(), []byte{byte(0)})
assert.Nil(err)
assert.Equal(val, []byte{byte(min(i, 50))})
getters = append(getters, getter)

// iter
iter := buffer.SnapshotIter(nil, nil)
assert.Nil(err)
iter := db.SnapshotIter(nil, nil)
assert.Equal(iter.Key(), []byte{byte(0)})
assert.Equal(iter.Value(), []byte{byte(min(i, 50))})
iter.Close()
iters = append(iters, buffer.SnapshotIter(nil, nil))
iters = append(iters, db.SnapshotIter(nil, nil))

// reverse iter
reverseIter := db.SnapshotIterReverse(nil, nil)
assert.Equal(reverseIter.Key(), []byte{byte(1)})
assert.Equal(reverseIter.Value(), []byte{byte(min(i, 50))})
reverseIter.Close()
reverseIters = append(reverseIters, db.SnapshotIterReverse(nil, nil))

// writes after staging should be bypassed in snapshot read.
if i == 50 {
_ = buffer.Staging()
_ = db.Staging()
}
}
for _, getter := range getters {
Expand All @@ -875,4 +887,8 @@ func TestSnapshotGetIter(t *testing.T) {
assert.Equal(iter.Key(), []byte{byte(0)})
assert.Equal(iter.Value(), []byte{byte(50)})
}
for _, reverseIter := range reverseIters {
assert.Equal(reverseIter.Key(), []byte{byte(1)})
assert.Equal(reverseIter.Value(), []byte{byte(50)})
}
}
24 changes: 14 additions & 10 deletions internal/unionstore/rbt/rbt.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ type RBT struct {
dirty bool
stages []arena.MemDBCheckpoint

// The lastTraversedNode must exist
lastTraversedNode atomic.Pointer[MemdbNodeAddr]
// The lastTraversedNode stores addr in uint64 of the last traversed node.
// Compare to atomic.Pointer, atomic.Uint64 can avoid allocation so it's more efficient.
lastTraversedNode atomic.Uint64
hitCount atomic.Uint64
missCount atomic.Uint64
}
Expand All @@ -85,24 +86,26 @@ func New() *RBT {
db.stages = make([]arena.MemDBCheckpoint, 0, 2)
db.entrySizeLimit = unlimitedSize
db.bufferSizeLimit = unlimitedSize
db.lastTraversedNode.Store(&nullNodeAddr)
db.lastTraversedNode.Store(arena.NullU64Addr)
return db
}

// updateLastTraversed updates the last traversed node atomically
func (db *RBT) updateLastTraversed(node MemdbNodeAddr) {
db.lastTraversedNode.Store(&node)
db.lastTraversedNode.Store(node.addr.AsU64())
}

// checkKeyInCache retrieves the last traversed node if the key matches
func (db *RBT) checkKeyInCache(key []byte) (MemdbNodeAddr, bool) {
nodePtr := db.lastTraversedNode.Load()
if nodePtr == nil || nodePtr.isNull() {
addrU64 := db.lastTraversedNode.Load()
if addrU64 == arena.NullU64Addr {
return nullNodeAddr, false
}
addr := arena.U64ToAddr(addrU64)
node := db.getNode(addr)

if bytes.Equal(key, nodePtr.memdbNode.getKey()) {
return *nodePtr, true
if bytes.Equal(key, node.memdbNode.getKey()) {
return node, true
}

return nullNodeAddr, false
Expand Down Expand Up @@ -209,6 +212,7 @@ func (db *RBT) Reset() {
db.count = 0
db.vlog.Reset()
db.allocator.reset()
db.lastTraversedNode.Store(arena.NullU64Addr)
}

// DiscardValues releases the memory used by all values.
Expand Down Expand Up @@ -592,8 +596,8 @@ func (db *RBT) rightRotate(y MemdbNodeAddr) {

func (db *RBT) deleteNode(z MemdbNodeAddr) {
var x, y MemdbNodeAddr
if db.lastTraversedNode.Load().addr == z.addr {
db.lastTraversedNode.Store(&nullNodeAddr)
if db.lastTraversedNode.Load() == z.addr.AsU64() {
db.lastTraversedNode.Store(arena.NullU64Addr)
}

db.count--
Expand Down

0 comments on commit 6b1453c

Please sign in to comment.