Skip to content

Commit

Permalink
chore: initial nitro mocking support
Browse files Browse the repository at this point in the history
  • Loading branch information
ibukanov committed Jun 22, 2024
1 parent 77542e3 commit 47fa552
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 30 deletions.
7 changes: 3 additions & 4 deletions libs/nitro/aws/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ func NewAWSConfig(ctx context.Context, proxyAddr string, region string) (aws.Con
Str("region", region).
Msg("setting up new aws config")

var client http.Client
tr := nitro.NewProxyRoundTripper(ctx, proxyAddr).(*http.Transport)
tr := nitro.NewProxyTransport(ctx, proxyAddr)

certs := x509.NewCertPool()
certs.AppendCertsFromPEM([]byte(amazonRoots))
Expand All @@ -132,7 +131,7 @@ func NewAWSConfig(ctx context.Context, proxyAddr string, region string) (aws.Con
return aws.Config{}, fmt.Errorf("failed to configure transport for HTTP/2, %v", err)
}

client = http.Client{
client := &http.Client{
Transport: tr,
}

Expand All @@ -143,7 +142,7 @@ func NewAWSConfig(ctx context.Context, proxyAddr string, region string) (aws.Con
}

cfg, err := config.LoadDefaultConfig(context.TODO(),
config.WithHTTPClient(&client),
config.WithHTTPClient(client),
config.WithRegion("us-west-2"),
config.WithLogger(applicationLogger),
)
Expand Down
8 changes: 6 additions & 2 deletions libs/nitro/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nitro

import (
"context"
"io"
"net"
"os"

Expand All @@ -17,7 +18,10 @@ type VsockWriter struct {
}

// NewVsockWriter - create a new vsock writer
func NewVsockWriter(addr string) *VsockWriter {
func NewVsockWriter(addr string) io.Writer {
if EnclaveMocking() {
return os.Stderr
}
return &VsockWriter{
socket: nil,
addr: addr,
Expand All @@ -27,7 +31,7 @@ func NewVsockWriter(addr string) *VsockWriter {
// Connect - interface implementation for connect method for VsockWriter
func (w *VsockWriter) Connect() error {
if w.socket == nil {
s, err := DialContext(context.Background(), "tcp", w.addr)
s, err := dialVsockContext(context.Background(), "tcp", w.addr)
if err != nil {
return err
}
Expand Down
32 changes: 25 additions & 7 deletions libs/nitro/vsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ func parseVsockAddr(addr string) (uint32, uint32, error) {
}

// DialContext is a net.Dial wrapper which additionally allows connecting to vsock networks
func DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
logger := logging.Logger(ctx, "nitro.DialContext")
func dialVsockContext(ctx context.Context, network, addr string) (net.Conn, error) {
logger := logging.Logger(ctx, "nitro.dialVsockContext")
logger.Debug().
Str("network", fmt.Sprintf("%v", network)).
Str("addr", fmt.Sprintf("%v", addr)).
Msg("DialContext")
Msg("dialVsockContext")

cid, port, err := parseVsockAddr(addr)
if err != nil {
Expand Down Expand Up @@ -102,12 +102,16 @@ func (p *proxyClientConfig) Proxy(*http.Request) (*url.URL, error) {
return v, err
}

// NewProxyRoundTripper returns an http.RoundTripper which routes outgoing requests through the proxy addr
func NewProxyRoundTripper(ctx context.Context, addr string) http.RoundTripper {
// NewProxyTransport returns an http.Transport which routes outgoing requests
// through the proxy addr.
func NewProxyTransport(ctx context.Context, addr string) *http.Transport {
if enclaveMocking {
return &http.Transport{}
}
config := proxyClientConfig{ctx, addr}
return &http.Transport{
Proxy: config.Proxy,
DialContext: DialContext,
DialContext: dialVsockContext,
}
}

Expand All @@ -122,7 +126,7 @@ func NewReverseProxyServer(
}
proxy := httputil.NewSingleHostReverseProxy(proxyURL)
proxy.Transport = &http.Transport{
DialContext: DialContext,
DialContext: dialVsockContext,
}
proxy.Director = func(req *http.Request) {
req.Header.Add("X-Forwarded-Host", req.Host)
Expand Down Expand Up @@ -240,3 +244,17 @@ func syncCopy(wg *sync.WaitGroup, dst io.WriteCloser, src io.ReadCloser) {
defer wg.Done()
_, _ = io.Copy(dst, src)
}

func Listen(ctx context.Context, address string) (net.Listener, error) {
if enclaveMocking {
return net.Listen("tcp", "address")
}

// TODO: share with parseVsockAddr
port, err := strconv.ParseUint(strings.Split(address, ":")[1], 10, 32)
if err != nil {
return nil, fmt.Errorf("Failed to parse vsock address - %w", err)
}

return vsock.Listen(uint32(port), &vsock.Config{})
}
23 changes: 12 additions & 11 deletions services/nitro/nitro.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/brave-intl/bat-go/services/payments"

"github.com/go-chi/chi"
"github.com/mdlayher/vsock"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
Expand Down Expand Up @@ -81,6 +80,8 @@ func init() {
viper.BindPFlag("enclave-decrypt-key-template-secret", NitroServeCmd.PersistentFlags().Lookup("enclave-decrypt-key-template-secret"))
viper.BindEnv("enclave-decrypt-key-template-secret", "ENCLAVE_DECRYPT_KEY_TEMPLATE_SECRET")

rootcmd.Must(viper.BindEnv("enclave-mocking", "ENCLAVE_MOCKING"))

NitroServeCmd.AddCommand(OutsideNitroServeCmd)
NitroServeCmd.AddCommand(InsideNitroServeCmd)
srvcmd.ServeCmd.AddCommand(NitroServeCmd)
Expand Down Expand Up @@ -110,10 +111,14 @@ var NitroServeCmd = &cobra.Command{
func RunNitroServerInEnclave(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

if viper.GetString("enclave-mocking") != "" {
nitro.MockEnclave()
}

logaddr := viper.GetString("log-address")
writer := nitro.NewVsockWriter(logaddr)
logWriter := nitro.NewVsockWriter(logaddr)

ctx = context.WithValue(ctx, appctx.LogWriterCTXKey, writer)
ctx = context.WithValue(ctx, appctx.LogWriterCTXKey, logWriter)
ctx = context.WithValue(ctx, appctx.EgressProxyAddrCTXKey, viper.GetString("egress-address"))
ctx = context.WithValue(ctx, appctx.AWSRegionCTXKey, viper.GetString("aws-region"))
ctx = context.WithValue(ctx, appctx.PaymentsQLDBRoleArnCTXKey, viper.GetString("qldb-role-arn"))
Expand Down Expand Up @@ -152,19 +157,15 @@ func RunNitroServerInEnclave(cmd *cobra.Command, args []string) error {
logger.Info().Msg("payments routes setup")

// setup listener
addr := viper.GetString("address")
port, err := strconv.ParseUint(strings.Split(addr, ":")[1], 10, 32)
if err != nil || port == 0 {
// panic if there is an error, or if the port is too large to fit in uint32
logger.Panic().Err(err).Msg("invalid --address")
}
listenAddress := viper.GetString("address")

// setup vsock listener
l, err := vsock.Listen(uint32(port), &vsock.Config{})
httpListener, err := nitro.Listen(ctx, listenAddress)
if err != nil {
logger.Panic().Err(err).Msg("listening on vsock port failed")
}
logger.Info().Msg("vsock listener setup")

// setup server
srv := http.Server{
Handler: chi.ServerBaseContext(ctx, r),
Expand All @@ -173,7 +174,7 @@ func RunNitroServerInEnclave(cmd *cobra.Command, args []string) error {
}
logger.Info().Msg("starting server")
// run the server in another routine
logger.Fatal().Err(srv.Serve(l)).Msg("server shutdown")
logger.Fatal().Err(srv.Serve(httpListener)).Msg("server shutdown")
return nil
}

Expand Down
9 changes: 4 additions & 5 deletions services/payments/secrets.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"strings"

Expand Down Expand Up @@ -56,7 +55,7 @@ type Vault struct {

// createAttestationDocument will create an attestation document and return the private key and
// attestation document which is attesting over the userData supplied
func createAttestationDocument(ctx context.Context, userData []byte) (crypto.PrivateKey, []byte, error) {
func createAttestationDocument(ctx context.Context) (crypto.PrivateKey, []byte, error) {
// create a one time use nonce
nonce, err := createAttestationNonce(ctx)
if err != nil {
Expand All @@ -75,7 +74,7 @@ func createAttestationDocument(ctx context.Context, userData []byte) (crypto.Pri
}

// attest to the document with passed in user data
document, err := nitro.Attest(ctx, nonce, userData, publicKeyMarshaled)
document, err := nitro.Attest(ctx, nonce, nil, publicKeyMarshaled)
if err != nil {
return nil, nil, fmt.Errorf("failed to create attestation document: %w", err)
}
Expand Down Expand Up @@ -439,12 +438,12 @@ func (s *Service) fetchOperatorShares(ctx context.Context, bucket string) error
return fmt.Errorf("failed to get operator share from s3: %w", err)
}

data, err := ioutil.ReadAll(shareResponse.Body)
data, err := io.ReadAll(shareResponse.Body)
if err != nil {
return fmt.Errorf("failed to read operator share from s3 response: %w", err)
}

privateKey, document, err := createAttestationDocument(ctx, nil)
privateKey, document, err := createAttestationDocument(ctx)
if err != nil {
return fmt.Errorf("failed to create attestation document: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion services/payments/statemachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (s *Service) StateMachineFromTransaction(
var machine TxStateMachine

client := http.Client{
Transport: nitro.NewProxyRoundTripper(ctx, s.egressAddr).(*http.Transport),
Transport: nitro.NewProxyTransport(ctx, s.egressAddr),
}

switch authenticatedState.PaymentDetails.Custodian {
Expand Down

0 comments on commit 47fa552

Please sign in to comment.