Skip to content

Commit

Permalink
add pkiSign function to use vault pki sign api endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: n-marton <[email protected]>
  • Loading branch information
n-marton committed Jun 17, 2024
1 parent 53a5f43 commit bc92044
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 15 deletions.
21 changes: 15 additions & 6 deletions dependency/vault_pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ type VaultPKIQuery struct {
pkiPath string
data map[string]interface{}
filePath string
// we have a var here for passing the private key to the class functions
// for the cases when we intend to use sign instead of issue
privateKey *string
}

// NewVaultReadQuery creates a new datacenter dependency.
func NewVaultPKIQuery(urlpath, filepath string, data map[string]interface{}) (*VaultPKIQuery, error) {
func NewVaultPKIQuery(urlpath, filepath string, data map[string]interface{}, privateKey *string) (*VaultPKIQuery, error) {
urlpath = strings.TrimSpace(urlpath)
urlpath = strings.Trim(urlpath, "/")
if urlpath == "" {
Expand All @@ -52,11 +55,12 @@ func NewVaultPKIQuery(urlpath, filepath string, data map[string]interface{}) (*V
}

return &VaultPKIQuery{
stopCh: make(chan struct{}, 1),
sleepCh: make(chan time.Duration, 1),
pkiPath: secretURL.Path,
data: data,
filePath: filepath,
stopCh: make(chan struct{}, 1),
sleepCh: make(chan time.Duration, 1),
pkiPath: secretURL.Path,
data: data,
filePath: filepath,
privateKey: privateKey,
}, nil
}

Expand Down Expand Up @@ -107,6 +111,11 @@ func (d *VaultPKIQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interface
default:
return PemEncoded{}, nil, err
}
// In the case that we are using sign vault endpoint we wont have an private key in the response
// so we should pass the one the we generated
if encPems.Key == "" && d.privateKey != nil {
encPems.Key = *d.privateKey
}
return respWithMetadata(encPems)
}

Expand Down
15 changes: 10 additions & 5 deletions dependency/vault_pki_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@ func init() {
}

func Test_VaultPKI_uniqueID(t *testing.T) {
d1, _ := NewVaultPKIQuery("pki/issue/example-dot-com", "/unique_1", nil)
d1, _ := NewVaultPKIQuery("pki/issue/example-dot-com", "/unique_1", nil, nil)
id1 := d1.String()
d2, _ := NewVaultPKIQuery("pki/issue/example-dot-com", "/unique_2", nil)
d2, _ := NewVaultPKIQuery("pki/issue/example-dot-com", "/unique_2", nil, nil)
id2 := d2.String()
if id1 == id2 {
t.Errorf("IDs should be unique.\n%s\n%s", id1, id2)
}
d3, _ := NewVaultPKIQuery("pki/sign/example-dot-com", "/unique_1", nil, nil)
id3 := d3.String()
if id1 == id3 {
t.Errorf("IDs should be unique.\n%s\n%s", id1, id3)
}
}

func Test_VaultPKI_notGoodFor(t *testing.T) {
Expand Down Expand Up @@ -149,7 +154,7 @@ func Test_VaultPKI_fetchPEM(t *testing.T) {
"ttl": "2h",
"ip_sans": "127.0.0.1,192.168.2.2",
}
d, err := NewVaultPKIQuery("pki/issue/example-dot-com", "/dev/null", data)
d, err := NewVaultPKIQuery("pki/issue/example-dot-com", "/dev/null", data, nil)
if err != nil {
t.Error(err)
}
Expand All @@ -161,7 +166,7 @@ func Test_VaultPKI_fetchPEM(t *testing.T) {
t.Errorf("pemsificate not fetched, got: %s", string(encPEM))
}
// test path error
d, err = NewVaultPKIQuery("pki/issue/does-not-exist", "/dev/null", data)
d, err = NewVaultPKIQuery("pki/issue/does-not-exist", "/dev/null", data, nil)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -195,7 +200,7 @@ func Test_VaultPKI_refetch(t *testing.T) {
"ttl": TTL,
"ip_sans": "127.0.0.1,192.168.2.2",
}
d, err := NewVaultPKIQuery("pki/issue/example-dot-com", f.Name(), data)
d, err := NewVaultPKIQuery("pki/issue/example-dot-com", f.Name(), data, nil)
if err != nil {
t.Fatal(err)
}
Expand Down
52 changes: 52 additions & 0 deletions docs/templating-language.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ provides the following functions:
+ [Write (and Read back)](#write-and-read-back)
* [`secrets`](#secrets)
* [`pkiCert`](#pkicert)
* [`pkiSign`](#pkisign)
* [`service`](#service)
* [`services`](#services)
* [`tree`](#tree)
Expand Down Expand Up @@ -765,6 +766,57 @@ to separate files from a template.
{{- end -}}
```

### pkiSign

Query [Vault][vault] for a PKI certificate. This is pretty similar to `pkiCert`
however, instead of using the `issue` api endpoint it uses the `sign`. This also
means, the private key generation is happening on the consul template side, this
can be quite useful if one generates a high number of certificates with low ttl
which can put high load on the vault servers.

The templating behaviour is the same as we have in `pkiCert` with a few special
attributes. You also need to pass `key_type=rsa|ec|ed25519` in alignment with your
role on vault server. If you have `use_csr_common_name` and/or `use_csr_sans` as true
on in your role, you should also pass them here, so the CSR will be appended with those
values (they can have any value `use_csr_sans=value` or `use_csr_common_name=value` the
code only check for the key).


```golang
{{ with pkiSign "pki/sign/my-domain-dot-com" "common_name=foo.example.com" }}
Certificate: {{ .Cert }}
Private Key: {{ .Key }}
Cert Authority: {{ .CA }}
{{ end }}
```

If the pki role has use_csr_common_name=true and use_csr_sans=true
```golang
{{ with pkiSign "pki/sign/my-domain-dot-com" "common_name=foo.example.com" "use_csr_common_name=some" "use_csr_sans=thing" }}
Certificate: {{ .Cert }}
Private Key: {{ .Key }}
Cert Authority: {{ .CA }}
{{ end }}
```

If the pki role has `ec` key
```golang
{{ with pkiSign "pki/sign/my-domain-dot-com" "common_name=foo.example.com" key_type="ec" key_bits="521" }}
Certificate: {{ .Cert }}
Private Key: {{ .Key }}
Cert Authority: {{ .CA }}
{{ end }}
```

If the pki role has `ed25519` key
```golang
{{ with pkiSign "pki/sign/my-domain-dot-com" "common_name=foo.example.com" key_type="ed25519" }}
Certificate: {{ .Cert }}
Private Key: {{ .Key }}
Cert Authority: {{ .CA }}
{{ end }}
```

### `service`

Query [Consul][consul] for services based on their health.
Expand Down
196 changes: 195 additions & 1 deletion template/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,22 @@ package template

import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net"
"os"
"os/exec"
"os/user"
Expand Down Expand Up @@ -456,7 +464,193 @@ func pkiCertFunc(b *Brain, used, missing *dep.Set, destPath string) func(...stri
data[k] = v
}

d, err := dep.NewVaultPKIQuery(path, destPath, data)
d, err := dep.NewVaultPKIQuery(path, destPath, data, nil)
if err != nil {
return nil, err
}

used.Add(d)
if value, ok := b.Recall(d); ok {
return value, nil
}
missing.Add(d)

return nil, nil
}
}

// pkiSignFunc generates a privatekey and csr, sends the latter to Vault to sign
func pkiSignFunc(b *Brain, used, missing *dep.Set, destPath string) func(...string) (interface{}, error) {
return func(s ...string) (interface{}, error) {
if len(s) == 0 {
return nil, nil
}

keyType := "rsa"
keyBits := 2048

var privateKey any
var rawKey string
var useCSRCommonName bool
var useCSRSans bool

path, rest := s[0], s[1:]
data := make(map[string]interface{})
for _, str := range rest {
if len(str) == 0 {
continue
}
parts := strings.SplitN(str, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("not k=v pair %q", str)
}

k, v := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
// since we are generating the private key on our end we should not send
// key_type and key_bits to Vault
// use_csr_common_name and use_csr_sans here are meant to mirror the settings on the
// vault role side, so we can configure on our end accordingly
if k != "key_type" && k != "key_bits" && k != "use_csr_common_name" && k != "use_csr_sans" {
data[k] = v
}
// if we passed a key_type and the value is either rsa, ec or ed25519 we override the default value
if k == "key_type" && (v == "rsa" || v == "ed25519" || v == "ec") {
keyType = v
}
// if we passed key_bits we override the default value
if k == "key_bits" {
keyBit, err := strconv.Atoi(v)
if err != nil {
return nil, err
}
keyBits = keyBit
}
// check if we passed use_csr_common_name for later usage
if k == "use_csr_common_name" {
useCSRCommonName = true
}
// check if we passed use_csr_sans for later usage
if k == "use_csr_sans" {
useCSRSans = true
}
}

var csrTemplate x509.CertificateRequest

// if we passed use_csr_common_name, that means serverside will expect the commonname from the csr
// so besides adding that param to the csr template, we also remove it from the map we pass later
// to vault, this way we spare a warning from the server side
if useCSRCommonName {
commonName, ok := data["common_name"]
if ok {
csrTemplate.Subject.CommonName = commonName.(string)
}
delete(data, "common_name")
}
// if we passed use_csr_sans, that means serverside will expect the subject alternate names from the csr
// so besides adding that param to the csr template, we also remove it from the map we pass later
if useCSRSans {
subjectAltNames, ok := data["uri_sans"]
if ok {
csrTemplate.DNSNames = strings.Split(subjectAltNames.(string), ",")
}
subjectAltIPs, ok := data["ip_sans"]
if ok {
for _, ip := range strings.Split(subjectAltIPs.(string), ",") {
parsedIP := net.ParseIP(ip)
csrTemplate.IPAddresses = append(csrTemplate.IPAddresses, parsedIP)
}
}
delete(data, "uri_sans")
delete(data, "ip_sans")
}

// generating private keys and also pem encode them for later usage
if keyType == "rsa" {
key, err := rsa.GenerateKey(rand.Reader, keyBits)
if err != nil {
return nil, err
}
privateKey = key
csrTemplate.SignatureAlgorithm = x509.SHA512WithRSA
marshaledKey := x509.MarshalPKCS1PrivateKey(key)
if err != nil {
return nil, err
}
keyPEMBlock := &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: marshaledKey,
}
rawKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
}
if keyType == "ed25519" {
_, key, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
privateKey = key
csrTemplate.SignatureAlgorithm = x509.PureEd25519
marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return nil, err
}
keyPEMBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: marshaledKey,
}
rawKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
}
if keyType == "ec" {
if keyBits == 2048 {
keyBits = 256
}
var err error
var key *ecdsa.PrivateKey
switch keyBits {
case 224:
key, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
case 256:
key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case 384:
key, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case 521:
key, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
default:
err = errors.New("Got unknown ec< key bits: " + fmt.Sprintf("%d", keyBits))
}
if err != nil {
return nil, err
}
privateKey = key
csrTemplate.SignatureAlgorithm = x509.ECDSAWithSHA512
marshaledKey, err := x509.MarshalECPrivateKey(key)
if err != nil {
return nil, err
}
keyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: marshaledKey,
}
rawKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
}

csr, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, privateKey)
if err != nil {
return nil, err
}

pemBlock := &pem.Block{
Type: "CERTIFICATE REQUEST",
Headers: nil,
Bytes: csr,
}
pemCsr := string(pem.EncodeToMemory(pemBlock))

// we need to pass the actual csr to the sign endpoint
data["csr"] = pemCsr

// we pass also the private key that we have generated
d, err := dep.NewVaultPKIQuery(path, destPath, data, &rawKey)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func funcMap(i *funcMapInput) template.FuncMap {
"caRoots": connectCARootsFunc(i.brain, i.used, i.missing),
"caLeaf": connectLeafFunc(i.brain, i.used, i.missing),
"pkiCert": pkiCertFunc(i.brain, i.used, i.missing, i.destination),

"pkiSign": pkiSignFunc(i.brain, i.used, i.missing, i.destination),
// Nomad Functions.
"nomadServices": nomadServicesFunc(i.brain, i.used, i.missing),
"nomadService": nomadServiceFunc(i.brain, i.used, i.missing),
Expand Down
Loading

0 comments on commit bc92044

Please sign in to comment.