From 61a7f074c730922c0239b4b8520b2e8a62003547 Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Wed, 12 Jul 2023 18:14:13 +0800 Subject: [PATCH] swagger: block swagger url if disbale swagger server (#6785) (#6790) close tikv/pd#6786 /swagger/* return 404 Signed-off-by: bufferflies <1045931706@qq.com> Co-authored-by: bufferflies <1045931706@qq.com> --- cmd/pd-server/main.go | 5 +++- pkg/swaggerserver/swagger_handler.go | 29 ------------------- pkg/swaggerserver/swaggerserver.go | 12 +++++++- ...ty_handler.go => swaggerserver_disable.go} | 19 ++++++++---- tests/cluster.go | 5 +++- tests/server/api/api_test.go | 9 ++++++ 6 files changed, 41 insertions(+), 38 deletions(-) delete mode 100644 pkg/swaggerserver/swagger_handler.go rename pkg/swaggerserver/{empty_handler.go => swaggerserver_disable.go} (61%) diff --git a/cmd/pd-server/main.go b/cmd/pd-server/main.go index 672ca15b77d..7f42c5b4adc 100644 --- a/cmd/pd-server/main.go +++ b/cmd/pd-server/main.go @@ -208,7 +208,10 @@ func start(cmd *cobra.Command, args []string, services ...string) { // Creates server. ctx, cancel := context.WithCancel(context.Background()) - serviceBuilders := []server.HandlerBuilder{api.NewHandler, apiv2.NewV2Handler, swaggerserver.NewHandler, autoscaling.NewHandler} + serviceBuilders := []server.HandlerBuilder{api.NewHandler, apiv2.NewV2Handler, autoscaling.NewHandler} + if swaggerserver.Enabled() { + serviceBuilders = append(serviceBuilders, swaggerserver.NewHandler) + } serviceBuilders = append(serviceBuilders, dashboard.GetServiceBuilders()...) svr, err := server.CreateServer(ctx, cfg, services, serviceBuilders...) if err != nil { diff --git a/pkg/swaggerserver/swagger_handler.go b/pkg/swaggerserver/swagger_handler.go deleted file mode 100644 index 69cff3d2751..00000000000 --- a/pkg/swaggerserver/swagger_handler.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2020 TiKV Project Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build swagger_server -// +build swagger_server - -package swaggerserver - -import ( - "net/http" - - httpSwagger "github.com/swaggo/http-swagger" - _ "github.com/tikv/pd/docs/swagger" -) - -func handler() http.Handler { - return httpSwagger.Handler() -} diff --git a/pkg/swaggerserver/swaggerserver.go b/pkg/swaggerserver/swaggerserver.go index 778844adbde..d68bab06eb2 100644 --- a/pkg/swaggerserver/swaggerserver.go +++ b/pkg/swaggerserver/swaggerserver.go @@ -12,12 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build swagger_server +// +build swagger_server + package swaggerserver import ( "context" "net/http" + httpSwagger "github.com/swaggo/http-swagger" + _ "github.com/tikv/pd/docs/swagger" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/server" ) @@ -33,9 +38,14 @@ var ( } ) +// Enabled return true if swagger server is disabled. +func Enabled() bool { + return true +} + // NewHandler creates a HTTP handler for Swagger. func NewHandler(context.Context, *server.Server) (http.Handler, apiutil.APIServiceGroup, error) { swaggerHandler := http.NewServeMux() - swaggerHandler.Handle(swaggerPrefix, handler()) + swaggerHandler.Handle(swaggerPrefix, httpSwagger.Handler()) return swaggerHandler, swaggerServiceGroup, nil } diff --git a/pkg/swaggerserver/empty_handler.go b/pkg/swaggerserver/swaggerserver_disable.go similarity index 61% rename from pkg/swaggerserver/empty_handler.go rename to pkg/swaggerserver/swaggerserver_disable.go index 79f33a9af6b..c3b861b3b6c 100644 --- a/pkg/swaggerserver/empty_handler.go +++ b/pkg/swaggerserver/swaggerserver_disable.go @@ -1,4 +1,4 @@ -// Copyright 2020 TiKV Project Authors. +// Copyright 2023 TiKV Project Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,12 +18,19 @@ package swaggerserver import ( - "io" + "context" "net/http" + + "github.com/tikv/pd/pkg/utils/apiutil" + "github.com/tikv/pd/server" ) -func handler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = io.WriteString(w, "Swagger UI is not built. Try `make` with `SWAGGER=1`.\n") - }) +// Enabled return false if swagger server is disabled. +func Enabled() bool { + return false +} + +// NewHandler creates a HTTP handler for Swagger. +func NewHandler(context.Context, *server.Server) (http.Handler, apiutil.APIServiceGroup, error) { + return nil, apiutil.APIServiceGroup{}, nil } diff --git a/tests/cluster.go b/tests/cluster.go index b0a5d529998..d4cb6f4da14 100644 --- a/tests/cluster.go +++ b/tests/cluster.go @@ -95,7 +95,10 @@ func createTestServer(ctx context.Context, cfg *config.Config, services []string if err != nil { return nil, err } - serviceBuilders := []server.HandlerBuilder{api.NewHandler, apiv2.NewV2Handler, swaggerserver.NewHandler, autoscaling.NewHandler} + serviceBuilders := []server.HandlerBuilder{api.NewHandler, apiv2.NewV2Handler, autoscaling.NewHandler} + if swaggerserver.Enabled() { + serviceBuilders = append(serviceBuilders, swaggerserver.NewHandler) + } serviceBuilders = append(serviceBuilders, dashboard.GetServiceBuilders()...) svr, err := server.CreateServer(ctx, cfg, services, serviceBuilders...) if err != nil { diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 7fb7e7d1236..243bdad399c 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -377,6 +377,15 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { } } +func (suite *middlewareTestSuite) TestSwaggerUrl() { + leader := suite.cluster.GetServer(suite.cluster.GetLeader()) + req, _ := http.NewRequest(http.MethodGet, leader.GetAddr()+"/swagger/ui/index", nil) + resp, err := dialClient.Do(req) + suite.NoError(err) + suite.True(resp.StatusCode == http.StatusNotFound) + resp.Body.Close() +} + func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { leader := suite.cluster.GetServer(suite.cluster.GetLeader()) input := map[string]interface{}{