diff --git a/server/tenant_test.go b/server/tenant_test.go index 66bed4b..e735fb0 100644 --- a/server/tenant_test.go +++ b/server/tenant_test.go @@ -49,18 +49,10 @@ func (ts *TestSuite) Test_tenantsCreateHandler() { for _, tt := range tests { ts.T().Run(tt.name, func(t *testing.T) { input := app.TenantCreateInput{Name: "new tenant"} - j, _ := json.Marshal(&input) - req := httptest.NewRequest(http.MethodPost, "/api/tenants", bytes.NewReader(j)) - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - req.Header.Set(echo.HeaderAuthorization, "Bearer "+tt.token) - - res := httptest.NewRecorder() - ts.server.ServeHTTP(res, req) - body, err := io.ReadAll(res.Body) - ts.NoError(err) + body, status := ts.request(http.MethodPost, "/api/tenants", tt.token, input) // Assertions - ts.Equal(tt.wantStatus, res.Code, "incorrect http status, body: \n%s", body) + ts.Equal(tt.wantStatus, status, "incorrect http status, body: \n%s", body) if tt.wantStatus != http.StatusOK { return diff --git a/server/user_test.go b/server/user_test.go index 196c787..53feada 100644 --- a/server/user_test.go +++ b/server/user_test.go @@ -2,13 +2,9 @@ package server_test import ( "encoding/json" - "io" "net/http" - "net/http/httptest" "testing" - "github.com/labstack/echo/v4" - "github.com/briskt/keygo/app" "github.com/briskt/keygo/db" ) @@ -64,17 +60,10 @@ func (ts *TestSuite) Test_GetUser() { for _, tt := range tests { ts.T().Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/api/users/"+tt.userID, nil) - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - req.Header.Set(echo.HeaderAuthorization, "Bearer "+tt.token) - - res := httptest.NewRecorder() - ts.server.ServeHTTP(res, req) - body, err := io.ReadAll(res.Body) - ts.NoError(err) + body, status := ts.request(http.MethodGet, "/api/users/"+tt.userID, tt.token, nil) // Assertions - ts.Equal(tt.wantStatus, res.Code, "incorrect http status, body: \n%s", body) + ts.Equal(tt.wantStatus, status, "incorrect http status, body: \n%s", body) if tt.wantStatus != http.StatusOK { return @@ -124,17 +113,10 @@ func (ts *TestSuite) Test_GetUserList() { for _, tt := range tests { ts.T().Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/api/users", nil) - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - req.Header.Set(echo.HeaderAuthorization, "Bearer "+tt.token) - - res := httptest.NewRecorder() - ts.server.ServeHTTP(res, req) - body, err := io.ReadAll(res.Body) - ts.NoError(err) + body, status := ts.request(http.MethodGet, "/api/users", tt.token, nil) // Assertions - ts.Equal(tt.wantStatus, res.Code, "incorrect http status, body: \n%s", body) + ts.Equal(tt.wantStatus, status, "incorrect http status, body: \n%s", body) if tt.wantStatus != http.StatusOK { return