Skip to content

Commit

Permalink
feat: Server side sessions (teamhanko#1673)
Browse files Browse the repository at this point in the history
* feat: add server side sessions

* feat: add lastUsed & admin endpoint

* feat: add session list to elements

* fix: fix public session endpoint

* chore: only store session info when enabled

* build: update go mod

* feat: add translations

* test: fix tests

* feat: change path

* feat: return userID on session validation endpoint

* feat: move all session endpoints to public router

* fix: add missing translation

* fix: add missing structs

* chore: align session persister with other persisters

* fix: use correct translation label

* chore: add db validator to session model

* feat: create server side session from cmd

* fix: fix review findings
  • Loading branch information
FreddyDevelop authored and adilkadivala committed Oct 27, 2024
1 parent 6c85c11 commit b401e57
Show file tree
Hide file tree
Showing 52 changed files with 946 additions and 39 deletions.
25 changes: 24 additions & 1 deletion backend/cmd/jwt/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/teamhanko/hanko/backend/crypto/jwk"
"github.com/teamhanko/hanko/backend/dto"
"github.com/teamhanko/hanko/backend/persistence"
"github.com/teamhanko/hanko/backend/persistence/models"
"github.com/teamhanko/hanko/backend/session"
"log"
)
Expand Down Expand Up @@ -66,12 +67,34 @@ func NewCreateCommand() *cobra.Command {
emailJwt = dto.JwtFromEmailModel(e)
}

token, err := sessionManager.GenerateJWT(userId, emailJwt)
token, rawToken, err := sessionManager.GenerateJWT(userId, emailJwt)
if err != nil {
fmt.Printf("failed to generate token: %s", err)
return
}

if cfg.Session.ServerSide.Enabled {
sessionID, _ := rawToken.Get("session_id")

expirationTime := rawToken.Expiration()
sessionModel := models.Session{
ID: uuid.FromStringOrNil(sessionID.(string)),
UserID: userId,
UserAgent: "",
IpAddress: "",
CreatedAt: rawToken.IssuedAt(),
UpdatedAt: rawToken.IssuedAt(),
ExpiresAt: &expirationTime,
LastUsed: rawToken.IssuedAt(),
}

err = persister.GetSessionPersister().Create(sessionModel)
if err != nil {
fmt.Printf("failed to store session: %s", err)
return
}
}

