From fab036986c7e91657302044b60d01fe1794469ff Mon Sep 17 00:00:00 2001 From: Yuri Shkuro Date: Sat, 15 Jul 2023 21:04:17 -0400 Subject: [PATCH] [hotrod] Validate user input to avoid security warnings from code scanning (#4583) --- examples/hotrod/services/customer/client.go | 6 +++--- examples/hotrod/services/customer/database.go | 19 ++++++++++--------- .../hotrod/services/customer/interface.go | 2 +- examples/hotrod/services/customer/server.go | 10 ++++++++-- examples/hotrod/services/frontend/best_eta.go | 2 +- examples/hotrod/services/frontend/server.go | 10 ++++++++-- 6 files changed, 31 insertions(+), 18 deletions(-) diff --git a/examples/hotrod/services/customer/client.go b/examples/hotrod/services/customer/client.go index 59787771519..56f0cb96e5f 100644 --- a/examples/hotrod/services/customer/client.go +++ b/examples/hotrod/services/customer/client.go @@ -43,10 +43,10 @@ func NewClient(tracer trace.TracerProvider, logger log.Factory, hostPort string) } // Get implements customer.Interface#Get as an RPC -func (c *Client) Get(ctx context.Context, customerID string) (*Customer, error) { - c.logger.For(ctx).Info("Getting customer", zap.String("customer_id", customerID)) +func (c *Client) Get(ctx context.Context, customerID int) (*Customer, error) { + c.logger.For(ctx).Info("Getting customer", zap.Int("customer_id", customerID)) - url := fmt.Sprintf("http://"+c.hostPort+"/customer?customer=%s", customerID) + url := fmt.Sprintf("http://"+c.hostPort+"/customer?customer=%d", customerID) var customer Customer if err := c.client.GetJSON(ctx, "/customer", url, &customer); err != nil { return nil, err diff --git a/examples/hotrod/services/customer/database.go b/examples/hotrod/services/customer/database.go index 57cc09f4c57..f20e1a59fa7 100644 --- a/examples/hotrod/services/customer/database.go +++ b/examples/hotrod/services/customer/database.go @@ -18,6 +18,7 @@ package customer import ( "context" "errors" + "fmt" "go.opentelemetry.io/otel/attribute" semconv "go.opentelemetry.io/otel/semconv/v1.20.0" @@ -34,7 +35,7 @@ import ( type database struct { tracer trace.Tracer logger log.Factory - customers map[string]*Customer + customers map[int]*Customer lock *tracing.Mutex } @@ -46,23 +47,23 @@ func newDatabase(tracer trace.Tracer, logger log.Factory) *database { SessionBaggageKey: "request", LogFactory: logger, }, - customers: map[string]*Customer{ - "123": { + customers: map[int]*Customer{ + 123: { ID: "123", Name: "Rachel's_Floral_Designs", Location: "115,277", }, - "567": { + 567: { ID: "567", Name: "Amazing_Coffee_Roasters", Location: "211,653", }, - "392": { + 392: { ID: "392", Name: "Trom_Chocolatier", Location: "577,322", }, - "731": { + 731: { ID: "731", Name: "Japanese_Desserts", Location: "728,326", @@ -71,15 +72,15 @@ func newDatabase(tracer trace.Tracer, logger log.Factory) *database { } } -func (d *database) Get(ctx context.Context, customerID string) (*Customer, error) { - d.logger.For(ctx).Info("Loading customer", zap.String("customer_id", customerID)) +func (d *database) Get(ctx context.Context, customerID int) (*Customer, error) { + d.logger.For(ctx).Info("Loading customer", zap.Int("customer_id", customerID)) // simulate opentracing instrumentation of an SQL query ctx, span := d.tracer.Start(ctx, "SQL SELECT", trace.WithSpanKind(trace.SpanKindClient)) // #nosec span.SetAttributes( semconv.PeerServiceKey.String("mysql"), - attribute.Key("sql.query").String("SELECT * FROM customer WHERE customer_id="+customerID), + attribute.Key("sql.query").String(fmt.Sprintf("SELECT * FROM customer WHERE customer_id=%d", customerID)), ) defer span.End() diff --git a/examples/hotrod/services/customer/interface.go b/examples/hotrod/services/customer/interface.go index 4a0bc3f86e0..e144832f81a 100644 --- a/examples/hotrod/services/customer/interface.go +++ b/examples/hotrod/services/customer/interface.go @@ -28,5 +28,5 @@ type Customer struct { // Interface exposed by the Customer service. type Interface interface { - Get(ctx context.Context, customerID string) (*Customer, error) + Get(ctx context.Context, customerID int) (*Customer, error) } diff --git a/examples/hotrod/services/customer/server.go b/examples/hotrod/services/customer/server.go index 0b6acf2815a..222d0ce47d0 100644 --- a/examples/hotrod/services/customer/server.go +++ b/examples/hotrod/services/customer/server.go @@ -18,6 +18,7 @@ package customer import ( "encoding/json" "net/http" + "strconv" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" @@ -70,11 +71,16 @@ func (s *Server) customer(w http.ResponseWriter, r *http.Request) { return } - customerID := r.Form.Get("customer") - if customerID == "" { + customer := r.Form.Get("customer") + if customer == "" { http.Error(w, "Missing required 'customer' parameter", http.StatusBadRequest) return } + customerID, err := strconv.Atoi(customer) + if err != nil { + http.Error(w, "Parameter 'customer' is not an integer", http.StatusBadRequest) + return + } response, err := s.database.Get(ctx, customerID) if httperr.HandleError(w, err, http.StatusInternalServerError) { diff --git a/examples/hotrod/services/frontend/best_eta.go b/examples/hotrod/services/frontend/best_eta.go index 3a1765a8e47..58c6bdbcc55 100644 --- a/examples/hotrod/services/frontend/best_eta.go +++ b/examples/hotrod/services/frontend/best_eta.go @@ -70,7 +70,7 @@ func newBestETA(tracer trace.TracerProvider, logger log.Factory, options ConfigO } } -func (eta *bestETA) Get(ctx context.Context, customerID string) (*Response, error) { +func (eta *bestETA) Get(ctx context.Context, customerID int) (*Response, error) { customer, err := eta.customer.Get(ctx, customerID) if err != nil { return nil, err diff --git a/examples/hotrod/services/frontend/server.go b/examples/hotrod/services/frontend/server.go index 45527569f85..828eae2fba1 100644 --- a/examples/hotrod/services/frontend/server.go +++ b/examples/hotrod/services/frontend/server.go @@ -20,6 +20,7 @@ import ( "encoding/json" "net/http" "path" + "strconv" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" @@ -99,11 +100,16 @@ func (s *Server) dispatch(w http.ResponseWriter, r *http.Request) { return } - customerID := r.Form.Get("customer") - if customerID == "" { + customer := r.Form.Get("customer") + if customer == "" { http.Error(w, "Missing required 'customer' parameter", http.StatusBadRequest) return } + customerID, err := strconv.Atoi(customer) + if err != nil { + http.Error(w, "Parameter 'customer' is not an integer", http.StatusBadRequest) + return + } // TODO distinguish between user errors (such as invalid customer ID) and server failures response, err := s.bestETA.Get(ctx, customerID)