diff --git a/README.md b/README.md index 3d4f6cdc..975421b9 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,7 @@ The configuration file is an [HCL](https://github.com/hashicorp/hcl) formatted f | `svid_file_name` | File name to be used to store the X.509 SVID public certificate in PEM format. | `"svid.pem"` | | `svid_key_file_name` | File name to be used to store the X.509 SVID private key and public certificate in PEM format. | `"svid_key.pem"` | | `svid_bundle_file_name` | File name to be used to store the X.509 SVID Bundle in PEM format. | `"svid_bundle.pem"` | - | `jwt_audience` | JWT SVID audience. | `"your-audience"` | - | `jwt_svid_file_name` | File name to be used to store JWT SVID in Base64-encoded string. | `"jwt_svid.token"` | + | `jwt_svids` | An array with the audience and file name to store the JWT SVIDs. File is Base64-encoded string). | `[{jwt_audience="your-audience", jwt_svid_file_name="jwt_svid.token"}]` | | `jwt_bundle_file_name` | File name to be used to store JWT Bundle in JSON format. | `"jwt_bundle.json"` | ### Configuration example diff --git a/pkg/sidecar/config.go b/pkg/sidecar/config.go index c6af3202..90457493 100644 --- a/pkg/sidecar/config.go +++ b/pkg/sidecar/config.go @@ -33,9 +33,8 @@ type Config struct { RenewSignalDeprecated string `hcl:"renewSignal"` // JWT configuration - JWTAudience string `hcl:"jwt_audience"` - JWTSvidFilename string `hcl:"jwt_svid_file_name"` - JWTBundleFilename string `hcl:"jwt_bundle_file_name"` + JwtSvids []JwtConfig `hcl:"jwt_svids"` + JWTBundleFilename string `hcl:"jwt_bundle_file_name"` // TODO: is there a reason for this to be exposed? and inside of config? ReloadExternalProcess func() error @@ -43,6 +42,11 @@ type Config struct { Log logrus.FieldLogger } +type JwtConfig struct { + JWTAudience string `hcl:"jwt_audience"` + JWTSvidFilename string `hcl:"jwt_svid_file_name"` +} + // ParseConfig parses the given HCL file into a SidecarConfig struct func ParseConfig(file string) (*Config, error) { sidecarConfig := new(Config) @@ -121,21 +125,25 @@ func ValidateConfig(c *Config) error { c.RenewSignal = c.RenewSignalDeprecated } + for _, jwtConfig := range c.JwtSvids { + if jwtConfig.JWTSvidFilename == "" { + return errors.New("'jwt_file_name' is required in 'jwt_svids'") + } + if jwtConfig.JWTAudience == "" { + return errors.New("'jwt_audience' is required in 'jwt_svids'") + } + } + x509EmptyCount := countEmpty(c.SvidFileName, c.SvidBundleFileName, c.SvidKeyFileName) - jwtSVIDEmptyCount := countEmpty(c.JWTSvidFilename, c.JWTAudience) jwtBundleEmptyCount := countEmpty(c.SvidBundleFileName) - if x509EmptyCount == 3 && jwtSVIDEmptyCount == 2 && jwtBundleEmptyCount == 1 { - return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified") + if x509EmptyCount == 3 && len(c.JwtSvids) == 0 && jwtBundleEmptyCount == 1 { + return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), 'jwt_svids', or 'jwt_bundle_file_name' must be fully specified") } if x509EmptyCount != 0 && x509EmptyCount != 3 { return errors.New("all or none of 'svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name' must be specified") } - if jwtSVIDEmptyCount != 0 && jwtSVIDEmptyCount != 2 { - return errors.New("all or none of 'jwt_file_name', 'jwt_audience' must be specified") - } - return nil } diff --git a/pkg/sidecar/config_test.go b/pkg/sidecar/config_test.go index 34e1580d..8d3f0570 100644 --- a/pkg/sidecar/config_test.go +++ b/pkg/sidecar/config_test.go @@ -34,9 +34,9 @@ func TestParseConfig(t *testing.T) { assert.Equal(t, expectedSvidFileName, c.SvidFileName) assert.Equal(t, expectedKeyFileName, c.SvidKeyFileName) assert.Equal(t, expectedSvidBundleFileName, c.SvidBundleFileName) - assert.Equal(t, expectedJWTSVIDFileName, c.JWTSvidFilename) + assert.Equal(t, expectedJWTSVIDFileName, c.JwtSvids[0].JWTSvidFilename) assert.Equal(t, expectedJWTBundleFileName, c.JWTBundleFilename) - assert.Equal(t, expectedJWTAudience, c.JWTAudience) + assert.Equal(t, expectedJWTAudience, c.JwtSvids[0].JWTAudience) assert.True(t, c.AddIntermediatesToBundle) } @@ -59,9 +59,11 @@ func TestValidateConfig(t *testing.T) { { name: "no error", config: &Config{ - AgentAddress: "path", - JWTAudience: "your-audience", - JWTSvidFilename: "jwt.token", + AgentAddress: "path", + JwtSvids: []JwtConfig{{ + JWTSvidFilename: "jwt.token", + JWTAudience: "your-audience", + }}, JWTBundleFilename: "bundle.json", }, }, @@ -70,7 +72,7 @@ func TestValidateConfig(t *testing.T) { config: &Config{ AgentAddress: "path", }, - expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified", + expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), 'jwt_svids', or 'jwt_bundle_file_name' must be fully specified", }, { name: "missing svid config", @@ -81,12 +83,24 @@ func TestValidateConfig(t *testing.T) { expectError: "all or none of 'svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name' must be specified", }, { - name: "missing jwt config", + name: "missing jwt audience", config: &Config{ - AgentAddress: "path", - JWTSvidFilename: "cert.pem", + AgentAddress: "path", + JwtSvids: []JwtConfig{{ + JWTSvidFilename: "jwt.token", + }}, + }, + expectError: "'jwt_audience' is required in 'jwt_svids'", + }, + { + name: "missing jwt path", + config: &Config{ + AgentAddress: "path", + JwtSvids: []JwtConfig{{ + JWTAudience: "my-audience", + }}, }, - expectError: "all or none of 'jwt_file_name', 'jwt_audience' must be specified", + expectError: "'jwt_file_name' is required in 'jwt_svids'", }, // Duplicated field error: { diff --git a/pkg/sidecar/sidecar.go b/pkg/sidecar/sidecar.go index 54d7b70f..67bd5e91 100644 --- a/pkg/sidecar/sidecar.go +++ b/pkg/sidecar/sidecar.go @@ -103,7 +103,7 @@ func (s *Sidecar) RunDaemon(ctx context.Context) error { }() } - if s.config.JWTSvidFilename != "" && s.config.JWTAudience != "" { + if len(s.config.JwtSvids) > 0 { jwtSource, err := workloadapi.NewJWTSource(ctx, workloadapi.WithClientOptions(s.getWorkloadAPIAdress())) if err != nil { s.config.Log.Fatalf("Error watching JWT svid updates: %v", err) @@ -111,11 +111,14 @@ func (s *Sidecar) RunDaemon(ctx context.Context) error { s.jwtSource = jwtSource defer s.jwtSource.Close() - wg.Add(1) - go func() { - defer wg.Done() - s.updateJWTSVID(ctx) - }() + for _, jwtConfig := range s.config.JwtSvids { + jwtConfig := jwtConfig + wg.Add(1) + go func() { + defer wg.Done() + s.updateJWTSVID(ctx, jwtConfig.JWTAudience, jwtConfig.JWTSvidFilename) + }() + } } wg.Wait() @@ -278,14 +281,14 @@ func (s *Sidecar) updateJWTBundle(jwkSet *jwtbundle.Set) { } } -func (s *Sidecar) fetchJWTSVID(ctx context.Context) (*jwtsvid.SVID, error) { - jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: s.config.JWTAudience}) +func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string) (*jwtsvid.SVID, error) { + jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: jwtAudience}) if err != nil { s.config.Log.Errorf("Unable to fetch JWT SVID: %v", err) return nil, err } - _, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{s.config.JWTAudience}) + _, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{jwtAudience}) if err != nil { s.config.Log.Errorf("Unable to parse or validate token: %v", err) return nil, err @@ -316,16 +319,16 @@ func getRefreshInterval(svid *jwtsvid.SVID) time.Duration { return time.Until(svid.Expiry)/2 + time.Second } -func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, error) { +func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtSvidFilename string) (*jwtsvid.SVID, error) { s.config.Log.Debug("Updating JWT SVID") - jwtSVID, err := s.fetchJWTSVID(ctx) + jwtSVID, err := s.fetchJWTSVIDs(ctx, jwtAudience) if err != nil { s.config.Log.Errorf("Unable to update JWT SVID: %v", err) return nil, err } - filePath := path.Join(s.config.CertDir, s.config.JWTSvidFilename) + filePath := path.Join(s.config.CertDir, jwtSvidFilename) if err = os.WriteFile(filePath, []byte(jwtSVID.Marshal()), os.ModePerm); err != nil { s.config.Log.Errorf("Unable to update JWT SVID: %v", err) return nil, err @@ -335,10 +338,10 @@ func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, erro return jwtSVID, nil } -func (s *Sidecar) updateJWTSVID(ctx context.Context) { +func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtSvidFilename string) { retryInterval := createRetryIntervalFunc() var initialInterval time.Duration - jwtSVID, err := s.performJWTSVIDUpdate(ctx) + jwtSVID, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename) if err != nil { // If the first update fails, use the retry interval initialInterval = retryInterval() @@ -354,7 +357,7 @@ func (s *Sidecar) updateJWTSVID(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - jwtSVID, err = s.performJWTSVIDUpdate(ctx) + jwtSVID, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename) if err == nil { retryInterval = createRetryIntervalFunc() ticker.Reset(getRefreshInterval(jwtSVID)) diff --git a/test/fixture/config/helper.conf b/test/fixture/config/helper.conf index 87b1b567..ccd57742 100644 --- a/test/fixture/config/helper.conf +++ b/test/fixture/config/helper.conf @@ -6,8 +6,12 @@ renew_signal = "SIGHUP" svid_file_name = "svid.pem" svid_key_file_name = "svid_key.pem" svid_bundle_file_name = "svid_bundle.pem" -jwt_svid_file_name = "jwt_svid.token" jwt_bundle_file_name = "jwt_bundle.json" -jwt_audience = "your-audience" +jwt_svids = [ + { + jwt_svid_file_name = "jwt_svid.token" + jwt_audience = "your-audience" + } +] timeout = "10s" add_intermediates_to_bundle = true