fmt.Printf("token: %s", token)
},
}
Expand Down
4 changes: 4 additions & 0 deletions backend/config/config_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ func DefaultConfig() *Config {
SameSite: "strict",
Secure: true,
},
ServerSide: ServerSide{
Enabled: false,
Limit: 100,
},
},
AuditLog: AuditLog{
ConsoleOutput: AuditLogConsole{
Expand Down
14 changes: 13 additions & 1 deletion backend/config/config_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ type Session struct {
// `issuer` is a string that identifies the principal (human user, an organization, or a service)
// that issued the JWT. Its value is set in the `iss` claim of a JWT.
Issuer string `yaml:"issuer" json:"issuer,omitempty" koanf:"issuer"`
// `lifespan` determines how long a session token (JWT) is valid. It must be a (possibly signed) sequence of decimal
// `lifespan` determines the maximum duration for which a session token (JWT) is valid. It must be a (possibly signed) sequence of decimal
// numbers, each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
Lifespan string `yaml:"lifespan" json:"lifespan,omitempty" koanf:"lifespan" jsonschema:"default=12h"`
// `server_side` contains configuration for server-side sessions.
ServerSide ServerSide `yaml:"server_side" json:"server_side" koanf:"server_side"`
}

func (s *Session) Validate() error {
Expand Down Expand Up @@ -61,3 +63,13 @@ func (c *Cookie) GetName() string {

return "hanko"
}

type ServerSide struct {
// `enabled` determines whether server-side sessions are enabled.
//
// NOTE: When enabled the session endpoint must be used in order to check if a session is still valid.
Enabled bool `yaml:"enabled" json:"enabled,omitempty" koanf:"enabled" jsonschema:"default=false"`
// `limit` determines the maximum number of server-side sessions a user can have. When the limit is exceeded,
// older sessions are invalidated.
Limit int `yaml:"limit" json:"limit,omitempty" koanf:"limit" jsonschema:"default=100"`
}
44 changes: 44 additions & 0 deletions backend/dto/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package dto

import (
"fmt"
"github.com/gofrs/uuid"
"github.com/mileusna/useragent"
"github.com/teamhanko/hanko/backend/persistence/models"
"time"
)

type SessionData struct {
ID uuid.UUID `json:"id"`
UserAgentRaw string `json:"user_agent_raw"`
UserAgent string `json:"user_agent"`
IpAddress string `json:"ip_address"`
Current bool `json:"current"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
LastUsed time.Time `json:"last_used"`
}

func FromSessionModel(model models.Session, current bool) SessionData {
ua := useragent.Parse(model.UserAgent)
return SessionData{
ID: model.ID,
UserAgentRaw: model.UserAgent,
UserAgent: fmt.Sprintf("%s (%s)", ua.OS, ua.Name),
IpAddress: model.IpAddress,
Current: current,
CreatedAt: model.CreatedAt,
ExpiresAt: model.ExpiresAt,
LastUsed: model.LastUsed,
}
}

type ValidateSessionResponse struct {
IsValid bool `json:"is_valid"`
ExpirationTime *time.Time `json:"expiration_time,omitempty"`
UserID *uuid.UUID `json:"user_id,omitempty"`
}

type ValidateSessionRequest struct {
SessionToken string `json:"session_token" validate:"required"`
}
3 changes: 2 additions & 1 deletion backend/flow_api/flow/flows.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ func NewProfileFlow(debug bool) flowpilot.Flow {
profile.WebauthnCredentialRename{},
profile.WebauthnCredentialCreate{},
profile.WebauthnCredentialDelete{},
profile.SessionDelete{},
).
State(shared.StateProfileWebauthnCredentialVerification,
profile.WebauthnVerifyAttestationResponse{},
Expand All @@ -155,7 +156,7 @@ func NewProfileFlow(debug bool) flowpilot.Flow {
InitialState(shared.StatePreflight, shared.StateProfileInit).
ErrorState(shared.StateError).
BeforeEachAction(profile.RefreshSessionUser{}).
BeforeState(shared.StateProfileInit, profile.GetProfileData{}).
BeforeState(shared.StateProfileInit, profile.GetProfileData{}, profile.GetSessions{}).
AfterState(shared.StateProfileWebauthnCredentialVerification, shared.WebauthnCredentialSave{}).
AfterState(shared.StatePasscodeConfirmation, shared.EmailPersistVerifiedStatus{}).
SubFlows(
Expand Down
67 changes: 67 additions & 0 deletions backend/flow_api/flow/profile/action_session_delete.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package profile

import (
"fmt"
"github.com/gofrs/uuid"
"github.com/teamhanko/hanko/backend/flow_api/flow/shared"
"github.com/teamhanko/hanko/backend/flowpilot"
"github.com/teamhanko/hanko/backend/persistence/models"
)

type SessionDelete struct {
shared.Action
}

func (a SessionDelete) GetName() flowpilot.ActionName {
return shared.ActionSessionDelete
}

func (a SessionDelete) GetDescription() string {
return "Delete a session."
}

func (a SessionDelete) Initialize(c flowpilot.InitializationContext) {
deps := a.GetDeps(c)
if !deps.Cfg.Session.ServerSide.Enabled {
c.SuspendAction()
}
userModel, ok := c.Get("session_user").(*models.User)
if !ok {
c.SuspendAction()
return
}

input := flowpilot.StringInput("session_id").Required(true).Hidden(true)

currentSessionID := uuid.FromStringOrNil(c.Get("session_id").(string))
sessions, err := deps.Persister.GetSessionPersisterWithConnection(deps.Tx).ListActive(userModel.ID)
if err != nil {
c.SuspendAction()
return
}

for _, session := range sessions {
if session.ID != currentSessionID {
input.AllowedValue(session.ID.String(), session.ID.String())
}
}

c.AddInputs(input)
}

func (a SessionDelete) Execute(c flowpilot.ExecutionContext) error {
deps := a.GetDeps(c)

sessionToBeDeleted := uuid.FromStringOrNil(c.Input().Get("session_id").String())

session, err := deps.Persister.GetSessionPersisterWithConnection(deps.Tx).Get(sessionToBeDeleted)
if err != nil {
return fmt.Errorf("failed to get session from db: %w", err)
}

if session != nil {
err = deps.Persister.GetSessionPersisterWithConnection(deps.Tx).Delete(*session)
}

return c.Continue(shared.StateProfileInit)
}
47 changes: 47 additions & 0 deletions backend/flow_api/flow/profile/hook_get_sessions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package profile

import (
"errors"
"fmt"
"github.com/gofrs/uuid"
"github.com/teamhanko/hanko/backend/dto"
"github.com/teamhanko/hanko/backend/flow_api/flow/shared"
"github.com/teamhanko/hanko/backend/flowpilot"
"github.com/teamhanko/hanko/backend/persistence/models"
)

type GetSessions struct {
shared.Action
}

func (h GetSessions) Execute(c flowpilot.HookExecutionContext) error {
deps := h.GetDeps(c)

if !deps.Cfg.Session.ServerSide.Enabled {
return nil
}

userModel, ok := c.Get("session_user").(*models.User)
if !ok {
return errors.New("no valid session")
}

activeSessions, err := deps.Persister.GetSessionPersisterWithConnection(deps.Tx).ListActive(userModel.ID)
if err != nil {
return fmt.Errorf("failed to get sessions from db: %w", err)
}

currentSessionID := uuid.FromStringOrNil(c.Get("session_id").(string))

sessionsDto := make([]dto.SessionData, len(activeSessions))
for i := range activeSessions {
sessionsDto[i] = dto.FromSessionModel(activeSessions[i], activeSessions[i].ID == currentSessionID)
}

err = c.Payload().Set("sessions", sessionsDto)
if err != nil {
return fmt.Errorf("failed to set sessions payload: %w", err)
}

return nil
}
5 changes: 5 additions & 0 deletions backend/flow_api/flow/profile/hook_refresh_session_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,10 @@ func (h RefreshSessionUser) Execute(c flowpilot.HookExecutionContext) error {
c.Set("session_user", userModel)
}

sessionId, found := sessionToken.Get("session_id")
if found {
c.Set("session_id", sessionId)
}

return nil
}
1 change: 1 addition & 0 deletions backend/flow_api/flow/shared/const_action_names.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ const (
ActionWebauthnGenerateRequestOptions flowpilot.ActionName = "webauthn_generate_request_options"
ActionWebauthnVerifyAssertionResponse flowpilot.ActionName = "webauthn_verify_assertion_response"
ActionWebauthnVerifyAttestationResponse flowpilot.ActionName = "webauthn_verify_attestation_response"
ActionSessionDelete flowpilot.ActionName = "session_delete"
)
42 changes: 39 additions & 3 deletions backend/flow_api/flow/shared/hook_issue_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,56 @@ func (h IssueSession) Execute(c flowpilot.HookExecutionContext) error {
emailDTO = dto.JwtFromEmailModel(email)
}

sessionToken, err := deps.SessionManager.GenerateJWT(userId, emailDTO)
signedSessionToken, rawToken, err := deps.SessionManager.GenerateJWT(userId, emailDTO)
if err != nil {
return fmt.Errorf("failed to generate JWT: %w", err)
}

cookie, err := deps.SessionManager.GenerateCookie(sessionToken)
activeSessions, err := deps.Persister.GetSessionPersisterWithConnection(deps.Tx).ListActive(userId)
if err != nil {
return fmt.Errorf("failed to list active sessions: %w", err)
}

if deps.Cfg.Session.ServerSide.Enabled {
// remove all server side sessions that exceed the limit
if len(activeSessions) >= deps.Cfg.Session.ServerSide.Limit {
for i := deps.Cfg.Session.ServerSide.Limit - 1; i < len(activeSessions); i++ {
err = deps.Persister.GetSessionPersisterWithConnection(deps.Tx).Delete(activeSessions[i])
if err != nil {
return fmt.Errorf("failed to remove latest session: %w", err)
}
}
}

sessionID, _ := rawToken.Get("session_id")

expirationTime := rawToken.Expiration()
sessionModel := models.Session{
ID: uuid.FromStringOrNil(sessionID.(string)),
UserID: userId,
UserAgent: deps.HttpContext.Request().UserAgent(),
IpAddress: deps.HttpContext.RealIP(),
CreatedAt: rawToken.IssuedAt(),
UpdatedAt: rawToken.IssuedAt(),
ExpiresAt: &expirationTime,
LastUsed: rawToken.IssuedAt(),
}

err = deps.Persister.GetSessionPersisterWithConnection(deps.Tx).Create(sessionModel)
if err != nil {
return fmt.Errorf("failed to store session: %w", err)
}
}

cookie, err := deps.SessionManager.GenerateCookie(signedSessionToken)
if err != nil {
return fmt.Errorf("failed to generate auth cookie, %w", err)
}

deps.HttpContext.Response().Header().Set("X-Session-Lifetime", fmt.Sprintf("%d", cookie.MaxAge))

if deps.Cfg.Session.EnableAuthTokenHeader {
deps.HttpContext.Response().Header().Set("X-Auth-Token", sessionToken)
deps.HttpContext.Response().Header().Set("X-Auth-Token", signedSessionToken)
} else {
deps.HttpContext.SetCookie(cookie)
}
Expand Down
33 changes: 33 additions & 0 deletions backend/flow_api/handler.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package flow_api

import (
"errors"
"fmt"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
echojwt "github.com/labstack/echo-jwt/v4"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog"
zeroLogger "github.com/rs/zerolog/log"
"github.com/sethvargo/go-limiter"
auditlog "github.com/teamhanko/hanko/backend/audit_log"
"github.com/teamhanko/hanko/backend/config"
"github.com/teamhanko/hanko/backend/dto"
"github.com/teamhanko/hanko/backend/ee/saml"
"github.com/teamhanko/hanko/backend/flow_api/flow"
"github.com/teamhanko/hanko/backend/flow_api/flow/shared"
Expand Down Expand Up @@ -81,6 +84,36 @@ func (h *FlowPilotHandler) validateSession(c echo.Context) error {
continue
}

if h.Cfg.Session.ServerSide.Enabled {
// check that the session id is stored in the database
sessionId, ok := token.Get("session_id")
if !ok {
lastTokenErr = errors.New("no session id found in token")
continue
}
sessionID, err := uuid.FromString(sessionId.(string))
if err != nil {
lastTokenErr = errors.New("session id has wrong format")
continue
}

sessionModel, err := h.Persister.GetSessionPersister().Get(sessionID)
if err != nil {
return fmt.Errorf("failed to get session from database: %w", err)
}
if sessionModel == nil {
lastTokenErr = fmt.Errorf("session id not found in database")
continue
}

// Update lastUsed field
sessionModel.LastUsed = time.Now().UTC()
err = h.Persister.GetSessionPersister().Update(*sessionModel)
if err != nil {
return dto.ToHttpError(err)
}
}

c.Set("session", token)

return nil
Expand Down
Loading

0 comments on commit b401e57

Please sign in to comment.