diff --git a/cmd/app/main.go b/cmd/app/main.go index 3041b3f..e2afca9 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -10,8 +10,8 @@ import ( "github.com/Roma7-7-7/shared-clipboard/internal/app" "github.com/Roma7-7-7/shared-clipboard/internal/config" - "github.com/Roma7-7-7/shared-clipboard/internal/domain" - "github.com/Roma7-7-7/shared-clipboard/tools/log" + ac "github.com/Roma7-7-7/shared-clipboard/internal/context" + "github.com/Roma7-7-7/shared-clipboard/internal/log" ) var configPath = flag.String("config", "", "path to config file") @@ -42,13 +42,14 @@ func main() { sLog := l.Sugar() traced := log.NewZapTracedLogger(sLog) - if a, err = app.NewApp(conf, traced); err != nil { - traced.Errorw(domain.RuntimeTraceID, "Create app", err) + bootstrapCtx := ac.WithTraceID(context.Background(), "bootstrap") + if a, err = app.NewApp(bootstrapCtx, conf, traced); err != nil { + traced.Errorw(bootstrapCtx, "Create app", err) os.Exit(1) } - if err = a.Run(context.Background()); err != nil { - traced.Errorw(domain.RuntimeTraceID, "Run", err) + if err = a.Run(ac.WithTraceID(context.Background(), "runtime")); err != nil { + traced.Errorw(bootstrapCtx, "Run", err) os.Exit(1) } } diff --git a/internal/app/app.go b/internal/app/app.go index aad7d11..eb2a7ac 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -18,7 +18,7 @@ import ( "github.com/Roma7-7-7/shared-clipboard/internal/handle" "github.com/Roma7-7-7/shared-clipboard/internal/handle/cookie" "github.com/Roma7-7-7/shared-clipboard/internal/handle/jwt" - "github.com/Roma7-7-7/shared-clipboard/tools/log" + "github.com/Roma7-7-7/shared-clipboard/internal/log" ) type ( @@ -29,20 +29,20 @@ type ( } ) -func NewApp(conf config.App, traced log.TracedLogger) (*App, error) { - traced.Infow(domain.RuntimeTraceID, "Initializing SQL DB") +func NewApp(ctx context.Context, conf config.App, traced log.TracedLogger) (*App, error) { + traced.Infow(ctx, "Initializing SQL DB") sqlDB, err := sql.Open(conf.DB.SQL.Driver, conf.DB.SQL.DataSource) if err != nil { return nil, fmt.Errorf("open sql db: %w", err) } - traced.Infow(domain.RuntimeTraceID, "Initializing Bolt DB") + traced.Infow(ctx, "Initializing Bolt DB") boltDB, err := bolt.Open(conf.DB.Bolt.Path, 0600, nil) if err != nil { return nil, fmt.Errorf("open bolt db: %w", err) } - traced.Infow(domain.RuntimeTraceID, "Initializing repositories") + traced.Infow(ctx, "Initializing repositories") userRpo, err := postgre.NewUserRepository(sqlDB) if err != nil { return nil, fmt.Errorf("create user repository: %w", err) @@ -60,17 +60,17 @@ func NewApp(conf config.App, traced log.TracedLogger) (*App, error) { return nil, fmt.Errorf("create jwt repository: %w", err) } - traced.Infow(domain.RuntimeTraceID, "Initializing services") + traced.Infow(ctx, "Initializing services") userService := domain.NewUserService(userRpo, traced) - traced.Infow(domain.RuntimeTraceID, "Initializing components") + traced.Infow(ctx, "Initializing components") jwtProcessor := jwt.NewProcessor(conf.JWT) cookieProcessor := cookie.NewProcessor(jwtProcessor, conf.Cookie) sessionService := domain.NewSessionService(sessionRepo, traced) - traced.Infow(domain.RuntimeTraceID, "Creating router") - h, err := handle.NewRouter(handle.Dependencies{ + traced.Infow(ctx, "Creating router") + h, err := handle.NewRouter(ctx, handle.Dependencies{ Config: conf, CookieProcessor: cookieProcessor, UserService: userService, @@ -107,20 +107,20 @@ func (a *App) Run(ctx context.Context) error { case <-ctx.Done(): ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - a.log.Infow(domain.RuntimeTraceID, "Shutting down server") + a.log.Infow(ctx, "Shutting down server") if err := s.Shutdown(ctx); err != nil { - a.log.Errorw(domain.RuntimeTraceID, "Shutdown server", err) + a.log.Errorw(ctx, "Shutdown server", err) } - a.log.Infow(domain.RuntimeTraceID, "Server stopped") + a.log.Infow(ctx, "Server stopped") return } }() - a.log.Infow(domain.RuntimeTraceID, "Starting server", "address", addr) + a.log.Infow(ctx, "Starting server", "address", addr) if err := s.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("server listen: %w", err) } - a.log.Infow(domain.RuntimeTraceID, "Server stopped") + a.log.Infow(ctx, "Server stopped") return nil } diff --git a/internal/domain/context.go b/internal/context/context.go similarity index 54% rename from internal/domain/context.go rename to internal/context/context.go index e0c201a..fede623 100644 --- a/internal/domain/context.go +++ b/internal/context/context.go @@ -1,10 +1,10 @@ -package domain +package context import ( "context" -) -const RuntimeTraceID = "runtime" + "github.com/Roma7-7-7/shared-clipboard/tools" +) type ( Authority struct { @@ -16,22 +16,22 @@ type ( traceIDCtxKey struct{} ) -func ContextWithTraceID(ctx context.Context, traceID string) context.Context { +func WithTraceID(ctx context.Context, traceID string) context.Context { return context.WithValue(ctx, &traceIDCtxKey{}, traceID) } -func TraceIDFromContext(ctx context.Context) string { +func TraceIDFrom(ctx context.Context) string { if traceID, ok := ctx.Value(&traceIDCtxKey{}).(string); ok { return traceID } - return "undefined" + return "undefined#" + tools.RandomAlphanumericKey(8) } -func AuthorityFromContext(ctx context.Context) (*Authority, bool) { +func AuthorityFrom(ctx context.Context) (*Authority, bool) { token, ok := ctx.Value(authorityContextKey{}).(*Authority) return token, ok } -func ContextWithAuthority(ctx context.Context, authority *Authority) context.Context { +func WithAuthority(ctx context.Context, authority *Authority) context.Context { return context.WithValue(ctx, authorityContextKey{}, authority) } diff --git a/internal/dal/postgre/session.go b/internal/dal/postgre/session.go index c5a440a..a6c6f1d 100644 --- a/internal/dal/postgre/session.go +++ b/internal/dal/postgre/session.go @@ -42,7 +42,7 @@ func (r *SessionRepository) GetByID(id uint64) (*dal.Session, error) { func (r *SessionRepository) GetAllByUserID(userID uint64) ([]*dal.Session, error) { res := make([]*dal.Session, 0, 10) - rows, err := r.db.Query("SELECT session_id, user_id, name, created_at, updated_at FROM sessions WHERE user_id = $1", userID) + rows, err := r.db.Query("SELECT session_id, user_id, name, created_at, updated_at FROM sessions WHERE user_id = $1 ORDER BY updated_at DESC", userID) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil diff --git a/internal/domain/session.go b/internal/domain/session.go index 9e9cb82..3b26994 100644 --- a/internal/domain/session.go +++ b/internal/domain/session.go @@ -8,7 +8,7 @@ import ( "time" "github.com/Roma7-7-7/shared-clipboard/internal/dal" - "github.com/Roma7-7-7/shared-clipboard/tools/log" + "github.com/Roma7-7-7/shared-clipboard/internal/log" ) var ( @@ -47,34 +47,35 @@ func NewSessionService(sessionRepo SessionRepository, log log.TracedLogger) *Ses } } -func (s *SessionService) GetByID(ctx context.Context, id uint64) (*Session, error) { - tid := TraceIDFromContext(ctx) - s.log.Debugw(tid, "get session by id", "sessionID", id) +func (s *SessionService) GetByID(ctx context.Context, userID, id uint64) (*Session, error) { + s.log.Debugw(ctx, "get session by id", "sessionID", id) session, err := s.sessionRepo.GetByID(id) if err != nil { if errors.Is(err, dal.ErrNotFound) { - s.log.Debugw(tid, "session not found") + s.log.Debugw(ctx, "session not found") return nil, ErrSessionNotFound } return nil, fmt.Errorf("get session by id=%d: %w", id, err) } + if session.UserID != userID { + return nil, ErrSessionPermissionDenied + } - s.log.Debugw(tid, "session found", "session", session) + s.log.Debugw(ctx, "session found", "session", session) return toSession(session), nil } func (s *SessionService) GetByUserID(ctx context.Context, userID uint64) ([]*Session, error) { - tid := TraceIDFromContext(ctx) - s.log.Debugw(tid, "get sessions by userID", "userID", userID) + s.log.Debugw(ctx, "get sessions by userID", "userID", userID) sessions, err := s.sessionRepo.GetAllByUserID(userID) if err != nil { return nil, fmt.Errorf("get sessions by userID=%d: %w", userID, err) } - s.log.Debugw(tid, "sessions found", "count", len(sessions)) + s.log.Debugw(ctx, "sessions found", "count", len(sessions)) res := make([]*Session, 0, len(sessions)) for _, session := range sessions { res = append(res, toSession(session)) @@ -82,9 +83,8 @@ func (s *SessionService) GetByUserID(ctx context.Context, userID uint64) ([]*Ses return res, nil } -func (s *SessionService) Create(ctx context.Context, name string, userID uint64) (*Session, error) { - tid := TraceIDFromContext(ctx) - s.log.Debugw(tid, "create session", "name", name, "userID", userID) +func (s *SessionService) Create(ctx context.Context, userID uint64, name string) (*Session, error) { + s.log.Debugw(ctx, "create session", "name", name, "userID", userID) if strings.TrimSpace(name) == "" { return nil, fmt.Errorf("name is empty") @@ -95,13 +95,12 @@ func (s *SessionService) Create(ctx context.Context, name string, userID uint64) return nil, fmt.Errorf("create session: %w", err) } - s.log.Debugw(tid, "session created", "session", session) + s.log.Debugw(ctx, "session created", "session", session) return toSession(session), nil } -func (s *SessionService) Update(ctx context.Context, sessionID, userID uint64, name string) (*Session, error) { - tid := TraceIDFromContext(ctx) - s.log.Debugw(tid, "update session", "sessionID", sessionID, "name", name) +func (s *SessionService) Update(ctx context.Context, userID, sessionID uint64, name string) (*Session, error) { + s.log.Debugw(ctx, "update session", "sessionID", sessionID, "name", name) if strings.TrimSpace(name) == "" { return nil, fmt.Errorf("name is empty") @@ -117,7 +116,7 @@ func (s *SessionService) Update(ctx context.Context, sessionID, userID uint64, n } if session.UserID != userID { - return nil, fmt.Errorf("user with ID %q is not allowed to modify session with ID %q: %w", userID, sessionID, ErrSessionPermissionDenied) + return nil, ErrSessionPermissionDenied } updated, err := s.sessionRepo.Update(sessionID, name) @@ -125,13 +124,12 @@ func (s *SessionService) Update(ctx context.Context, sessionID, userID uint64, n return nil, fmt.Errorf("update session by id=%q: %w", sessionID, err) } - s.log.Debugw(tid, "session updated", "session", updated) + s.log.Debugw(ctx, "session updated", "session", updated) return toSession(updated), nil } -func (s *SessionService) Delete(ctx context.Context, sessionID, userID uint64) error { - tid := TraceIDFromContext(ctx) - s.log.Debugw(tid, "delete session", "sessionID", sessionID) +func (s *SessionService) Delete(ctx context.Context, userID, sessionID uint64) error { + s.log.Debugw(ctx, "delete session", "sessionID", sessionID) session, err := s.sessionRepo.GetByID(sessionID) if err != nil { @@ -143,14 +141,14 @@ func (s *SessionService) Delete(ctx context.Context, sessionID, userID uint64) e } if session.UserID != userID { - return fmt.Errorf("user with ID %d is not allowed to delete session with ID %d: %w", userID, sessionID, ErrSessionPermissionDenied) + return ErrSessionPermissionDenied } if err = s.sessionRepo.Delete(sessionID); err != nil { return fmt.Errorf("delete session by id=%d: %w", sessionID, err) } - s.log.Debugw(tid, "session deleted", "sessionID", sessionID) + s.log.Debugw(ctx, "session deleted", "sessionID", sessionID) return nil } diff --git a/internal/domain/user.go b/internal/domain/user.go index bb0fb97..ed33962 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -8,8 +8,8 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/Roma7-7-7/shared-clipboard/internal/dal" + "github.com/Roma7-7-7/shared-clipboard/internal/log" "github.com/Roma7-7-7/shared-clipboard/tools" - "github.com/Roma7-7-7/shared-clipboard/tools/log" ) const ( @@ -36,10 +36,7 @@ func NewUserService(repo UserRepository, log log.TracedLogger) *UserService { } func (s *UserService) Create(ctx context.Context, name, password string) (*dal.User, error) { - var ( - tid = TraceIDFromContext(ctx) - ) - s.log.Debugw(tid, "creating user", "name", name) + s.log.Debugw(ctx, "creating user", "name", name) if err := validateSignup(name, password); err != nil { return nil, fmt.Errorf("validate signup: %w", err) @@ -55,7 +52,7 @@ func (s *UserService) Create(ctx context.Context, name, password string) (*dal.U user, err := s.repo.Create(name, string(hashed), passwordSalt) if err != nil { if errors.Is(err, dal.ErrConflictUnique) { - s.log.Debugw(tid, "user with this name already exists") + s.log.Debugw(ctx, "user with this name already exists") return nil, &RenderableError{ Code: ErrorCodeSignupConflict, Message: "User with specified name already exists", @@ -65,21 +62,17 @@ func (s *UserService) Create(ctx context.Context, name, password string) (*dal.U return nil, fmt.Errorf("create user: %w", err) } - s.log.Debugw(tid, "user created", "id", user.ID) + s.log.Debugw(ctx, "user created", "id", user.ID) return user, nil } func (s *UserService) VerifyPassword(ctx context.Context, name, password string) (*dal.User, error) { - var ( - tid = TraceIDFromContext(ctx) - user *dal.User - err error - ) - s.log.Debugw(tid, "verifying password", "name", name) + s.log.Debugw(ctx, "verifying password", "name", name) - if user, err = s.repo.GetByName(name); err != nil { + user, err := s.repo.GetByName(name) + if err != nil { if errors.Is(err, dal.ErrNotFound) { - s.log.Debugw(tid, "user not found") + s.log.Debugw(ctx, "user not found") return nil, &RenderableError{ Code: ErrorCodeUserNotFound, Message: "User not found", @@ -91,14 +84,14 @@ func (s *UserService) VerifyPassword(ctx context.Context, name, password string) salted := saltedPassword(password, user.PasswordSalt) if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(salted)); err != nil { - s.log.Debugw(tid, "wrong password") + s.log.Debugw(ctx, "wrong password") return nil, &RenderableError{ Code: ErrorCodeSiginWrongPassword, Message: "Wrong password", } } - s.log.Debugw(tid, "password verified", "id", user.ID) + s.log.Debugw(ctx, "password verified", "id", user.ID) return user, nil } diff --git a/internal/handle/auth.go b/internal/handle/auth.go index 26937a7..1776ae3 100644 --- a/internal/handle/auth.go +++ b/internal/handle/auth.go @@ -7,14 +7,12 @@ import ( "net/http" "time" - "github.com/go-chi/chi/v5" "github.com/golang-jwt/jwt/v5" "github.com/Roma7-7-7/shared-clipboard/internal/dal" "github.com/Roma7-7-7/shared-clipboard/internal/domain" "github.com/Roma7-7-7/shared-clipboard/internal/handle/cookie" - "github.com/Roma7-7-7/shared-clipboard/tools/log" - "github.com/Roma7-7-7/shared-clipboard/tools/rest" + "github.com/Roma7-7-7/shared-clipboard/internal/log" ) type ( @@ -42,6 +40,8 @@ type ( } AuthHandler struct { + resp *responder + userService UserService cookieProcessor CookieProcessor jwtRepository JWTRepository @@ -56,9 +56,11 @@ type ( ) func NewAuthHandler( - userService UserService, cookieProcessor CookieProcessor, jwtRepository JWTRepository, log log.TracedLogger, + userService UserService, cookieProcessor CookieProcessor, jwtRepository JWTRepository, resp *responder, log log.TracedLogger, ) *AuthHandler { return &AuthHandler{ + resp: resp, + userService: userService, cookieProcessor: cookieProcessor, jwtRepository: jwtRepository, @@ -67,137 +69,112 @@ func NewAuthHandler( } } -func (h *AuthHandler) RegisterRoutes(r chi.Router) { - r.Post("/signup", h.SignUp) - r.Post("/signin", h.SignIn) - r.Post("/signout", h.SignOut) -} - func (h *AuthHandler) SignUp(rw http.ResponseWriter, r *http.Request) { var ( - ctx = r.Context() - tid = domain.TraceIDFromContext(ctx) - req namePasswordRequest - user *dal.User - userCookie *http.Cookie - marshaled []byte - err error + ctx = r.Context() + req namePasswordRequest + err error ) if err = json.NewDecoder(r.Body).Decode(&req); err != nil { - h.log.Debugw(tid, "failed to decode request", err) - sendBadRequest(ctx, rw, "failed to parse request", h.log) + h.log.Debugw(ctx, "failed to decode request", err) + h.resp.SendBadRequest(ctx, rw, "failed to parse request") return } - if user, err = h.userService.Create(ctx, req.Name, req.Password); err != nil { + user, err := h.userService.Create(ctx, req.Name, req.Password) + if err != nil { var re *domain.RenderableError if errors.As(err, &re) { - h.log.Debugw(tid, "failed to create user", err) - sendRenderableError(ctx, re, rw, h.log) + h.log.Infow(ctx, "failed to create user", err) + h.resp.SendError(ctx, rw, http.StatusConflict, re.Code.Value, re.Message, re.Details) return } - h.log.Errorw(tid, "failed to create user", err) - sendInternalServerError(ctx, rw, h.log) + h.log.Errorw(ctx, "failed to create user", err) + h.resp.SendInternalServerError(ctx, rw) return } - if userCookie, err = h.cookieProcessor.ToAccessToken(user.ID, user.Name); err != nil { - h.log.Errorw(tid, "failed to create cookie", err) - sendInternalServerError(ctx, rw, h.log) + userCookie, err := h.cookieProcessor.ToAccessToken(user.ID, user.Name) + if err != nil { + h.log.Errorw(ctx, "failed to create cookie", err) + h.resp.SendInternalServerError(ctx, rw) return } http.SetCookie(rw, userCookie) - if marshaled, err = json.Marshal(userToDTO(user)); err != nil { - h.log.Errorw(tid, "failed to marshal response", err) - sendErrorMarshalBody(ctx, rw, h.log) - return - } - - rest.Send(ctx, rw, http.StatusCreated, rest.ContentTypeJSON, marshaled, h.log) + h.resp.Send(ctx, rw, http.StatusCreated, nil, userToDTO(user)) } func (h *AuthHandler) SignIn(rw http.ResponseWriter, r *http.Request) { var ( - ctx = r.Context() - tid = domain.TraceIDFromContext(ctx) - req namePasswordRequest - user *dal.User - userCookie *http.Cookie - marshaled []byte - err error + ctx = r.Context() + req namePasswordRequest + err error ) if err = json.NewDecoder(r.Body).Decode(&req); err != nil { - h.log.Debugw(tid, "failed to decode request", err) - sendBadRequest(ctx, rw, "failed to parse request", h.log) + h.log.Debugw(ctx, "failed to decode request", err) + h.resp.SendBadRequest(ctx, rw, "failed to parse request") return } - if user, err = h.userService.VerifyPassword(ctx, req.Name, req.Password); err != nil { + user, err := h.userService.VerifyPassword(ctx, req.Name, req.Password) + if err != nil { var re *domain.RenderableError if errors.As(err, &re) { - h.log.Debugw(tid, "failed to verify password", err) - sendRenderableError(ctx, re, rw, h.log) + h.log.Debugw(ctx, "failed to verify password", err) + h.resp.SendError(ctx, rw, http.StatusUnauthorized, re.Code.Value, re.Message, re.Details) return } - h.log.Errorw(tid, "failed to verify password", err) - sendInternalServerError(ctx, rw, h.log) + h.log.Errorw(ctx, "failed to verify password", err) + h.resp.SendInternalServerError(ctx, rw) return } - if userCookie, err = h.cookieProcessor.ToAccessToken(user.ID, user.Name); err != nil { - h.log.Errorw(tid, "failed to create cookie", err) - sendInternalServerError(ctx, rw, h.log) + userCookie, err := h.cookieProcessor.ToAccessToken(user.ID, user.Name) + if err != nil { + h.log.Errorw(ctx, "failed to create cookie", err) + h.resp.SendInternalServerError(ctx, rw) return } http.SetCookie(rw, userCookie) - if marshaled, err = json.Marshal(userToDTO(user)); err != nil { - h.log.Errorw(tid, "failed to marshal response", err) - sendErrorMarshalBody(ctx, rw, h.log) - return - } - - rest.Send(ctx, rw, http.StatusOK, rest.ContentTypeJSON, marshaled, h.log) + h.resp.Send(ctx, rw, http.StatusOK, nil, userToDTO(user)) } func (h *AuthHandler) SignOut(rw http.ResponseWriter, r *http.Request) { var ( - ctx = r.Context() - tid = domain.TraceIDFromContext(ctx) - token *jwt.Token - claims jwt.MapClaims - ok bool - err error + ctx = r.Context() ) - h.log.Debugw(tid, "signing out") + h.log.Debugw(ctx, "signing out") - if token, err = h.cookieProcessor.AccessTokenFromRequest(r); err != nil { + token, err := h.cookieProcessor.AccessTokenFromRequest(r) + if err != nil { if errors.Is(err, cookie.ErrAccessTokenNotFound) { - h.log.Debugw(tid, "access token cookie not found") - rest.SendNoContent(ctx, rw, h.log) + h.log.Debugw(ctx, "access token cookie not found") + rw.WriteHeader(http.StatusNoContent) return } if errors.Is(err, cookie.ErrParseAccessToken) { - h.log.Debugw(tid, "failed to parse access token cookie") + h.log.Debugw(ctx, "failed to parse access token cookie") http.SetCookie(rw, h.cookieProcessor.ExpireAccessToken()) - rest.SendNoContent(ctx, rw, h.log) + rw.WriteHeader(http.StatusNoContent) return } - h.log.Errorw(tid, "failed to get access token cookie from request", err) - sendInternalServerError(ctx, rw, h.log) + h.log.Errorw(ctx, "failed to get access token cookie from request", err) + h.resp.SendInternalServerError(ctx, rw) return } - if claims, ok = token.Claims.(jwt.MapClaims); !ok || !token.Valid { - h.log.Debugw(tid, "failed to parse access token cookie") + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + h.log.Debugw(ctx, "failed to parse access token cookie") http.SetCookie(rw, h.cookieProcessor.ExpireAccessToken()) - rest.SendNoContent(ctx, rw, h.log) + rw.WriteHeader(http.StatusNoContent) return } @@ -206,12 +183,12 @@ func (h *AuthHandler) SignOut(rw http.ResponseWriter, r *http.Request) { if jti != "" && exp > 0 { expAt := time.Unix(int64(exp), 0) if err = h.jwtRepository.CreateBlockedJTI(jti, expAt); err != nil { - h.log.Errorw(tid, "failed to create blocked jti", err) + h.log.Errorw(ctx, "failed to create blocked jti", err) } } http.SetCookie(rw, h.cookieProcessor.ExpireAccessToken()) - rest.SendNoContent(ctx, rw, h.log) + rw.WriteHeader(http.StatusNoContent) } func userToDTO(user *dal.User) *User { diff --git a/internal/handle/error.go b/internal/handle/error.go deleted file mode 100644 index 1d85ba9..0000000 --- a/internal/handle/error.go +++ /dev/null @@ -1,93 +0,0 @@ -package handle - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - - "github.com/Roma7-7-7/shared-clipboard/internal/domain" - "github.com/Roma7-7-7/shared-clipboard/tools/log" - "github.com/Roma7-7-7/shared-clipboard/tools/rest" -) - -const errorResponseTmpl = `{"error": true, "code": "%s", "message": "%s"}` - -type genericErrorResponse struct { - Error bool `json:"error"` - Code string `json:"code"` - Message string `json:"message"` - Details any `json:"details,omitempty"` -} - -func badRequestErrorBody(message string) []byte { - return []byte(fmt.Sprintf(errorResponseTmpl, domain.ErrorBadRequest.Value, message)) -} - -func notFoundErrorBody(message string) []byte { - return []byte(fmt.Sprintf(errorResponseTmpl, domain.ErrorCodeNotFound.Value, message)) -} - -func unauthorizedErrorBody(message string) []byte { - return []byte(fmt.Sprintf(errorResponseTmpl, domain.ErrorCodeUnauthorized.Value, message)) -} - -func forbiddenErrorBody(message string) []byte { - return []byte(fmt.Sprintf(errorResponseTmpl, domain.ErrorCodeForbidden.Value, message)) -} - -func methodNotAllowedErrorBody(method string) []byte { - return []byte(fmt.Sprintf(errorResponseTmpl, domain.ErrorCodeMethodNotAllowed.Value, fmt.Sprintf("Method %s is not allowed", method))) -} - -func internalServerErrorBody() []byte { - return []byte(fmt.Sprintf(errorResponseTmpl, domain.ErrorCodeInternalServerError.Value, "Internal server error")) -} - -func marshalErrorBody() []byte { - return []byte(fmt.Sprintf(errorResponseTmpl, domain.ErrorCodeMarshalResponse.Value, "Failed to marshal response")) -} - -func sendBadRequest(ctx context.Context, rw http.ResponseWriter, message string, log log.TracedLogger) { - rest.Send(ctx, rw, http.StatusBadRequest, rest.ContentTypeJSON, badRequestErrorBody(message), log) -} - -func sendNotFound(ctx context.Context, rw http.ResponseWriter, message string, log log.TracedLogger) { - rest.Send(ctx, rw, http.StatusNotFound, rest.ContentTypeJSON, notFoundErrorBody(message), log) -} - -func sendUnauthorized(ctx context.Context, rw http.ResponseWriter, log log.TracedLogger) { - rest.Send(ctx, rw, http.StatusUnauthorized, rest.ContentTypeJSON, unauthorizedErrorBody("Request is not authorized"), log) -} - -func sendForbidden(ctx context.Context, rw http.ResponseWriter, message string, log log.TracedLogger) { - rest.Send(ctx, rw, http.StatusUnauthorized, rest.ContentTypeJSON, forbiddenErrorBody(message), log) -} - -func sendErrorMarshalBody(ctx context.Context, rw http.ResponseWriter, log log.TracedLogger) { - rest.Send(ctx, rw, http.StatusInternalServerError, rest.ContentTypeJSON, marshalErrorBody(), log) -} - -func sendErrorMethodNotAllowed(ctx context.Context, method string, rw http.ResponseWriter, log log.TracedLogger) { - rest.Send(ctx, rw, http.StatusMethodNotAllowed, rest.ContentTypeJSON, methodNotAllowedErrorBody(method), log) -} - -func sendInternalServerError(ctx context.Context, rw http.ResponseWriter, log log.TracedLogger) { - rest.Send(ctx, rw, http.StatusInternalServerError, rest.ContentTypeJSON, internalServerErrorBody(), log) -} - -func sendRenderableError(ctx context.Context, err *domain.RenderableError, rw http.ResponseWriter, log log.TracedLogger) { - bytes, mErr := json.Marshal(genericErrorResponse{ - Error: true, - Code: err.Code.Value, - Message: err.Message, - Details: err.Details, - }) - if mErr != nil { - log.Errorw(domain.TraceIDFromContext(ctx), "failed to marshal renderable error", mErr) - sendErrorMarshalBody(ctx, rw, log) - return - } - - rest.Send(ctx, rw, err.Code.StatusCode, rest.ContentTypeJSON, bytes, log) -} diff --git a/internal/handle/middleware.go b/internal/handle/middleware.go index 16df966..a8e2ec6 100644 --- a/internal/handle/middleware.go +++ b/internal/handle/middleware.go @@ -1,6 +1,7 @@ package handle import ( + "context" "errors" "fmt" "math/rand" @@ -11,12 +12,14 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/golang-jwt/jwt/v5" + ac "github.com/Roma7-7-7/shared-clipboard/internal/context" "github.com/Roma7-7-7/shared-clipboard/internal/domain" "github.com/Roma7-7-7/shared-clipboard/internal/handle/cookie" - "github.com/Roma7-7-7/shared-clipboard/tools/log" + "github.com/Roma7-7-7/shared-clipboard/internal/log" ) type AuthorizedMiddleware struct { + resp *responder cookieProcessor CookieProcessor jwtRepository JWTRepository log log.TracedLogger @@ -29,7 +32,7 @@ func TraceID(next http.Handler) http.Handler { tid = randomAlphanumericTraceID() } w.Header().Set(middleware.RequestIDHeader, tid) - next.ServeHTTP(w, r.WithContext(domain.ContextWithTraceID(r.Context(), tid))) + next.ServeHTTP(w, r.WithContext(ac.WithTraceID(r.Context(), tid))) }) } @@ -40,7 +43,7 @@ func Logger(l log.TracedLogger) func(next http.Handler) http.Handler { started := time.Now() defer func() { - l.Infow(domain.TraceIDFromContext(r.Context()), "request", + l.Infow(r.Context(), "request", "method", r.Method, "url", r.URL.String(), "proto", r.Proto, @@ -56,9 +59,10 @@ func Logger(l log.TracedLogger) func(next http.Handler) http.Handler { } func NewAuthorizedMiddleware( - cookieProcessor CookieProcessor, jwtRepository JWTRepository, log log.TracedLogger, + cookieProcessor CookieProcessor, jwtRepository JWTRepository, resp *responder, log log.TracedLogger, ) *AuthorizedMiddleware { return &AuthorizedMiddleware{ + resp: resp, cookieProcessor: cookieProcessor, jwtRepository: jwtRepository, log: log, @@ -68,42 +72,40 @@ func NewAuthorizedMiddleware( func (m *AuthorizedMiddleware) Handle(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { var ( - ctx = r.Context() - tid = domain.TraceIDFromContext(ctx) - token *jwt.Token - claims jwt.MapClaims - ok bool - authority *domain.Authority - err error + ctx = r.Context() + token *jwt.Token + err error ) - m.log.Debugw(tid, "authorized middleware") + m.log.Debugw(ctx, "authorized middleware") if token, err = m.cookieProcessor.AccessTokenFromRequest(r); err != nil { if errors.Is(err, cookie.ErrAccessTokenNotFound) { - m.log.Debugw(tid, "access token cookie not found") - sendUnauthorized(ctx, rw, m.log) + m.log.Debugw(ctx, "access token cookie not found") + m.resp.SendError(ctx, rw, http.StatusUnauthorized, domain.ErrorCodeUnauthorized.Value, "Request is not authorized", nil) return } if errors.Is(err, cookie.ErrParseAccessToken) { - m.log.Debugw(tid, "failed to parse access token cookie") - sendForbidden(ctx, rw, "JWT token is not valid or expired", m.log) + m.log.Debugw(ctx, "failed to parse access token cookie") + m.sendForbidden(ctx, rw, "JWT token is not valid or expired") return } - m.log.Errorw(tid, "failed to get access token cookie from request", err) - sendInternalServerError(ctx, rw, m.log) + m.log.Errorw(ctx, "failed to get access token cookie from request", err) + m.resp.SendInternalServerError(ctx, rw) return } - if claims, ok = token.Claims.(jwt.MapClaims); !ok || !token.Valid { - m.log.Debugw(tid, "failed to parse access token cookie") - sendForbidden(ctx, rw, "JWT token is not valid or expired", m.log) + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + m.log.Debugw(ctx, "failed to parse access token cookie") + m.sendForbidden(ctx, rw, "JWT token is not valid or expired") return } - if authority, err = toAuthority(claims); err != nil { - m.log.Errorw(tid, "failed to parse authority", err) - sendInternalServerError(ctx, rw, m.log) + authority, err := toAuthority(claims) + if err != nil { + m.log.Errorw(ctx, "failed to parse authority", err) + m.resp.SendInternalServerError(ctx, rw) return } @@ -111,21 +113,25 @@ func (m *AuthorizedMiddleware) Handle(next http.Handler) http.Handler { if ok && jti != "" { ok, err = m.jwtRepository.IsBlockedJTIExists(jti) if err != nil { - m.log.Errorw(tid, "failed to check blocked jti", err) - sendInternalServerError(ctx, rw, m.log) + m.log.Errorw(ctx, "failed to check blocked jti", err) + m.resp.SendInternalServerError(ctx, rw) return } if ok { - m.log.Debugw(tid, "blocked jti") - sendForbidden(ctx, rw, "JWT token is not valid or expired", m.log) + m.log.Debugw(ctx, "blocked jti") + m.sendForbidden(ctx, rw, "JWT token is not valid or expired") return } } - next.ServeHTTP(rw, r.WithContext(domain.ContextWithAuthority(ctx, authority))) + next.ServeHTTP(rw, r.WithContext(ac.WithAuthority(ctx, authority))) }) } +func (m *AuthorizedMiddleware) sendForbidden(ctx context.Context, rw http.ResponseWriter, message string) { + m.resp.SendError(ctx, rw, http.StatusForbidden, domain.ErrorCodeForbidden.Value, message, nil) +} + var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") func randomAlphanumericTraceID() string { @@ -136,7 +142,7 @@ func randomAlphanumericTraceID() string { return string(b) } -func toAuthority(claims jwt.MapClaims) (*domain.Authority, error) { +func toAuthority(claims jwt.MapClaims) (*ac.Authority, error) { var ( ids string id uint64 @@ -154,7 +160,7 @@ func toAuthority(claims jwt.MapClaims) (*domain.Authority, error) { return nil, errors.New("name is not a string") } - return &domain.Authority{ + return &ac.Authority{ UserID: id, UserName: name, }, nil diff --git a/internal/handle/respond.go b/internal/handle/respond.go new file mode 100644 index 0000000..282a830 --- /dev/null +++ b/internal/handle/respond.go @@ -0,0 +1,98 @@ +package handle + +import ( + "context" + "encoding/json" + "net/http" + "strings" + + "github.com/Roma7-7-7/shared-clipboard/internal/domain" + "github.com/Roma7-7-7/shared-clipboard/internal/log" +) + +const ( + ContentTypeHeader = "Content-Type" + ContentTypeJSON = "application/json" + LastModifiedHeader = "Last-Modified" + IfModifiedSinceHeader = "If-Modified-Since" +) + +type genericErrorResponse struct { + Error bool `json:"error"` + Code string `json:"code"` + Message string `json:"message"` + Details any `json:"details,omitempty"` +} + +type responder struct { + log log.TracedLogger +} + +func (r *responder) Send(ctx context.Context, rw http.ResponseWriter, status int, headers map[string][]string, value interface{}) { + body, err := json.Marshal(value) + if err != nil { + r.log.Errorw(ctx, "Failed to marshal response", err) + r.SendInternalServerError(ctx, rw) + return + } + + contentTypeSet := false + for k, v := range headers { + if strings.EqualFold(k, ContentTypeHeader) { + contentTypeSet = true + } + for _, v := range v { + rw.Header().Add(k, v) + } + } + if !contentTypeSet { + rw.Header().Set(ContentTypeJSON, ContentTypeJSON) + } + rw.WriteHeader(status) + + if n, err := rw.Write(body); err != nil { + // no reason to return error if we already wrote some bytes + r.log.Errorw(ctx, "Failed to write response", "bytesWritten", n, err) + } +} + +func (r *responder) SendError(ctx context.Context, rw http.ResponseWriter, status int, code, message string, details any) { + r.Send(ctx, rw, status, nil, genericErrorResponse{ + Error: true, + Code: code, + Message: message, + Details: details, + }) +} + +func (r *responder) SendUnauthorized(ctx context.Context, rw http.ResponseWriter) { + r.Send(ctx, rw, http.StatusUnauthorized, nil, genericErrorResponse{ + Error: true, + Code: domain.ErrorCodeUnauthorized.Value, + Message: "Request is not authorized", + }) +} + +func (r *responder) SendBadRequest(ctx context.Context, rw http.ResponseWriter, message string) { + r.Send(ctx, rw, http.StatusBadRequest, nil, genericErrorResponse{ + Error: true, + Code: domain.ErrorBadRequest.Value, + Message: message, + }) +} + +func (r *responder) SendNotFound(ctx context.Context, rw http.ResponseWriter, message string) { + r.Send(ctx, rw, http.StatusNotFound, nil, genericErrorResponse{ + Error: true, + Code: domain.ErrorCodeNotFound.Value, + Message: message, + }) +} + +func (r *responder) SendInternalServerError(ctx context.Context, rw http.ResponseWriter) { + r.Send(ctx, rw, http.StatusInternalServerError, nil, genericErrorResponse{ + Error: true, + Code: domain.ErrorCodeInternalServerError.Value, + Message: "Internal server error", + }) +} diff --git a/internal/handle/route.go b/internal/handle/route.go deleted file mode 100644 index dc17aea..0000000 --- a/internal/handle/route.go +++ /dev/null @@ -1,86 +0,0 @@ -package handle - -import ( - "fmt" - "net/http" - "time" - - "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" - "github.com/go-chi/cors" - "github.com/go-chi/httprate" - - "github.com/Roma7-7-7/shared-clipboard/internal/config" - "github.com/Roma7-7-7/shared-clipboard/internal/domain" - "github.com/Roma7-7-7/shared-clipboard/tools/log" -) - -type Dependencies struct { - Config config.App - CookieProcessor - UserService - JWTRepository - SessionService - ClipboardRepository -} - -func NewRouter(deps Dependencies, log log.TracedLogger) (*chi.Mux, error) { - log.Infow(domain.RuntimeTraceID, "Initializing router") - - r := chi.NewRouter() - conf := deps.Config - - r.Use(TraceID) - r.Use(Logger(log)) - r.Use(httprate.LimitByIP(10, 1*time.Second)) - r.Use(middleware.RedirectSlashes) - r.Use(middleware.Recoverer) - r.Use(middleware.Compress(5, "text/html", "text/css", "text/javascript")) - r.Use(cors.Handler(cors.Options{ - AllowedOrigins: conf.CORS.AllowOrigins, - AllowedMethods: conf.CORS.AllowMethods, - AllowedHeaders: conf.CORS.AllowHeaders, - ExposedHeaders: conf.CORS.ExposeHeaders, - MaxAge: conf.CORS.MaxAge, - AllowCredentials: conf.CORS.AllowCredentials, - })) - - r.Route("/", NewAuthHandler(deps.UserService, deps.CookieProcessor, deps.JWTRepository, log).RegisterRoutes) - - r.With(NewAuthorizedMiddleware(deps.CookieProcessor, deps.JWTRepository, log).Handle). - Route("/v1", func(r chi.Router) { - r.Route("/sessions", NewSessionHandler(deps.SessionService, deps.ClipboardRepository, log).RegisterRoutes) - r.Route("/user", NewUserHandler(log).RegisterRoutes) - }) - - r.NotFound(handleNotFound(log)) - r.MethodNotAllowed(handleMethodNotAllowed(log)) - - printRoutes(r, log) - - log.Infow(domain.RuntimeTraceID, "Router initialized") - return r, nil -} - -func handleNotFound(log log.TracedLogger) func(rw http.ResponseWriter, r *http.Request) { - return func(rw http.ResponseWriter, r *http.Request) { - sendNotFound(r.Context(), rw, "Not Found", log) - } -} - -func handleMethodNotAllowed(log log.TracedLogger) func(rw http.ResponseWriter, r *http.Request) { - return func(rw http.ResponseWriter, r *http.Request) { - sendErrorMethodNotAllowed(r.Context(), r.Method, rw, log) - } -} - -func printRoutes(r *chi.Mux, logger log.TracedLogger) { - logger.Infow(domain.RuntimeTraceID, "Routes:") - err := chi.Walk(r, func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { - logger.Infow(domain.RuntimeTraceID, fmt.Sprintf("[%s]: '%s' has %d middlewares", method, route, len(middlewares))) - return nil - }) - if err != nil { - logger.Errorw(domain.RuntimeTraceID, "Failed to walk routes", err) - } -} diff --git a/internal/handle/routes.go b/internal/handle/routes.go new file mode 100644 index 0000000..3475c01 --- /dev/null +++ b/internal/handle/routes.go @@ -0,0 +1,100 @@ +package handle + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/cors" + "github.com/go-chi/httprate" + + "github.com/Roma7-7-7/shared-clipboard/internal/config" + "github.com/Roma7-7-7/shared-clipboard/internal/domain" + "github.com/Roma7-7-7/shared-clipboard/internal/log" +) + +type Dependencies struct { + Config config.App + CookieProcessor + UserService + JWTRepository + SessionService + ClipboardRepository +} + +func NewRouter(ctx context.Context, deps Dependencies, log log.TracedLogger) (*chi.Mux, error) { + log.Infow(ctx, "Initializing router") + + r := chi.NewRouter() + conf := deps.Config + + r.Use(TraceID) + r.Use(Logger(log)) + r.Use(httprate.LimitByIP(10, 1*time.Second)) + r.Use(middleware.RedirectSlashes) + r.Use(middleware.Recoverer) + r.Use(middleware.Compress(5, "text/html", "text/css", "text/javascript")) + r.Use(cors.Handler(cors.Options{ + AllowedOrigins: conf.CORS.AllowOrigins, + AllowedMethods: conf.CORS.AllowMethods, + AllowedHeaders: conf.CORS.AllowHeaders, + ExposedHeaders: conf.CORS.ExposeHeaders, + MaxAge: conf.CORS.MaxAge, + AllowCredentials: conf.CORS.AllowCredentials, + })) + + resp := &responder{log: log} + + authHandler := NewAuthHandler(deps.UserService, deps.CookieProcessor, deps.JWTRepository, resp, log) + r.Post("/signup", authHandler.SignUp) + r.Post("/signin", authHandler.SignIn) + r.Post("/signout", authHandler.SignOut) + + authorizedRouter := r.With(NewAuthorizedMiddleware(deps.CookieProcessor, deps.JWTRepository, resp, log).Handle) + + sessionHandler := NewSessionHandler(deps.SessionService, deps.ClipboardRepository, resp, log) + authorizedRouter.Post("/v1/sessions", sessionHandler.Create) + authorizedRouter.Get("/v1/sessions", sessionHandler.GetAllByUserID) + authorizedRouter.Get("/v1/sessions/{sessionID}", sessionHandler.GetByID) + authorizedRouter.Put("/v1/sessions/{sessionID}", sessionHandler.Update) + authorizedRouter.Delete("/v1/sessions/{sessionID}", sessionHandler.Delete) + authorizedRouter.Get("/v1/sessions/{sessionID}/clipboard", sessionHandler.GetClipboard) + authorizedRouter.Put("/v1/sessions/{sessionID}/clipboard", sessionHandler.SetClipboard) + + userHandler := NewUserHandler(resp, log) + authorizedRouter.Get("/v1/user/info", userHandler.GetUserInfo) + + r.NotFound(handleNotFound(resp)) + r.MethodNotAllowed(handleMethodNotAllowed(resp)) + + printRoutes(ctx, r, log) + + log.Infow(ctx, "Router initialized") + return r, nil +} + +func handleNotFound(resp *responder) func(rw http.ResponseWriter, r *http.Request) { + return func(rw http.ResponseWriter, r *http.Request) { + resp.SendNotFound(r.Context(), rw, "Not Found") + } +} + +func handleMethodNotAllowed(resp *responder) func(rw http.ResponseWriter, r *http.Request) { + return func(rw http.ResponseWriter, r *http.Request) { + resp.SendError(r.Context(), rw, http.StatusMethodNotAllowed, domain.ErrorCodeMethodNotAllowed.Value, "Method Not Allowed", nil) + } +} + +func printRoutes(ctx context.Context, r *chi.Mux, logger log.TracedLogger) { + logger.Infow(ctx, "Routes:") + err := chi.Walk(r, func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { + logger.Infow(ctx, fmt.Sprintf("[%s]: '%s' has %d middlewares", method, route, len(middlewares))) + return nil + }) + if err != nil { + logger.Errorw(ctx, "Failed to walk routes", err) + } +} diff --git a/internal/handle/session.go b/internal/handle/session.go index ea2c0a5..df00f7b 100644 --- a/internal/handle/session.go +++ b/internal/handle/session.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net/http" "strconv" @@ -12,10 +11,10 @@ import ( "github.com/go-chi/chi/v5" + ac "github.com/Roma7-7-7/shared-clipboard/internal/context" "github.com/Roma7-7-7/shared-clipboard/internal/dal" "github.com/Roma7-7-7/shared-clipboard/internal/domain" - "github.com/Roma7-7-7/shared-clipboard/tools/log" - "github.com/Roma7-7-7/shared-clipboard/tools/rest" + "github.com/Roma7-7-7/shared-clipboard/internal/log" ) type ( @@ -31,11 +30,11 @@ type ( } SessionService interface { - GetByID(ctx context.Context, id uint64) (*domain.Session, error) + GetByID(ctx context.Context, userID, id uint64) (*domain.Session, error) GetByUserID(ctx context.Context, userID uint64) ([]*domain.Session, error) - Create(ctx context.Context, name string, userID uint64) (*domain.Session, error) - Update(ctx context.Context, sessionID, userID uint64, name string) (*domain.Session, error) - Delete(ctx context.Context, sessionID, userID uint64) error + Create(ctx context.Context, userID uint64, name string) (*domain.Session, error) + Update(ctx context.Context, userID, sessionID uint64, name string) (*domain.Session, error) + Delete(ctx context.Context, userID, sessionID uint64) error } ClipboardRepository interface { @@ -44,369 +43,337 @@ type ( } SessionHandler struct { + resp *responder service SessionService clipboardRepo ClipboardRepository log log.TracedLogger } ) -func NewSessionHandler(sessionService SessionService, clipboardRepo ClipboardRepository, log log.TracedLogger) *SessionHandler { +func NewSessionHandler(sessionService SessionService, clipboardRepo ClipboardRepository, resp *responder, log log.TracedLogger) *SessionHandler { return &SessionHandler{ + resp: resp, service: sessionService, clipboardRepo: clipboardRepo, log: log, } } -func (s *SessionHandler) RegisterRoutes(r chi.Router) { - r.Post("/", s.Create) - r.Get("/", s.GetAllByUserID) - r.Get("/{sessionID}", s.GetByID) - r.Put("/{sessionID}", s.Update) - r.Delete("/{sessionID}", s.Delete) - r.Get("/{sessionID}/clipboard", s.GetClipboard) - r.Put("/{sessionID}/clipboard", s.SetClipboard) -} - -func (s *SessionHandler) GetByID(rw http.ResponseWriter, r *http.Request) { +func (h *SessionHandler) GetByID(rw http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() - tid = domain.TraceIDFromContext(ctx) sessionID = chi.URLParam(r, "sessionID") ) - s.log.Debugw(tid, "Get session by ID", "sessionID", sessionID) + h.log.Debugw(ctx, "Get session by ID", "sessionID", sessionID) + + auth, ok := ac.AuthorityFrom(ctx) + if !ok { + h.log.Debugw(ctx, "user not found in context") + h.resp.SendUnauthorized(ctx, rw) + return + } if sessionID == "" { - s.log.Debugw(tid, "sessionID is empty") - sendBadRequest(ctx, rw, "sessionID param is required", s.log) + h.log.Debugw(ctx, "sessionID is empty") + h.resp.SendBadRequest(ctx, rw, "sessionID param is required") return } sid, err := strconv.ParseUint(sessionID, 10, 64) if err != nil { - s.log.Errorw(tid, "failed to parse sessionID", err) - sendBadRequest(ctx, rw, "sessionID param must be a valid uint64 value", s.log) + h.log.Errorw(ctx, "failed to parse sessionID", err) + h.resp.SendBadRequest(ctx, rw, "sessionID param must be a valid uint64 value") return } - session, err := s.service.GetByID(ctx, sid) + session, err := h.service.GetByID(ctx, auth.UserID, sid) if err != nil { if errors.Is(err, domain.ErrSessionNotFound) { - s.log.Debugw(tid, "session not found", "sessionID", sessionID) - sendNotFound(ctx, rw, "Session with provided ID not found", s.log) + h.log.Debugw(ctx, "session not found", "sessionID", sessionID) + h.resp.SendNotFound(ctx, rw, "Session with provided ID not found") return } - s.log.Errorw(tid, "failed to get session", "sessionID", sessionID, err) - sendInternalServerError(ctx, rw, s.log) + h.log.Errorw(ctx, "failed to get session", "sessionID", sessionID, err) + h.resp.SendInternalServerError(ctx, rw) return } - s.log.Debugw(tid, "Got session", "sessionID", session.ID) - body, err := rest.ToJSON(toDTO(session)) - if err != nil { - s.log.Errorw(tid, "failed to marshal session", err) - sendErrorMarshalBody(ctx, rw, s.log) - return - } - - rw.Header().Set(rest.LastModifiedHeader, session.UpdatedAt.Format(http.TimeFormat)) - rest.Send(ctx, rw, http.StatusOK, rest.ContentTypeJSON, body, s.log) + h.log.Debugw(ctx, "Got session", "sessionID", session.ID) + h.resp.Send(ctx, rw, http.StatusOK, map[string][]string{ + LastModifiedHeader: {session.UpdatedAt.Format(http.TimeFormat)}, + }, toDTO(session)) } -func (s *SessionHandler) GetAllByUserID(rw http.ResponseWriter, r *http.Request) { +func (h *SessionHandler) GetAllByUserID(rw http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() - tid = domain.TraceIDFromContext(ctx) ) - auth, ok := domain.AuthorityFromContext(ctx) + auth, ok := ac.AuthorityFrom(ctx) if !ok { - s.log.Debugw(tid, "user not found in context") - sendUnauthorized(ctx, rw, s.log) + h.log.Debugw(ctx, "user not found in context") + h.resp.SendUnauthorized(ctx, rw) return } - s.log.Debugw(tid, "Get all sessions by user", "userID", auth.UserID) + h.log.Debugw(ctx, "Get all sessions by user", "userID", auth.UserID) - sessions, err := s.service.GetByUserID(ctx, auth.UserID) + sessions, err := h.service.GetByUserID(ctx, auth.UserID) if err != nil { - s.log.Errorw(tid, "failed to get sessions", "userID", auth.UserID, err) - sendInternalServerError(ctx, rw, s.log) + h.log.Errorw(ctx, "failed to get sessions", "userID", auth.UserID, err) + h.resp.SendInternalServerError(ctx, rw) return } - s.log.Debugw(tid, "Got sessions", "count", len(sessions)) + h.log.Debugw(ctx, "Got sessions", "count", len(sessions)) res := make([]*Session, 0, len(sessions)) for _, session := range sessions { res = append(res, toDTO(session)) } - body, err := rest.ToJSON(res) - if err != nil { - s.log.Errorw(tid, "failed to marshal sessions", err) - sendErrorMarshalBody(ctx, rw, s.log) - return - } - - rest.Send(ctx, rw, http.StatusOK, rest.ContentTypeJSON, body, s.log) + h.resp.Send(ctx, rw, http.StatusOK, nil, res) } -func (s *SessionHandler) Create(rw http.ResponseWriter, r *http.Request) { +func (h *SessionHandler) Create(rw http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() - tid = domain.TraceIDFromContext(ctx) ) - user, ok := domain.AuthorityFromContext(ctx) + user, ok := ac.AuthorityFrom(ctx) if !ok { - s.log.Debugw(tid, "user not found in context") - sendUnauthorized(ctx, rw, s.log) + h.log.Debugw(ctx, "user not found in context") + h.resp.SendUnauthorized(ctx, rw) return } var req sessionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.log.Debugw(tid, "failed to decode body", err) - sendBadRequest(ctx, rw, "failed to parse request", s.log) + h.log.Debugw(ctx, "failed to decode body", err) + h.resp.SendBadRequest(ctx, rw, "failed to parse request") return } if strings.TrimSpace(req.Name) == "" { - s.log.Debugw(tid, "name is empty") - sendBadRequest(ctx, rw, "name param is required", s.log) + h.log.Debugw(ctx, "name is empty") + h.resp.SendBadRequest(ctx, rw, "name param is required") return } - session, err := s.service.Create(ctx, req.Name, user.UserID) + session, err := h.service.Create(ctx, user.UserID, req.Name) if err != nil { - s.log.Errorw(tid, "failed to create session", err) - sendInternalServerError(ctx, rw, s.log) + h.log.Errorw(ctx, "failed to create session", err) + h.resp.SendInternalServerError(ctx, rw) return } - s.log.Debugw(tid, "Created session", "id", session.ID) + h.log.Debugw(ctx, "Created session", "id", session.ID) - body, err := rest.ToJSON(toDTO(session)) - if err != nil { - s.log.Errorw(tid, "failed to marshal session", err) - sendErrorMarshalBody(ctx, rw, s.log) - return - } - - rw.Header().Set(rest.LastModifiedHeader, session.UpdatedAt.Format(http.TimeFormat)) - rest.Send(ctx, rw, http.StatusCreated, rest.ContentTypeJSON, body, s.log) + h.resp.Send(ctx, rw, http.StatusCreated, map[string][]string{ + LastModifiedHeader: {session.UpdatedAt.Format(http.TimeFormat)}, + }, toDTO(session)) } -func (s *SessionHandler) Update(rw http.ResponseWriter, r *http.Request) { +func (h *SessionHandler) Update(rw http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() sessionID = chi.URLParam(r, "sessionID") - tid = domain.TraceIDFromContext(ctx) ) if sessionID == "" { - s.log.Debugw(tid, "sessionID is empty") - sendBadRequest(ctx, rw, "sessionID param is required", s.log) + h.log.Debugw(ctx, "sessionID is empty") + h.resp.SendBadRequest(ctx, rw, "sessionID param is required") return } - auth, ok := domain.AuthorityFromContext(ctx) + auth, ok := ac.AuthorityFrom(ctx) if !ok { - s.log.Debugw(tid, "user not found in context") - sendUnauthorized(ctx, rw, s.log) + h.log.Debugw(ctx, "user not found in context") + h.resp.SendUnauthorized(ctx, rw) return } sid, err := strconv.ParseUint(sessionID, 10, 64) if err != nil { - s.log.Errorw(tid, "failed to parse sessionID", err) - sendBadRequest(ctx, rw, "sessionID param must be a valid uint64 value", s.log) + h.log.Errorw(ctx, "failed to parse sessionID", err) + h.resp.SendBadRequest(ctx, rw, "sessionID param must be a valid uint64 value") return } var req sessionRequest if err = json.NewDecoder(r.Body).Decode(&req); err != nil { - s.log.Debugw(tid, "failed to decode body", err) - sendBadRequest(ctx, rw, "failed to parse request", s.log) + h.log.Debugw(ctx, "failed to decode body", err) + h.resp.SendBadRequest(ctx, rw, "failed to parse request") return } - session, err := s.service.Update(ctx, sid, auth.UserID, req.Name) + session, err := h.service.Update(ctx, auth.UserID, sid, req.Name) if err != nil { if errors.Is(err, domain.ErrSessionNotFound) { - s.log.Debugw(tid, "session not found", "sessionID", sessionID) - sendNotFound(ctx, rw, "Session with provided ID not found", s.log) + h.log.Debugw(ctx, "session not found", "sessionID", sessionID) + h.resp.SendNotFound(ctx, rw, "Session with provided ID not found") return } if errors.Is(err, domain.ErrSessionPermissionDenied) { - s.log.Debugw(tid, "permission denied", "sessionID", sessionID) - sendForbidden(ctx, rw, "Permission denied", s.log) + h.log.Debugw(ctx, "permission denied", "sessionID", sessionID) + h.resp.SendNotFound(ctx, rw, "session not found") return } - s.log.Errorw(tid, "failed to update session", err) - sendInternalServerError(ctx, rw, s.log) - return - } - - body, err := rest.ToJSON(toDTO(session)) - if err != nil { - s.log.Errorw(tid, "failed to marshal session", err) - sendErrorMarshalBody(ctx, rw, s.log) + h.log.Errorw(ctx, "failed to update session", err) + h.resp.SendInternalServerError(ctx, rw) return } - rw.Header().Set(rest.LastModifiedHeader, session.UpdatedAt.Format(http.TimeFormat)) - rest.Send(ctx, rw, http.StatusOK, rest.ContentTypeJSON, body, s.log) + h.log.Debugw(ctx, "Updated session", "id", session.ID) + h.resp.Send(ctx, rw, http.StatusOK, map[string][]string{ + LastModifiedHeader: {session.UpdatedAt.Format(http.TimeFormat)}, + }, toDTO(session)) } -func (s *SessionHandler) Delete(rw http.ResponseWriter, r *http.Request) { +func (h *SessionHandler) Delete(rw http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() sessionID = chi.URLParam(r, "sessionID") - tid = domain.TraceIDFromContext(ctx) ) if sessionID == "" { - s.log.Debugw(tid, "sessionID is empty") - sendBadRequest(ctx, rw, "sessionID param is required", s.log) + h.log.Debugw(ctx, "sessionID is empty") + h.resp.SendBadRequest(ctx, rw, "sessionID param is required") return } - auth, ok := domain.AuthorityFromContext(ctx) + auth, ok := ac.AuthorityFrom(ctx) if !ok { - s.log.Debugw(tid, "user not found in context") - sendUnauthorized(ctx, rw, s.log) + h.log.Debugw(ctx, "user not found in context") + h.resp.SendUnauthorized(ctx, rw) return } sid, err := strconv.ParseUint(sessionID, 10, 64) if err != nil { - s.log.Errorw(tid, "failed to parse sessionID", err) - sendBadRequest(ctx, rw, "sessionID param must be a valid uint64 value", s.log) + h.log.Errorw(ctx, "failed to parse sessionID", err) + h.resp.SendBadRequest(ctx, rw, "sessionID param must be a valid uint64 value") return } - if err = s.service.Delete(ctx, sid, auth.UserID); err != nil { + if err = h.service.Delete(ctx, auth.UserID, sid); err != nil { if errors.Is(err, domain.ErrSessionNotFound) { - s.log.Debugw(tid, "session not found", "sessionID", sessionID) - sendNotFound(ctx, rw, "Session with provided ID not found", s.log) + h.log.Debugw(ctx, "session not found", "sessionID", sessionID) + h.resp.SendNotFound(ctx, rw, "Session with provided ID not found") return } if errors.Is(err, domain.ErrSessionPermissionDenied) { - s.log.Debugw(tid, "permission denied", "sessionID", sessionID) - sendForbidden(ctx, rw, "Permission denied", s.log) + h.log.Debugw(ctx, "permission denied", "sessionID", sessionID) + h.resp.SendNotFound(ctx, rw, "session not found") return } - s.log.Errorw(tid, "failed to delete session", err) - sendInternalServerError(ctx, rw, s.log) + h.log.Errorw(ctx, "failed to delete session", err) + h.resp.SendInternalServerError(ctx, rw) return } - rest.SendNoContent(ctx, rw, s.log) + rw.WriteHeader(http.StatusNoContent) } -func (s *SessionHandler) GetClipboard(rw http.ResponseWriter, r *http.Request) { +func (h *SessionHandler) GetClipboard(rw http.ResponseWriter, r *http.Request) { var ( - ifLastModified = r.Header.Get(rest.IfModifiedSinceHeader) - sessionID string - sid uint64 - clipboard *dal.Clipboard - err error + ctx = r.Context() + ifLastModified = r.Header.Get(IfModifiedSinceHeader) + sessionID = chi.URLParam(r, "sessionID") ) - sessionID = chi.URLParam(r, "sessionID") if sessionID == "" { - s.log.Debugw(domain.TraceIDFromContext(r.Context()), "sessionID is empty") - sendBadRequest(r.Context(), rw, "sessionID param is required", s.log) + h.log.Debugw(ctx, "sessionID is empty") + h.resp.SendBadRequest(ctx, rw, "sessionID param is required") return } - if sid, err = strconv.ParseUint(sessionID, 10, 64); err != nil { - s.log.Errorw(domain.TraceIDFromContext(r.Context()), "failed to parse sessionID", err) - sendBadRequest(r.Context(), rw, "sessionID param must be a valid uint64 value", s.log) + sid, err := strconv.ParseUint(sessionID, 10, 64) + if err != nil { + h.log.Errorw(ctx, "failed to parse sessionID", err) + h.resp.SendBadRequest(ctx, rw, "sessionID param must be a valid uint64 value") return } - if clipboard, err = s.clipboardRepo.GetBySessionID(sid); err != nil { + clipboard, err := h.clipboardRepo.GetBySessionID(sid) + if err != nil { if errors.Is(err, dal.ErrNotFound) { - s.log.Debugw(domain.TraceIDFromContext(r.Context()), "clipboard not found", "id", sessionID) - rest.SendNoContent(r.Context(), rw, s.log) + h.log.Debugw(ctx, "clipboard not found", "id", sessionID) + rw.WriteHeader(http.StatusNoContent) return } - s.log.Errorw(domain.TraceIDFromContext(r.Context()), "failed to get clipboard", err) - sendInternalServerError(r.Context(), rw, s.log) + h.log.Errorw(ctx, "failed to get clipboard", err) + h.resp.SendInternalServerError(ctx, rw) return } lastModified := clipboard.UpdatedAt.UTC().Format(http.TimeFormat) if ifLastModified != "" && lastModified == ifLastModified { - s.log.Debugw(domain.TraceIDFromContext(r.Context()), "Not modified", "id", sid) + h.log.Debugw(ctx, "Not modified", "id", sid) rw.WriteHeader(http.StatusNotModified) return } - s.log.Debugw(domain.TraceIDFromContext(r.Context()), "Got session", "id", sid) - rw.Header().Set(rest.LastModifiedHeader, lastModified) - rw.Header().Set(rest.ContentTypeHeader, clipboard.ContentType) + h.log.Debugw(ctx, "Got session", "id", sid) + rw.Header().Set(LastModifiedHeader, lastModified) + rw.Header().Set(ContentTypeHeader, clipboard.ContentType) if _, err = rw.Write(clipboard.Content); err != nil { - s.log.Errorw(domain.TraceIDFromContext(r.Context()), "failed to write content", err) + h.log.Errorw(ctx, "failed to write content", err) } } -func (s *SessionHandler) SetClipboard(rw http.ResponseWriter, r *http.Request) { +func (h *SessionHandler) SetClipboard(rw http.ResponseWriter, r *http.Request) { var ( - contentType = r.Header.Get(rest.ContentTypeHeader) - sessionID string - sid uint64 - clipboard *dal.Clipboard - body []byte - err error + ctx = r.Context() + contentType = r.Header.Get(ContentTypeHeader) + sessionID = chi.URLParam(r, "sessionID") ) - if contentType != "text/plain" { - s.log.Debugw(domain.TraceIDFromContext(r.Context()), "Content-Type is not text/plain") - sendBadRequest(r.Context(), rw, fmt.Sprintf("Content-Type %s is not supported", contentType), s.log) + if strings.ToLower(contentType) != "text/plain" { + h.log.Debugw(ctx, "Content-Type is not text/plain") + h.resp.SendBadRequest(ctx, rw, "Content-Type text/plain is required") return } - sessionID = chi.URLParam(r, "sessionID") if sessionID == "" { - s.log.Debugw(domain.TraceIDFromContext(r.Context()), "sessionID is empty") - sendBadRequest(r.Context(), rw, "sessionID param is required", s.log) + h.log.Debugw(ctx, "sessionID is empty") + h.resp.SendBadRequest(ctx, rw, "sessionID param is required") return } - if body, err = io.ReadAll(r.Body); err != nil { - s.log.Errorw(domain.TraceIDFromContext(r.Context()), "failed to read body", err) - sendInternalServerError(r.Context(), rw, s.log) + body, err := io.ReadAll(r.Body) + if err != nil { + h.log.Errorw(ctx, "failed to read body", err) + h.resp.SendInternalServerError(ctx, rw) return } - if sid, err = strconv.ParseUint(sessionID, 10, 64); err != nil { - s.log.Errorw(domain.TraceIDFromContext(r.Context()), "failed to parse sessionID", err) - sendBadRequest(r.Context(), rw, "sessionID param must be a valid uint64 value", s.log) + sid, err := strconv.ParseUint(sessionID, 10, 64) + if err != nil { + h.log.Errorw(ctx, "failed to parse sessionID", err) + h.resp.SendBadRequest(ctx, rw, "sessionID param must be a valid uint64 value") return } - if clipboard, err = s.clipboardRepo.SetBySessionID(sid, contentType, body); err != nil { + clipboard, err := h.clipboardRepo.SetBySessionID(sid, contentType, body) + if err != nil { if errors.Is(err, dal.ErrNotFound) { - s.log.Debugw(domain.TraceIDFromContext(r.Context()), "session not found", "id", sessionID) - sendNotFound(r.Context(), rw, "Session with provided ID not found", s.log) + h.log.Debugw(ctx, "session not found", "id", sessionID) + h.resp.SendNotFound(ctx, rw, "Session with provided ID not found") return } - s.log.Errorw(domain.TraceIDFromContext(r.Context()), "failed to set content", err) - sendInternalServerError(r.Context(), rw, s.log) + h.log.Errorw(ctx, "failed to set content", err) + h.resp.SendInternalServerError(ctx, rw) return } - s.log.Debugw(domain.TraceIDFromContext(r.Context()), "Set content", "id", sessionID) - rw.Header().Set(rest.LastModifiedHeader, clipboard.UpdatedAt.UTC().Format(http.TimeFormat)) - rest.SendNoContent(r.Context(), rw, s.log) + h.log.Debugw(ctx, "Set content", "id", sessionID) + rw.Header().Set(LastModifiedHeader, clipboard.UpdatedAt.UTC().Format(http.TimeFormat)) + rw.WriteHeader(http.StatusNoContent) } func toDTO(session *domain.Session) *Session { diff --git a/internal/handle/userinfo.go b/internal/handle/userinfo.go index b102536..7b35432 100644 --- a/internal/handle/userinfo.go +++ b/internal/handle/userinfo.go @@ -3,11 +3,8 @@ package handle import ( "net/http" - "github.com/go-chi/chi/v5" - - "github.com/Roma7-7-7/shared-clipboard/internal/domain" - "github.com/Roma7-7-7/shared-clipboard/tools/log" - "github.com/Roma7-7-7/shared-clipboard/tools/rest" + "github.com/Roma7-7-7/shared-clipboard/internal/context" + "github.com/Roma7-7-7/shared-clipboard/internal/log" ) type ( @@ -17,40 +14,31 @@ type ( } UserHandler struct { - log log.TracedLogger + resp *responder + log log.TracedLogger } ) -func NewUserHandler(log log.TracedLogger) *UserHandler { +func NewUserHandler(resp *responder, log log.TracedLogger) *UserHandler { return &UserHandler{ - log: log, + resp: resp, + log: log, } } -func (h *UserHandler) RegisterRoutes(r chi.Router) { - r.Get("/info", h.GetUserInfo) -} - func (h *UserHandler) GetUserInfo(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - h.log.Debugw(domain.TraceIDFromContext(ctx), "get user info") + h.log.Debugw(ctx, "get user info") - auth, ok := domain.AuthorityFromContext(ctx) + auth, ok := context.AuthorityFrom(ctx) if !ok { - h.log.Errorw(domain.TraceIDFromContext(ctx), "authority not found in context") - sendInternalServerError(ctx, w, h.log) + h.log.Errorw(ctx, "authority not found in context") + h.resp.SendInternalServerError(ctx, w) return } - body, err := rest.ToJSON(UserInfo{ + h.resp.Send(ctx, w, http.StatusOK, nil, UserInfo{ ID: auth.UserID, Name: auth.UserName, }) - if err != nil { - h.log.Errorw(domain.TraceIDFromContext(ctx), "marshal body", err) - sendInternalServerError(ctx, w, h.log) - return - } - - rest.Send(ctx, w, http.StatusOK, rest.ContentTypeJSON, body, h.log) } diff --git a/internal/log/log.go b/internal/log/log.go new file mode 100644 index 0000000..2d3ad70 --- /dev/null +++ b/internal/log/log.go @@ -0,0 +1,68 @@ +package log + +import ( + "context" + + "go.uber.org/zap" + + ac "github.com/Roma7-7-7/shared-clipboard/internal/context" +) + +const ( + traceIDLogKey = "traceID" +) + +type ( + TracedLogger interface { + Debugw(ctx context.Context, msg string, keysAndValues ...interface{}) + Infow(ctx context.Context, msg string, keysAndValues ...interface{}) + Warnw(ctx context.Context, msg string, keysAndValues ...interface{}) + Errorw(ctx context.Context, msg string, keysAndValues ...interface{}) + } + + ZapTracedLogger struct { + log *zap.SugaredLogger + } +) + +func NewZapTracedLogger(logger *zap.SugaredLogger) *ZapTracedLogger { + return &ZapTracedLogger{ + log: logger.WithOptions(zap.AddCallerSkip(1)), + } +} + +func (l *ZapTracedLogger) Debugw(ctx context.Context, msg string, keysAndValues ...interface{}) { + log := l.log.With(traceIDLogKey, ac.TraceIDFrom(ctx)) + if len(keysAndValues) > 0 { + log.Debugw(msg, keysAndValues...) + } else { + log.Debug(msg) + } +} + +func (l *ZapTracedLogger) Infow(ctx context.Context, msg string, keysAndValues ...interface{}) { + log := l.log.With(traceIDLogKey, ac.TraceIDFrom(ctx)) + if len(keysAndValues) > 0 { + log.Infow(msg, keysAndValues...) + } else { + log.Info(msg) + } +} + +func (l *ZapTracedLogger) Warnw(ctx context.Context, msg string, keysAndValues ...interface{}) { + log := l.log.With(traceIDLogKey, ac.TraceIDFrom(ctx)) + if len(keysAndValues) > 0 { + log.Warnw(msg, keysAndValues...) + } else { + log.Warn(msg) + } +} + +func (l *ZapTracedLogger) Errorw(ctx context.Context, msg string, keysAndValues ...interface{}) { + log := l.log.With(traceIDLogKey, ac.TraceIDFrom(ctx)) + if len(keysAndValues) > 0 { + log.Errorw(msg, keysAndValues...) + } else { + log.Error(msg) + } +} diff --git a/tools/log/traced.go b/tools/log/traced.go deleted file mode 100644 index 5f1783f..0000000 --- a/tools/log/traced.go +++ /dev/null @@ -1,64 +0,0 @@ -package log - -import ( - "go.uber.org/zap" -) - -const ( - traceIDLogKey = "traceID" -) - -type ( - TracedLogger interface { - Debugw(tid string, msg string, keysAndValues ...interface{}) - Infow(tid string, msg string, keysAndValues ...interface{}) - Warnw(tid string, msg string, keysAndValues ...interface{}) - Errorw(tid string, msg string, keysAndValues ...interface{}) - } - - ZapTracedLogger struct { - log *zap.SugaredLogger - } -) - -func NewZapTracedLogger(logger *zap.SugaredLogger) *ZapTracedLogger { - return &ZapTracedLogger{ - log: logger.WithOptions(zap.AddCallerSkip(1)), - } -} - -func (l *ZapTracedLogger) Debugw(tid string, msg string, keysAndValues ...interface{}) { - log := l.log.With(traceIDLogKey, tid) - if len(keysAndValues) > 0 { - log.Debugw(msg, keysAndValues...) - } else { - log.Debug(msg) - } -} - -func (l *ZapTracedLogger) Infow(tid string, msg string, keysAndValues ...interface{}) { - log := l.log.With(traceIDLogKey, tid) - if len(keysAndValues) > 0 { - log.Infow(msg, keysAndValues...) - } else { - log.Info(msg) - } -} - -func (l *ZapTracedLogger) Warnw(tid string, msg string, keysAndValues ...interface{}) { - log := l.log.With(traceIDLogKey, tid) - if len(keysAndValues) > 0 { - log.Warnw(msg, keysAndValues...) - } else { - log.Warn(msg) - } -} - -func (l *ZapTracedLogger) Errorw(tid string, msg string, keysAndValues ...interface{}) { - log := l.log.With(traceIDLogKey, tid) - if len(keysAndValues) > 0 { - log.Errorw(msg, keysAndValues...) - } else { - log.Error(msg) - } -} diff --git a/tools/rest/response.go b/tools/rest/response.go deleted file mode 100644 index 1a2ea9f..0000000 --- a/tools/rest/response.go +++ /dev/null @@ -1,41 +0,0 @@ -package rest - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - - "github.com/Roma7-7-7/shared-clipboard/internal/domain" - "github.com/Roma7-7-7/shared-clipboard/tools/log" -) - -const ( - ContentTypeHeader = "Content-Type" - LastModifiedHeader = "Last-Modified" - IfModifiedSinceHeader = "If-Modified-Since" - - ContentTypeJSON = "application/json" -) - -func ToJSON(data any) ([]byte, error) { - marshal, err := json.Marshal(data) - if err != nil { - return nil, fmt.Errorf("marshal body: %w", err) - } - return marshal, err -} - -func Send(ctx context.Context, rw http.ResponseWriter, status int, contentType string, body []byte, log log.TracedLogger) { - if contentType != "" { - rw.Header().Set(ContentTypeHeader, contentType) - } - rw.WriteHeader(status) - if _, err := rw.Write(body); err != nil { - log.Errorw(domain.TraceIDFromContext(ctx), "Failed to write response", err) - } -} - -func SendNoContent(ctx context.Context, rw http.ResponseWriter, log log.TracedLogger) { - Send(ctx, rw, http.StatusNoContent, "", nil, log) -} diff --git a/tools/common.go b/tools/tools.go similarity index 100% rename from tools/common.go rename to tools/tools.go