diff --git a/libs/nitro/aws/sdk.go b/libs/nitro/aws/sdk.go index 9afc7bca5..69fc7d313 100644 --- a/libs/nitro/aws/sdk.go +++ b/libs/nitro/aws/sdk.go @@ -114,8 +114,7 @@ func NewAWSConfig(ctx context.Context, proxyAddr string, region string) (aws.Con Str("region", region). Msg("setting up new aws config") - var client http.Client - tr := nitro.NewProxyRoundTripper(ctx, proxyAddr).(*http.Transport) + tr := nitro.NewProxyTransport(ctx, proxyAddr) certs := x509.NewCertPool() certs.AppendCertsFromPEM([]byte(amazonRoots)) @@ -132,7 +131,7 @@ func NewAWSConfig(ctx context.Context, proxyAddr string, region string) (aws.Con return aws.Config{}, fmt.Errorf("failed to configure transport for HTTP/2, %v", err) } - client = http.Client{ + client := &http.Client{ Transport: tr, } @@ -143,7 +142,7 @@ func NewAWSConfig(ctx context.Context, proxyAddr string, region string) (aws.Con } cfg, err := config.LoadDefaultConfig(context.TODO(), - config.WithHTTPClient(&client), + config.WithHTTPClient(client), config.WithRegion("us-west-2"), config.WithLogger(applicationLogger), ) diff --git a/libs/nitro/log.go b/libs/nitro/log.go index 8a58ccae0..635cc8a5e 100644 --- a/libs/nitro/log.go +++ b/libs/nitro/log.go @@ -2,6 +2,7 @@ package nitro import ( "context" + "io" "net" "os" @@ -17,7 +18,10 @@ type VsockWriter struct { } // NewVsockWriter - create a new vsock writer -func NewVsockWriter(addr string) *VsockWriter { +func NewVsockWriter(addr string) io.Writer { + if EnclaveMocking() { + return os.Stderr + } return &VsockWriter{ socket: nil, addr: addr, @@ -27,7 +31,7 @@ func NewVsockWriter(addr string) *VsockWriter { // Connect - interface implementation for connect method for VsockWriter func (w *VsockWriter) Connect() error { if w.socket == nil { - s, err := DialContext(context.Background(), "tcp", w.addr) + s, err := dialVsockContext(context.Background(), "tcp", w.addr) if err != nil { return err } diff --git a/libs/nitro/vsock.go b/libs/nitro/vsock.go index 90880eb71..307dae35d 100644 --- a/libs/nitro/vsock.go +++ b/libs/nitro/vsock.go @@ -55,12 +55,12 @@ func parseVsockAddr(addr string) (uint32, uint32, error) { } // DialContext is a net.Dial wrapper which additionally allows connecting to vsock networks -func DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - logger := logging.Logger(ctx, "nitro.DialContext") +func dialVsockContext(ctx context.Context, network, addr string) (net.Conn, error) { + logger := logging.Logger(ctx, "nitro.dialVsockContext") logger.Debug(). Str("network", fmt.Sprintf("%v", network)). Str("addr", fmt.Sprintf("%v", addr)). - Msg("DialContext") + Msg("dialVsockContext") cid, port, err := parseVsockAddr(addr) if err != nil { @@ -102,12 +102,16 @@ func (p *proxyClientConfig) Proxy(*http.Request) (*url.URL, error) { return v, err } -// NewProxyRoundTripper returns an http.RoundTripper which routes outgoing requests through the proxy addr -func NewProxyRoundTripper(ctx context.Context, addr string) http.RoundTripper { +// NewProxyTransport returns an http.Transport which routes outgoing requests +// through the proxy addr. +func NewProxyTransport(ctx context.Context, addr string) *http.Transport { + if enclaveMocking { + return &http.Transport{} + } config := proxyClientConfig{ctx, addr} return &http.Transport{ Proxy: config.Proxy, - DialContext: DialContext, + DialContext: dialVsockContext, } } @@ -122,7 +126,7 @@ func NewReverseProxyServer( } proxy := httputil.NewSingleHostReverseProxy(proxyURL) proxy.Transport = &http.Transport{ - DialContext: DialContext, + DialContext: dialVsockContext, } proxy.Director = func(req *http.Request) { req.Header.Add("X-Forwarded-Host", req.Host) @@ -240,3 +244,17 @@ func syncCopy(wg *sync.WaitGroup, dst io.WriteCloser, src io.ReadCloser) { defer wg.Done() _, _ = io.Copy(dst, src) } + +func Listen(ctx context.Context, address string) (net.Listener, error) { + if enclaveMocking { + return net.Listen("tcp", "address") + } + + // TODO: share with parseVsockAddr + port, err := strconv.ParseUint(strings.Split(address, ":")[1], 10, 32) + if err != nil { + return nil, fmt.Errorf("Failed to parse vsock address - %w", err) + } + + return vsock.Listen(uint32(port), &vsock.Config{}) +} diff --git a/services/nitro/nitro.go b/services/nitro/nitro.go index 05e5f2296..aa7996645 100644 --- a/services/nitro/nitro.go +++ b/services/nitro/nitro.go @@ -17,7 +17,6 @@ import ( "github.com/brave-intl/bat-go/services/payments" "github.com/go-chi/chi" - "github.com/mdlayher/vsock" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -81,6 +80,8 @@ func init() { viper.BindPFlag("enclave-decrypt-key-template-secret", NitroServeCmd.PersistentFlags().Lookup("enclave-decrypt-key-template-secret")) viper.BindEnv("enclave-decrypt-key-template-secret", "ENCLAVE_DECRYPT_KEY_TEMPLATE_SECRET") + rootcmd.Must(viper.BindEnv("enclave-mocking", "ENCLAVE_MOCKING")) + NitroServeCmd.AddCommand(OutsideNitroServeCmd) NitroServeCmd.AddCommand(InsideNitroServeCmd) srvcmd.ServeCmd.AddCommand(NitroServeCmd) @@ -110,10 +111,14 @@ var NitroServeCmd = &cobra.Command{ func RunNitroServerInEnclave(cmd *cobra.Command, args []string) error { ctx := cmd.Context() + if viper.GetString("enclave-mocking") != "" { + nitro.MockEnclave() + } + logaddr := viper.GetString("log-address") - writer := nitro.NewVsockWriter(logaddr) + logWriter := nitro.NewVsockWriter(logaddr) - ctx = context.WithValue(ctx, appctx.LogWriterCTXKey, writer) + ctx = context.WithValue(ctx, appctx.LogWriterCTXKey, logWriter) ctx = context.WithValue(ctx, appctx.EgressProxyAddrCTXKey, viper.GetString("egress-address")) ctx = context.WithValue(ctx, appctx.AWSRegionCTXKey, viper.GetString("aws-region")) ctx = context.WithValue(ctx, appctx.PaymentsQLDBRoleArnCTXKey, viper.GetString("qldb-role-arn")) @@ -152,19 +157,15 @@ func RunNitroServerInEnclave(cmd *cobra.Command, args []string) error { logger.Info().Msg("payments routes setup") // setup listener - addr := viper.GetString("address") - port, err := strconv.ParseUint(strings.Split(addr, ":")[1], 10, 32) - if err != nil || port == 0 { - // panic if there is an error, or if the port is too large to fit in uint32 - logger.Panic().Err(err).Msg("invalid --address") - } + listenAddress := viper.GetString("address") // setup vsock listener - l, err := vsock.Listen(uint32(port), &vsock.Config{}) + httpListener, err := nitro.Listen(ctx, listenAddress) if err != nil { logger.Panic().Err(err).Msg("listening on vsock port failed") } logger.Info().Msg("vsock listener setup") + // setup server srv := http.Server{ Handler: chi.ServerBaseContext(ctx, r), @@ -173,7 +174,7 @@ func RunNitroServerInEnclave(cmd *cobra.Command, args []string) error { } logger.Info().Msg("starting server") // run the server in another routine - logger.Fatal().Err(srv.Serve(l)).Msg("server shutdown") + logger.Fatal().Err(srv.Serve(httpListener)).Msg("server shutdown") return nil } diff --git a/services/payments/secrets.go b/services/payments/secrets.go index 6ccfbe9ac..0bff609cb 100644 --- a/services/payments/secrets.go +++ b/services/payments/secrets.go @@ -14,7 +14,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "os" "strings" @@ -56,7 +55,7 @@ type Vault struct { // createAttestationDocument will create an attestation document and return the private key and // attestation document which is attesting over the userData supplied -func createAttestationDocument(ctx context.Context, userData []byte) (crypto.PrivateKey, []byte, error) { +func createAttestationDocument(ctx context.Context) (crypto.PrivateKey, []byte, error) { // create a one time use nonce nonce, err := createAttestationNonce(ctx) if err != nil { @@ -75,7 +74,7 @@ func createAttestationDocument(ctx context.Context, userData []byte) (crypto.Pri } // attest to the document with passed in user data - document, err := nitro.Attest(ctx, nonce, userData, publicKeyMarshaled) + document, err := nitro.Attest(ctx, nonce, nil, publicKeyMarshaled) if err != nil { return nil, nil, fmt.Errorf("failed to create attestation document: %w", err) } @@ -439,12 +438,12 @@ func (s *Service) fetchOperatorShares(ctx context.Context, bucket string) error return fmt.Errorf("failed to get operator share from s3: %w", err) } - data, err := ioutil.ReadAll(shareResponse.Body) + data, err := io.ReadAll(shareResponse.Body) if err != nil { return fmt.Errorf("failed to read operator share from s3 response: %w", err) } - privateKey, document, err := createAttestationDocument(ctx, nil) + privateKey, document, err := createAttestationDocument(ctx) if err != nil { return fmt.Errorf("failed to create attestation document: %w", err) } diff --git a/services/payments/statemachine.go b/services/payments/statemachine.go index 293e59330..122cb9e96 100644 --- a/services/payments/statemachine.go +++ b/services/payments/statemachine.go @@ -98,7 +98,7 @@ func (s *Service) StateMachineFromTransaction( var machine TxStateMachine client := http.Client{ - Transport: nitro.NewProxyRoundTripper(ctx, s.egressAddr).(*http.Transport), + Transport: nitro.NewProxyTransport(ctx, s.egressAddr), } switch authenticatedState.PaymentDetails.Custodian {