Skip to content

Commit

Permalink
[hotrod] Validate user input to avoid security warnings from code sca…
Browse files Browse the repository at this point in the history
…nning (#4583)
  • Loading branch information
yurishkuro authored Jul 16, 2023
1 parent 96f7fdc commit fab0369
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 18 deletions.
6 changes: 3 additions & 3 deletions examples/hotrod/services/customer/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ func NewClient(tracer trace.TracerProvider, logger log.Factory, hostPort string)
}

// Get implements customer.Interface#Get as an RPC
func (c *Client) Get(ctx context.Context, customerID string) (*Customer, error) {
c.logger.For(ctx).Info("Getting customer", zap.String("customer_id", customerID))
func (c *Client) Get(ctx context.Context, customerID int) (*Customer, error) {
c.logger.For(ctx).Info("Getting customer", zap.Int("customer_id", customerID))

url := fmt.Sprintf("http://"+c.hostPort+"/customer?customer=%s", customerID)
url := fmt.Sprintf("http://"+c.hostPort+"/customer?customer=%d", customerID)
var customer Customer
if err := c.client.GetJSON(ctx, "/customer", url, &customer); err != nil {
return nil, err
Expand Down
19 changes: 10 additions & 9 deletions examples/hotrod/services/customer/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package customer
import (
"context"
"errors"
"fmt"

"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
Expand All @@ -34,7 +35,7 @@ import (
type database struct {
tracer trace.Tracer
logger log.Factory
customers map[string]*Customer
customers map[int]*Customer
lock *tracing.Mutex
}

Expand All @@ -46,23 +47,23 @@ func newDatabase(tracer trace.Tracer, logger log.Factory) *database {
SessionBaggageKey: "request",
LogFactory: logger,
},
customers: map[string]*Customer{
"123": {
customers: map[int]*Customer{
123: {
ID: "123",
Name: "Rachel's_Floral_Designs",
Location: "115,277",
},
"567": {
567: {
ID: "567",
Name: "Amazing_Coffee_Roasters",
Location: "211,653",
},
"392": {
392: {
ID: "392",
Name: "Trom_Chocolatier",
Location: "577,322",
},
"731": {
731: {
ID: "731",
Name: "Japanese_Desserts",
Location: "728,326",
Expand All @@ -71,15 +72,15 @@ func newDatabase(tracer trace.Tracer, logger log.Factory) *database {
}
}

func (d *database) Get(ctx context.Context, customerID string) (*Customer, error) {
d.logger.For(ctx).Info("Loading customer", zap.String("customer_id", customerID))
func (d *database) Get(ctx context.Context, customerID int) (*Customer, error) {
d.logger.For(ctx).Info("Loading customer", zap.Int("customer_id", customerID))

// simulate opentracing instrumentation of an SQL query
ctx, span := d.tracer.Start(ctx, "SQL SELECT", trace.WithSpanKind(trace.SpanKindClient))
// #nosec
span.SetAttributes(
semconv.PeerServiceKey.String("mysql"),
attribute.Key("sql.query").String("SELECT * FROM customer WHERE customer_id="+customerID),
attribute.Key("sql.query").String(fmt.Sprintf("SELECT * FROM customer WHERE customer_id=%d", customerID)),
)
defer span.End()

Expand Down
2 changes: 1 addition & 1 deletion examples/hotrod/services/customer/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ type Customer struct {

// Interface exposed by the Customer service.
type Interface interface {
Get(ctx context.Context, customerID string) (*Customer, error)
Get(ctx context.Context, customerID int) (*Customer, error)
}
10 changes: 8 additions & 2 deletions examples/hotrod/services/customer/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package customer
import (
"encoding/json"
"net/http"
"strconv"

"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
Expand Down Expand Up @@ -70,11 +71,16 @@ func (s *Server) customer(w http.ResponseWriter, r *http.Request) {
return
}

customerID := r.Form.Get("customer")
if customerID == "" {
customer := r.Form.Get("customer")
if customer == "" {
http.Error(w, "Missing required 'customer' parameter", http.StatusBadRequest)
return
}
customerID, err := strconv.Atoi(customer)
if err != nil {
http.Error(w, "Parameter 'customer' is not an integer", http.StatusBadRequest)
return
}

response, err := s.database.Get(ctx, customerID)
if httperr.HandleError(w, err, http.StatusInternalServerError) {
Expand Down
2 changes: 1 addition & 1 deletion examples/hotrod/services/frontend/best_eta.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func newBestETA(tracer trace.TracerProvider, logger log.Factory, options ConfigO
}
}

func (eta *bestETA) Get(ctx context.Context, customerID string) (*Response, error) {
func (eta *bestETA) Get(ctx context.Context, customerID int) (*Response, error) {
customer, err := eta.customer.Get(ctx, customerID)
if err != nil {
return nil, err
Expand Down
10 changes: 8 additions & 2 deletions examples/hotrod/services/frontend/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/json"
"net/http"
"path"
"strconv"

"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
Expand Down Expand Up @@ -99,11 +100,16 @@ func (s *Server) dispatch(w http.ResponseWriter, r *http.Request) {
return
}

customerID := r.Form.Get("customer")
if customerID == "" {
customer := r.Form.Get("customer")
if customer == "" {
http.Error(w, "Missing required 'customer' parameter", http.StatusBadRequest)
return
}
customerID, err := strconv.Atoi(customer)
if err != nil {
http.Error(w, "Parameter 'customer' is not an integer", http.StatusBadRequest)
return
}

// TODO distinguish between user errors (such as invalid customer ID) and server failures
response, err := s.bestETA.Get(ctx, customerID)
Expand Down

0 comments on commit fab0369

Please sign in to comment.