Skip to content

Commit

Permalink
feat: tls support for GRPC shield API
Browse files Browse the repository at this point in the history
Signed-off-by: Kush Sharma <[email protected]>
  • Loading branch information
kushsharma committed Jul 23, 2023
1 parent 7bbaf3c commit c9dcc6d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 14 deletions.
21 changes: 15 additions & 6 deletions cmd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,33 @@ import (
"context"
"time"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

shieldv1beta1 "github.com/raystack/shield/proto/v1beta1"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

func createConnection(ctx context.Context, host string) (*grpc.ClientConn, error) {
func createConnection(ctx context.Context, host string, caCertFile string) (*grpc.ClientConn, error) {
creds := insecure.NewCredentials()
if caCertFile != "" {
tlsCreds, err := credentials.NewClientTLSFromFile(caCertFile, "")
if err != nil {
return nil, err
}
creds = tlsCreds
}
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithTransportCredentials(creds),
grpc.WithBlock(),
}

return grpc.DialContext(ctx, host, opts...)
}

func createClient(ctx context.Context, host string) (shieldv1beta1.ShieldServiceClient, func(), error) {
dialTimeoutCtx, dialCancel := context.WithTimeout(ctx, time.Second*2)
conn, err := createConnection(dialTimeoutCtx, host)
conn, err := createConnection(dialTimeoutCtx, host, "")
if err != nil {
dialCancel()
return nil, nil, err
Expand All @@ -37,7 +46,7 @@ func createClient(ctx context.Context, host string) (shieldv1beta1.ShieldService

func createAdminClient(ctx context.Context, host string) (shieldv1beta1.AdminServiceClient, func(), error) {
dialTimeoutCtx, dialCancel := context.WithTimeout(ctx, time.Second*2)
conn, err := createConnection(dialTimeoutCtx, host)
conn, err := createConnection(dialTimeoutCtx, host, "")
if err != nil {
dialCancel()
return nil, nil, err
Expand Down
4 changes: 4 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ app:
port: 8000
grpc:
port: 8001
# optional tls config
# tls_cert_file: "temp/server-cert.pem"
# tls_key_file: "temp/server-key.pem"
# tls_client_ca_file: "temp/ca-cert.pem"
metrics_port: 9000
identity_proxy_header: X-Shield-Email
# full path prefixed with scheme where resources config yaml files are kept
Expand Down
4 changes: 4 additions & 0 deletions docs/docs/reference/configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ app:
port: 8000
grpc:
port: 8001
# optional tls configuration for grpc server
tls_cert_file: "temp/server-cert.pem"
tls_key_file: "temp/server-key.pem"
tls_client_ca_file: "temp/ca-cert.pem"
metrics_port: 9000
identity_proxy_header: X-Shield-Email
# full path prefixed with scheme where resources config yaml files are kept
Expand Down
9 changes: 6 additions & 3 deletions pkg/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import (
)

type GRPCConfig struct {
Port int `mapstructure:"port" default:"8081"`
MaxRecvMsgSize int `mapstructure:"max_recv_msg_size" default:"33554432"`
MaxSendMsgSize int `mapstructure:"max_send_msg_size" default:"33554432"`
Port int `mapstructure:"port" default:"8081"`
MaxRecvMsgSize int `mapstructure:"max_recv_msg_size" default:"33554432"`
MaxSendMsgSize int `mapstructure:"max_send_msg_size" default:"33554432"`
TLSCertFile string `mapstructure:"tls_cert_file" default:""`
TLSKeyFile string `mapstructure:"tls_key_file" default:""`
TLSClientCAFile string `mapstructure:"tls_client_ca_file" default:""`
}

func (cfg Config) grpcAddr() string { return fmt.Sprintf("%s:%d", cfg.Host, cfg.GRPC.Port) }
Expand Down
29 changes: 24 additions & 5 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"strings"
"time"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

"google.golang.org/protobuf/encoding/protojson"

"github.com/raystack/shield/pkg/server/consts"
Expand Down Expand Up @@ -35,7 +38,6 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
Expand All @@ -53,13 +55,22 @@ func Serve(
deps api.Deps,
) error {
httpMux := http.NewServeMux()

grpcDialCtx, grpcDialCancel := context.WithTimeout(ctx, grpcDialTimeout)
defer grpcDialCancel()

grpcGatewayClientCreds := insecure.NewCredentials()
if cfg.GRPC.TLSClientCAFile != "" {
tlsCreds, err := credentials.NewClientTLSFromFile(cfg.GRPC.TLSClientCAFile, "")
if err != nil {
return err
}
grpcGatewayClientCreds = tlsCreds
}
// initialize grpc gateway client
grpcConn, err := grpc.DialContext(
grpcDialCtx,
cfg.grpcAddr(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithTransportCredentials(grpcGatewayClientCreds),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(cfg.GRPC.MaxRecvMsgSize),
grpc.MaxCallSendMsgSize(cfg.GRPC.MaxSendMsgSize),
Expand Down Expand Up @@ -126,8 +137,16 @@ func Serve(
httpMux.Handle("/console/", http.StripPrefix("/console/", spaHandler))
}

grpcMiddlewares := getGRPCMiddleware(logger, cfg.IdentityProxyHeader, nrApp, sessionMiddleware, deps)
grpcServer := grpc.NewServer(grpcMiddlewares)
grpcMiddleware := getGRPCMiddleware(logger, cfg.IdentityProxyHeader, nrApp, sessionMiddleware, deps)
grpcServerOpts := []grpc.ServerOption{grpcMiddleware}
if cfg.GRPC.TLSCertFile != "" && cfg.GRPC.TLSKeyFile != "" {
creds, err := credentials.NewServerTLSFromFile(cfg.GRPC.TLSCertFile, cfg.GRPC.TLSKeyFile)
if err != nil {
return err
}
grpcServerOpts = append(grpcServerOpts, grpc.Creds(creds))
}
grpcServer := grpc.NewServer(grpcServerOpts...)
reflection.Register(grpcServer)
grpc_health_v1.RegisterHealthServer(grpcServer, health.NewHandler())

Expand Down

0 comments on commit c9dcc6d

Please sign in to comment.