Skip to content

Commit

Permalink
Allow multiple JWT tokens to be configured (closes #108) (#109)
Browse files Browse the repository at this point in the history
* Allow multiple JWT SVIDs to be configured (closes #108)

Signed-off-by: Keegan Witt <[email protected]>
  • Loading branch information
keeganwitt committed Jan 11, 2024
1 parent 1aafd98 commit 8d4d3ab
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 39 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ The configuration file is an [HCL](https://github.com/hashicorp/hcl) formatted f
| `svid_file_name` | File name to be used to store the X.509 SVID public certificate in PEM format. | `"svid.pem"` |
| `svid_key_file_name` | File name to be used to store the X.509 SVID private key and public certificate in PEM format. | `"svid_key.pem"` |
| `svid_bundle_file_name` | File name to be used to store the X.509 SVID Bundle in PEM format. | `"svid_bundle.pem"` |
| `jwt_audience` | JWT SVID audience. | `"your-audience"` |
| `jwt_svid_file_name` | File name to be used to store JWT SVID in Base64-encoded string. | `"jwt_svid.token"` |
| `jwt_svids` | An array with the audience and file name to store the JWT SVIDs. File is Base64-encoded string). | `[{jwt_audience="your-audience", jwt_svid_file_name="jwt_svid.token"}]` |
| `jwt_bundle_file_name` | File name to be used to store JWT Bundle in JSON format. | `"jwt_bundle.json"` |

### Configuration example
Expand Down
28 changes: 18 additions & 10 deletions pkg/sidecar/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,20 @@ type Config struct {
RenewSignalDeprecated string `hcl:"renewSignal"`

// JWT configuration
JWTAudience string `hcl:"jwt_audience"`
JWTSvidFilename string `hcl:"jwt_svid_file_name"`
JWTBundleFilename string `hcl:"jwt_bundle_file_name"`
JwtSvids []JwtConfig `hcl:"jwt_svids"`
JWTBundleFilename string `hcl:"jwt_bundle_file_name"`

// TODO: is there a reason for this to be exposed? and inside of config?
ReloadExternalProcess func() error
// TODO: is there a reason for this to be exposed? and inside of config?
Log logrus.FieldLogger
}

type JwtConfig struct {
JWTAudience string `hcl:"jwt_audience"`
JWTSvidFilename string `hcl:"jwt_svid_file_name"`
}

// ParseConfig parses the given HCL file into a SidecarConfig struct
func ParseConfig(file string) (*Config, error) {
sidecarConfig := new(Config)
Expand Down Expand Up @@ -121,21 +125,25 @@ func ValidateConfig(c *Config) error {
c.RenewSignal = c.RenewSignalDeprecated
}

for _, jwtConfig := range c.JwtSvids {
if jwtConfig.JWTSvidFilename == "" {
return errors.New("'jwt_file_name' is required in 'jwt_svids'")
}
if jwtConfig.JWTAudience == "" {
return errors.New("'jwt_audience' is required in 'jwt_svids'")
}
}

x509EmptyCount := countEmpty(c.SvidFileName, c.SvidBundleFileName, c.SvidKeyFileName)
jwtSVIDEmptyCount := countEmpty(c.JWTSvidFilename, c.JWTAudience)
jwtBundleEmptyCount := countEmpty(c.SvidBundleFileName)
if x509EmptyCount == 3 && jwtSVIDEmptyCount == 2 && jwtBundleEmptyCount == 1 {
return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified")
if x509EmptyCount == 3 && len(c.JwtSvids) == 0 && jwtBundleEmptyCount == 1 {
return errors.New("at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), 'jwt_svids', or 'jwt_bundle_file_name' must be fully specified")
}

if x509EmptyCount != 0 && x509EmptyCount != 3 {
return errors.New("all or none of 'svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name' must be specified")
}

if jwtSVIDEmptyCount != 0 && jwtSVIDEmptyCount != 2 {
return errors.New("all or none of 'jwt_file_name', 'jwt_audience' must be specified")
}

return nil
}

Expand Down
34 changes: 24 additions & 10 deletions pkg/sidecar/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ func TestParseConfig(t *testing.T) {
assert.Equal(t, expectedSvidFileName, c.SvidFileName)
assert.Equal(t, expectedKeyFileName, c.SvidKeyFileName)
assert.Equal(t, expectedSvidBundleFileName, c.SvidBundleFileName)
assert.Equal(t, expectedJWTSVIDFileName, c.JWTSvidFilename)
assert.Equal(t, expectedJWTSVIDFileName, c.JwtSvids[0].JWTSvidFilename)
assert.Equal(t, expectedJWTBundleFileName, c.JWTBundleFilename)
assert.Equal(t, expectedJWTAudience, c.JWTAudience)
assert.Equal(t, expectedJWTAudience, c.JwtSvids[0].JWTAudience)
assert.True(t, c.AddIntermediatesToBundle)
}

Expand All @@ -59,9 +59,11 @@ func TestValidateConfig(t *testing.T) {
{
name: "no error",
config: &Config{
AgentAddress: "path",
JWTAudience: "your-audience",
JWTSvidFilename: "jwt.token",
AgentAddress: "path",
JwtSvids: []JwtConfig{{
JWTSvidFilename: "jwt.token",
JWTAudience: "your-audience",
}},
JWTBundleFilename: "bundle.json",
},
},
Expand All @@ -70,7 +72,7 @@ func TestValidateConfig(t *testing.T) {
config: &Config{
AgentAddress: "path",
},
expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), ('jwt_file_name', 'jwt_audience'), or ('jwt_bundle_file_name') must be fully specified",
expectError: "at least one of the sets ('svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name'), 'jwt_svids', or 'jwt_bundle_file_name' must be fully specified",
},
{
name: "missing svid config",
Expand All @@ -81,12 +83,24 @@ func TestValidateConfig(t *testing.T) {
expectError: "all or none of 'svid_file_name', 'svid_key_file_name', 'svid_bundle_file_name' must be specified",
},
{
name: "missing jwt config",
name: "missing jwt audience",
config: &Config{
AgentAddress: "path",
JWTSvidFilename: "cert.pem",
AgentAddress: "path",
JwtSvids: []JwtConfig{{
JWTSvidFilename: "jwt.token",
}},
},
expectError: "'jwt_audience' is required in 'jwt_svids'",
},
{
name: "missing jwt path",
config: &Config{
AgentAddress: "path",
JwtSvids: []JwtConfig{{
JWTAudience: "my-audience",
}},
},
expectError: "all or none of 'jwt_file_name', 'jwt_audience' must be specified",
expectError: "'jwt_file_name' is required in 'jwt_svids'",
},
// Duplicated field error:
{
Expand Down
33 changes: 18 additions & 15 deletions pkg/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,22 @@ func (s *Sidecar) RunDaemon(ctx context.Context) error {
}()
}

if s.config.JWTSvidFilename != "" && s.config.JWTAudience != "" {
if len(s.config.JwtSvids) > 0 {
jwtSource, err := workloadapi.NewJWTSource(ctx, workloadapi.WithClientOptions(s.getWorkloadAPIAdress()))
if err != nil {
s.config.Log.Fatalf("Error watching JWT svid updates: %v", err)
}
s.jwtSource = jwtSource
defer s.jwtSource.Close()

wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx)
}()
for _, jwtConfig := range s.config.JwtSvids {
jwtConfig := jwtConfig
wg.Add(1)
go func() {
defer wg.Done()
s.updateJWTSVID(ctx, jwtConfig.JWTAudience, jwtConfig.JWTSvidFilename)
}()
}
}

wg.Wait()
Expand Down Expand Up @@ -278,14 +281,14 @@ func (s *Sidecar) updateJWTBundle(jwkSet *jwtbundle.Set) {
}
}

func (s *Sidecar) fetchJWTSVID(ctx context.Context) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: s.config.JWTAudience})
func (s *Sidecar) fetchJWTSVIDs(ctx context.Context, jwtAudience string) (*jwtsvid.SVID, error) {
jwtSVID, err := s.jwtSource.FetchJWTSVID(ctx, jwtsvid.Params{Audience: jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to fetch JWT SVID: %v", err)
return nil, err
}

_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{s.config.JWTAudience})
_, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), s.jwtSource, []string{jwtAudience})
if err != nil {
s.config.Log.Errorf("Unable to parse or validate token: %v", err)
return nil, err
Expand Down Expand Up @@ -316,16 +319,16 @@ func getRefreshInterval(svid *jwtsvid.SVID) time.Duration {
return time.Until(svid.Expiry)/2 + time.Second
}

func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, error) {
func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context, jwtAudience string, jwtSvidFilename string) (*jwtsvid.SVID, error) {
s.config.Log.Debug("Updating JWT SVID")

jwtSVID, err := s.fetchJWTSVID(ctx)
jwtSVID, err := s.fetchJWTSVIDs(ctx, jwtAudience)
if err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
}

