Skip to content

Commit

Permalink
feat: refactor vault unsealing
Browse files Browse the repository at this point in the history
Move data related to vault unsealing to separated Unsealing struct and
decouple the relevant code from Service for simpler testing.
  • Loading branch information
ibukanov committed Jul 1, 2024
1 parent 665038d commit de0b2a1
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 177 deletions.
244 changes: 135 additions & 109 deletions services/payments/secrets.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
"errors"
"fmt"
"io"
"os"
"strings"
"time"

"filippo.io/age"
"filippo.io/age/agessh"
Expand All @@ -31,6 +31,7 @@ import (
nitroawsutils "github.com/brave-intl/bat-go/libs/nitro/aws"
paymentLib "github.com/brave-intl/bat-go/libs/payments"
"github.com/hashicorp/vault/shamir"
"github.com/rs/zerolog"
)

// ChainAddress represents an on-chain address used for payouts. It needs to be persisted
Expand All @@ -53,6 +54,23 @@ type Vault struct {
shares paymentLib.CreateVaultResponse
}

type OperatorKey = *age.X25519Identity

// State of vault unsealing
type Unsealing struct {
// id for AWS KMS key to encrypt/decrypt operator shares
kmsDecryptKeyArn string
getChainAddress func(ctx context.Context, address string) (*ChainAddress, error)

// private key reconstructed from the operator shares
operatorKey OperatorKey

keyShares [][]byte
secretsCiphertext []byte
solanaPrivCiphertext []byte
secrets map[string]string
}

