Skip to content

Commit

Permalink
Adding context based parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshaw committed Nov 6, 2023
1 parent 7431c54 commit 35af001
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/_example/_example
/_example/cache/
/.cache/
*.txt
29 changes: 26 additions & 3 deletions diskcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package diskcache
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"errors"
"fmt"
Expand Down Expand Up @@ -176,7 +177,7 @@ func (c *Cache) EvictKey(key string) error {
// or if the cached response is stale the request will be executed and cached.
func (c *Cache) Fetch(key string, p Policy, req *http.Request, force bool) (bool, time.Time, *http.Response, error) {
// check stale
stale, mod, err := c.Stale(key, p.TTL)
stale, mod, err := c.Stale(req.Context(), key, p.TTL)
if err != nil {
return false, time.Time{}, nil, err
}
Expand Down Expand Up @@ -213,14 +214,17 @@ func (c *Cache) Mod(key string) (time.Time, error) {
}

// Stale returns whether or not the key is stale, based on ttl.
func (c *Cache) Stale(key string, ttl time.Duration) (bool, time.Time, error) {
func (c *Cache) Stale(ctx context.Context, key string, ttl time.Duration) (bool, time.Time, error) {
mod, err := c.Mod(key)
switch {
case err != nil && errors.Is(err, fs.ErrNotExist):
return true, mod, nil
case err != nil:
return false, time.Time{}, err
}
if d, ok := TTL(ctx); ok {
ttl = d
}
return ttl != 0 && time.Now().After(mod.Add(ttl)), mod, nil
}

Expand All @@ -230,7 +234,7 @@ func (c *Cache) Cached(req *http.Request) (bool, error) {
if err != nil {
return false, err
}
stale, _, err := c.Stale(key, p.TTL)
stale, _, err := c.Stale(req.Context(), key, p.TTL)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -347,3 +351,22 @@ func UserCacheDir(paths ...string) (string, error) {
}
return filepath.Join(append([]string{dir}, paths...)...), nil
}

// contextKey is a context key.
type contextKey string

// context keys.
const (
ttlKey contextKey = "ttl"
)

// WithContextTTL adds the ttl to the context.
func WithContextTTL(parent context.Context, ttl time.Duration) context.Context {
return context.WithValue(parent, ttlKey, ttl)
}

// TTL returns the ttl from the context.
func TTL(ctx context.Context) (time.Duration, bool) {
ttl, ok := ctx.Value(ttlKey).(time.Duration)
return ttl, ok
}
97 changes: 97 additions & 0 deletions diskcache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package diskcache

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"sync/atomic"
"testing"
"time"
)

func TestWithContextTTL(t *testing.T) {
// set up simple test server for demonstration
var count uint64
s := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
fmt.Fprintf(res, "%d\n", atomic.AddUint64(&count, 1))
}))
defer s.Close()
baseDir := setupDir(t, "test-with-context-ttl")
// create disk cache
c, err := New(
WithBasePathFs(baseDir),
WithErrorTruncator(),
WithTTL(365*24*time.Hour),
)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
cl := &http.Client{
Transport: c,
}
ctx := context.Background()
for i := 0; i < 3; i++ {
v, err := doReq(ctx, cl, s.URL)
switch {
case err != nil:
t.Fatalf("expected no error, got: %v", err)
case v != 1:
t.Errorf("expected %d, got: %d", 1, v)
}
}
if count != 1 {
t.Fatalf("expected count == %d, got: %d", 1, count)
}
for i := 1; i < 5; i++ {
v, err := doReq(WithContextTTL(ctx, 1*time.Millisecond), cl, s.URL)
switch {
case err != nil:
t.Fatalf("expected no error, got: %v", err)
case v != i+1:
t.Errorf("expected %d, got: %d", i+1, v)
}
<-time.After(2 * time.Millisecond)
}
}

func doReq(ctx context.Context, cl *http.Client, urlstr string) (int, error) {
req, err := http.NewRequestWithContext(ctx, "GET", urlstr, nil)
if err != nil {
return -1, err
}
res, err := cl.Do(req)
if err != nil {
return -1, err
}
defer res.Body.Close()
buf, err := io.ReadAll(res.Body)
if err != nil {
return -1, err
}
return strconv.Atoi(string(bytes.TrimSpace(buf)))
}

func setupDir(t *testing.T, name string) string {
t.Helper()
wd, err := os.Getwd()
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
dir := filepath.Join(wd, ".cache", name)
switch err := os.RemoveAll(dir); {
case errors.Is(err, os.ErrNotExist):
case err != nil:
t.Fatalf("expected no error, got: %v", err)
}
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatalf("expected no error, got: %v", err)
}
return dir
}

0 comments on commit 35af001

Please sign in to comment.