Skip to content

Commit

Permalink
Add Codec interface
Browse files Browse the repository at this point in the history
  • Loading branch information
alexedwards committed Sep 3, 2019
1 parent ddba8b6 commit 4cd374b
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 75 deletions.
49 changes: 49 additions & 0 deletions codec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package scs

import (
"bytes"
"encoding/gob"
"time"
)

// Codec is the interface for encoding/decoding session data to and from a byte
// slice for use by the session store.
type Codec interface {
Encode(deadline time.Time, values map[string]interface{}) ([]byte, error)
Decode([]byte) (deadline time.Time, values map[string]interface{}, err error)
}

type gobCodec struct{}

func (gobCodec) Encode(deadline time.Time, values map[string]interface{}) ([]byte, error) {
aux := &struct {
Deadline time.Time
Values map[string]interface{}
}{
Deadline: deadline,
Values: values,
}

var b bytes.Buffer
err := gob.NewEncoder(&b).Encode(&aux)
if err != nil {
return nil, err
}

return b.Bytes(), nil
}

func (gobCodec) Decode(b []byte) (time.Time, map[string]interface{}, error) {
aux := &struct {
Deadline time.Time
Values map[string]interface{}
}{}

r := bytes.NewReader(b)
err := gob.NewDecoder(r).Decode(&aux)
if err != nil {
return time.Time{}, nil, err
}

return aux.Deadline, aux.Values, nil
}
63 changes: 23 additions & 40 deletions data.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package scs

import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/gob"
"fmt"
"sort"
"sync"
Expand All @@ -30,18 +28,18 @@ const (
)

type sessionData struct {
Deadline time.Time // Exported for gob encoding.
deadline time.Time
status Status
token string
Values map[string]interface{} // Exported for gob encoding.
values map[string]interface{}
mu sync.Mutex
}

func newSessionData(lifetime time.Duration) *sessionData {
return &sessionData{
Deadline: time.Now().Add(lifetime).UTC(),
deadline: time.Now().Add(lifetime).UTC(),
status: Unmodified,
Values: make(map[string]interface{}),
values: make(map[string]interface{}),
}
}

Expand Down Expand Up @@ -71,7 +69,7 @@ func (s *SessionManager) Load(ctx context.Context, token string) (context.Contex
status: Unmodified,
token: token,
}
err = sd.decode(b)
sd.deadline, sd.values, err = s.Codec.Decode(b)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -104,12 +102,12 @@ func (s *SessionManager) Commit(ctx context.Context) (string, time.Time, error)
}
}

b, err := sd.encode()
b, err := s.Codec.Encode(sd.deadline, sd.values)
if err != nil {
return "", time.Time{}, err
}

expiry := sd.Deadline
expiry := sd.deadline
if s.IdleTimeout > 0 {
ie := time.Now().Add(s.IdleTimeout)
if ie.Before(expiry) {
Expand Down Expand Up @@ -143,9 +141,9 @@ func (s *SessionManager) Destroy(ctx context.Context) error {

// Reset everything else to defaults.
sd.token = ""
sd.Deadline = time.Now().Add(s.Lifetime).UTC()
for key := range sd.Values {
delete(sd.Values, key)
sd.deadline = time.Now().Add(s.Lifetime).UTC()
for key := range sd.values {
delete(sd.values, key)
}

return nil
Expand All @@ -158,7 +156,7 @@ func (s *SessionManager) Put(ctx context.Context, key string, val interface{}) {
sd := s.getSessionDataFromContext(ctx)

sd.mu.Lock()
sd.Values[key] = val
sd.values[key] = val
sd.status = Modified
sd.mu.Unlock()
}
Expand All @@ -180,7 +178,7 @@ func (s *SessionManager) Get(ctx context.Context, key string) interface{} {
sd.mu.Lock()
defer sd.mu.Unlock()

return sd.Values[key]
return sd.values[key]
}

// Pop acts like a one-time Get. It returns the value for a given key from the
Expand All @@ -193,11 +191,11 @@ func (s *SessionManager) Pop(ctx context.Context, key string) interface{} {
sd.mu.Lock()
defer sd.mu.Unlock()

val, exists := sd.Values[key]
val, exists := sd.values[key]
if !exists {
return nil
}
delete(sd.Values, key)
delete(sd.values, key)
sd.status = Modified

return val
Expand All @@ -212,12 +210,12 @@ func (s *SessionManager) Remove(ctx context.Context, key string) {
sd.mu.Lock()
defer sd.mu.Unlock()

_, exists := sd.Values[key]
_, exists := sd.values[key]
if !exists {
return
}

delete(sd.Values, key)
delete(sd.values, key)
sd.status = Modified
}

Expand All @@ -230,12 +228,12 @@ func (s *SessionManager) Clear(ctx context.Context) error {
sd.mu.Lock()
defer sd.mu.Unlock()

if len(sd.Values) == 0 {
if len(sd.values) == 0 {
return nil
}

for key := range sd.Values {
delete(sd.Values, key)
for key := range sd.values {
delete(sd.values, key)
}
sd.status = Modified
return nil
Expand All @@ -246,7 +244,7 @@ func (s *SessionManager) Exists(ctx context.Context, key string) bool {
sd := s.getSessionDataFromContext(ctx)

sd.mu.Lock()
_, exists := sd.Values[key]
_, exists := sd.values[key]
sd.mu.Unlock()

return exists
Expand All @@ -259,9 +257,9 @@ func (s *SessionManager) Keys(ctx context.Context) []string {
sd := s.getSessionDataFromContext(ctx)

sd.mu.Lock()
keys := make([]string, len(sd.Values))
keys := make([]string, len(sd.values))
i := 0
for key := range sd.Values {
for key := range sd.values {
keys[i] = key
i++
}
Expand Down Expand Up @@ -298,7 +296,7 @@ func (s *SessionManager) RenewToken(ctx context.Context) error {
}

sd.token = newToken
sd.Deadline = time.Now().Add(s.Lifetime).UTC()
sd.deadline = time.Now().Add(s.Lifetime).UTC()
sd.status = Modified

return nil
Expand Down Expand Up @@ -477,21 +475,6 @@ func (s *SessionManager) getSessionDataFromContext(ctx context.Context) *session
return c
}

func (sd *sessionData) encode() ([]byte, error) {
var b bytes.Buffer
err := gob.NewEncoder(&b).Encode(sd)
if err != nil {
return nil, err
}

return b.Bytes(), nil
}

func (sd *sessionData) decode(b []byte) error {
r := bytes.NewReader(b)
return gob.NewDecoder(r).Decode(sd)
}

func generateToken() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
Expand Down
Loading

0 comments on commit 4cd374b

Please sign in to comment.