Skip to content

Commit

Permalink
Allow multiple JWT audiences to be configured (closes #108)
Browse files Browse the repository at this point in the history
  • Loading branch information
keeganwitt committed Dec 13, 2023
1 parent c86aaa3 commit 38cb421
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 29 deletions.
32 changes: 26 additions & 6 deletions pkg/sidecar/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,22 @@ type Config struct {
RenewSignalDeprecated string `hcl:"renewSignal"`

// JWT configuration
JWTAudience string `hcl:"jwt_audience"`
JWTSvidFilename string `hcl:"jwt_svid_file_name"`
JWTBundleFilename string `hcl:"jwt_bundle_file_name"`
JwtSvids []JwtConfig `hcl:"jwt_svids"`
JWTAudienceDeprecated string `hcl:"jwt_audience"`
JWTSvidFilenameDeprecated string `hcl:"jwt_svid_file_name"`
JWTBundleFilename string `hcl:"jwt_bundle_file_name"`

// TODO: is there a reason for this to be exposed? and inside of config?
ReloadExternalProcess func() error
// TODO: is there a reason for this to be exposed? and inside of config?
Log logrus.FieldLogger
}

type JwtConfig struct {
JWTAudience string `hcl:"jwt_audience"`
JWTSvidFilename string `hcl:"jwt_svid_file_name"`
}

// ParseConfig parses the given HCL file into a SidecarConfig struct
func ParseConfig(file string) (*Config, error) {
sidecarConfig := new(Config)
Expand Down Expand Up @@ -120,11 +126,17 @@ func ValidateConfig(c *Config) error {
c.RenewSignal = c.RenewSignalDeprecated
}

for _, jwtConfig := range c.JwtSvids {
if countEmpty(jwtConfig.JWTSvidFilename, jwtConfig.JWTAudience) > 0 {
return errors.New("both 'jwt_file_name' and 'jwt_audience' are required in 'jwt_svids'")
}
}

x509EmptyCount := countEmpty(c.SvidFileName, c.SvidBundleFileName, c.SvidKeyFileName)
jwtSVIDEmptyCount := countEmpty(c.JWTSvidFilename, c.JWTAudience)
jwtSVIDEmptyCount := countEmpty(c.JWTSvidFilenameDeprecated, c.JWTAudienceDeprecated)
jwtBundleEmptyCount := countEmpty(c.SvidBundleFileName)
if x509EmptyCount == 3 && jwtSVIDEmptyCount == 2 && jwtBundleEmptyCount == 1 {
return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified")
if x509EmptyCount == 3 && jwtSVIDEmptyCount == 2 && c.JwtSvids == nil && jwtBundleEmptyCount == 1 {
return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), 'jwt_svids', or ('jwt_bundle_file_name') must be fully specified")
}

if x509EmptyCount != 0 && x509EmptyCount != 3 {
Expand All @@ -135,6 +147,14 @@ func ValidateConfig(c *Config) error {
return errors.New("all or none of 'jwt_file_name', 'jwt_audience' must be specified")
}

if jwtSVIDEmptyCount == 0 {
c.Log.Warn(getWarning("jwt_file_name and jwt_audience", "jwt_svids"))
}

if jwtSVIDEmptyCount != 0 && c.JwtSvids == nil {
return errors.New("must not specify deprecated JWT configs ('jwt_file_name' and 'jwt_audience') and new JWT config ('jwt_svids')")
}

return nil
}

Expand Down
33 changes: 25 additions & 8 deletions pkg/sidecar/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ func TestParseConfig(t *testing.T) {
assert.Equal(t, expectedSvidFileName, c.SvidFileName)
assert.Equal(t, expectedKeyFileName, c.SvidKeyFileName)
assert.Equal(t, expectedSvidBundleFileName, c.SvidBundleFileName)
assert.Equal(t, expectedJWTSVIDFileName, c.JWTSvidFilename)
assert.Equal(t, expectedJWTSVIDFileName, c.JWTSvidFilenameDeprecated)
assert.Equal(t, expectedJWTBundleFileName, c.JWTBundleFilename)
assert.Equal(t, expectedJWTAudience, c.JWTAudience)
assert.Equal(t, expectedJWTAudience, c.JWTAudienceDeprecated)
assert.True(t, c.AddIntermediatesToBundle)
}

Expand All @@ -56,12 +56,29 @@ func TestValidateConfig(t *testing.T) {
SvidBundleFileName: "bundle.pem",
},
},
{
name: "warns on deprecated jwt configs",
config: &Config{
AgentAddress: "path",
JWTAudienceDeprecated: "your-audience",
JWTSvidFilenameDeprecated: "jwt.token",
JWTBundleFilename: "bundle.json",
},
expectLogs: []shortEntry{
{
Level: logrus.WarnLevel,
Message: "jwt_file_name and jwt_audience will be deprecated, should be used as jwt_svids",
},
},
},
{
name: "no error",
config: &Config{
AgentAddress: "path",
JWTAudience: "your-audience",
JWTSvidFilename: "jwt.token",
AgentAddress: "path",
JwtSvids: []JwtConfig{{
JWTSvidFilename: "jwt.token",
JWTAudience: "your-audience",
}},
JWTBundleFilename: "bundle.json",
},
},
Expand All @@ -70,7 +87,7 @@ func TestValidateConfig(t *testing.T) {
config: &Config{
AgentAddress: "path",
},
expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified",
expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), 'jwt_svids', or ('jwt_bundle_file_name') must be fully specified",
},
{
name: "missing svid config",
Expand All @@ -83,8 +100,8 @@ func TestValidateConfig(t *testing.T) {
{
name: "missing jwt config",
config: &Config{
AgentAddress: "path",
JWTSvidFilename: "cert.pem",
AgentAddress: "path",
JWTSvidFilenameDeprecated: "cert.pem",
},
expectError: "all or none of 'jwt_file_name', 'jwt_audience' must be specified",
},
Expand Down
41 changes: 26 additions & 15 deletions pkg/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,30 @@ func (s *Sidecar) RunDaemon(ctx context.Context) error {
}()
}

if s.config.JWTSvidFilename != "" && s.config.JWTAudience != "" {
if s.config.JWTSvidFilenameDeprecated != "" && s.config.JWTAudienceDeprecated != "" {
jwtSource, err := workloadapi.NewJWTSource(ctx, workloadapi.WithClientOptions(s.getWorkloadAPIAdress()))
if err != nil {
s.config.Log.Fatalf("Error watching JWT svid updates: %v", err)
}
s.jwtSource = jwtSource
defer s.jwtSource.Close()

wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx)
}()
if s.config.JwtSvids != nil {
for _, jwtConfig := range s.config.JwtSvids {
jwtConfig := jwtConfig
wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx, jwtConfig.JWTAudience, jwtConfig.JWTSvidFilename)
}()
}
} else {
wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx, s.config.JWTAudienceDeprecated, s.config.JWTSvidFilenameDeprecated)
}()
}
}

