Skip to content

Commit

Permalink
Refactoring (#11)
Browse files Browse the repository at this point in the history
* Refactoring

* Refactoring
  • Loading branch information
Roma7-7-7 authored Feb 14, 2024
1 parent 0bce956 commit 044ef2e
Show file tree
Hide file tree
Showing 18 changed files with 570 additions and 658 deletions.
13 changes: 7 additions & 6 deletions cmd/app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (

"github.com/Roma7-7-7/shared-clipboard/internal/app"
"github.com/Roma7-7-7/shared-clipboard/internal/config"
"github.com/Roma7-7-7/shared-clipboard/internal/domain"
"github.com/Roma7-7-7/shared-clipboard/tools/log"
ac "github.com/Roma7-7-7/shared-clipboard/internal/context"
"github.com/Roma7-7-7/shared-clipboard/internal/log"
)

var configPath = flag.String("config", "", "path to config file")
Expand Down Expand Up @@ -42,13 +42,14 @@ func main() {
sLog := l.Sugar()
traced := log.NewZapTracedLogger(sLog)

if a, err = app.NewApp(conf, traced); err != nil {
traced.Errorw(domain.RuntimeTraceID, "Create app", err)
bootstrapCtx := ac.WithTraceID(context.Background(), "bootstrap")
if a, err = app.NewApp(bootstrapCtx, conf, traced); err != nil {
traced.Errorw(bootstrapCtx, "Create app", err)
os.Exit(1)
}

if err = a.Run(context.Background()); err != nil {
traced.Errorw(domain.RuntimeTraceID, "Run", err)
if err = a.Run(ac.WithTraceID(context.Background(), "runtime")); err != nil {
traced.Errorw(bootstrapCtx, "Run", err)
os.Exit(1)
}
}
28 changes: 14 additions & 14 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/Roma7-7-7/shared-clipboard/internal/handle"
"github.com/Roma7-7-7/shared-clipboard/internal/handle/cookie"
"github.com/Roma7-7-7/shared-clipboard/internal/handle/jwt"
"github.com/Roma7-7-7/shared-clipboard/tools/log"
"github.com/Roma7-7-7/shared-clipboard/internal/log"
)

type (
Expand All @@ -29,20 +29,20 @@ type (
}
)

func NewApp(conf config.App, traced log.TracedLogger) (*App, error) {
traced.Infow(domain.RuntimeTraceID, "Initializing SQL DB")
func NewApp(ctx context.Context, conf config.App, traced log.TracedLogger) (*App, error) {
traced.Infow(ctx, "Initializing SQL DB")
sqlDB, err := sql.Open(conf.DB.SQL.Driver, conf.DB.SQL.DataSource)
if err != nil {
return nil, fmt.Errorf("open sql db: %w", err)
}

traced.Infow(domain.RuntimeTraceID, "Initializing Bolt DB")
traced.Infow(ctx, "Initializing Bolt DB")
boltDB, err := bolt.Open(conf.DB.Bolt.Path, 0600, nil)
if err != nil {
return nil, fmt.Errorf("open bolt db: %w", err)
}

traced.Infow(domain.RuntimeTraceID, "Initializing repositories")
traced.Infow(ctx, "Initializing repositories")
userRpo, err := postgre.NewUserRepository(sqlDB)
if err != nil {
return nil, fmt.Errorf("create user repository: %w", err)
Expand All @@ -60,17 +60,17 @@ func NewApp(conf config.App, traced log.TracedLogger) (*App, error) {
return nil, fmt.Errorf("create jwt repository: %w", err)
}

traced.Infow(domain.RuntimeTraceID, "Initializing services")
traced.Infow(ctx, "Initializing services")
userService := domain.NewUserService(userRpo, traced)

traced.Infow(domain.RuntimeTraceID, "Initializing components")
traced.Infow(ctx, "Initializing components")
jwtProcessor := jwt.NewProcessor(conf.JWT)
cookieProcessor := cookie.NewProcessor(jwtProcessor, conf.Cookie)

sessionService := domain.NewSessionService(sessionRepo, traced)

traced.Infow(domain.RuntimeTraceID, "Creating router")
h, err := handle.NewRouter(handle.Dependencies{
traced.Infow(ctx, "Creating router")
h, err := handle.NewRouter(ctx, handle.Dependencies{
Config: conf,
CookieProcessor: cookieProcessor,
UserService: userService,
Expand Down Expand Up @@ -107,20 +107,20 @@ func (a *App) Run(ctx context.Context) error {
case <-ctx.Done():
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
a.log.Infow(domain.RuntimeTraceID, "Shutting down server")
a.log.Infow(ctx, "Shutting down server")
if err := s.Shutdown(ctx); err != nil {
a.log.Errorw(domain.RuntimeTraceID, "Shutdown server", err)
a.log.Errorw(ctx, "Shutdown server", err)
}
a.log.Infow(domain.RuntimeTraceID, "Server stopped")
a.log.Infow(ctx, "Server stopped")
return
}
}()

a.log.Infow(domain.RuntimeTraceID, "Starting server", "address", addr)
a.log.Infow(ctx, "Starting server", "address", addr)
if err := s.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server listen: %w", err)
}
a.log.Infow(domain.RuntimeTraceID, "Server stopped")
a.log.Infow(ctx, "Server stopped")

return nil
}
16 changes: 8 additions & 8 deletions internal/domain/context.go → internal/context/context.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package domain
package context

import (
"context"
)

const RuntimeTraceID = "runtime"
"github.com/Roma7-7-7/shared-clipboard/tools"
)

type (
Authority struct {
Expand All @@ -16,22 +16,22 @@ type (
traceIDCtxKey struct{}
)

func ContextWithTraceID(ctx context.Context, traceID string) context.Context {
func WithTraceID(ctx context.Context, traceID string) context.Context {
return context.WithValue(ctx, &traceIDCtxKey{}, traceID)
}

func TraceIDFromContext(ctx context.Context) string {
func TraceIDFrom(ctx context.Context) string {
if traceID, ok := ctx.Value(&traceIDCtxKey{}).(string); ok {
return traceID
}
return "undefined"
return "undefined#" + tools.RandomAlphanumericKey(8)
}

func AuthorityFromContext(ctx context.Context) (*Authority, bool) {
func AuthorityFrom(ctx context.Context) (*Authority, bool) {
token, ok := ctx.Value(authorityContextKey{}).(*Authority)
return token, ok
}

func ContextWithAuthority(ctx context.Context, authority *Authority) context.Context {
func WithAuthority(ctx context.Context, authority *Authority) context.Context {
return context.WithValue(ctx, authorityContextKey{}, authority)
}
2 changes: 1 addition & 1 deletion internal/dal/postgre/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (r *SessionRepository) GetByID(id uint64) (*dal.Session, error) {
func (r *SessionRepository) GetAllByUserID(userID uint64) ([]*dal.Session, error) {
res := make([]*dal.Session, 0, 10)

rows, err := r.db.Query("SELECT session_id, user_id, name, created_at, updated_at FROM sessions WHERE user_id = $1", userID)
rows, err := r.db.Query("SELECT session_id, user_id, name, created_at, updated_at FROM sessions WHERE user_id = $1 ORDER BY updated_at DESC", userID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
Expand Down
44 changes: 21 additions & 23 deletions internal/domain/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"time"

"github.com/Roma7-7-7/shared-clipboard/internal/dal"
"github.com/Roma7-7-7/shared-clipboard/tools/log"
"github.com/Roma7-7-7/shared-clipboard/internal/log"
)

var (
Expand Down Expand Up @@ -47,44 +47,44 @@ func NewSessionService(sessionRepo SessionRepository, log log.TracedLogger) *Ses
}
}

func (s *SessionService) GetByID(ctx context.Context, id uint64) (*Session, error) {
tid := TraceIDFromContext(ctx)
s.log.Debugw(tid, "get session by id", "sessionID", id)
func (s *SessionService) GetByID(ctx context.Context, userID, id uint64) (*Session, error) {
s.log.Debugw(ctx, "get session by id", "sessionID", id)

session, err := s.sessionRepo.GetByID(id)
if err != nil {
if errors.Is(err, dal.ErrNotFound) {
s.log.Debugw(tid, "session not found")
s.log.Debugw(ctx, "session not found")
return nil, ErrSessionNotFound
}

return nil, fmt.Errorf("get session by id=%d: %w", id, err)
}
if session.UserID != userID {
return nil, ErrSessionPermissionDenied
}

s.log.Debugw(tid, "session found", "session", session)
s.log.Debugw(ctx, "session found", "session", session)
return toSession(session), nil
}

func (s *SessionService) GetByUserID(ctx context.Context, userID uint64) ([]*Session, error) {
tid := TraceIDFromContext(ctx)
s.log.Debugw(tid, "get sessions by userID", "userID", userID)
s.log.Debugw(ctx, "get sessions by userID", "userID", userID)

sessions, err := s.sessionRepo.GetAllByUserID(userID)
if err != nil {
return nil, fmt.Errorf("get sessions by userID=%d: %w", userID, err)
}

s.log.Debugw(tid, "sessions found", "count", len(sessions))
s.log.Debugw(ctx, "sessions found", "count", len(sessions))
res := make([]*Session, 0, len(sessions))
for _, session := range sessions {
res = append(res, toSession(session))
}
return res, nil
}

func (s *SessionService) Create(ctx context.Context, name string, userID uint64) (*Session, error) {
tid := TraceIDFromContext(ctx)
s.log.Debugw(tid, "create session", "name", name, "userID", userID)
func (s *SessionService) Create(ctx context.Context, userID uint64, name string) (*Session, error) {
s.log.Debugw(ctx, "create session", "name", name, "userID", userID)

if strings.TrimSpace(name) == "" {
return nil, fmt.Errorf("name is empty")
Expand All @@ -95,13 +95,12 @@ func (s *SessionService) Create(ctx context.Context, name string, userID uint64)
return nil, fmt.Errorf("create session: %w", err)
}

s.log.Debugw(tid, "session created", "session", session)
s.log.Debugw(ctx, "session created", "session", session)
return toSession(session), nil
}

func (s *SessionService) Update(ctx context.Context, sessionID, userID uint64, name string) (*Session, error) {
tid := TraceIDFromContext(ctx)
s.log.Debugw(tid, "update session", "sessionID", sessionID, "name", name)
func (s *SessionService) Update(ctx context.Context, userID, sessionID uint64, name string) (*Session, error) {
s.log.Debugw(ctx, "update session", "sessionID", sessionID, "name", name)

if strings.TrimSpace(name) == "" {
return nil, fmt.Errorf("name is empty")
Expand All @@ -117,21 +116,20 @@ func (s *SessionService) Update(ctx context.Context, sessionID, userID uint64, n
}

if session.UserID != userID {
return nil, fmt.Errorf("user with ID %q is not allowed to modify session with ID %q: %w", userID, sessionID, ErrSessionPermissionDenied)
return nil, ErrSessionPermissionDenied
}

updated, err := s.sessionRepo.Update(sessionID, name)
if err != nil {
return nil, fmt.Errorf("update session by id=%q: %w", sessionID, err)
}

s.log.Debugw(tid, "session updated", "session", updated)
s.log.Debugw(ctx, "session updated", "session", updated)
return toSession(updated), nil
}

func (s *SessionService) Delete(ctx context.Context, sessionID, userID uint64) error {
tid := TraceIDFromContext(ctx)
s.log.Debugw(tid, "delete session", "sessionID", sessionID)
func (s *SessionService) Delete(ctx context.Context, userID, sessionID uint64) error {
s.log.Debugw(ctx, "delete session", "sessionID", sessionID)

session, err := s.sessionRepo.GetByID(sessionID)
if err != nil {
Expand All @@ -143,14 +141,14 @@ func (s *SessionService) Delete(ctx context.Context, sessionID, userID uint64) e
}

if session.UserID != userID {
return fmt.Errorf("user with ID %d is not allowed to delete session with ID %d: %w", userID, sessionID, ErrSessionPermissionDenied)
return ErrSessionPermissionDenied
}

if err = s.sessionRepo.Delete(sessionID); err != nil {
return fmt.Errorf("delete session by id=%d: %w", sessionID, err)
}

s.log.Debugw(tid, "session deleted", "sessionID", sessionID)
s.log.Debugw(ctx, "session deleted", "sessionID", sessionID)
return nil
}

Expand Down
27 changes: 10 additions & 17 deletions internal/domain/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"golang.org/x/crypto/bcrypt"

"github.com/Roma7-7-7/shared-clipboard/internal/dal"
"github.com/Roma7-7-7/shared-clipboard/internal/log"
"github.com/Roma7-7-7/shared-clipboard/tools"
"github.com/Roma7-7-7/shared-clipboard/tools/log"
)

const (
Expand All @@ -36,10 +36,7 @@ func NewUserService(repo UserRepository, log log.TracedLogger) *UserService {
}

func (s *UserService) Create(ctx context.Context, name, password string) (*dal.User, error) {
var (
tid = TraceIDFromContext(ctx)
)
s.log.Debugw(tid, "creating user", "name", name)
s.log.Debugw(ctx, "creating user", "name", name)

if err := validateSignup(name, password); err != nil {
return nil, fmt.Errorf("validate signup: %w", err)
Expand All @@ -55,7 +52,7 @@ func (s *UserService) Create(ctx context.Context, name, password string) (*dal.U
user, err := s.repo.Create(name, string(hashed), passwordSalt)
if err != nil {
if errors.Is(err, dal.ErrConflictUnique) {
s.log.Debugw(tid, "user with this name already exists")
s.log.Debugw(ctx, "user with this name already exists")
return nil, &RenderableError{
Code: ErrorCodeSignupConflict,
Message: "User with specified name already exists",
Expand All @@ -65,21 +62,17 @@ func (s *UserService) Create(ctx context.Context, name, password string) (*dal.U
return nil, fmt.Errorf("create user: %w", err)
}

s.log.Debugw(tid, "user created", "id", user.ID)
s.log.Debugw(ctx, "user created", "id", user.ID)
return user, nil
}

func (s *UserService) VerifyPassword(ctx context.Context, name, password string) (*dal.User, error) {
var (
tid = TraceIDFromContext(ctx)
user *dal.User
err error
)
s.log.Debugw(tid, "verifying password", "name", name)
s.log.Debugw(ctx, "verifying password", "name", name)

if user, err = s.repo.GetByName(name); err != nil {
user, err := s.repo.GetByName(name)
if err != nil {
if errors.Is(err, dal.ErrNotFound) {
s.log.Debugw(tid, "user not found")
s.log.Debugw(ctx, "user not found")
return nil, &RenderableError{
Code: ErrorCodeUserNotFound,
Message: "User not found",
Expand All @@ -91,14 +84,14 @@ func (s *UserService) VerifyPassword(ctx context.Context, name, password string)

salted := saltedPassword(password, user.PasswordSalt)
if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(salted)); err != nil {
s.log.Debugw(tid, "wrong password")
s.log.Debugw(ctx, "wrong password")
return nil, &RenderableError{
Code: ErrorCodeSiginWrongPassword,
Message: "Wrong password",
}
}

s.log.Debugw(tid, "password verified", "id", user.ID)
s.log.Debugw(ctx, "password verified", "id", user.ID)
return user, nil
}

Expand Down
Loading

0 comments on commit 044ef2e

Please sign in to comment.