filePath := path.Join(s.config.CertDir, s.config.JWTSvidFilename)
filePath := path.Join(s.config.CertDir, jwtSvidFilename)
if err = os.WriteFile(filePath, []byte(jwtSVID.Marshal()), os.ModePerm); err != nil {
s.config.Log.Errorf("Unable to update JWT SVID: %v", err)
return nil, err
Expand All @@ -335,10 +338,10 @@ func (s *Sidecar) performJWTSVIDUpdate(ctx context.Context) (*jwtsvid.SVID, erro
return jwtSVID, nil
}

func (s *Sidecar) updateJWTSVID(ctx context.Context) {
func (s *Sidecar) updateJWTSVID(ctx context.Context, jwtAudience string, jwtSvidFilename string) {
retryInterval := createRetryIntervalFunc()
var initialInterval time.Duration
jwtSVID, err := s.performJWTSVIDUpdate(ctx)
jwtSVID, err := s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename)
if err != nil {
// If the first update fails, use the retry interval
initialInterval = retryInterval()
Expand All @@ -354,7 +357,7 @@ func (s *Sidecar) updateJWTSVID(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
jwtSVID, err = s.performJWTSVIDUpdate(ctx)
jwtSVID, err = s.performJWTSVIDUpdate(ctx, jwtAudience, jwtSvidFilename)
if err == nil {
retryInterval = createRetryIntervalFunc()
ticker.Reset(getRefreshInterval(jwtSVID))
Expand Down
8 changes: 6 additions & 2 deletions test/fixture/config/helper.conf
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ renew_signal = "SIGHUP"
svid_file_name = "svid.pem"
svid_key_file_name = "svid_key.pem"
svid_bundle_file_name = "svid_bundle.pem"
jwt_svid_file_name = "jwt_svid.token"
jwt_bundle_file_name = "jwt_bundle.json"
jwt_audience = "your-audience"
jwt_svids = [
{
jwt_svid_file_name = "jwt_svid.token"
jwt_audience = "your-audience"
}
]
timeout = "10s"
add_intermediates_to_bundle = true

0 comments on commit 8d4d3ab

Please sign in to comment.