diff --git a/arc.go b/arc.go index bce30ed..24e357b 100644 --- a/arc.go +++ b/arc.go @@ -33,6 +33,9 @@ func (c *ArcCache) Init(clock Clock, capacity int) { } func (c *ArcCache) Set(ctx context.Context, key string, val interface{}, ttl time.Duration) error { + c.Lock() + defer c.Unlock() + value := deref(val) item, ok := c.items[key] if ttl > 0 { @@ -50,7 +53,7 @@ func (c *ArcCache) Set(ctx context.Context, key string, val interface{}, ttl tim } defer func() { - c.Evict(ctx, 1) + c.evict(ctx, 1) if c.t1.Has(key) || c.t2.Has(key) { return } @@ -65,6 +68,9 @@ func (c *ArcCache) Set(ctx context.Context, key string, val interface{}, ttl tim } func (c *ArcCache) Get(ctx context.Context, key string) (interface{}, error) { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if !ok { return nil, KeyNotFoundError @@ -72,7 +78,7 @@ func (c *ArcCache) Get(ctx context.Context, key string) (interface{}, error) { c.update(ctx, key) if item.IsExpired(c.clock) { - c.Remove(ctx, key) + c.remove(ctx, key) return nil, KeyExpiredError } @@ -80,6 +86,9 @@ func (c *ArcCache) Get(ctx context.Context, key string) (interface{}, error) { } func (c *ArcCache) Exists(ctx context.Context, key string) bool { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if !ok { return false @@ -87,13 +96,27 @@ func (c *ArcCache) Exists(ctx context.Context, key string) bool { c.update(ctx, key) if item.IsExpired(c.clock) { - c.Remove(ctx, key) + c.remove(ctx, key) return false } return true } func (c *ArcCache) Remove(ctx context.Context, key string) bool { + c.Lock() + defer c.Unlock() + + return c.remove(ctx, key) +} + +func (c *ArcCache) Evict(ctx context.Context, count int) { + c.Lock() + defer c.Unlock() + + c.evict(ctx, count) +} + +func (c *ArcCache) remove(ctx context.Context, key string) bool { delete(c.items, key) if elt := c.b1.Lookup(key); elt != nil { c.b1.Remove(key, elt) @@ -115,7 +138,7 @@ func (c *ArcCache) Remove(ctx context.Context, key string) bool { return false } -func (c *ArcCache) Evict(ctx context.Context, count int) { +func (c *ArcCache) evict(ctx context.Context, count int) { if !c.isCacheFull() && c.t1.Len()+c.t2.Len() < c.cap { return } diff --git a/cache.go b/cache.go index 5987db3..405afcf 100644 --- a/cache.go +++ b/cache.go @@ -27,9 +27,8 @@ type Cache interface { //only for debug DebugShardIndex(key string) uint64 - debugFromLocal2(ctx context.Context, key string, onLoad bool) (interface{}, error) - debugFromLocal(ctx context.Context, key string, onLoad bool) (interface{}, error) - debugRemove(ctx context.Context, key string) bool + debugLocalGet(ctx context.Context, key string) (interface{}, error) + debugLocalRemove(ctx context.Context, key string) bool serializer } diff --git a/lfu.go b/lfu.go index 0a5c17d..6db34e0 100644 --- a/lfu.go +++ b/lfu.go @@ -32,6 +32,9 @@ func (c *LfuCache) Init(clock Clock, capacity int) { } func (c *LfuCache) Set(ctx context.Context, key string, val interface{}, ttl time.Duration) error { + c.Lock() + defer c.Unlock() + value := deref(val) item, ok := c.items[key] if ttl > 0 { @@ -43,7 +46,7 @@ func (c *LfuCache) Set(ctx context.Context, key string, val interface{}, ttl tim if ok { item.value = value } else { - c.Evict(ctx, 1) + c.evict(ctx, 1) item.key = key item.value = value item.freqElement = nil @@ -60,6 +63,9 @@ func (c *LfuCache) Set(ctx context.Context, key string, val interface{}, ttl tim } func (c *LfuCache) Get(ctx context.Context, key string) (interface{}, error) { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if ok { if !item.IsExpired(c.clock) { @@ -72,29 +78,10 @@ func (c *LfuCache) Get(ctx context.Context, key string) (interface{}, error) { return nil, KeyNotFoundError } -func (c *LfuCache) Evict(ctx context.Context, count int) { - if len(c.items) < c.cap { - return - } - - entry := c.freqList.Front() - for i := 0; i < count; { - if entry == nil { - return - } else { - for _, item := range entry.Value.(*freqEntry).items { - if i >= count { - return - } - c.removeItem(item) - i++ - } - entry = entry.Next() - } - } -} - func (c *LfuCache) Exists(ctx context.Context, key string) bool { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if !ok { return false @@ -109,6 +96,9 @@ func (c *LfuCache) Exists(ctx context.Context, key string) bool { } func (c *LfuCache) Remove(ctx context.Context, key string) bool { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if ok { c.removeItem(&item) @@ -117,6 +107,35 @@ func (c *LfuCache) Remove(ctx context.Context, key string) bool { return false } +func (c *LfuCache) Evict(ctx context.Context, count int) { + c.Lock() + defer c.Unlock() + + c.evict(ctx, count) +} + +func (c *LfuCache) evict(ctx context.Context, count int) { + if len(c.items) < c.cap { + return + } + + entry := c.freqList.Front() + for i := 0; i < count; { + if entry == nil { + return + } else { + for _, item := range entry.Value.(*freqEntry).items { + if i >= count { + return + } + c.removeItem(item) + i++ + } + entry = entry.Next() + } + } +} + func (c *LfuCache) removeItem(item *lfuItem) { entry := item.freqElement.Value.(*freqEntry) delete(c.items, item.key) diff --git a/lru.go b/lru.go index faa0c84..e820d1e 100644 --- a/lru.go +++ b/lru.go @@ -23,6 +23,9 @@ func (c *LruCache) Init(clock Clock, capacity int) { } func (c *LruCache) Set(ctx context.Context, key string, val interface{}, ttl time.Duration) error { + c.Lock() + defer c.Unlock() + value := deref(val) it, ok := c.items[key] if ok { @@ -35,7 +38,7 @@ func (c *LruCache) Set(ctx context.Context, key string, val interface{}, ttl tim } c.evictList.MoveToFront(it) } else { - c.Evict(ctx, 1) + c.evict(ctx, 1) item := lruItem{ key: key, value: value, @@ -52,6 +55,9 @@ func (c *LruCache) Set(ctx context.Context, key string, val interface{}, ttl tim } func (c *LruCache) Get(ctx context.Context, key string) (interface{}, error) { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if ok { it := item.Value.(*lruItem) @@ -64,22 +70,10 @@ func (c *LruCache) Get(ctx context.Context, key string) (interface{}, error) { return nil, KeyNotFoundError } -func (c *LruCache) Evict(ctx context.Context, count int) { - if c.evictList.Len() < c.cap { - return - } - - for i := 0; i < count; i++ { - ent := c.evictList.Back() - if ent == nil { - return - } else { - c.removeElement(ent) - } - } -} - func (c *LruCache) Exists(ctx context.Context, key string) bool { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if !ok { return false @@ -93,6 +87,9 @@ func (c *LruCache) Exists(ctx context.Context, key string) bool { } func (c *LruCache) Remove(ctx context.Context, key string) bool { + c.Lock() + defer c.Unlock() + if ent, ok := c.items[key]; ok { c.removeElement(ent) return true @@ -100,6 +97,28 @@ func (c *LruCache) Remove(ctx context.Context, key string) bool { return false } +func (c *LruCache) Evict(ctx context.Context, count int) { + c.Lock() + defer c.Unlock() + + c.evict(ctx, count) +} + +func (c *LruCache) evict(ctx context.Context, count int) { + if c.evictList.Len() < c.cap { + return + } + + for i := 0; i < count; i++ { + ent := c.evictList.Back() + if ent == nil { + return + } else { + c.removeElement(ent) + } + } +} + func (c *LruCache) removeElement(e *list.Element) { c.evictList.Remove(e) entry := e.Value.(*lruItem) diff --git a/shard.go b/shard.go index 8e4e054..16fb200 100644 --- a/shard.go +++ b/shard.go @@ -50,11 +50,7 @@ func (c cacheHandler[T, P]) Set(ctx context.Context, key string, value interface } } - s := c.getShard(key) - s.Lock() - defer s.Unlock() - - return s.Set(ctx, key, value, o.TTL) + return c.getShard(key).Set(ctx, key, value, o.TTL) } func (c cacheHandler[T, P]) MSet(ctx context.Context, keys []string, values []interface{}, opts ...Option) error { @@ -73,13 +69,9 @@ func (c cacheHandler[T, P]) MSet(ctx context.Context, keys []string, values []in } for i, k := range keys { - s := c.getShard(k) - s.Lock() - if err := s.Set(ctx, k, values[i], o.TTL); err != nil { - s.Unlock() + if err := c.getShard(k).Set(ctx, k, values[i], o.TTL); err != nil { return err } - s.Unlock() } return nil } @@ -89,8 +81,6 @@ func (c cacheHandler[T, P]) Get(ctx context.Context, key string, opts ...Option) defer c.putOpt(o) s := c.getShard(key) - s.Lock() - defer s.Unlock() val, err := s.Get(ctx, key) if err == KeyNotFoundError { @@ -129,14 +119,12 @@ func (c cacheHandler[T, P]) MGet(ctx context.Context, keys []string, opts ...Opt miss := make(map[string]P, len(keys)) for _, key := range keys { s := c.getShard(key) - s.Lock() val, err := s.Get(ctx, key) if err == nil { res[key] = val } else if err == KeyNotFoundError { miss[key] = s } - s.Unlock() } if len(miss) > 0 { @@ -155,9 +143,7 @@ func (c cacheHandler[T, P]) MGet(ctx context.Context, keys []string, opts ...Opt for key, val := range kvs { s := miss[key] - s.Lock() err := s.Set(ctx, key, val, o.TTL) - s.Unlock() if err != nil { goto END } @@ -178,11 +164,7 @@ func (c cacheHandler[T, P]) Remove(ctx context.Context, key string) bool { } } - s := c.getShard(key) - s.Lock() - defer s.Unlock() - - return s.Remove(ctx, key) + return c.getShard(key).Remove(ctx, key) } func (c cacheHandler[T, P]) MRemove(ctx context.Context, keys []string) bool { @@ -194,24 +176,16 @@ func (c cacheHandler[T, P]) MRemove(ctx context.Context, keys []string) bool { } for _, key := range keys { - s := c.getShard(key) - s.Lock() - if !s.Remove(ctx, key) { - s.Unlock() + if !c.getShard(key).Remove(ctx, key) { return false } - s.Unlock() } return true } func (c cacheHandler[T, P]) Exists(ctx context.Context, key string) bool { - s := c.getShard(key) - s.Lock() - defer s.Unlock() - - return s.Exists(ctx, key) + return c.getShard(key).Exists(ctx, key) } func (c cacheHandler[T, P]) getShard(key string) P { @@ -222,39 +196,17 @@ func (c cacheHandler[T, P]) DebugShardIndex(key string) uint64 { return MemHashString(key) & uint64(c.shardCount-1) } -func (c cacheHandler[T, P]) debugFromLocal2(ctx context.Context, key string, onLoad bool) (interface{}, error) { - s := c.getShard(key) - val, err := s.Get(ctx, key) +func (c cacheHandler[T, P]) debugLocalGet(ctx context.Context, key string) (interface{}, error) { + val, err := c.getShard(key).Get(ctx, key) if err != nil { - if !onLoad { - } return nil, err } return val, nil } -func (c cacheHandler[T, P]) debugFromLocal(ctx context.Context, key string, onLoad bool) (interface{}, error) { - s := c.getShard(key) - s.Lock() - defer s.Unlock() - - val, err := s.Get(ctx, key) - if err != nil { - if !onLoad { - } - return nil, err - } - - return val, nil -} - -func (c cacheHandler[T, P]) debugRemove(ctx context.Context, key string) bool { - s := c.getShard(key) - s.Lock() - defer s.Unlock() - - return s.Remove(ctx, key) +func (c cacheHandler[T, P]) debugLocalRemove(ctx context.Context, key string) bool { + return c.getShard(key).Remove(ctx, key) } func (c cacheHandler[T, P]) serialize(ctx context.Context, val interface{}, opts ...Option) ([]byte, error) { diff --git a/simple.go b/simple.go index b5c8c23..024d392 100644 --- a/simple.go +++ b/simple.go @@ -23,6 +23,9 @@ func (c *SimpleCache) Init(clock Clock, capacity int) { } func (c *SimpleCache) Set(ctx context.Context, key string, val interface{}, ttl time.Duration) error { + c.Lock() + defer c.Unlock() + value := deref(val) var entry = simpleEntry{ key: key, @@ -40,7 +43,7 @@ func (c *SimpleCache) Set(ctx context.Context, key string, val interface{}, ttl entry.priority = item.expireAt.UnixNano() c.pq.update(entry.index) } else { - c.Evict(ctx, 1) + c.evict(ctx, 1) item.value = value item.index = c.pq.Len() c.items[key] = item @@ -53,41 +56,16 @@ func (c *SimpleCache) Set(ctx context.Context, key string, val interface{}, ttl return nil } -func (c *SimpleCache) Evict(ctx context.Context, count int) { - if len(c.items) < c.cap { - return - } - - cnt := 0 - now := c.clock.Now() - if n := c.pq.Len(); n > 0 { - entry := c.pq[0] - item := c.items[entry.key] - if now.After(item.expireAt) { - heap.Pop(&c.pq) - delete(c.items, entry.key) - cnt++ - } - } - - for k, v := range c.items { - if cnt >= count { - return - } - - heap.Remove(&c.pq, v.index) - delete(c.items, k) - cnt++ - } -} - func (c *SimpleCache) Get(ctx context.Context, key string) (interface{}, error) { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if ok { if !item.IsExpired(c.clock) { return item.value, nil } - c.Remove(ctx, key) + c.remove(ctx, key) return nil, KeyExpiredError } @@ -95,19 +73,36 @@ func (c *SimpleCache) Get(ctx context.Context, key string) (interface{}, error) } func (c *SimpleCache) Exists(ctx context.Context, key string) bool { + c.Lock() + defer c.Unlock() + item, ok := c.items[key] if !ok { return false } if item.IsExpired(c.clock) { - c.Remove(ctx, key) + c.remove(ctx, key) return false } return true } func (c *SimpleCache) Remove(ctx context.Context, key string) bool { + c.Lock() + defer c.Unlock() + + return c.remove(ctx, key) +} + +func (c *SimpleCache) Evict(ctx context.Context, count int) { + c.Lock() + defer c.Unlock() + + c.evict(ctx, count) +} + +func (c *SimpleCache) remove(ctx context.Context, key string) bool { item, ok := c.items[key] if ok { heap.Remove(&c.pq, item.index) @@ -117,6 +112,34 @@ func (c *SimpleCache) Remove(ctx context.Context, key string) bool { return false } +func (c *SimpleCache) evict(ctx context.Context, count int) { + if len(c.items) < c.cap { + return + } + + cnt := 0 + now := c.clock.Now() + if n := c.pq.Len(); n > 0 { + entry := c.pq[0] + item := c.items[entry.key] + if now.After(item.expireAt) { + heap.Pop(&c.pq) + delete(c.items, entry.key) + cnt++ + } + } + + for k, v := range c.items { + if cnt >= count { + return + } + + heap.Remove(&c.pq, v.index) + delete(c.items, k) + cnt++ + } +} + type simpleItem struct { value interface{} expireAt time.Time