diff --git a/dependency/vault_pki.go b/dependency/vault_pki.go index 8bfdab311..37bc3d6a8 100644 --- a/dependency/vault_pki.go +++ b/dependency/vault_pki.go @@ -21,11 +21,40 @@ import ( var _ Dependency = (*VaultPKIQuery)(nil) // Return type containing PEMs as strings -type PemEncoded struct{ Cert, Key, CA string } +type PemEncoded struct { + Cert, Key, CA string + CAChain []string +} + +func (a PemEncoded) Equals(b PemEncoded) bool { + if a.CA != b.CA || a.Cert != b.Cert || a.Key != b.Key { + return false + } + + if len(a.CAChain) != len(b.CAChain) { + return false + } + + for i, v := range a.CAChain { + if v != b.CAChain[i] { + return false + } + } + return true +} + +func (a PemEncoded) CaChainContains(item string) bool { + for _, v := range a.CAChain { + if v == item { + return true + } + } + return false +} // a wrapper to mimic v2 secrets Data wrapper -func (p PemEncoded) Data() PemEncoded { - return p +func (a PemEncoded) Data() PemEncoded { + return a } // VaultPKIQuery is the dependency to Vault for a secret @@ -152,10 +181,12 @@ func pemsCert(encoded []byte) (PemEncoded, *x509.Certificate, error) { var cert *x509.Certificate var encPems PemEncoded var aPem []byte + for { aPem, encoded = nextPem(encoded) // scan, find and parse PEM blocks block, _ = pem.Decode(aPem) + switch { case block == nil: // end of scan, no more PEMs found return encPems, cert, nil @@ -170,7 +201,13 @@ func pemsCert(encoded []byte) (PemEncoded, *x509.Certificate, error) { case err != nil: return PemEncoded{}, nil, err case maybeCert.IsCA: - encPems.CA = string(pem.EncodeToMemory(block)) + if encPems.CA == "" { + // set the first CA found to CA to be backward compatible + encPems.CA = string(pem.EncodeToMemory(block)) + } + if !encPems.CaChainContains(string(pem.EncodeToMemory(block))) { + encPems.CAChain = append(encPems.CAChain, string(pem.EncodeToMemory(block))) + } default: // the certificate cert = maybeCert encPems.Cert = string(pem.EncodeToMemory(block)) @@ -205,10 +242,20 @@ func (d *VaultPKIQuery) fetchPEMs(clients *ClientSet) ([]byte, error) { } printVaultWarnings(d, vaultSecret.Warnings) pems := bytes.Buffer{} - for _, v := range vaultSecret.Data { + + for k, v := range vaultSecret.Data { switch v := v.(type) { case string: pems.WriteString(v + "\n") + case []interface{}: + if k == "ca_chain" { + for _, item := range v { + switch item := item.(type) { + case string: + pems.WriteString(item + "\n") + } + } + } } } diff --git a/dependency/vault_pki_test.go b/dependency/vault_pki_test.go index 3c1eca21d..0b6727cec 100644 --- a/dependency/vault_pki_test.go +++ b/dependency/vault_pki_test.go @@ -238,7 +238,8 @@ func Test_VaultPKI_refetch(t *testing.T) { t.Fatalf("expected a pems but found: %s", pems2) } // using cached copy, so should be a match - if pems1 != pems2 { + + if !pems1.Equals(pems2) { t.Errorf("pemss don't match and should.") } @@ -261,7 +262,7 @@ func Test_VaultPKI_refetch(t *testing.T) { t.Fatalf("expected a pems but found: %s", pems2) } - if pems2 == pems3 { + if pems2.Equals(pems3) { t.Errorf("pemss match and shouldn't.") } }