Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple JWT tokens to be configured (closes #108) #109

Merged
merged 14 commits into from
Jan 11, 2024
Merged
27 changes: 13 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,19 @@ If `-config` is not specified, the default value `helper.conf` is assumed.
## Configuration
The configuration file is an [HCL](https://github.com/hashicorp/hcl) formatted file that defines the following configurations:

| Configuration | Description | Example Value |
|-----------------------------|----------------------------------------------------------------------------------------------------------------| -------------------- |
|`agent_address` | Socket address of SPIRE Agent. | `"/tmp/agent.sock"` |
|`cmd` | The path to the process to launch. | `"ghostunnel"` |
|`cmd_args` | The arguments of the process to launch. | `"server --listen localhost:8002 --target localhost:8001--keystore certs/svid_key.pem --cacert certs/svid_bundle.pem --allow-uri-san spiffe://example.org/Database"` |
|`cert_dir` | Directory name to store the fetched certificates. This directory must be created previously. | `"certs"` |
|`add_intermediates_to_bundle`| Add intermediate certificates into Bundle file instead of SVID file. | `true` |
|`renew_signal` | The signal that the process to be launched expects to reload the certificates. It is not supported on Windows. | `"SIGUSR1"` |
|`svid_file_name` | File name to be used to store the X.509 SVID public certificate in PEM format. | `"svid.pem"` |
|`svid_key_file_name` | File name to be used to store the X.509 SVID private key and public certificate in PEM format. | `"svid_key.pem"` |
|`svid_bundle_file_name` | File name to be used to store the X.509 SVID Bundle in PEM format. | `"svid_bundle.pem"` |
|`jwt_audience` | JWT SVID audience. | `"your-audience"` |
|`jwt_svid_file_name` | File name to be used to store JWT SVID in Base64-encoded string. | `"jwt_svid.token"` |
|`jwt_bundle_file_name` | File name to be used to store JWT Bundle in JSON format. | `"jwt_bundle.json"` |
| Configuration | Description | Example Value |
|-------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `agent_address` | Socket address of SPIRE Agent. | `"/tmp/agent.sock"` |
| `cmd` | The path to the process to launch. | `"ghostunnel"` |
| `cmd_args` | The arguments of the process to launch. | `"server --listen localhost:8002 --target localhost:8001--keystore certs/svid_key.pem --cacert certs/svid_bundle.pem --allow-uri-san spiffe://example.org/Database"` |
| `cert_dir` | Directory name to store the fetched certificates. This directory must be created previously. | `"certs"` |
| `add_intermediates_to_bundle` | Add intermediate certificates into Bundle file instead of SVID file. | `true` |
| `renew_signal` | The signal that the process to be launched expects to reload the certificates. It is not supported on Windows. | `"SIGUSR1"` |
| `svid_file_name` | File name to be used to store the X.509 SVID public certificate in PEM format. | `"svid.pem"` |
| `svid_key_file_name` | File name to be used to store the X.509 SVID private key and public certificate in PEM format. | `"svid_key.pem"` |
| `svid_bundle_file_name` | File name to be used to store the X.509 SVID Bundle in PEM format. | `"svid_bundle.pem"` |
| `jwt_svids` | An array of objects containing `jwt_audience` (which is the JWT SVID audience) and `jwt_svid_file_name` (which is the file name to be used to store JWT SVID in Base64-encoded string). | `[{jwt_audience="your-audience", jwt_svid_file_name="jwt_svid.token"}]` |
keeganwitt marked this conversation as resolved.
Show resolved Hide resolved
| `jwt_bundle_file_name` | File name to be used to store JWT Bundle in JSON format. | `"jwt_bundle.json"` |

### Configuration example
```
Expand Down
28 changes: 18 additions & 10 deletions pkg/sidecar/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,20 @@ type Config struct {
RenewSignalDeprecated string `hcl:"renewSignal"`

// JWT configuration
JWTAudience string `hcl:"jwt_audience"`
JWTSvidFilename string `hcl:"jwt_svid_file_name"`
JWTBundleFilename string `hcl:"jwt_bundle_file_name"`
JwtSvids []JwtConfig `hcl:"jwt_svids"`
JWTBundleFilename string `hcl:"jwt_bundle_file_name"`

// TODO: is there a reason for this to be exposed? and inside of config?
ReloadExternalProcess func() error
// TODO: is there a reason for this to be exposed? and inside of config?
Log logrus.FieldLogger
}

type JwtConfig struct {
JWTAudience string `hcl:"jwt_audience"`
JWTSvidFilename string `hcl:"jwt_svid_file_name"`
}

// ParseConfig parses the given HCL file into a SidecarConfig struct
func ParseConfig(file string) (*Config, error) {
sidecarConfig := new(Config)
Expand Down Expand Up @@ -120,21 +124,25 @@ func ValidateConfig(c *Config) error {
c.RenewSignal = c.RenewSignalDeprecated
}

for _, jwtConfig := range c.JwtSvids {
if countEmpty(jwtConfig.JWTSvidFilename) > 0 {
keeganwitt marked this conversation as resolved.
Show resolved Hide resolved
return errors.New("'jwt_file_name' is required in 'jwt_svids'")
}
if countEmpty(jwtConfig.JWTAudience) > 0 {
keeganwitt marked this conversation as resolved.
Show resolved Hide resolved
return errors.New("'jwt_audience' is required in 'jwt_svids'")
}
}

x509EmptyCount := countEmpty(c.SvidFileName, c.SvidBundleFileName, c.SvidKeyFileName)
jwtSVIDEmptyCount := countEmpty(c.JWTSvidFilename, c.JWTAudience)
jwtBundleEmptyCount := countEmpty(c.SvidBundleFileName)
if x509EmptyCount == 3 && jwtSVIDEmptyCount == 2 && jwtBundleEmptyCount == 1 {
return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified")
if x509EmptyCount == 3 && len(c.JwtSvids) == 0 && jwtBundleEmptyCount == 1 {
return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), 'jwt_svids', or 'jwt_bundle_file_name' must be fully specified")
}

if x509EmptyCount != 0 && x509EmptyCount != 3 {
return errors.New("all or none of 'svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name' must be specified")
}

if jwtSVIDEmptyCount != 0 && jwtSVIDEmptyCount != 2 {
return errors.New("all or none of 'jwt_file_name', 'jwt_audience' must be specified")
}

return nil
}

Expand Down
34 changes: 24 additions & 10 deletions pkg/sidecar/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ func TestParseConfig(t *testing.T) {
assert.Equal(t, expectedSvidFileName, c.SvidFileName)
assert.Equal(t, expectedKeyFileName, c.SvidKeyFileName)
assert.Equal(t, expectedSvidBundleFileName, c.SvidBundleFileName)
assert.Equal(t, expectedJWTSVIDFileName, c.JWTSvidFilename)
assert.Equal(t, expectedJWTSVIDFileName, c.JwtSvids[0].JWTSvidFilename)
assert.Equal(t, expectedJWTBundleFileName, c.JWTBundleFilename)
assert.Equal(t, expectedJWTAudience, c.JWTAudience)
assert.Equal(t, expectedJWTAudience, c.JwtSvids[0].JWTAudience)
assert.True(t, c.AddIntermediatesToBundle)
}

Expand All @@ -59,9 +59,11 @@ func TestValidateConfig(t *testing.T) {
{
name: "no error",
config: &Config{
AgentAddress: "path",
JWTAudience: "your-audience",
JWTSvidFilename: "jwt.token",
AgentAddress: "path",
JwtSvids: []JwtConfig{{
JWTSvidFilename: "jwt.token",
JWTAudience: "your-audience",
}},
JWTBundleFilename: "bundle.json",
},
},
Expand All @@ -70,7 +72,7 @@ func TestValidateConfig(t *testing.T) {
config: &Config{
AgentAddress: "path",
},
expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified",
expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), 'jwt_svids', or 'jwt_bundle_file_name' must be fully specified",
},
{
name: "missing svid config",
Expand All @@ -81,12 +83,24 @@ func TestValidateConfig(t *testing.T) {
expectError: "all or none of 'svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name' must be specified",
},
{
name: "missing jwt config",
name: "missing jwt audience",
config: &Config{
AgentAddress: "path",
JWTSvidFilename: "cert.pem",
AgentAddress: "path",
JwtSvids: []JwtConfig{{
JWTSvidFilename: "jwt.token",
}},
},
expectError: "'jwt_audience' is required in 'jwt_svids'",
},
{
name: "missing jwt path",
config: &Config{
AgentAddress: "path",
JwtSvids: []JwtConfig{{
JWTAudience: "my-audience",
}},
},
expectError: "all or none of 'jwt_file_name', 'jwt_audience' must be specified",
expectError: "'jwt_file_name' is required in 'jwt_svids'",
},
// Duplicated field error:
{
Expand Down
33 changes: 18 additions & 15 deletions pkg/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,22 @@ func (s *Sidecar) RunDaemon(ctx context.Context) error {
}()
}

if s.config.JWTSvidFilename != "" && s.config.JWTAudience != "" {
if len(s.config.JwtSvids) > 0 {
jwtSource, err := workloadapi.NewJWTSource(ctx, workloadapi.WithClientOptions(s.getWorkloadAPIAdress()))
if err != nil {
s.config.Log.Fatalf("Error watching JWT svid updates: %v", err)
}
s.jwtSource = jwtSource
defer s.jwtSource.Close()

wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx)
}()
for _, jwtConfig := range s.config.JwtSvids {
jwtConfig := jwtConfig
wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx, jwtConfig.JWTAudience, jwtConfig.JWTSvidFilename)
}()
}
}

wg.Wait()
Expand Down Expand Up @@ -274,14 +277,14 @@ func (s *Sidecar) updateJWTBundle(jwkSet *jwtbundle.Set) {
}
}

func (s *Sidecar) fetchJWTSVID(ctx context.Context) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: s.config.JWTAudience})
func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to fetch JWT SVID: %v", err)
return nil, err
}

_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{s.config.JWTAudience})
_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to parse or validate token: %v", err)
return nil, err
Expand Down Expand Up @@ -312,16 +315,16 @@ func getRefreshInterval(svid *jwtsvid.SVID) time.Duration {
return time.Until(svid.Expiry)/2 + time.Second
}

func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, error) {
func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtSvidFilename string) (*jwtsvid.SVID, error) {
s.config.Log.Debug("Updating JWT SVID")

jwtSVID, err := s.fetchJWTSVID(ctx)
jwtSVID, err := s.fetchJWTSVIDs(ctx, jwtAudience)
if err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
}

filePath := path.Join(s.config.CertDir, s.config.JWTSvidFilename)
filePath := path.Join(s.config.CertDir, jwtSvidFilename)
if err = os.WriteFile(filePath, []byte(jwtSVID.Marshal()), os.ModePerm); err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
Expand All @@ -331,10 +334,10 @@ func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, erro
return jwtSVID, nil
}

func (s *Sidecar) updateJWTSVID(ctx context.Context) {
func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtSvidFilename string) {
retryInterval := createRetryIntervalFunc()
var initialInterval time.Duration
jwtSVID, err := s.performJWTSVIDUpdate(ctx)
jwtSVID, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename)
if err != nil {
// If the first update fails, use the retry interval
initialInterval = retryInterval()
Expand All @@ -350,7 +353,7 @@ func (s *Sidecar) updateJWTSVID(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
jwtSVID, err = s.performJWTSVIDUpdate(ctx)
jwtSVID, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename)
if err == nil {
retryInterval = createRetryIntervalFunc()
ticker.Reset(getRefreshInterval(jwtSVID))
Expand Down
8 changes: 6 additions & 2 deletions test/fixture/config/helper.conf
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ renew_signal = "SIGHUP"
svid_file_name = "svid.pem"
svid_key_file_name = "svid_key.pem"
svid_bundle_file_name = "svid_bundle.pem"
jwt_svid_file_name = "jwt_svid.token"
jwt_bundle_file_name = "jwt_bundle.json"
jwt_audience = "your-audience"
jwt_svids = [
faisal-memon marked this conversation as resolved.
Show resolved Hide resolved
{
jwt_svid_file_name = "jwt_svid.token"
jwt_audience = "your-audience"
}
]
timeout = "10s"
add_intermediates_to_bundle = true
Loading