diff --git a/test/helpers.go b/test/helpers.go index 0967e6ddaf98..7747bebdd903 100644 --- a/test/helpers.go +++ b/test/helpers.go @@ -512,31 +512,36 @@ func generateCertificateBundleFiles(td string, genIntermediate bool, outputSuffi err = fmt.Errorf("error generating certificate bundle: %w", err) return } - err = os.WriteFile(filepath.Join(td, fmt.Sprintf("caCert%s.pem", outputSuffix)), caCertBuf.Bytes(), 0600) + caCertFile = filepath.Join(td, fmt.Sprintf("caCert%s.pem", outputSuffix)) + err = os.WriteFile(caCertFile, caCertBuf.Bytes(), 0600) if err != nil { - err = fmt.Errorf("error writing caCert to file: %w", err) + err = fmt.Errorf("error writing caCert to file %s: %w", caCertFile, err) return } - err = os.WriteFile(filepath.Join(td, fmt.Sprintf("caPrivKey%s.pem", outputSuffix)), caPrivKeyBuf.Bytes(), 0600) + caPrivKeyFile = filepath.Join(td, fmt.Sprintf("caPrivKey%s.pem", outputSuffix)) + err = os.WriteFile(caPrivKeyFile, caPrivKeyBuf.Bytes(), 0600) if err != nil { - err = fmt.Errorf("error writing caPrivKey to file: %w", err) + err = fmt.Errorf("error writing caPrivKey to file %s: %w", caPrivKeyFile, err) return } if genIntermediate { - err = os.WriteFile(filepath.Join(td, fmt.Sprintf("caIntermediateCert%s.pem", outputSuffix)), caIntermediateCertBuf.Bytes(), 0600) + caIntermediateCertFile = filepath.Join(td, fmt.Sprintf("caIntermediateCert%s.pem", outputSuffix)) + err = os.WriteFile(caIntermediateCertFile, caIntermediateCertBuf.Bytes(), 0600) if err != nil { - err = fmt.Errorf("error writing caIntermediateCert to file: %w", err) + err = fmt.Errorf("error writing caIntermediateCert to file %s: %w", caIntermediateCertFile, err) return } - err = os.WriteFile(filepath.Join(td, fmt.Sprintf("caIntermediatePrivKey%s.pem", outputSuffix)), caIntermediatePrivKeyBuf.Bytes(), 0600) + caIntermediatePrivKeyFile = filepath.Join(td, fmt.Sprintf("caIntermediatePrivKey%s.pem", outputSuffix)) + err = os.WriteFile(caIntermediatePrivKeyFile, caIntermediatePrivKeyBuf.Bytes(), 0600) if err != nil { - err = fmt.Errorf("error writing caIntermediatePrivKey to file: %w", err) + err = fmt.Errorf("error writing caIntermediatePrivKey to file %s: %w", caIntermediatePrivKeyFile, err) return } } - err = os.WriteFile(filepath.Join(td, fmt.Sprintf("cert%s.pem", outputSuffix)), certBuf.Bytes(), 0600) + certFile = filepath.Join(td, fmt.Sprintf("cert%s.pem", outputSuffix)) + err = os.WriteFile(certFile, certBuf.Bytes(), 0600) if err != nil { - err = fmt.Errorf("error writing cert to file: %w", err) + err = fmt.Errorf("error writing cert to file %s: %w", certFile, err) return } @@ -544,7 +549,7 @@ func generateCertificateBundleFiles(td string, genIntermediate bool, outputSuffi certChainFile = filepath.Join(td, fmt.Sprintf("certchain%s.pem", outputSuffix)) err = os.WriteFile(certChainFile, certChainBuf.Bytes(), 0600) if err != nil { - err = fmt.Errorf("error writing certificate chain to file: %w", err) + err = fmt.Errorf("error writing certificate chain to file %s: %w", certFile, err) return } return diff --git a/test/helpers_test.go b/test/helpers_test.go index 95b8872ddd88..ae04371422a7 100644 --- a/test/helpers_test.go +++ b/test/helpers_test.go @@ -2,10 +2,16 @@ package test -import "testing" +import ( + "crypto/x509" + "encoding/pem" + "io/ioutil" + "log" + "testing" +) -func TestGenerateCertificateBundle(t *testing.T) { - for _, test := range []struct { +func TestGenerateCertificateBundleFiles(t *testing.T) { + for _, tt := range []struct { name string genIntermediate bool }{ @@ -18,11 +24,144 @@ func TestGenerateCertificateBundle(t *testing.T) { genIntermediate: true, }, } { - t.Run(test.name, func(t *testing.T) { - _, _, _, _, _, _, err := generateCertificateBundle(true) + t.Run(tt.name, func(t *testing.T) { + td := t.TempDir() + suffix := "foobar" + caCertFile, caPrivKeyFile, caIntermediateCertFile, caIntermediatePrivKeyFile, + certFile, certChainFile, err := generateCertificateBundleFiles(td, true, suffix) if err != nil { t.Fatalf("Error generating certificate bundle: %v", err) } + verifyCertificate(t, caCertFile) + if tt.genIntermediate { + verifyCertificate(t, caIntermediateCertFile) + } + verifyCertificate(t, certFile) + + verifyPrivateKey(t, caPrivKeyFile) + if tt.genIntermediate { + verifyPrivateKey(t, caIntermediatePrivKeyFile) + verifyCertificateChain(t, certChainFile) + } }) } } + +func verifyCertificate(t *testing.T, certFile string) { + t.Helper() + // open and parse certFile, ensure it is a TLS certificate + data, err := ioutil.ReadFile(certFile) + if err != nil { + t.Fatalf("Error reading certificate file %s: %v\n", certFile, err) + return + } + + // Check if the file contents are a PEM-encoded TLS certificate + if !isPEMEncodedCert(data) { + t.Fatalf("file %s doesn't contain a valid PEM-encoded TLS certificate", certFile) + } +} + +func verifyCertificateChain(t *testing.T, certChainFile string) { + t.Helper() + // open and parse certChainFile, ensure it is a TLS certificate chain + data, err := ioutil.ReadFile(certChainFile) + if err != nil { + t.Fatalf("Error reading certificate file %s: %v\n", certChainFile, err) + } + + // Check if the file contents are a PEM-encoded TLS certificate + t.Logf("DMDEBUG 76 before isPEMEncodedCertChain") + if !isPEMEncodedCertChain(data) { + t.Fatalf("file %s doesn't contain a valid PEM-encoded TLS certificate chain", certChainFile) + } +} + +// isPEMEncodedCert checks if the provided data is a PEM-encoded certificate +func isPEMEncodedCert(data []byte) bool { + // Decode the PEM data + block, _ := pem.Decode(data) + if block == nil || block.Type != "CERTIFICATE" { + return false + } + + // Parse the certificate to ensure it is valid + _, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false + } + + return true +} + +func verifyPrivateKey(t *testing.T, privKeyFile string) { + t.Helper() + // open and parse certFile, ensure it is a TLS certificate + data, err := ioutil.ReadFile(privKeyFile) + if err != nil { + t.Fatalf("Error reading private key file %s: %v\n", privKeyFile, err) + return + } + + // Check if the file contents are a PEM-encoded private key + if !isPEMEncodedPrivateKey(data) { + t.Fatalf("file %s doesn't contain a valid PEM-encoded private key", privKeyFile) + } +} + +// isPEMEncodedPrivateKey checks if the provided data is a PEM-encoded private key +func isPEMEncodedPrivateKey(data []byte) bool { + // Decode the PEM data + block, _ := pem.Decode(data) + if block == nil { + return false + } + var err error + + switch block.Type { + case "PRIVATE KEY": + _, err = x509.ParsePKCS8PrivateKey(block.Bytes) + case "RSA PRIVATE KEY": + _, err = x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + _, err = x509.ParseECPrivateKey(block.Bytes) + default: + return false + } + if err != nil { + log.Printf("isPEMEncodedPrivateKey: %v", err) + return false + } + + return true +} + +// isPEMEncodedCertChain checks if the provided data is a concatenation of a PEM-encoded +// intermediate certificate followed by a root certificate +func isPEMEncodedCertChain(data []byte) bool { + // Decode the PEM blocks one by one + blockCnt := 0 + for len(data) > 0 { + var block *pem.Block + block, data = pem.Decode(data) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { + return false + } + + // Parse the certificate to ensure it is valid + _, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false + } + + blockCnt++ + } + // we want exactly two blocks in the certificate chain - intermediate and root + if blockCnt != 2 { + return false + } + return true +}