// createAttestationDocument will create an attestation document and return the private key and
// attestation document which is attesting over the userData supplied
func createAttestationDocument(ctx context.Context) (crypto.PrivateKey, []byte, error) {
Expand Down Expand Up @@ -216,18 +234,11 @@ func encryptShares(shares [][]byte, operatorKeys []string) ([]paymentLib.Operato
return shareResult, nil
}

// AreSecretsLoaded will tell you if we have successfully loaded secrets on the service
func (s *Service) AreSecretsLoaded(ctx context.Context) bool {
if len(s.secrets) > 0 {
return true
}
return false
}

func (s *Service) createSolanaAddress(ctx context.Context, bucket, creatorKey string) (*ChainAddress, error) {
solAccount := solTypes.NewAccount()
b58PubKey := solAccount.PublicKey.ToBase58()
encSeed, err := s.encryptWithShares(ctx, solAccount.PrivateKey.Seed())
encBuf := &bytes.Buffer{}
err := encryptToWriter(ctx, s.operatorKey, solAccount.PrivateKey.Seed(), encBuf)
if err != nil {
return nil, fmt.Errorf("failed to encrypt seed: %w", err)
}
Expand All @@ -239,15 +250,11 @@ func (s *Service) createSolanaAddress(ctx context.Context, bucket, creatorKey st
}
s3Client := s3.NewFromConfig(awsCfg)

encSeedBytes, err := io.ReadAll(encSeed)
if err != nil {
return nil, fmt.Errorf("failed to seed to bytes: %w", err)
}
h := md5.New()
h.Write(encSeedBytes)
h.Write(encBuf.Bytes())

input := &s3.PutObjectInput{
Body: bytes.NewBuffer(encSeedBytes),
Body: encBuf,
Bucket: aws.String(bucket),
Key: aws.String("solana-address-" + b58PubKey),
ContentMD5: aws.String(base64.StdEncoding.EncodeToString(h.Sum(nil))),
Expand Down Expand Up @@ -298,7 +305,7 @@ func (s *Service) approveSolanaAddress(ctx context.Context, address, approverKey

// fetchSecrets will take an s3 bucket/object and fetch the configuration and store the
// ciphertext on the service for decryption later
func (s *Service) fetchSecrets(ctx context.Context, bucket, secretsObject string, solanaPubAddr string) error {
func (u *Unsealing) tryFetchSecrets(ctx context.Context, bucket, secretsObject string, solanaPubAddr string) error {
logger := logging.Logger(ctx, "payments.secrets")
awsCfg, err := nitroAwsCfg(ctx)
if err != nil {
Expand All @@ -315,14 +322,14 @@ func (s *Service) fetchSecrets(ctx context.Context, bucket, secretsObject string
}

// we are not able to decrypt secretsCiphertext until all operator shares are available
s.secretsCiphertext, err = io.ReadAll(secretsResponse.Body)
u.secretsCiphertext, err = io.ReadAll(secretsResponse.Body)
if err != nil {
return fmt.Errorf("failed to read secrets bytes: %w", err)
}

if solanaPubAddr != "" {
logger.Debug().Str("solana public key", string(solanaPubAddr)).Msg("fetching solana key from s3")
chainAddress, err := s.datastore.GetChainAddress(ctx, solanaPubAddr)
chainAddress, err := u.getChainAddress(ctx, solanaPubAddr)
if err != nil {
return fmt.Errorf("failed to get solana address from QLDB: %w", err)
}
Expand All @@ -337,11 +344,11 @@ func (s *Service) fetchSecrets(ctx context.Context, bucket, secretsObject string
return fmt.Errorf("failed to get solana address from s3: %w", err)
}
logger.Debug().Msg("no error reading solana key from s3")
s.solanaPrivCiphertext, err = io.ReadAll(solanaAddressResponse.Body)
u.solanaPrivCiphertext, err = io.ReadAll(solanaAddressResponse.Body)
if err != nil {
return fmt.Errorf("failed to read solana address bytes: %w", err)
}
logger.Debug().Int("solana ciphertext length", len(s.solanaPrivCiphertext)).Msg("setting solana ciphertext to service")
logger.Debug().Int("solana ciphertext length", len(u.solanaPrivCiphertext)).Msg("setting solana ciphertext to service")
} else {
return fmt.Errorf("provided solana address has insufficient approvals")
}
Expand All @@ -350,60 +357,57 @@ func (s *Service) fetchSecrets(ctx context.Context, bucket, secretsObject string
return nil
}

// enoughOperatorShares informs the caller if there are enough operator shares present to attempt a decrypt
func (s *Service) enoughOperatorShares(ctx context.Context, required int) bool {
if len(s.keyShares) > required { // TODO: configurable in future, right now need two shares
return true
func (u *Unsealing) fetchSecretes(
ctx context.Context,
logger *zerolog.Logger,
) error {
// get the secrets object key and bucket name from environment
secretsBucketName, ok := ctx.Value(appctx.EnclaveSecretsBucketNameCTXKey).(string)
if !ok {
return errNoSecretsBucketConfigured
}
return false
}

var (
errNoSecretsCiphertext = errors.New("failed to get service configuration ciphertext")
)

// configureSecrets takes the ciphertext configuration from fetchSecrets, then decrypts it with the keyshares
// from fetchOperatorShares then stores the values in the configuration map
func (s *Service) configureSecrets(ctx context.Context) error {
logger := logging.Logger(ctx, "payments.secrets")
// do we have secrets downloaded?
if len(s.secretsCiphertext) < 1 {
return errNoSecretsCiphertext
// download the configuration file, kms decrypt the file
secretsObjectName, ok := ctx.Value(appctx.EnclaveSecretsObjectNameCTXKey).(string)
if !ok {
return errNoSecretsObjectConfigured
}

// decrypt configuration ciphertext
secrets, err := s.decryptSecrets(ctx)
if err != nil {
return fmt.Errorf("failed to decrypt secrets: %w", err)
solanaAddress, ok := ctx.Value(appctx.EnclaveSolanaAddressCTXKey).(string)
if !ok {
return errNoSolanaAddressConfigured
}
logger.Debug().Msg("decrypted secrets without error")
logger.Debug().Str("solana address:", solanaAddress).Msg("solana address configured")

// store conf on service
s.secrets = secrets
for {
// fetch the secrets, result will store the secrets (age ciphertext) on the service instance
if err := u.tryFetchSecrets(ctx, secretsBucketName, secretsObjectName, solanaAddress); err != nil {
// log the error, we will retry again
logger.Error().Err(err).Msg("failed to fetch secrets, will retry shortly")
<-time.After(30 * time.Second)
continue
}
break
}

s.setEnvFromSecrets(ctx, secrets)
logger.Debug().Msg("set env from secrets")
return nil
}

// setEnvFromSecrets takes a secrets map and loads the secrets as environment variables
func (s *Service) setEnvFromSecrets(ctx context.Context, secrets map[string]string) {
logger := logging.Logger(ctx, "payments.secrets")
os.Setenv("ZEBPAY_API_KEY", secrets["zebpayApiKey"])
os.Setenv("ZEBPAY_SIGNING_KEY", secrets["zebpayPrivateKey"])
os.Setenv("SOLANA_RPC_ENDPOINT", secrets["solanaRpcEndpoint"])

if solKey, ok := secrets["solanaPrivateKey"]; ok {
logger.Debug().Int("solana key length", len(secrets["solanaPrivateKey"])).Msg("setting solana key environment varialbe")
os.Setenv("SOLANA_SIGNING_KEY", solKey)
logger.Debug().Int("solana env var key length", len(os.Getenv("SOLANA_SIGNING_KEY"))).Msg("set solana key environment varialbe")
// enoughOperatorShares informs the caller if there are enough operator shares present to attempt a decrypt
func (u *Unsealing) enoughOperatorShares(ctx context.Context, required int) bool {
if len(u.keyShares) > required { // TODO: configurable in future, right now need two shares
return true
}
return false
}

var (
errNoSecretsCiphertext = errors.New("failed to get service configuration ciphertext")
)

// fetchOperatorShares will take an s3 bucket and fetch all of the operator shares and store them
func (s *Service) fetchOperatorShares(ctx context.Context, bucket string) error {
func (u *Unsealing) tryFetchOperatorShares(ctx context.Context, bucket string) error {
// clear out all keyshares and start over, we will be downloading ALL shares from the s3 bucket
s.keyShares = [][]byte{}
u.keyShares = [][]byte{}

// get the aws configuration
awsCfg, err := nitroAwsCfg(ctx)
Expand Down Expand Up @@ -452,7 +456,7 @@ func (s *Service) fetchOperatorShares(ctx context.Context, bucket string) error
decryptOutput, err := kms.NewFromConfig(awsCfg).Decrypt(ctx, &kms.DecryptInput{
CiphertextBlob: data,
EncryptionAlgorithm: kmsTypes.EncryptionAlgorithmSpecSymmetricDefault,
KeyId: aws.String(s.kmsDecryptKeyArn),
KeyId: aws.String(u.kmsDecryptKeyArn),
Recipient: &kmsTypes.RecipientInfo{
AttestationDocument: document, // attestation document
KeyEncryptionAlgorithm: kmsTypes.KeyEncryptionMechanismRsaesOaepSha256, // how to decrypt
Expand All @@ -473,34 +477,74 @@ func (s *Service) fetchOperatorShares(ctx context.Context, bucket string) error
return fmt.Errorf("failed to base64 decode operator key share: %w", err)
}

s.keyShares = append(s.keyShares, share)
u.keyShares = append(u.keyShares, share)
}

return nil
}

func (u *Unsealing) fetchOperatorShares(
ctx context.Context,
logger *zerolog.Logger,
) error {
// operator shares files
operatorSharesBucketName, ok := ctx.Value(appctx.EnclaveOperatorSharesBucketNameCTXKey).(string)
if !ok {
return errNoOperatorSharesBucketConfigured
}

for {
// do we have enough shares to attempt to reconstitute the key?
if err := u.tryFetchOperatorShares(ctx, operatorSharesBucketName); err != nil {
logger.Error().Err(err).Msg("failed to fetch operator shares, will retry shortly")
<-time.After(60 * time.Second)
continue
}
if ok := u.enoughOperatorShares(ctx, 1); ok { // 2 is the number of shares required
break
}
logger.Error().Msg("need more operator shares to decrypt secrets")
// no - poll for operator shares until we can attempt to decrypt the file
<-time.After(60 * time.Second) // wait a minute before attempting again to get operator shares
}
return nil
}

// decryptSecrets combines the shamir shares stored on the service instance and decrypts the ciphertext
// returning a map of secret values from the configuration
func (s *Service) decryptSecrets(ctx context.Context) (map[string]string, error) {
func (u *Unsealing) decryptSecrets(ctx context.Context) error {
// do we have secrets downloaded?
if len(u.secretsCiphertext) < 1 {
return errNoSecretsCiphertext
}

// combine the service configured key shares
privateKey, err := shamir.Combine(u.keyShares)
if err != nil {
return fmt.Errorf("failed to combine keyShares: %w", err)
}

u.operatorKey, err = age.ParseX25519Identity(string(privateKey))
if err != nil {
return fmt.Errorf("failed to parse private key bytes for secret decryption: %w", err)
}

logger := logging.Logger(ctx, "payments.secrets")
var output = map[string]string{}

secBuf := bytes.NewBuffer(s.secretsCiphertext)

sec, err := s.decryptWithShares(ctx, *secBuf)
sec, err := getDecryptReader(ctx, u.operatorKey, u.secretsCiphertext)
if err != nil {
return nil, fmt.Errorf("failed to decrypt secrets with shares: %w", err)
return fmt.Errorf("failed to decrypt secrets with shares: %w", err)
}
if err := json.NewDecoder(sec).Decode(&output); err != nil {
return nil, fmt.Errorf("failed to json decode the secrets: %w", err)
return fmt.Errorf("failed to json decode the secrets: %w", err)
}

if len(s.solanaPrivCiphertext) > 0 {
logger.Debug().Int("solana ciphertext length", len(s.solanaPrivCiphertext)).Msg("decrypting solana ciphertext")
solBuf := bytes.NewBuffer(s.solanaPrivCiphertext)
solReader, err := s.decryptWithShares(ctx, *solBuf)
if len(u.solanaPrivCiphertext) > 0 {
logger.Debug().Int("solana ciphertext length", len(u.solanaPrivCiphertext)).Msg("decrypting solana ciphertext")
solReader, err := getDecryptReader(ctx, u.operatorKey, u.solanaPrivCiphertext)
if err != nil {
return nil, fmt.Errorf("failed to decrypt solana address with shares: %w", err)
return fmt.Errorf("failed to decrypt solana address with shares: %w", err)
}
logger.Debug().Msg("decryptWithShares completed without error")
buf := new(bytes.Buffer)
Expand All @@ -510,47 +554,29 @@ func (s *Service) decryptSecrets(ctx context.Context) (map[string]string, error)
logger.Debug().Int("solana key length", len(output["solanaPrivateKey"])).Msg("set decrypted key to secret map")
}

return output, nil
u.secrets = output
return nil
}

func (s *Service) decryptWithShares(ctx context.Context, buf bytes.Buffer) (io.Reader, error) {
// combine the service configured key shares
privateKey, err := shamir.Combine(s.keyShares)
if err != nil {
return nil, fmt.Errorf("failed to combine keyShares: %w", err)
}

identity, err := age.ParseX25519Identity(string(privateKey))
if err != nil {
return nil, fmt.Errorf("failed to parse private key bytes for secret decryption: %w", err)
}

return age.Decrypt(bytes.NewReader(buf.Bytes()), identity)
func getDecryptReader(
ctx context.Context,
key OperatorKey,
cipherText []byte,
) (io.Reader, error) {
return age.Decrypt(bytes.NewReader(cipherText), key)
}

func (s *Service) encryptWithShares(ctx context.Context, data []byte) (io.Reader, error) {
// combine the service configured key shares
privateKey, err := shamir.Combine(s.keyShares)
if err != nil {
return nil, fmt.Errorf("failed to combine keyShares: %w", err)
}

identity, err := age.ParseX25519Identity(string(privateKey))
if err != nil {
return nil, fmt.Errorf("failed to parse private key bytes for secret decryption: %w", err)
}

out := &bytes.Buffer{}
func encryptToWriter(ctx context.Context, key OperatorKey, data []byte, destination io.Writer) error {

w, err := age.Encrypt(out, identity.Recipient())
w, err := age.Encrypt(destination, key.Recipient())
if err != nil {
return nil, fmt.Errorf("Failed to create encrypted file: %v", err)
return fmt.Errorf("Failed to create encryption stream: %v", err)
}
if _, err := io.WriteString(w, string(data)); err != nil {
return nil, fmt.Errorf("Failed to write to encrypted file: %v", err)
_, err = w.Write(data)
err2 := w.Close()
if err != nil || err2 != nil {
err = errors.Join(err, err2)
return fmt.Errorf("Failed to encrypt %d bytes: %w", len(data), err)
}
if err := w.Close(); err != nil {
return nil, fmt.Errorf("Failed to close encrypted file: %v", err)
}
return out, nil
return nil
}
Loading

0 comments on commit de0b2a1

Please sign in to comment.