wg.Wait()
Expand Down Expand Up @@ -274,14 +285,14 @@ func (s *Sidecar) updateJWTBundle(jwkSet *jwtbundle.Set) {
}
}

func (s *Sidecar) fetchJWTSVID(ctx context.Context) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: s.config.JWTAudience})
func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to fetch JWT SVID: %v", err)
return nil, err
}

_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{s.config.JWTAudience})
_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to parse or validate token: %v", err)
return nil, err
Expand Down Expand Up @@ -312,16 +323,16 @@ func getRefreshInterval(svid *jwtsvid.SVID) time.Duration {
return time.Until(svid.Expiry)/2 + time.Second
}

func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, error) {
func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtSvidFilename string) (*jwtsvid.SVID, error) {
s.config.Log.Debug("Updating JWT SVID")

jwtSVID, err := s.fetchJWTSVID(ctx)
jwtSVID, err := s.fetchJWTSVIDs(ctx, jwtAudience)
if err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
}

filePath := path.Join(s.config.CertDir, s.config.JWTSvidFilename)
filePath := path.Join(s.config.CertDir, jwtSvidFilename)
if err = os.WriteFile(filePath, []byte(jwtSVID.Marshal()), os.ModePerm); err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
Expand All @@ -331,10 +342,10 @@ func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, erro
return jwtSVID, nil
}

func (s *Sidecar) updateJWTSVID(ctx context.Context) {
func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtSvidFilename string) {
retryInterval := createRetryIntervalFunc()
var initialInterval time.Duration
jwtSVID, err := s.performJWTSVIDUpdate(ctx)
jwtSVID, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename)
if err != nil {
// If the first update fails, use the retry interval
initialInterval = retryInterval()
Expand All @@ -350,7 +361,7 @@ func (s *Sidecar) updateJWTSVID(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
jwtSVID, err = s.performJWTSVIDUpdate(ctx)
jwtSVID, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename)
if err == nil {
retryInterval = createRetryIntervalFunc()
ticker.Reset(getRefreshInterval(jwtSVID))
Expand Down

0 comments on commit 38cb421

Please sign in to comment.