diff --git a/.buildkite/scripts/steps/integration_tests.sh b/.buildkite/scripts/steps/integration_tests.sh index b747d73b5a8..ea28f1f503d 100755 --- a/.buildkite/scripts/steps/integration_tests.sh +++ b/.buildkite/scripts/steps/integration_tests.sh @@ -13,7 +13,13 @@ TESTS_EXIT_STATUS=$? set -e # HTML report -go install github.com/alexec/junit2html@latest -junit2html < build/TEST-go-integration.xml > build/TEST-report.html +outputXML="build/TEST-go-integration.xml" + +if [ -f "$outputXML" ]; then + go install github.com/alexec/junit2html@latest + junit2html < "$outputXML" > build/TEST-report.html +else + echo "Cannot generate HTML test report: $outputXML not found" +fi exit $TESTS_EXIT_STATUS diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_policy_change_test.go b/internal/pkg/agent/application/actions/handlers/handler_action_policy_change_test.go index f17bf1b8fa0..cdfc2a7110a 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_policy_change_test.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_policy_change_test.go @@ -33,7 +33,11 @@ import ( func TestPolicyChange(t *testing.T) { log, _ := logger.New("", false) ack := noopacker.New() - agentInfo, _ := info.NewAgentInfo(true) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + agentInfo, _ := info.NewAgentInfo(ctx, true) nullStore := &storage.NullStore{} t.Run("Receive a config change and successfully emits a raw configuration", func(t *testing.T) { @@ -59,7 +63,10 @@ func TestPolicyChange(t *testing.T) { func TestPolicyAcked(t *testing.T) { log, _ := logger.New("", false) - agentInfo, _ := info.NewAgentInfo(true) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + agentInfo, _ := info.NewAgentInfo(ctx, true) nullStore := &storage.NullStore{} t.Run("Config change should ACK", func(t *testing.T) { diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_settings.go b/internal/pkg/agent/application/actions/handlers/handler_action_settings.go index e6c22bd4e5f..6fd348baca6 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_settings.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_settings.go @@ -55,7 +55,7 @@ func (h *Settings) Handle(ctx context.Context, a fleetapi.Action, acker acker.Ac return fmt.Errorf("failed to unpack log level: %w", err) } - if err := h.agentInfo.SetLogLevel(action.LogLevel); err != nil { + if err := h.agentInfo.SetLogLevel(ctx, action.LogLevel); err != nil { return fmt.Errorf("failed to update log level: %w", err) } diff --git a/internal/pkg/agent/application/actions/handlers/handler_action_upgrade_test.go b/internal/pkg/agent/application/actions/handlers/handler_action_upgrade_test.go index f767018adb7..17de63af699 100644 --- a/internal/pkg/agent/application/actions/handlers/handler_action_upgrade_test.go +++ b/internal/pkg/agent/application/actions/handlers/handler_action_upgrade_test.go @@ -57,7 +57,7 @@ func TestUpgradeHandler(t *testing.T) { defer cancel() log, _ := logger.New("", false) - agentInfo, _ := info.NewAgentInfo(true) + agentInfo, _ := info.NewAgentInfo(ctx, true) msgChan := make(chan string) // Create and start the coordinator @@ -89,7 +89,7 @@ func TestUpgradeHandlerSameVersion(t *testing.T) { defer cancel() log, _ := logger.New("", false) - agentInfo, _ := info.NewAgentInfo(true) + agentInfo, _ := info.NewAgentInfo(ctx, true) msgChan := make(chan string) // Create and start the Coordinator @@ -123,7 +123,7 @@ func TestUpgradeHandlerNewVersion(t *testing.T) { defer cancel() log, _ := logger.New("", false) - agentInfo, _ := info.NewAgentInfo(true) + agentInfo, _ := info.NewAgentInfo(ctx, true) msgChan := make(chan string) // Create and start the Coordinator diff --git a/internal/pkg/agent/application/application.go b/internal/pkg/agent/application/application.go index 1980e3c9e52..445bba8434f 100644 --- a/internal/pkg/agent/application/application.go +++ b/internal/pkg/agent/application/application.go @@ -5,6 +5,7 @@ package application import ( + "context" "fmt" "time" @@ -34,6 +35,7 @@ import ( // New creates a new Agent and bootstrap the required subsystem. func New( + ctx context.Context, log *logger.Logger, baseLogger *logger.Logger, logLevel logp.Level, @@ -139,7 +141,7 @@ func New( } else { isManaged = true var store storage.Store - store, cfg, err = mergeFleetConfig(rawConfig) + store, cfg, err = mergeFleetConfig(ctx, rawConfig) if err != nil { return nil, nil, nil, err } @@ -158,7 +160,7 @@ func New( EndpointSignedComponentModifier(), ) - managed, err = newManagedConfigManager(log, agentInfo, cfg, store, runtime, fleetInitTimeout) + managed, err = newManagedConfigManager(ctx, log, agentInfo, cfg, store, runtime, fleetInitTimeout) if err != nil { return nil, nil, nil, err } @@ -196,9 +198,9 @@ func New( return coord, configMgr, composable, nil } -func mergeFleetConfig(rawConfig *config.Config) (storage.Store, *configuration.Configuration, error) { +func mergeFleetConfig(ctx context.Context, rawConfig *config.Config) (storage.Store, *configuration.Configuration, error) { path := paths.AgentConfigFile() - store := storage.NewEncryptedDiskStore(path) + store := storage.NewEncryptedDiskStore(ctx, path) reader, err := store.Load() if err != nil { diff --git a/internal/pkg/agent/application/application_test.go b/internal/pkg/agent/application/application_test.go index b831105be21..cf67f19c6cf 100644 --- a/internal/pkg/agent/application/application_test.go +++ b/internal/pkg/agent/application/application_test.go @@ -5,6 +5,7 @@ package application import ( + "context" "fmt" "testing" "time" @@ -37,7 +38,7 @@ func TestMergeFleetConfig(t *testing.T) { } rawConfig := config.MustNewConfigFrom(cfg) - storage, conf, err := mergeFleetConfig(rawConfig) + storage, conf, err := mergeFleetConfig(context.Background(), rawConfig) require.NoError(t, err) assert.NotNil(t, storage) assert.NotNil(t, conf) @@ -48,7 +49,11 @@ func TestMergeFleetConfig(t *testing.T) { func TestLimitsLog(t *testing.T) { log, obs := logger.NewTesting("TestLimitsLog") + ctx, cn := context.WithCancel(context.Background()) + defer cn() + _, _, _, err := New( + ctx, log, log, logp.DebugLevel, diff --git a/internal/pkg/agent/application/coordinator/coordinator_test.go b/internal/pkg/agent/application/coordinator/coordinator_test.go index a22c65a1eee..53d4524a1cb 100644 --- a/internal/pkg/agent/application/coordinator/coordinator_test.go +++ b/internal/pkg/agent/application/coordinator/coordinator_test.go @@ -100,7 +100,7 @@ func TestCoordinator_State_Starting(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - coord, cfgMgr, varsMgr := createCoordinator(t) + coord, cfgMgr, varsMgr := createCoordinator(t, ctx) stateChan := coord.StateSubscribe(ctx, 32) go func() { err := coord.Run(ctx) @@ -147,7 +147,7 @@ func TestCoordinator_State_ConfigError_NotManaged(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - coord, cfgMgr, varsMgr := createCoordinator(t) + coord, cfgMgr, varsMgr := createCoordinator(t, ctx) go func() { err := coord.Run(ctx) if errors.Is(err, context.Canceled) { @@ -190,7 +190,7 @@ func TestCoordinator_State_ConfigError_Managed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - coord, cfgMgr, varsMgr := createCoordinator(t, ManagedCoordinator(true)) + coord, cfgMgr, varsMgr := createCoordinator(t, ctx, ManagedCoordinator(true)) go func() { err := coord.Run(ctx) if errors.Is(err, context.Canceled) { @@ -232,7 +232,7 @@ func TestCoordinator_StateSubscribe(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - coord, cfgMgr, varsMgr := createCoordinator(t) + coord, cfgMgr, varsMgr := createCoordinator(t, ctx) go func() { err := coord.Run(ctx) if errors.Is(err, context.Canceled) { @@ -392,7 +392,7 @@ func TestCoordinator_ReExec(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - coord, cfgMgr, varsMgr := createCoordinator(t) + coord, cfgMgr, varsMgr := createCoordinator(t, ctx) go func() { err := coord.Run(ctx) if errors.Is(err, context.Canceled) { @@ -431,7 +431,7 @@ func TestCoordinator_Upgrade(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - coord, cfgMgr, varsMgr := createCoordinator(t) + coord, cfgMgr, varsMgr := createCoordinator(t, ctx) go func() { err := coord.Run(ctx) if errors.Is(err, context.Canceled) { @@ -472,7 +472,7 @@ func ManagedCoordinator(managed bool) CoordinatorOpt { // createCoordinator creates a coordinator that using a fake config manager and a fake vars manager. // // The runtime specifications is set up to use both the fake component and fake shipper. -func createCoordinator(t *testing.T, opts ...CoordinatorOpt) (*Coordinator, *fakeConfigManager, *fakeVarsManager) { +func createCoordinator(t *testing.T, ctx context.Context, opts ...CoordinatorOpt) (*Coordinator, *fakeConfigManager, *fakeVarsManager) { t.Helper() o := &createCoordinatorOpts{} @@ -482,7 +482,7 @@ func createCoordinator(t *testing.T, opts ...CoordinatorOpt) (*Coordinator, *fak l := newErrorLogger(t) - ai, err := info.NewAgentInfo(false) + ai, err := info.NewAgentInfo(ctx, false) require.NoError(t, err) componentSpec := component.InputRuntimeSpec{ diff --git a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go index b844a188e0b..363df4bdbe5 100644 --- a/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go +++ b/internal/pkg/agent/application/gateway/fleet/fleet_gateway.go @@ -308,7 +308,7 @@ func (f *fleetGateway) convertToCheckinComponents(components []runtime.Component } func (f *fleetGateway) execute(ctx context.Context) (*fleetapi.CheckinResponse, time.Duration, error) { - ecsMeta, err := info.Metadata(f.log) + ecsMeta, err := info.Metadata(ctx, f.log) if err != nil { f.log.Error(errors.New("failed to load metadata", err)) } diff --git a/internal/pkg/agent/application/info/agent_id.go b/internal/pkg/agent/application/info/agent_id.go index 8056fd0cce1..34ba79ea0c0 100644 --- a/internal/pkg/agent/application/info/agent_id.go +++ b/internal/pkg/agent/application/info/agent_id.go @@ -6,6 +6,7 @@ package info import ( "bytes" + "context" "fmt" "io" "time" @@ -40,8 +41,8 @@ type ioStore interface { } // updateLogLevel updates log level and persists it to disk. -func updateLogLevel(level string) error { - ai, err := loadAgentInfoWithBackoff(false, defaultLogLevel, false) +func updateLogLevel(ctx context.Context, level string) error { + ai, err := loadAgentInfoWithBackoff(ctx, false, defaultLogLevel, false) if err != nil { return err } @@ -52,7 +53,7 @@ func updateLogLevel(level string) error { } agentConfigFile := paths.AgentConfigFile() - diskStore := storage.NewEncryptedDiskStore(agentConfigFile) + diskStore := storage.NewEncryptedDiskStore(ctx, agentConfigFile) ai.LogLevel = level return updateAgentInfo(diskStore, ai) @@ -173,7 +174,7 @@ func yamlToReader(in interface{}) (io.Reader, error) { return bytes.NewReader(data), nil } -func loadAgentInfoWithBackoff(forceUpdate bool, logLevel string, createAgentID bool) (*persistentAgentInfo, error) { +func loadAgentInfoWithBackoff(ctx context.Context, forceUpdate bool, logLevel string, createAgentID bool) (*persistentAgentInfo, error) { var err error var ai *persistentAgentInfo @@ -182,7 +183,7 @@ func loadAgentInfoWithBackoff(forceUpdate bool, logLevel string, createAgentID b for i := 0; i <= maxRetriesloadAgentInfo; i++ { backExp.Wait() - ai, err = loadAgentInfo(forceUpdate, logLevel, createAgentID) + ai, err = loadAgentInfo(ctx, forceUpdate, logLevel, createAgentID) if !errors.Is(err, filelock.ErrAppAlreadyRunning) { break } @@ -192,7 +193,7 @@ func loadAgentInfoWithBackoff(forceUpdate bool, logLevel string, createAgentID b return ai, err } -func loadAgentInfo(forceUpdate bool, logLevel string, createAgentID bool) (*persistentAgentInfo, error) { +func loadAgentInfo(ctx context.Context, forceUpdate bool, logLevel string, createAgentID bool) (*persistentAgentInfo, error) { idLock := paths.AgentConfigFileLock() if err := idLock.TryLock(); err != nil { return nil, err @@ -201,7 +202,7 @@ func loadAgentInfo(forceUpdate bool, logLevel string, createAgentID bool) (*pers defer idLock.Unlock() agentConfigFile := paths.AgentConfigFile() - diskStore := storage.NewEncryptedDiskStore(agentConfigFile) + diskStore := storage.NewEncryptedDiskStore(ctx, agentConfigFile) agentInfo, err := getInfoFromStore(diskStore, logLevel) if err != nil { diff --git a/internal/pkg/agent/application/info/agent_info.go b/internal/pkg/agent/application/info/agent_info.go index 649c6d5eacb..f335b2eb76e 100644 --- a/internal/pkg/agent/application/info/agent_info.go +++ b/internal/pkg/agent/application/info/agent_info.go @@ -5,6 +5,8 @@ package info import ( + "context" + "github.com/elastic/elastic-agent/internal/pkg/release" "github.com/elastic/elastic-agent/pkg/core/logger" ) @@ -25,8 +27,8 @@ type AgentInfo struct { // new unique identifier for agent. // If agent config file does not exist it gets created. // Initiates log level to predefined value. -func NewAgentInfoWithLog(level string, createAgentID bool) (*AgentInfo, error) { - agentInfo, err := loadAgentInfoWithBackoff(false, level, createAgentID) +func NewAgentInfoWithLog(ctx context.Context, level string, createAgentID bool) (*AgentInfo, error) { + agentInfo, err := loadAgentInfoWithBackoff(ctx, false, level, createAgentID) if err != nil { return nil, err } @@ -43,8 +45,8 @@ func NewAgentInfoWithLog(level string, createAgentID bool) (*AgentInfo, error) { // this created ID otherwise it generates // new unique identifier for agent. // If agent config file does not exist it gets created. -func NewAgentInfo(createAgentID bool) (*AgentInfo, error) { - return NewAgentInfoWithLog(defaultLogLevel, createAgentID) +func NewAgentInfo(ctx context.Context, createAgentID bool) (*AgentInfo, error) { + return NewAgentInfoWithLog(ctx, defaultLogLevel, createAgentID) } // LogLevel retrieves a log level. @@ -56,8 +58,8 @@ func (i *AgentInfo) LogLevel() string { } // SetLogLevel updates log level of agent. -func (i *AgentInfo) SetLogLevel(level string) error { - if err := updateLogLevel(level); err != nil { +func (i *AgentInfo) SetLogLevel(ctx context.Context, level string) error { + if err := updateLogLevel(ctx, level); err != nil { return err } @@ -66,8 +68,8 @@ func (i *AgentInfo) SetLogLevel(level string) error { } // ReloadID reloads agent info ID from configuration file. -func (i *AgentInfo) ReloadID() error { - newInfo, err := NewAgentInfoWithLog(i.logLevel, false) +func (i *AgentInfo) ReloadID(ctx context.Context) error { + newInfo, err := NewAgentInfoWithLog(ctx, i.logLevel, false) if err != nil { return err } diff --git a/internal/pkg/agent/application/info/agent_metadata.go b/internal/pkg/agent/application/info/agent_metadata.go index b298fdc817c..81ad78a834d 100644 --- a/internal/pkg/agent/application/info/agent_metadata.go +++ b/internal/pkg/agent/application/info/agent_metadata.go @@ -5,6 +5,7 @@ package info import ( + "context" "fmt" "runtime" "strings" @@ -125,8 +126,8 @@ const ( ) // Metadata loads metadata from disk. -func Metadata(l *logger.Logger) (*ECSMeta, error) { - agentInfo, err := NewAgentInfo(false) +func Metadata(ctx context.Context, l *logger.Logger) (*ECSMeta, error) { + agentInfo, err := NewAgentInfo(ctx, false) if err != nil { return nil, fmt.Errorf("failed to create new agent info: %w", err) } diff --git a/internal/pkg/agent/application/managed_mode.go b/internal/pkg/agent/application/managed_mode.go index 1357f3dfeff..6895003c38e 100644 --- a/internal/pkg/agent/application/managed_mode.go +++ b/internal/pkg/agent/application/managed_mode.go @@ -56,6 +56,7 @@ type managedConfigManager struct { } func newManagedConfigManager( + ctx context.Context, log *logger.Logger, agentInfo *info.AgentInfo, cfg *configuration.Configuration, @@ -72,7 +73,7 @@ func newManagedConfigManager( } // Create the state store that will persist the last good policy change on disk. - stateStore, err := store.NewStateStoreWithMigration(log, paths.AgentActionStoreFile(), paths.AgentStateStoreFile()) + stateStore, err := store.NewStateStoreWithMigration(ctx, log, paths.AgentActionStoreFile(), paths.AgentStateStoreFile()) if err != nil { return nil, errors.New(err, fmt.Sprintf("fail to read action store '%s'", paths.AgentActionStoreFile())) } @@ -116,7 +117,7 @@ func (m *managedConfigManager) Run(ctx context.Context) error { } // Reload ID because of win7 sync issue - if err := m.agentInfo.ReloadID(); err != nil { + if err := m.agentInfo.ReloadID(ctx); err != nil { return err } diff --git a/internal/pkg/agent/application/secret/secret.go b/internal/pkg/agent/application/secret/secret.go index bd2ee546454..84d1d3f0b39 100644 --- a/internal/pkg/agent/application/secret/secret.go +++ b/internal/pkg/agent/application/secret/secret.go @@ -6,6 +6,7 @@ package secret import ( + "context" "encoding/json" "fmt" "runtime" @@ -14,6 +15,7 @@ import ( "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" "github.com/elastic/elastic-agent/internal/pkg/agent/vault" + "github.com/elastic/elastic-agent/internal/pkg/agent/vault/aesgcm" ) const agentSecretKey = "secret" @@ -45,14 +47,14 @@ func WithVaultPath(vaultPath string) OptionFunc { } // CreateAgentSecret creates agent secret key if it doesn't exist -func CreateAgentSecret(opts ...OptionFunc) error { - return Create(agentSecretKey, opts...) +func CreateAgentSecret(ctx context.Context, opts ...OptionFunc) error { + return Create(ctx, agentSecretKey, opts...) } // Create creates secret and stores it in the vault under given key -func Create(key string, opts ...OptionFunc) error { +func Create(ctx context.Context, key string, opts ...OptionFunc) error { options := applyOptions(opts...) - v, err := vault.New(options.vaultPath) + v, err := vault.New(ctx, options.vaultPath) if err != nil { return fmt.Errorf("could not create new vault: %w", err) } @@ -63,7 +65,7 @@ func Create(key string, opts ...OptionFunc) error { defer mxCreate.Unlock() // Check if the key exists - exists, err := v.Exists(key) + exists, err := v.Exists(ctx, key) if err != nil { return err } @@ -72,7 +74,7 @@ func Create(key string, opts ...OptionFunc) error { } // Create new AES256 key - k, err := vault.NewKey(vault.AES256) + k, err := aesgcm.NewKey(aesgcm.AES256) if err != nil { return err } @@ -82,31 +84,31 @@ func Create(key string, opts ...OptionFunc) error { CreatedOn: time.Now().UTC(), } - return set(v, key, secret) + return set(ctx, v, key, secret) } // GetAgentSecret read the agent secret from the vault -func GetAgentSecret(opts ...OptionFunc) (secret Secret, err error) { - return Get(agentSecretKey, opts...) +func GetAgentSecret(ctx context.Context, opts ...OptionFunc) (secret Secret, err error) { + return Get(ctx, agentSecretKey, opts...) } // SetAgentSecret saves the agent secret from the vault // This is needed for migration from 8.3.0-8.3.2 to higher versions -func SetAgentSecret(secret Secret, opts ...OptionFunc) error { - return Set(agentSecretKey, secret, opts...) +func SetAgentSecret(ctx context.Context, secret Secret, opts ...OptionFunc) error { + return Set(ctx, agentSecretKey, secret, opts...) } // Get reads the secret key from the vault -func Get(key string, opts ...OptionFunc) (secret Secret, err error) { +func Get(ctx context.Context, key string, opts ...OptionFunc) (secret Secret, err error) { options := applyOptions(opts...) // open vault readonly, will not create the vault directory or the seed it was not created before - v, err := vault.New(options.vaultPath, vault.WithReadonly(true)) + v, err := vault.New(ctx, options.vaultPath, vault.WithReadonly(true)) if err != nil { return secret, err } defer v.Close() - b, err := v.Get(key) + b, err := v.Get(ctx, key) if err != nil { return secret, err } @@ -116,35 +118,35 @@ func Get(key string, opts ...OptionFunc) (secret Secret, err error) { } // Set saves the secret key to the vault -func Set(key string, secret Secret, opts ...OptionFunc) error { +func Set(ctx context.Context, key string, secret Secret, opts ...OptionFunc) error { options := applyOptions(opts...) - v, err := vault.New(options.vaultPath) + v, err := vault.New(ctx, options.vaultPath) if err != nil { return fmt.Errorf("could not create new vault: %w", err) } defer v.Close() - return set(v, key, secret) + return set(ctx, v, key, secret) } -func set(v *vault.Vault, key string, secret Secret) error { +func set(ctx context.Context, v *vault.Vault, key string, secret Secret) error { b, err := json.Marshal(secret) if err != nil { return fmt.Errorf("could not marshal secret: %w", err) } - return v.Set(key, b) + return v.Set(ctx, key, b) } // Remove removes the secret key from the vault -func Remove(key string, opts ...OptionFunc) error { +func Remove(ctx context.Context, key string, opts ...OptionFunc) error { options := applyOptions(opts...) - v, err := vault.New(options.vaultPath) + v, err := vault.New(ctx, options.vaultPath) if err != nil { return fmt.Errorf("could not create new vault: %w", err) } defer v.Close() - return v.Remove(key) + return v.Remove(ctx, key) } func applyOptions(opts ...OptionFunc) options { diff --git a/internal/pkg/agent/application/secret/secret_test.go b/internal/pkg/agent/application/secret/secret_test.go index 7433e8c9c9a..ec90dff8f42 100644 --- a/internal/pkg/agent/application/secret/secret_test.go +++ b/internal/pkg/agent/application/secret/secret_test.go @@ -7,6 +7,7 @@ package secret import ( + "context" "os" "path/filepath" "testing" @@ -14,7 +15,7 @@ import ( "github.com/google/go-cmp/cmp" - "github.com/elastic/elastic-agent/internal/pkg/agent/vault" + "github.com/elastic/elastic-agent/internal/pkg/agent/vault/aesgcm" ) func getTestVaultPath(t *testing.T) string { @@ -31,10 +32,13 @@ func getTestOptions(t *testing.T) []OptionFunc { func TestCreate(t *testing.T) { opts := getTestOptions(t) + ctx, cn := context.WithCancel(context.Background()) + defer cn() + start := time.Now().UTC() keys := []string{"secret1", "secret2", "secret3"} for _, key := range keys { - err := Create(key, opts...) + err := Create(ctx, key, opts...) if err != nil { t.Fatal(err) } @@ -42,7 +46,7 @@ func TestCreate(t *testing.T) { end := time.Now().UTC() for _, key := range keys { - secret, err := Get(key, opts...) + secret, err := Get(ctx, key, opts...) if err != nil { t.Error(err) } @@ -51,14 +55,14 @@ func TestCreate(t *testing.T) { t.Errorf("invalid created on date/time: %v", secret.CreatedOn) } - diff := cmp.Diff(int(vault.AES256), len(secret.Value)) + diff := cmp.Diff(int(aesgcm.AES256), len(secret.Value)) if diff != "" { t.Error(diff) } } for _, key := range keys { - err := Remove(key, opts...) + err := Remove(ctx, key, opts...) if err != nil { t.Fatal(err) } diff --git a/internal/pkg/agent/cmd/enroll.go b/internal/pkg/agent/cmd/enroll.go index ccb85c075ab..1bce5f7e547 100644 --- a/internal/pkg/agent/cmd/enroll.go +++ b/internal/pkg/agent/cmd/enroll.go @@ -393,6 +393,7 @@ func enroll(streams *cli.IOStreams, cmd *cobra.Command) error { } c, err := newEnrollCmd( + ctx, logger, &options, pathConfigFile, diff --git a/internal/pkg/agent/cmd/enroll_cmd.go b/internal/pkg/agent/cmd/enroll_cmd.go index 3f9145435c0..d092aa61492 100644 --- a/internal/pkg/agent/cmd/enroll_cmd.go +++ b/internal/pkg/agent/cmd/enroll_cmd.go @@ -152,6 +152,7 @@ func (e *enrollCmdOption) remoteConfig() (remote.Config, error) { // newEnrollCmd creates a new enroll command that will registers the current beats to the remote // system. func newEnrollCmd( + ctx context.Context, log *logger.Logger, options *enrollCmdOption, configPath string, @@ -160,7 +161,7 @@ func newEnrollCmd( store := storage.NewReplaceOnSuccessStore( configPath, application.DefaultAgentFleetConfig, - storage.NewEncryptedDiskStore(paths.AgentConfigFile()), + storage.NewEncryptedDiskStore(ctx, paths.AgentConfigFile()), ) return newEnrollCmdWithStore( @@ -198,7 +199,7 @@ func (c *enrollCmd) Execute(ctx context.Context, streams *cli.IOStreams) error { // Create encryption key from the agent before touching configuration if !c.options.SkipCreateSecret { - err = secret.CreateAgentSecret() + err = secret.CreateAgentSecret(ctx) if err != nil { return err } @@ -504,7 +505,7 @@ func (c *enrollCmd) enrollWithBackoff(ctx context.Context, persistentConfig map[ func (c *enrollCmd) enroll(ctx context.Context, persistentConfig map[string]interface{}) error { cmd := fleetapi.NewEnrollCmd(c.client) - metadata, err := info.Metadata(c.log) + metadata, err := info.Metadata(ctx, c.log) if err != nil { return errors.New(err, "acquiring metadata failed") } diff --git a/internal/pkg/agent/cmd/inspect.go b/internal/pkg/agent/cmd/inspect.go index 5a593e01ac6..094d4584a06 100644 --- a/internal/pkg/agent/cmd/inspect.go +++ b/internal/pkg/agent/cmd/inspect.go @@ -134,7 +134,7 @@ func inspectConfig(ctx context.Context, cfgPath string, opts inspectConfigOpts, } if !opts.variables && !opts.includeMonitoring { - fullCfg, err := operations.LoadFullAgentConfig(l, cfgPath, true) + fullCfg, err := operations.LoadFullAgentConfig(ctx, l, cfgPath, true) if err != nil { return err } @@ -146,7 +146,7 @@ func inspectConfig(ctx context.Context, cfgPath string, opts inspectConfigOpts, return err } - agentInfo, err := info.NewAgentInfoWithLog("error", false) + agentInfo, err := info.NewAgentInfoWithLog(ctx, "error", false) if err != nil { return fmt.Errorf("could not load agent info: %w", err) } @@ -163,7 +163,7 @@ func inspectConfig(ctx context.Context, cfgPath string, opts inspectConfigOpts, return fmt.Errorf("failed to detect inputs and outputs: %w", err) } - monitorFn, err := getMonitoringFn(cfg) + monitorFn, err := getMonitoringFn(ctx, cfg) if err != nil { return fmt.Errorf("failed to get monitoring: %w", err) } @@ -254,12 +254,12 @@ func inspectComponents(ctx context.Context, cfgPath string, opts inspectComponen return err } - monitorFn, err := getMonitoringFn(m) + monitorFn, err := getMonitoringFn(ctx, m) if err != nil { return fmt.Errorf("failed to get monitoring: %w", err) } - agentInfo, err := info.NewAgentInfoWithLog("error", false) + agentInfo, err := info.NewAgentInfoWithLog(ctx, "error", false) if err != nil { return fmt.Errorf("could not load agent info: %w", err) } @@ -330,7 +330,7 @@ func inspectComponents(ctx context.Context, cfgPath string, opts inspectComponen return printComponents(allowed, blocked, streams) } -func getMonitoringFn(cfg map[string]interface{}) (component.GenerateMonitoringCfgFn, error) { +func getMonitoringFn(ctx context.Context, cfg map[string]interface{}) (component.GenerateMonitoringCfgFn, error) { config, err := config.NewConfigFrom(cfg) if err != nil { return nil, err @@ -341,7 +341,7 @@ func getMonitoringFn(cfg map[string]interface{}) (component.GenerateMonitoringCf return nil, err } - agentInfo, err := info.NewAgentInfoWithLog("error", false) + agentInfo, err := info.NewAgentInfoWithLog(ctx, "error", false) if err != nil { return nil, fmt.Errorf("could not load agent info: %w", err) } @@ -352,7 +352,7 @@ func getMonitoringFn(cfg map[string]interface{}) (component.GenerateMonitoringCf func getConfigWithVariables(ctx context.Context, l *logger.Logger, cfgPath string, timeout time.Duration) (map[string]interface{}, logp.Level, error) { - cfg, err := operations.LoadFullAgentConfig(l, cfgPath, true) + cfg, err := operations.LoadFullAgentConfig(ctx, l, cfgPath, true) if err != nil { return nil, logp.InfoLevel, err } diff --git a/internal/pkg/agent/cmd/run.go b/internal/pkg/agent/cmd/run.go index 9f31c18c4a3..961322a81ca 100644 --- a/internal/pkg/agent/cmd/run.go +++ b/internal/pkg/agent/cmd/run.go @@ -130,7 +130,7 @@ func run(override cfgOverrider, testingMode bool, fleetInitTimeout time.Duration defer cancel() go service.ProcessWindowsControlEvents(stopBeat) - cfg, err := loadConfig(override) + cfg, err := loadConfig(ctx, override) if err != nil { return err } @@ -166,7 +166,7 @@ func run(override cfgOverrider, testingMode bool, fleetInitTimeout time.Duration // The secret is not created here if it exists already from the previous enrollment. // This is needed for compatibility with agent running in standalone mode, // that writes the agentID into fleet.enc (encrypted fleet.yml) before even loading the configuration. - err = secret.CreateAgentSecret() + err = secret.CreateAgentSecret(ctx) if err != nil { return fmt.Errorf("failed to read/write secrets: %w", err) } @@ -174,18 +174,18 @@ func run(override cfgOverrider, testingMode bool, fleetInitTimeout time.Duration // Migrate .yml files if the corresponding .enc does not exist // the encrypted config does not exist but the unencrypted file does - err = migration.MigrateToEncryptedConfig(l, paths.AgentConfigYmlFile(), paths.AgentConfigFile()) + err = migration.MigrateToEncryptedConfig(ctx, l, paths.AgentConfigYmlFile(), paths.AgentConfigFile()) if err != nil { return errors.New(err, "error migrating fleet config") } // the encrypted state does not exist but the unencrypted file does - err = migration.MigrateToEncryptedConfig(l, paths.AgentStateStoreYmlFile(), paths.AgentStateStoreFile()) + err = migration.MigrateToEncryptedConfig(ctx, l, paths.AgentStateStoreYmlFile(), paths.AgentStateStoreFile()) if err != nil { return errors.New(err, "error migrating agent state") } - agentInfo, err := info.NewAgentInfoWithLog(defaultLogLevel(cfg, logLvl.String()), createAgentID) + agentInfo, err := info.NewAgentInfoWithLog(ctx, defaultLogLevel(cfg, logLvl.String()), createAgentID) if err != nil { return errors.New(err, "could not load agent info", @@ -236,7 +236,7 @@ func run(override cfgOverrider, testingMode bool, fleetInitTimeout time.Duration l.Info("APM instrumentation disabled") } - coord, configMgr, composable, err := application.New(l, baseLogger, logLvl, agentInfo, rex, tracer, testingMode, fleetInitTimeout, configuration.IsFleetServerBootstrap(cfg.Fleet), modifiers...) + coord, configMgr, composable, err := application.New(ctx, l, baseLogger, logLvl, agentInfo, rex, tracer, testingMode, fleetInitTimeout, configuration.IsFleetServerBootstrap(cfg.Fleet), modifiers...) if err != nil { return err } @@ -324,7 +324,7 @@ LOOP: return err } -func loadConfig(override cfgOverrider) (*configuration.Configuration, error) { +func loadConfig(ctx context.Context, override cfgOverrider) (*configuration.Configuration, error) { pathConfigFile := paths.ConfigFile() rawConfig, err := config.LoadFile(pathConfigFile) if err != nil { @@ -334,7 +334,7 @@ func loadConfig(override cfgOverrider) (*configuration.Configuration, error) { errors.M(errors.MetaKeyPath, pathConfigFile)) } - if err := getOverwrites(rawConfig); err != nil { + if err := getOverwrites(ctx, rawConfig); err != nil { return nil, errors.New(err, "could not read overwrites") } @@ -366,7 +366,7 @@ func reexecPath() (string, error) { return potentialReexec, nil } -func getOverwrites(rawConfig *config.Config) error { +func getOverwrites(ctx context.Context, rawConfig *config.Config) error { cfg, err := configuration.NewFromConfig(rawConfig) if err != nil { return err @@ -377,7 +377,7 @@ func getOverwrites(rawConfig *config.Config) error { return nil } path := paths.AgentConfigFile() - store := storage.NewEncryptedDiskStore(path) + store := storage.NewEncryptedDiskStore(ctx, path) reader, err := store.Load() if err != nil && errors.Is(err, os.ErrNotExist) { @@ -450,6 +450,7 @@ func tryDelayEnroll(ctx context.Context, logger *logger.Logger, cfg *configurati options.DelayEnroll = false options.FleetServer.SpawnAgent = false c, err := newEnrollCmd( + ctx, logger, &options, paths.ConfigFile(), @@ -470,7 +471,7 @@ func tryDelayEnroll(ctx context.Context, logger *logger.Logger, cfg *configurati errors.M("path", enrollPath))) } logger.Info("Successfully performed delayed enrollment of this Elastic Agent.") - return loadConfig(override) + return loadConfig(ctx, override) } func initTracer(agentName, version string, mcfg *monitoringCfg.MonitoringConfig) (*apm.Tracer, error) { diff --git a/internal/pkg/agent/install/uninstall.go b/internal/pkg/agent/install/uninstall.go index ab96a354fe5..a7977a172b5 100644 --- a/internal/pkg/agent/install/uninstall.go +++ b/internal/pkg/agent/install/uninstall.go @@ -198,7 +198,7 @@ func uninstallComponents(ctx context.Context, cfgFile string, uninstallToken str return fmt.Errorf("failed to detect inputs and outputs: %w", err) } - cfg, err := operations.LoadFullAgentConfig(log, cfgFile, false) + cfg, err := operations.LoadFullAgentConfig(ctx, log, cfgFile, false) if err != nil { return err } diff --git a/internal/pkg/agent/migration/migrate_config.go b/internal/pkg/agent/migration/migrate_config.go index 305ccd762f5..ad489315810 100644 --- a/internal/pkg/agent/migration/migrate_config.go +++ b/internal/pkg/agent/migration/migrate_config.go @@ -5,6 +5,7 @@ package migration import ( + "context" "fmt" "io/fs" "os" @@ -22,7 +23,7 @@ import ( // - The contents from the unencrypted file will be copied as a byte stream without any transformation. // - The function will not perform any operation if the encryptedConfigPath already exists and it's not empty to avoid overwrites. // - If neither the encrypted file nor the unencrypted file exist this call is a no-op -func MigrateToEncryptedConfig(l *logp.Logger, unencryptedConfigPath string, encryptedConfigPath string) error { +func MigrateToEncryptedConfig(ctx context.Context, l *logp.Logger, unencryptedConfigPath string, encryptedConfigPath string) error { encStat, encFileErr := os.Stat(encryptedConfigPath) if encFileErr != nil && !errors.Is(encFileErr, fs.ErrNotExist) { @@ -53,7 +54,7 @@ func MigrateToEncryptedConfig(l *logp.Logger, unencryptedConfigPath string, encr l.Errorf("Error closing unencrypted store reader for %q: %v", unencryptedConfigPath, err) } }() - store := storage.NewEncryptedDiskStore(encryptedConfigPath) + store := storage.NewEncryptedDiskStore(ctx, encryptedConfigPath) err = store.Save(reader) if err != nil { return errors.New(err, fmt.Sprintf("error writing encrypted config from file %q to file %q", unencryptedConfigPath, encryptedConfigPath)) diff --git a/internal/pkg/agent/migration/migrate_config_test.go b/internal/pkg/agent/migration/migrate_config_test.go index 95e3445f2a4..da4eafd84fd 100644 --- a/internal/pkg/agent/migration/migrate_config_test.go +++ b/internal/pkg/agent/migration/migrate_config_test.go @@ -8,6 +8,7 @@ package migration import ( "bytes" + "context" "io" "io/fs" "os" @@ -32,6 +33,9 @@ type configfile struct { } func TestMigrateToEncryptedConfig(t *testing.T) { + ctx, cn := context.WithCancel(context.Background()) + defer cn() + testcases := []struct { name string unencryptedConfig configfile @@ -107,12 +111,12 @@ func TestMigrateToEncryptedConfig(t *testing.T) { paths.SetTop(top) vaultPath := paths.AgentVaultPath() - err := secret.CreateAgentSecret(secret.WithVaultPath(vaultPath)) + err := secret.CreateAgentSecret(ctx, secret.WithVaultPath(vaultPath)) require.NoError(t, err) - createAndPersistStore(t, top, tc.unencryptedConfig, false) - encryptedStore := createAndPersistStore(t, top, tc.encryptedConfig, true) + createAndPersistStore(t, ctx, top, tc.unencryptedConfig, false) + encryptedStore := createAndPersistStore(t, ctx, top, tc.encryptedConfig, true) absUnencryptedFile := path.Join(top, tc.unencryptedConfig.name) absEncryptedFile := path.Join(top, tc.encryptedConfig.name) @@ -132,7 +136,7 @@ func TestMigrateToEncryptedConfig(t *testing.T) { log := logp.NewLogger("test_migrate_config") // setup end - err = MigrateToEncryptedConfig(log, absUnencryptedFile, absEncryptedFile) + err = MigrateToEncryptedConfig(ctx, log, absUnencryptedFile, absEncryptedFile) assert.NoError(t, err) if len(tc.expectedEncryptedContent) > 0 { @@ -159,6 +163,9 @@ func TestErrorMigrateToEncryptedConfig(t *testing.T) { t.Skip("cannot reliably reproduce permission errors on windows") } + ctx, cn := context.WithCancel(context.Background()) + defer cn() + testcases := []struct { name string unencryptedConfig configfile @@ -199,12 +206,12 @@ func TestErrorMigrateToEncryptedConfig(t *testing.T) { paths.SetTop(top) vaultPath := paths.AgentVaultPath() - err := secret.CreateAgentSecret(secret.WithVaultPath(vaultPath)) + err := secret.CreateAgentSecret(ctx, secret.WithVaultPath(vaultPath)) require.NoError(t, err) - createAndPersistStore(t, top, tc.unencryptedConfig, false) - createAndPersistStore(t, top, tc.encryptedConfig, true) + createAndPersistStore(t, ctx, top, tc.unencryptedConfig, false) + createAndPersistStore(t, ctx, top, tc.encryptedConfig, true) err = os.Chmod(top, 0555&os.ModePerm) require.NoError(t, err) @@ -231,7 +238,7 @@ func TestErrorMigrateToEncryptedConfig(t *testing.T) { log := logp.NewLogger("test_migrate_config") // setup end - err = MigrateToEncryptedConfig(log, absUnencryptedFile, absEncryptedFile) + err = MigrateToEncryptedConfig(ctx, log, absUnencryptedFile, absEncryptedFile) assert.Error(t, err) }) @@ -239,13 +246,13 @@ func TestErrorMigrateToEncryptedConfig(t *testing.T) { } -func createAndPersistStore(t *testing.T, baseDir string, cf configfile, encrypted bool) storage.Storage { +func createAndPersistStore(t *testing.T, ctx context.Context, baseDir string, cf configfile, encrypted bool) storage.Storage { var store storage.Storage asbFilePath := path.Join(baseDir, cf.name) if encrypted { - store = storage.NewEncryptedDiskStore(asbFilePath) + store = storage.NewEncryptedDiskStore(ctx, asbFilePath) } else { store = storage.NewDiskStore(asbFilePath) } diff --git a/internal/pkg/agent/migration/migrate_secret.go b/internal/pkg/agent/migration/migrate_secret.go deleted file mode 100644 index 08cfc3e5eb1..00000000000 --- a/internal/pkg/agent/migration/migrate_secret.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -// or more contributor license agreements. Licensed under the Elastic License; -// you may not use this file except in compliance with the Elastic License. - -package migration - -import ( - "errors" - "fmt" - "io" - "io/fs" - "io/ioutil" - "os" - "path/filepath" - "runtime" - "strings" - - "github.com/elastic/elastic-agent-libs/logp" - "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" - "github.com/elastic/elastic-agent/internal/pkg/agent/application/secret" - "github.com/elastic/elastic-agent/internal/pkg/agent/storage" - "github.com/elastic/elastic-agent/internal/pkg/fileutil" -) - -const ( - darwin = "darwin" -) - -// MigrateAgentSecret migrates agent secret if the secret doesn't exists agent upgrade from 8.3.0 - 8.3.2 to 8.x and above on Linux and Windows platforms. -func MigrateAgentSecret(log *logp.Logger) error { - // Nothing to migrate for darwin - if runtime.GOOS == darwin { - return nil - } - - // Check if the secret already exists - log.Debug("migrate agent secret, check if secret already exists") - _, err := secret.GetAgentSecret() - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - // The secret doesn't exists, perform migration below - log.Debug("agent secret doesn't exists, perform migration") - } else { - err = fmt.Errorf("failed read the agent secret: %w", err) - log.Error(err) - return err - } - } else { - // The secret already exists, nothing to migrate - log.Debug("secret already exists nothing to migrate") - return nil - } - - // Check if the secret was copied by the fleet upgrade handler to the legacy location - log.Debug("check if secret was copied over by 8.3.0-8.3.2 version of the agent") - sec, err := getAgentSecretFromHomePath(paths.Home()) - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - // The secret is not found in this instance of the vault, continue with migration - log.Debug("agent secret copied from 8.3.0-8.3.2 doesn't exists, continue with migration") - } else { - err = fmt.Errorf("failed agent 8.3.0-8.3.2 secret check: %w", err) - log.Error(err) - return err - } - } else { - // The secret is found, save in the new agent vault - log.Debug("agent secret from 8.3.0-8.3.2 is found, migrate to the new vault") - return secret.SetAgentSecret(sec) - } - - // Scan other agent data directories, find the latest agent secret - log.Debug("search for possible latest agent 8.3.0-8.3.2 secret") - dataDir := paths.Data() - - sec, err = findPreviousAgentSecret(dataDir) - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - // The secret is not found - log.Debug("no previous agent 8.3.0-8.3.2 secrets found, nothing to migrate") - return nil - } - err = fmt.Errorf("search for possible latest agent 8.3.0-8.3.2 secret failed: %w", err) - log.Error(err) - return err - } - log.Debug("found previous agent 8.3.0-8.3.2 secret, migrate to the new vault") - return secret.SetAgentSecret(sec) -} - -func findPreviousAgentSecret(dataDir string) (secret.Secret, error) { - found := false - var sec secret.Secret - fileSystem := os.DirFS(dataDir) - _ = fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - if strings.HasPrefix(d.Name(), "elastic-agent-") { - vaultPath := getLegacyVaultPathFromPath(filepath.Join(dataDir, path)) - s, err := secret.GetAgentSecret(secret.WithVaultPath(vaultPath)) - if err != nil { - // Ignore if fs.ErrNotExist error, keep scanning - if errors.Is(err, fs.ErrNotExist) { - return nil - } - return err - } - - // Check that the configuration can be decrypted with the found agent secret - exists, _ := fileutil.FileExists(paths.AgentConfigFile()) - if exists { - store := storage.NewEncryptedDiskStore(paths.AgentConfigFile(), storage.WithVaultPath(vaultPath)) - r, err := store.Load() - if err != nil { - //nolint:nilerr // ignore the error keep scanning - return nil - } - - defer r.Close() - _, err = ioutil.ReadAll(r) - if err != nil { - //nolint:nilerr // ignore the error keep scanning - return nil - } - - sec = s - found = true - return io.EOF - } - } else if d.Name() != "." { - return fs.SkipDir - } - } - return nil - }) - if !found { - return sec, fs.ErrNotExist - } - return sec, nil -} - -func getAgentSecretFromHomePath(homePath string) (sec secret.Secret, err error) { - vaultPath := getLegacyVaultPathFromPath(homePath) - fi, err := os.Stat(vaultPath) - if err != nil { - return - } - - if !fi.IsDir() { - return sec, fs.ErrNotExist - } - return secret.GetAgentSecret(secret.WithVaultPath(vaultPath)) -} - -func getLegacyVaultPath() string { - return getLegacyVaultPathFromPath(paths.Home()) -} - -func getLegacyVaultPathFromPath(path string) string { - return filepath.Join(path, "vault") -} diff --git a/internal/pkg/agent/migration/migrate_secret_test.go b/internal/pkg/agent/migration/migrate_secret_test.go deleted file mode 100644 index 11b4790e64e..00000000000 --- a/internal/pkg/agent/migration/migrate_secret_test.go +++ /dev/null @@ -1,386 +0,0 @@ -// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -// or more contributor license agreements. Licensed under the Elastic License; -// you may not use this file except in compliance with the Elastic License. - -//go:build linux || windows - -package migration - -import ( - "errors" - "io/fs" - "io/ioutil" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/gofrs/uuid" - "github.com/google/go-cmp/cmp" - - "github.com/elastic/elastic-agent-libs/logp" - "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" - "github.com/elastic/elastic-agent/internal/pkg/agent/application/secret" - "github.com/elastic/elastic-agent/internal/pkg/agent/storage" - "github.com/elastic/elastic-agent/internal/pkg/agent/vault" -) - -func TestFindAgentSecretFromHomePath(t *testing.T) { - - tests := []struct { - name string - setupFn func(homePath string) error - wantErr error - }{ - { - name: "no data dir", - wantErr: fs.ErrNotExist, - }, - { - name: "no vault dir", - setupFn: func(homePath string) error { - return os.MkdirAll(homePath, 0750) - }, - wantErr: fs.ErrNotExist, - }, - { - name: "vault file instead of directory", - setupFn: func(homePath string) error { - err := os.MkdirAll(homePath, 0750) - if err != nil { - return err - } - return ioutil.WriteFile(getLegacyVaultPathFromPath(homePath), []byte{}, 0600) - }, - wantErr: fs.ErrNotExist, - }, - { - name: "empty vault directory", - setupFn: func(homePath string) error { - return os.MkdirAll(getLegacyVaultPathFromPath(homePath), 0750) - }, - wantErr: fs.ErrNotExist, - }, - { - name: "empty vault", - setupFn: func(homePath string) error { - v, err := vault.New(getLegacyVaultPathFromPath(homePath)) - if err != nil { - return err - } - defer v.Close() - return nil - }, - wantErr: fs.ErrNotExist, - }, - { - name: "vault dir with no seed", - setupFn: func(homePath string) error { - vaultPath := getLegacyVaultPathFromPath(homePath) - v, err := vault.New(vaultPath) - if err != nil { - return err - } - defer v.Close() - return os.Remove(filepath.Join(vaultPath, ".seed")) - }, - wantErr: fs.ErrNotExist, - }, - { - name: "vault with secret and misplaced seed vault", - setupFn: func(homePath string) error { - vaultPath := getLegacyVaultPathFromPath(homePath) - err := secret.CreateAgentSecret(secret.WithVaultPath(vaultPath)) - if err != nil { - return err - } - return os.Remove(filepath.Join(vaultPath, ".seed")) - }, - wantErr: fs.ErrNotExist, - }, - { - name: "vault with valid secret", - setupFn: func(homePath string) error { - vaultPath := getLegacyVaultPathFromPath(homePath) - err := secret.CreateAgentSecret(secret.WithVaultPath(vaultPath)) - if err != nil { - return err - } - return generateTestConfig(vaultPath) - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - top := t.TempDir() - paths.SetTop(top) - homePath := paths.Home() - - if tc.setupFn != nil { - if err := tc.setupFn(homePath); err != nil { - t.Fatal(err) - } - } - - sec, err := getAgentSecretFromHomePath(homePath) - if !errors.Is(err, tc.wantErr) { - t.Fatalf("want err: %v, got err: %v", tc.wantErr, err) - } - - foundSec, err := findPreviousAgentSecret(filepath.Dir(homePath)) - if !errors.Is(err, tc.wantErr) { - t.Fatalf("want err: %v, got err: %v", tc.wantErr, err) - } - diff := cmp.Diff(sec, foundSec) - if diff != "" { - t.Fatal(diff) - } - - }) - } -} - -type configType int - -const ( - NoConfig configType = iota - MatchingConfig - NonMatchingConfig -) - -func TestFindNewestAgentSecret(t *testing.T) { - - tests := []struct { - name string - cfgType configType - wantErr error - }{ - { - name: "missing config", - cfgType: NoConfig, - wantErr: fs.ErrNotExist, - }, - { - name: "matching config", - cfgType: MatchingConfig, - }, - { - name: "non-matching config", - cfgType: NonMatchingConfig, - wantErr: fs.ErrNotExist, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - top := t.TempDir() - paths.SetTop(top) - paths.SetConfig(top) - dataDir := paths.Data() - - wantSecret, err := generateTestSecrets(dataDir, 3, tc.cfgType) - if err != nil { - t.Fatal(err) - } - sec, err := findPreviousAgentSecret(dataDir) - - if !errors.Is(err, tc.wantErr) { - t.Fatalf("want err: %v, got err: %v", tc.wantErr, err) - } - diff := cmp.Diff(sec, wantSecret) - if diff != "" { - t.Fatal(diff) - } - }) - } -} - -func TestMigrateAgentSecret(t *testing.T) { - top := t.TempDir() - paths.SetTop(top) - paths.SetConfig(top) - dataDir := paths.Data() - - // No vault home path - homePath := generateTestHomePath(dataDir) - if err := os.MkdirAll(homePath, 0750); err != nil { - t.Fatal(err) - } - - // Empty vault home path - homePath = generateTestHomePath(dataDir) - vaultPath := getLegacyVaultPathFromPath(homePath) - if err := os.MkdirAll(vaultPath, 0750); err != nil { - t.Fatal(err) - } - - // Vault with missing seed - homePath = generateTestHomePath(dataDir) - vaultPath = getLegacyVaultPathFromPath(homePath) - v, err := vault.New(vaultPath) - if err != nil { - t.Fatal(err) - } - defer v.Close() - - if err = os.Remove(filepath.Join(vaultPath, ".seed")); err != nil { - t.Fatal(err) - } - - // Generate few valid secrets to scan for - wantSecret, err := generateTestSecrets(dataDir, 5, MatchingConfig) - if err != nil { - t.Fatal(err) - } - - // Expect no agent secret found - _, err = secret.GetAgentSecret(secret.WithVaultPath(paths.AgentVaultPath())) - if !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("expected err: %v", fs.ErrNotExist) - } - - // Perform migration - log := logp.NewLogger("test_agent_secret") - err = MigrateAgentSecret(log) - if err != nil { - t.Fatal(err) - } - - // Expect the agent secret is migrated now - sec, err := secret.GetAgentSecret(secret.WithVaultPath(paths.AgentVaultPath())) - if err != nil { - t.Fatal(err) - } - - // Compare the migrated secret with the expected newest one - diff := cmp.Diff(sec, wantSecret) - if diff != "" { - t.Fatal(diff) - } -} - -func TestMigrateAgentSecretAlreadyExists(t *testing.T) { - top := t.TempDir() - paths.SetTop(top) - err := secret.CreateAgentSecret(secret.WithVaultPath(paths.AgentVaultPath())) - if err != nil { - t.Fatal(err) - } - - // Expect agent secret created - wantSecret, err := secret.GetAgentSecret(secret.WithVaultPath(paths.AgentVaultPath())) - if err != nil { - t.Fatal(err) - } - - // Perform migration - log := logp.NewLogger("test_agent_secret") - err = MigrateAgentSecret(log) - if err != nil { - t.Fatal(err) - } - - sec, err := secret.GetAgentSecret(secret.WithVaultPath(paths.AgentVaultPath())) - if err != nil { - t.Fatal(err) - } - - // Compare, should be the same secret - diff := cmp.Diff(sec, wantSecret) - if diff != "" { - t.Fatal(diff) - } -} - -func TestMigrateAgentSecretFromLegacyLocation(t *testing.T) { - top := t.TempDir() - paths.SetTop(top) - paths.SetConfig(top) - vaultPath := getLegacyVaultPath() - err := secret.CreateAgentSecret(secret.WithVaultPath(vaultPath)) - if err != nil { - t.Fatal(err) - } - - // Expect agent secret created - wantSecret, err := secret.GetAgentSecret(secret.WithVaultPath(vaultPath)) - if err != nil { - t.Fatal(err) - } - - // Perform migration - log := logp.NewLogger("test_agent_secret") - err = MigrateAgentSecret(log) - if err != nil { - t.Fatal(err) - } - - sec, err := secret.GetAgentSecret(secret.WithVaultPath(paths.AgentVaultPath())) - if err != nil { - t.Fatal(err) - } - - // Compare, should be the same secret - diff := cmp.Diff(sec, wantSecret) - if diff != "" { - t.Fatal(diff) - } -} - -func generateTestHomePath(dataDir string) string { - suffix := uuid.Must(uuid.NewV4()).String()[:6] - return filepath.Join(dataDir, "elastic-agent-"+suffix) -} - -func generateTestConfig(vaultPath string) error { - fleetEncConfigFile := paths.AgentConfigFile() - store := storage.NewEncryptedDiskStore(fleetEncConfigFile, storage.WithVaultPath(vaultPath)) - return store.Save(strings.NewReader("foo")) -} - -func generateTestSecrets(dataDir string, count int, cfgType configType) (wantSecret secret.Secret, err error) { - now := time.Now() - - // Generate multiple home paths - //homePaths := make([]string, count) - for i := 0; i < count; i++ { - homePath := generateTestHomePath(dataDir) - k, err := vault.NewKey(vault.AES256) - if err != nil { - return wantSecret, err - } - - sec := secret.Secret{ - Value: k, - CreatedOn: now.Add(-time.Duration(i+1) * time.Minute), - } - - vaultPath := getLegacyVaultPathFromPath(homePath) - err = secret.SetAgentSecret(sec, secret.WithVaultPath(vaultPath)) - if err != nil { - return wantSecret, err - } - - switch cfgType { - case NoConfig: - case MatchingConfig, NonMatchingConfig: - if i == 0 { - wantSecret = sec - // Create matching encrypted config file, the content of the file doesn't matter for this test - err = generateTestConfig(vaultPath) - if err != nil { - return wantSecret, err - } - } - } - // Delete - if cfgType == NonMatchingConfig && i == 0 { - _ = os.RemoveAll(vaultPath) - wantSecret = secret.Secret{} - } - } - - return wantSecret, nil -} diff --git a/internal/pkg/agent/storage/encrypted_disk_storage_windows_linux_test.go b/internal/pkg/agent/storage/encrypted_disk_storage_windows_linux_test.go index fc1df378150..c29c7ec8fd9 100644 --- a/internal/pkg/agent/storage/encrypted_disk_storage_windows_linux_test.go +++ b/internal/pkg/agent/storage/encrypted_disk_storage_windows_linux_test.go @@ -8,6 +8,7 @@ package storage import ( "bytes" + "context" "errors" "io/fs" "io/ioutil" @@ -28,8 +29,11 @@ const ( func TestEncryptedDiskStorageWindowsLinuxLoad(t *testing.T) { dir := t.TempDir() + ctx, cn := context.WithCancel(context.Background()) + defer cn() + fp := filepath.Join(dir, testConfigFile) - s := NewEncryptedDiskStore(fp, WithVaultPath(dir)) + s := NewEncryptedDiskStore(ctx, fp, WithVaultPath(dir)) // Test that the file loads and doesn't create vault r, err := s.Load() @@ -67,7 +71,7 @@ func TestEncryptedDiskStorageWindowsLinuxLoad(t *testing.T) { } // Create agent secret - err = secret.CreateAgentSecret(secret.WithVaultPath(dir)) + err = secret.CreateAgentSecret(ctx, secret.WithVaultPath(dir)) if err != nil { t.Fatal(err) } diff --git a/internal/pkg/agent/storage/encrypted_disk_store.go b/internal/pkg/agent/storage/encrypted_disk_store.go index be78e4235df..edad97f19ef 100644 --- a/internal/pkg/agent/storage/encrypted_disk_store.go +++ b/internal/pkg/agent/storage/encrypted_disk_store.go @@ -6,6 +6,7 @@ package storage import ( "bytes" + "context" "fmt" "io" "io/fs" @@ -39,11 +40,12 @@ type OptionFunc func(s *EncryptedDiskStore) // NewEncryptedDiskStore creates an encrypted disk store. // Drop-in replacement for NewDiskStorage -func NewEncryptedDiskStore(target string, opts ...OptionFunc) Storage { +func NewEncryptedDiskStore(ctx context.Context, target string, opts ...OptionFunc) Storage { if encryptionDisabled { return NewDiskStore(target) } s := &EncryptedDiskStore{ + ctx: ctx, target: target, vaultPath: paths.AgentVaultPath(), } @@ -75,9 +77,9 @@ func (d *EncryptedDiskStore) Exists() (bool, error) { return true, nil } -func (d *EncryptedDiskStore) ensureKey() error { +func (d *EncryptedDiskStore) ensureKey(ctx context.Context) error { if d.key == nil { - key, err := secret.GetAgentSecret(secret.WithVaultPath(d.vaultPath)) + key, err := secret.GetAgentSecret(ctx, secret.WithVaultPath(d.vaultPath)) if err != nil { return fmt.Errorf("could not get agent key: %w", err) } @@ -90,7 +92,7 @@ func (d *EncryptedDiskStore) ensureKey() error { // Specifically it will write to a .tmp file then rotate the file to the target name to ensure that an error does not corrupt the previously written file. func (d *EncryptedDiskStore) Save(in io.Reader) error { // Ensure has agent key - err := d.ensureKey() + err := d.ensureKey(d.ctx) if err != nil { return errors.New(err, "failed to ensure key") } @@ -179,7 +181,7 @@ func (d *EncryptedDiskStore) Load() (rc io.ReadCloser, err error) { }() // Ensure has agent key - err = d.ensureKey() + err = d.ensureKey(d.ctx) if err != nil { return nil, errors.New(err, "failed to ensure key during encrypted disk store Load") } diff --git a/internal/pkg/agent/storage/storage.go b/internal/pkg/agent/storage/storage.go index db434ed7226..952f82c5883 100644 --- a/internal/pkg/agent/storage/storage.go +++ b/internal/pkg/agent/storage/storage.go @@ -5,6 +5,7 @@ package storage import ( + "context" "io" "os" ) @@ -36,6 +37,7 @@ type DiskStore struct { // EncryptedDiskStore encrypts config when saving to disk. // When saving it will save to a temporary file then replace the target file. type EncryptedDiskStore struct { + ctx context.Context target string vaultPath string key []byte diff --git a/internal/pkg/agent/storage/store/state_store.go b/internal/pkg/agent/storage/store/state_store.go index dd78a195ce9..6f64f1184bf 100644 --- a/internal/pkg/agent/storage/store/state_store.go +++ b/internal/pkg/agent/storage/store/state_store.go @@ -75,13 +75,13 @@ type stateSerializer struct { } // NewStateStoreWithMigration creates a new state store and migrates the old one. -func NewStateStoreWithMigration(log *logger.Logger, actionStorePath, stateStorePath string) (*StateStore, error) { - err := migrateStateStore(log, actionStorePath, stateStorePath) +func NewStateStoreWithMigration(ctx context.Context, log *logger.Logger, actionStorePath, stateStorePath string) (*StateStore, error) { + err := migrateStateStore(ctx, log, actionStorePath, stateStorePath) if err != nil { return nil, err } - return NewStateStore(log, storage.NewEncryptedDiskStore(stateStorePath)) + return NewStateStore(log, storage.NewEncryptedDiskStore(ctx, stateStorePath)) } // NewStateStoreActionAcker creates a new state store backed action acker. @@ -143,10 +143,10 @@ func NewStateStore(log *logger.Logger, store storeLoad) (*StateStore, error) { }, nil } -func migrateStateStore(log *logger.Logger, actionStorePath, stateStorePath string) (err error) { +func migrateStateStore(ctx context.Context, log *logger.Logger, actionStorePath, stateStorePath string) (err error) { log = log.Named("state_migration") actionDiskStore := storage.NewDiskStore(actionStorePath) - stateDiskStore := storage.NewEncryptedDiskStore(stateStorePath) + stateDiskStore := storage.NewEncryptedDiskStore(ctx, stateStorePath) stateStoreExits, err := stateDiskStore.Exists() if err != nil { diff --git a/internal/pkg/agent/storage/store/state_store_test.go b/internal/pkg/agent/storage/store/state_store_test.go index 83b6ebfcefc..0e969a7525e 100644 --- a/internal/pkg/agent/storage/store/state_store_test.go +++ b/internal/pkg/agent/storage/store/state_store_test.go @@ -31,6 +31,10 @@ func TestStateStore(t *testing.T) { func runTestStateStore(t *testing.T, ackToken string) { log, _ := logger.New("state_store", false) + + ctx, cn := context.WithCancel(context.Background()) + defer cn() + withFile := func(fn func(t *testing.T, file string)) func(*testing.T) { return func(t *testing.T) { dir := t.TempDir() @@ -244,7 +248,7 @@ func runTestStateStore(t *testing.T, ackToken string) { t.Run("migrate actions file does not exists", withFile(func(t *testing.T, actionStorePath string) { withFile(func(t *testing.T, stateStorePath string) { - err := migrateStateStore(log, actionStorePath, stateStorePath) + err := migrateStateStore(ctx, log, actionStorePath, stateStorePath) require.NoError(t, err) stateStore, err := NewStateStore(log, storage.NewDiskStore(stateStorePath)) require.NoError(t, err) @@ -275,7 +279,7 @@ func runTestStateStore(t *testing.T, ackToken string) { require.Len(t, actionStore.actions(), 1) withFile(func(t *testing.T, stateStorePath string) { - err = migrateStateStore(log, actionStorePath, stateStorePath) + err = migrateStateStore(ctx, log, actionStorePath, stateStorePath) require.NoError(t, err) stateStore, err := NewStateStore(log, storage.NewDiskStore(stateStorePath)) diff --git a/internal/pkg/agent/transpiler/vars_test.go b/internal/pkg/agent/transpiler/vars_test.go index 76a1bbfd9d2..0843e1a73a1 100644 --- a/internal/pkg/agent/transpiler/vars_test.go +++ b/internal/pkg/agent/transpiler/vars_test.go @@ -5,6 +5,7 @@ package transpiler import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -361,6 +362,6 @@ func (p *contextProviderMock) Fetch(key string) (string, bool) { return "mockedFetchContent", true } -func (p *contextProviderMock) Run(comm corecomp.ContextProviderComm) error { +func (p *contextProviderMock) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { return nil } diff --git a/internal/pkg/agent/vault/aesgcm.go b/internal/pkg/agent/vault/aesgcm/aesgcm.go similarity index 99% rename from internal/pkg/agent/vault/aesgcm.go rename to internal/pkg/agent/vault/aesgcm/aesgcm.go index aa209f994c8..ea9e9de184c 100644 --- a/internal/pkg/agent/vault/aesgcm.go +++ b/internal/pkg/agent/vault/aesgcm/aesgcm.go @@ -2,7 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -package vault +package aesgcm import ( "crypto/aes" diff --git a/internal/pkg/agent/vault/aesgcm_test.go b/internal/pkg/agent/vault/aesgcm/aesgcm_test.go similarity index 99% rename from internal/pkg/agent/vault/aesgcm_test.go rename to internal/pkg/agent/vault/aesgcm/aesgcm_test.go index 0c17ad4374f..7edcbae4001 100644 --- a/internal/pkg/agent/vault/aesgcm_test.go +++ b/internal/pkg/agent/vault/aesgcm/aesgcm_test.go @@ -2,7 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -package vault +package aesgcm import ( "crypto/aes" diff --git a/internal/pkg/agent/vault/seed.go b/internal/pkg/agent/vault/seed.go index ecc369b0918..6112e48313e 100644 --- a/internal/pkg/agent/vault/seed.go +++ b/internal/pkg/agent/vault/seed.go @@ -2,7 +2,7 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -//go:build linux || windows +//go:build !darwin package vault @@ -14,6 +14,8 @@ import ( "os" "path/filepath" "sync" + + "github.com/elastic/elastic-agent/internal/pkg/agent/vault/aesgcm" ) const ( @@ -36,8 +38,8 @@ func getSeed(path string) ([]byte, error) { } // return fs.ErrNotExists if invalid length of bytes returned - if len(b) != int(AES256) { - return nil, fmt.Errorf("invalid seed length, expected: %v, got: %v: %w", int(AES256), len(b), fs.ErrNotExist) + if len(b) != int(aesgcm.AES256) { + return nil, fmt.Errorf("invalid seed length, expected: %v, got: %v: %w", int(aesgcm.AES256), len(b), fs.ErrNotExist) } return b, nil } @@ -59,7 +61,7 @@ func createSeedIfNotExists(path string) ([]byte, error) { return b, nil } - seed, err := NewKey(AES256) + seed, err := aesgcm.NewKey(aesgcm.AES256) if err != nil { return nil, err } diff --git a/internal/pkg/agent/vault/seed_test.go b/internal/pkg/agent/vault/seed_test.go index 3ef63678128..93a88877ec8 100644 --- a/internal/pkg/agent/vault/seed_test.go +++ b/internal/pkg/agent/vault/seed_test.go @@ -2,22 +2,23 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -//go:build linux || windows +//go:build !darwin package vault import ( "context" "encoding/hex" - "io/fs" + "errors" + "os" "path/filepath" "sync" "testing" "github.com/google/go-cmp/cmp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + + "github.com/elastic/elastic-agent/internal/pkg/agent/vault/aesgcm" ) func TestGetSeed(t *testing.T) { @@ -25,27 +26,41 @@ func TestGetSeed(t *testing.T) { fp := filepath.Join(dir, seedFile) - require.NoFileExists(t, fp) + // check the test prerequisites + if _, err := os.Stat(fp); !errors.Is(err, os.ErrNotExist) { + t.Fatal(err) + } // seed is not yet created - _, err := getSeed(dir) + if _, err := getSeed(dir); !errors.Is(err, os.ErrNotExist) { + t.Fatal(err) + } // should be not found - require.ErrorIs(t, err, fs.ErrNotExist) + if _, err := os.Stat(fp); !errors.Is(err, os.ErrNotExist) { + t.Fatal(err) + } b, err := createSeedIfNotExists(dir) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } - require.FileExists(t, fp) + // file should exist + if _, err := os.Stat(fp); err != nil { + t.Fatal(err) + } - diff := cmp.Diff(int(AES256), len(b)) + diff := cmp.Diff(int(aesgcm.AES256), len(b)) if diff != "" { t.Error(diff) } // try get seed gotSeed, err := getSeed(dir) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } diff = cmp.Diff(b, gotSeed) if diff != "" { @@ -58,14 +73,21 @@ func TestCreateSeedIfNotExists(t *testing.T) { fp := filepath.Join(dir, seedFile) - assert.NoFileExists(t, fp) + if _, err := os.Stat(fp); !errors.Is(err, os.ErrNotExist) { + t.Fatal(err) + } b, err := createSeedIfNotExists(dir) - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } - require.FileExists(t, fp) + // file should exist + if _, err := os.Stat(fp); err != nil { + t.Fatal(err) + } - diff := cmp.Diff(int(AES256), len(b)) + diff := cmp.Diff(int(aesgcm.AES256), len(b)) if diff != "" { t.Error(diff) } @@ -95,7 +117,9 @@ func TestCreateSeedIfNotExistsRace(t *testing.T) { } err = g.Wait() - assert.NoError(t, err) + if err != nil { + t.Fatal(err) + } set := make(map[string]struct{}) diff --git a/internal/pkg/agent/vault/vault_darwin.go b/internal/pkg/agent/vault/vault_darwin.go index d0cde61ab18..026583efd14 100644 --- a/internal/pkg/agent/vault/vault_darwin.go +++ b/internal/pkg/agent/vault/vault_darwin.go @@ -23,6 +23,7 @@ extern char* GetOSStatusMessage(OSStatus status); */ import "C" import ( + "context" "fmt" "sync" "unsafe" @@ -37,7 +38,7 @@ type Vault struct { // New initializes the vault store // Call Close when done to release the resources -func New(name string, opts ...OptionFunc) (*Vault, error) { +func New(ctx context.Context, name string, opts ...OptionFunc) (*Vault, error) { var keychain C.SecKeychainRef err := statusToError(C.OpenKeychain(keychain)) @@ -64,7 +65,7 @@ func (v *Vault) Close() error { } // Set sets the key in the vault store -func (v *Vault) Set(key string, data []byte) error { +func (v *Vault) Set(ctx context.Context, key string, data []byte) error { v.mx.Lock() defer v.mx.Unlock() @@ -81,7 +82,7 @@ func (v *Vault) Set(key string, data []byte) error { } // Get retrieves the key from the vault store -func (v *Vault) Get(key string) ([]byte, error) { +func (v *Vault) Get(ctx context.Context, key string) ([]byte, error) { var ( data unsafe.Pointer len C.size_t @@ -106,7 +107,7 @@ func (v *Vault) Get(key string) ([]byte, error) { } // Exists checks if the key exists -func (v *Vault) Exists(key string) (bool, error) { +func (v *Vault) Exists(ctx context.Context, key string) (bool, error) { v.mx.Lock() defer v.mx.Unlock() @@ -128,7 +129,7 @@ func (v *Vault) Exists(key string) (bool, error) { } // Remove will remove a key from the keychain. -func (v *Vault) Remove(key string) error { +func (v *Vault) Remove(ctx context.Context, key string) error { v.mx.Lock() defer v.mx.Unlock() diff --git a/internal/pkg/agent/vault/vault_key.go b/internal/pkg/agent/vault/vault_key.go deleted file mode 100644 index 648cecf73cb..00000000000 --- a/internal/pkg/agent/vault/vault_key.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -// or more contributor license agreements. Licensed under the Elastic License; -// you may not use this file except in compliance with the Elastic License. - -//go:build linux || windows - -package vault - -import ( - "crypto/sha256" - "encoding/hex" -) - -// fileNameFromKey returns the filename as a hash of the vault seed combined with the key -// this ties the key with the vault seed eliminating the change of attempting -// to decrypt the key for the wrong vault seed value. -func fileNameFromKey(seed []byte, key string) string { - hash := sha256.Sum256(append(seed, []byte(key)...)) - return hex.EncodeToString(hash[:]) -} diff --git a/internal/pkg/agent/vault/vault_linux.go b/internal/pkg/agent/vault/vault_linux.go index 93f813138d7..aed423dadc1 100644 --- a/internal/pkg/agent/vault/vault_linux.go +++ b/internal/pkg/agent/vault/vault_linux.go @@ -11,121 +11,25 @@ import ( "crypto/sha256" "errors" "fmt" - "io/fs" - "io/ioutil" "os" "path/filepath" - "sync" "syscall" "golang.org/x/crypto/pbkdf2" -) - -const saltSize = 8 - -type Vault struct { - path string - key []byte - mx sync.Mutex -} - -// New creates the vault store -func New(path string, opts ...OptionFunc) (v *Vault, err error) { - options := applyOptions(opts...) - dir := filepath.Dir(path) - - // If there is no specific path then get the executable directory - if dir == "." { - exefp, err := os.Executable() - if err != nil { - return nil, fmt.Errorf("could not get executable path: %w", err) - } - dir = filepath.Dir(exefp) - path = filepath.Join(dir, path) - } - - if options.readonly { - fi, err := os.Stat(path) - if err != nil { - return nil, err - } - if !fi.IsDir() { - return nil, fs.ErrNotExist - } - } else { - err := os.MkdirAll(path, 0750) - if err != nil { - return nil, fmt.Errorf("failed to create vault path: %v, err: %w", path, err) - } - } - - key, err := getOrCreateSeed(path, options.readonly) - if err != nil { - return nil, fmt.Errorf("could not get seed to create new valt: %w", err) - } - - return &Vault{ - path: path, - key: key, - }, nil -} - -// Close closes the valut store -// Noop on linux -func (v *Vault) Close() error { - return nil -} - -// Set stores the key in the vault store -func (v *Vault) Set(key string, data []byte) error { - enc, err := v.encrypt(data) - if err != nil { - return err - } - - v.mx.Lock() - defer v.mx.Unlock() - - return ioutil.WriteFile(v.filepathFromKey(key), enc, 0600) -} - -// Get retrieves the key from the vault store -func (v *Vault) Get(key string) ([]byte, error) { - v.mx.Lock() - defer v.mx.Unlock() - - enc, err := ioutil.ReadFile(v.filepathFromKey(key)) - if err != nil { - return nil, err - } - - return v.decrypt(enc) -} - -// Exists checks if the key exists -func (v *Vault) Exists(key string) (ok bool, err error) { - v.mx.Lock() - defer v.mx.Unlock() - if _, err = os.Stat(v.filepathFromKey(key)); err == nil { - ok = true - } else if errors.Is(err, fs.ErrNotExist) { - err = nil - } - return ok, err -} + "github.com/elastic/elastic-agent/internal/pkg/agent/vault/aesgcm" +) -// Remove removes the key -func (v *Vault) Remove(key string) error { - return os.RemoveAll(v.filepathFromKey(key)) -} +const ( + saltSize = 8 +) func (v *Vault) encrypt(data []byte) ([]byte, error) { - key, salt, err := deriveKey(v.key, nil) + key, salt, err := deriveKey(v.seed, nil) if err != nil { return nil, err } - enc, err := Encrypt(key, data) + enc, err := aesgcm.Encrypt(key, data) if err != nil { return nil, err } @@ -137,11 +41,11 @@ func (v *Vault) decrypt(data []byte) ([]byte, error) { return nil, syscall.EINVAL } salt, data := data[:saltSize], data[saltSize:] - key, _, err := deriveKey(v.key, salt) + key, _, err := deriveKey(v.seed, salt) if err != nil { return nil, err } - return Decrypt(key, data) + return aesgcm.Decrypt(key, data) } func deriveKey(pw []byte, salt []byte) ([]byte, []byte, error) { @@ -154,6 +58,44 @@ func deriveKey(pw []byte, salt []byte) ([]byte, []byte, error) { return pbkdf2.Key(pw, salt, 12022, 32, sha256.New), salt, nil } -func (v *Vault) filepathFromKey(key string) string { - return filepath.Join(v.path, fileNameFromKey(v.key, key)) +func tightenPermissions(path string) error { + // Noop for linx + return nil +} + +// writeFile "atomic" file write, utilizes temp file and replace +func writeFile(fp string, data []byte) (err error) { + dir, fn := filepath.Split(fp) + if dir == "" { + dir = "." + } + + f, err := os.CreateTemp(dir, fn) + if err != nil { + return fmt.Errorf("failed creating temp file: %w", err) + } + defer func() { + rerr := os.Remove(f.Name()) + if rerr != nil && !errors.Is(rerr, os.ErrNotExist) { + err = errors.Join(err, fmt.Errorf("cleanup failed, could not remove temp file: %w", rerr)) + } + }() + defer f.Close() + + _, err = f.Write(data) + if err != nil { + return fmt.Errorf("failed writing temp file: %w", err) + } + + err = f.Sync() + if err != nil { + return fmt.Errorf("failed syncing temp file: %w", err) + } + + err = f.Close() + if err != nil { + return fmt.Errorf("failed closing temp file: %w", err) + } + + return os.Rename(f.Name(), fp) } diff --git a/internal/pkg/agent/vault/vault_notdarwin.go b/internal/pkg/agent/vault/vault_notdarwin.go new file mode 100644 index 00000000000..43f65bd4197 --- /dev/null +++ b/internal/pkg/agent/vault/vault_notdarwin.go @@ -0,0 +1,215 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build !darwin + +package vault + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "time" + + "github.com/gofrs/flock" +) + +const ( + // defaultFlockRetryDelay default file lock retry delay + defaultFlockRetryDelay = 10 * time.Millisecond + + // lock file name + lockFile = `.lock` +) + +type Vault struct { + path string + seed []byte + + lockRetryDelay time.Duration + lock *flock.Flock +} + +// New creates the vault store +func New(ctx context.Context, path string, opts ...OptionFunc) (v *Vault, err error) { + options := applyOptions(opts...) + dir := filepath.Dir(path) + + // If there is no specific path then get the executable directory + if dir == "." { + exefp, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("could not get executable path: %w", err) + } + dir = filepath.Dir(exefp) + path = filepath.Join(dir, path) + } + + if options.readonly { + fi, err := os.Stat(path) + if err != nil { + return nil, err + } + if !fi.IsDir() { + return nil, fs.ErrNotExist + } + } else { + err := os.MkdirAll(path, 0750) + if err != nil { + return nil, fmt.Errorf("failed to create vault path: %v, err: %w", path, err) + } + err = tightenPermissions(path) + if err != nil { + return nil, err + } + } + + r := &Vault{ + path: path, + lockRetryDelay: options.lockRetryDelay, + lock: flock.New(filepath.Join(path, lockFile)), + } + + err = r.tryLock(ctx) + if err != nil { + return nil, err + } + defer func() { + err = r.unlockAndJoinErrors(err) + }() + + r.seed, err = getOrCreateSeed(path, options.readonly) + if err != nil { + return nil, fmt.Errorf("could not get or create seed for the vault at %s: %w", path, err) + } + + return r, nil +} + +// Set stores the key in the vault store +func (v *Vault) Set(ctx context.Context, key string, data []byte) (err error) { + enc, err := v.encrypt(data) + if err != nil { + return err + } + + err = v.tryLock(ctx) + if err != nil { + return err + } + defer func() { + err = v.unlockAndJoinErrors(err) + }() + + return writeFile(v.filepathFromKey(key), enc) +} + +// Get retrieves the key from the vault store +func (v *Vault) Get(ctx context.Context, key string) (dec []byte, err error) { + err = v.tryRLock(ctx) + if err != nil { + return nil, err + } + defer func() { + err = v.unlockAndJoinErrors(err) + }() + + enc, err := os.ReadFile(v.filepathFromKey(key)) + if err != nil { + return nil, err + } + + return v.decrypt(enc) +} + +// Exists checks if the key exists +func (v *Vault) Exists(ctx context.Context, key string) (ok bool, err error) { + err = v.tryRLock(ctx) + if err != nil { + return false, err + } + defer func() { + err = v.unlockAndJoinErrors(err) + }() + + if _, err = os.Stat(v.filepathFromKey(key)); err != nil { + if errors.Is(err, fs.ErrNotExist) { + return false, nil + } + return false, err + } + return true, nil +} + +// Remove removes the key +func (v *Vault) Remove(ctx context.Context, key string) (err error) { + err = v.tryLock(ctx) + if err != nil { + return err + } + defer func() { + err = v.unlockAndJoinErrors(err) + }() + + return os.RemoveAll(v.filepathFromKey(key)) +} + +// Close closes the vault store +// Noop for non-darwin implementation +func (v *Vault) Close() error { + return nil +} + +// applyOptions applies options for windows and linux, not used for darwin implementation +func applyOptions(opts ...OptionFunc) Options { + options := Options{ + lockRetryDelay: defaultFlockRetryDelay, + } + + for _, opt := range opts { + opt(&options) + } + + return options +} + +// fileNameFromKey returns the filename as a hash of the vault seed combined with the key +// This ties the key with the vault seed eliminating the chance of attempting +// to decrypt the key for the wrong vault seed value. +func fileNameFromKey(seed []byte, key string) string { + hash := sha256.Sum256(append(seed, []byte(key)...)) + return hex.EncodeToString(hash[:]) +} + +func (v *Vault) filepathFromKey(key string) string { + return filepath.Join(v.path, fileNameFromKey(v.seed, key)) +} + +// try to acquire exclusive lock +func (v *Vault) tryLock(ctx context.Context) error { + _, err := v.lock.TryLockContext(ctx, v.lockRetryDelay) + if err != nil { + return fmt.Errorf("failed to acquire exclusive lock: %v, err: %w", v.lock.Path(), err) + } + return nil +} + +// try to acquire shared lock +func (v *Vault) tryRLock(ctx context.Context) error { + _, err := v.lock.TryRLockContext(ctx, v.lockRetryDelay) + if err != nil { + return fmt.Errorf("failed to acquire shared lock: %v, err: %w", v.lock.Path(), err) + } + return nil +} + +// unlockAndJoinErrors Helper function that unlocks the file lock and returns joined error +func (v *Vault) unlockAndJoinErrors(err error) error { + return errors.Join(err, v.lock.Unlock()) +} diff --git a/internal/pkg/agent/vault/vault_test.go b/internal/pkg/agent/vault/vault_notdarwin_test.go similarity index 54% rename from internal/pkg/agent/vault/vault_test.go rename to internal/pkg/agent/vault/vault_notdarwin_test.go index 8356d5b5999..7d75c9c10d9 100644 --- a/internal/pkg/agent/vault/vault_test.go +++ b/internal/pkg/agent/vault/vault_notdarwin_test.go @@ -2,19 +2,25 @@ // or more contributor license agreements. Licensed under the Elastic License; // you may not use this file except in compliance with the Elastic License. -//go:build linux || windows +//go:build !darwin package vault import ( + "context" + "encoding/json" "errors" - "io/ioutil" + "fmt" "os" "path/filepath" + "time" "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/sync/errgroup" + + "github.com/elastic/elastic-agent/internal/pkg/agent/vault/aesgcm" ) func getTestVaultPath(t *testing.T) string { @@ -25,26 +31,29 @@ func getTestVaultPath(t *testing.T) string { func TestVaultRekey(t *testing.T) { const key = "foo" + ctx, cn := context.WithCancel(context.Background()) + defer cn() + vaultPath := getTestVaultPath(t) - v, err := New(vaultPath) + v, err := New(ctx, vaultPath) if err != nil { t.Fatal(err) } defer v.Close() - err = v.Set(key, []byte("bar")) + err = v.Set(ctx, key, []byte("bar")) if err != nil { t.Fatal(err) } // Read seed file value seedPath := filepath.Join(vaultPath, ".seed") - seedBytes, err := ioutil.ReadFile(seedPath) + seedBytes, err := os.ReadFile(seedPath) if err != nil { t.Fatal(err) } - diff := cmp.Diff(int(AES256), len(seedBytes)) + diff := cmp.Diff(int(aesgcm.AES256), len(seedBytes)) if diff != "" { t.Fatal(diff) } @@ -57,14 +66,14 @@ func TestVaultRekey(t *testing.T) { } // The vault with the new seed - v2, err := New(vaultPath) + v2, err := New(ctx, vaultPath) if err != nil { t.Fatal(err) } defer v2.Close() // The key should be not found - _, err = v2.Get(key) + _, err = v2.Get(ctx, key) if !errors.Is(err, os.ErrNotExist) { t.Fatal(err) } @@ -73,7 +82,10 @@ func TestVaultRekey(t *testing.T) { func TestVault(t *testing.T) { vaultPath := getTestVaultPath(t) - v, err := New(vaultPath) + ctx, cn := context.WithCancel(context.Background()) + defer cn() + + v, err := New(ctx, vaultPath) if err != nil { t.Fatal(err) } @@ -95,7 +107,7 @@ func TestVault(t *testing.T) { // Test that keys do not exists for _, key := range keys { - exists, err := v.Exists(key) + exists, err := v.Exists(ctx, key) if err != nil { t.Fatal(err) } @@ -107,7 +119,7 @@ func TestVault(t *testing.T) { // Create keys, except the last one for i := 0; i < len(keys)-1; i++ { - err := v.Set(keys[i], []byte(vals[i])) + err := v.Set(ctx, keys[i], []byte(vals[i])) if err != nil { t.Fatal(err) } @@ -115,7 +127,7 @@ func TestVault(t *testing.T) { // Verify the keys that were created now exist for i := 0; i < len(keys)-1; i++ { - exists, err := v.Exists(keys[i]) + exists, err := v.Exists(ctx, keys[i]) if err != nil { t.Fatal(err) } @@ -127,7 +139,7 @@ func TestVault(t *testing.T) { // Verify the keys values for i := 0; i < len(keys)-1; i++ { - b, err := v.Get(keys[i]) + b, err := v.Get(ctx, keys[i]) if err != nil { t.Fatal(err) } @@ -138,7 +150,7 @@ func TestVault(t *testing.T) { } // Verify that the last key that was not creates still doesn't exists - exists, err := v.Exists(keys[len(keys)-1]) + exists, err := v.Exists(ctx, keys[len(keys)-1]) if err != nil { t.Fatal(err) } @@ -148,13 +160,13 @@ func TestVault(t *testing.T) { } // Delete the first key - err = v.Remove(keys[0]) + err = v.Remove(ctx, keys[0]) if err != nil { t.Fatal(err) } // Verify that just deleted key doesn't exist anymore - exists, err = v.Exists(keys[0]) + exists, err = v.Exists(ctx, keys[0]) if err != nil { t.Fatal(err) } @@ -164,3 +176,71 @@ func TestVault(t *testing.T) { t.Fatal(diff) } } + +type secret struct { + Value []byte `json:"v"` // binary value + CreatedOn time.Time `json:"t"` // date/time the secret was created on +} + +func TestVaultConcurrent(t *testing.T) { + const ( + parallel = 15 + iterations = 7 + + key = `secret` + ) + + vaultPath := getTestVaultPath(t) + + ctx, cn := context.WithCancel(context.Background()) + defer cn() + + for i := 0; i < iterations; i++ { + g, _ := errgroup.WithContext(context.Background()) + for j := 0; j < parallel; j++ { + g.Go(func() error { + return doCrud(t, ctx, vaultPath, key) + }) + } + err := g.Wait() + if err != nil { + t.Fatal(err) + } + } +} + +func doCrud(t *testing.T, ctx context.Context, vaultPath, key string) error { + v, err := New(ctx, vaultPath) + if err != nil { + return fmt.Errorf("could not create new vault: %w", err) + } + defer v.Close() + + // Create new AES256 key + k, err := aesgcm.NewKey(aesgcm.AES256) + if err != nil { + return err + } + + sec := secret{ + Value: k, + CreatedOn: time.Now().UTC(), + } + + b, err := json.Marshal(sec) + if err != nil { + return fmt.Errorf("could not marshal secret: %w", err) + } + + err = v.Set(ctx, key, b) + if err != nil { + return fmt.Errorf("failed to set secret: %w", err) + } + + _, err = v.Get(ctx, key) + if err != nil { + return fmt.Errorf("failed to get secret: %w", err) + } + + return nil +} diff --git a/internal/pkg/agent/vault/vault_options.go b/internal/pkg/agent/vault/vault_options.go index 2673ae6aa53..7570cff8bcf 100644 --- a/internal/pkg/agent/vault/vault_options.go +++ b/internal/pkg/agent/vault/vault_options.go @@ -4,25 +4,11 @@ package vault -type Options struct { - readonly bool -} - type OptionFunc func(o *Options) +// WithReadonly opens storage for read-only access only, noop for Darwin func WithReadonly(readonly bool) OptionFunc { return func(o *Options) { o.readonly = readonly } } - -//nolint:unused // not used on darwin -func applyOptions(opts ...OptionFunc) Options { - var options Options - - for _, opt := range opts { - opt(&options) - } - - return options -} diff --git a/internal/pkg/agent/vault/vault_options_darwin.go b/internal/pkg/agent/vault/vault_options_darwin.go new file mode 100644 index 00000000000..0600d02dfa2 --- /dev/null +++ b/internal/pkg/agent/vault/vault_options_darwin.go @@ -0,0 +1,11 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build darwin + +package vault + +type Options struct { + readonly bool +} diff --git a/internal/pkg/agent/vault/vault_options_nondarwin.go b/internal/pkg/agent/vault/vault_options_nondarwin.go new file mode 100644 index 00000000000..cc967d13d75 --- /dev/null +++ b/internal/pkg/agent/vault/vault_options_nondarwin.go @@ -0,0 +1,14 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +//go:build !darwin + +package vault + +import "time" + +type Options struct { + readonly bool + lockRetryDelay time.Duration +} diff --git a/internal/pkg/agent/vault/vault_windows.go b/internal/pkg/agent/vault/vault_windows.go index f09a4a55701..afd484d351a 100644 --- a/internal/pkg/agent/vault/vault_windows.go +++ b/internal/pkg/agent/vault/vault_windows.go @@ -7,132 +7,23 @@ package vault import ( - "errors" - "io/fs" - "io/ioutil" "os" - "path/filepath" - "sync" "github.com/billgraziano/dpapi" "github.com/hectane/go-acl" "golang.org/x/sys/windows" ) -type Vault struct { - path string - entropy []byte - mx sync.Mutex -} - -// Open initializes the vault store -func New(path string, opts ...OptionFunc) (v *Vault, err error) { - options := applyOptions(opts...) - dir := filepath.Dir(path) - - // If there is no specific path then get the executable directory - if dir == "." { - exefp, err := os.Executable() - if err != nil { - return nil, err - } - dir = filepath.Dir(exefp) - path = filepath.Join(dir, path) - } - - if options.readonly { - fi, err := os.Stat(path) - if err != nil { - return nil, err - } - if !fi.IsDir() { - return nil, fs.ErrNotExist - } - } else { - err := os.MkdirAll(path, 0750) - if err != nil { - return nil, err - } - err = systemAdministratorsOnly(path, false) - if err != nil { - return nil, err - } - } - - entropy, err := getOrCreateSeed(path, options.readonly) - if err != nil { - return nil, err - } - - return &Vault{ - path: path, - entropy: entropy, - }, nil -} - -// Close closes the valut store -// Noop on windows -func (v *Vault) Close() error { - return nil -} - -// Set stores the key in the vault store -func (v *Vault) Set(key string, data []byte) error { - enc, err := v.encrypt(data) - if err != nil { - return err - } - - v.mx.Lock() - defer v.mx.Unlock() - - return ioutil.WriteFile(v.filepathFromKey(key), enc, 0600) -} - -// Get retrieves the key from the vault store -func (v *Vault) Get(key string) ([]byte, error) { - v.mx.Lock() - defer v.mx.Unlock() - - enc, err := ioutil.ReadFile(v.filepathFromKey(key)) - if err != nil { - return nil, err - } - - return v.decrypt(enc) -} - -// Exists checks if the key exists -func (v *Vault) Exists(key string) (ok bool, err error) { - v.mx.Lock() - defer v.mx.Unlock() - - if _, err = os.Stat(v.filepathFromKey(key)); err == nil { - ok = true - } else if errors.Is(err, fs.ErrNotExist) { - err = nil - } - return ok, err -} - -// Remove removes the key -func (v *Vault) Remove(key string) error { - v.mx.Lock() - defer v.mx.Unlock() - - return os.RemoveAll(v.filepathFromKey(key)) -} - func (v *Vault) encrypt(data []byte) ([]byte, error) { - return dpapi.EncryptBytesMachineLocalEntropy(data, v.entropy) + return dpapi.EncryptBytesMachineLocalEntropy(data, v.seed) } func (v *Vault) decrypt(data []byte) ([]byte, error) { - return dpapi.DecryptBytesEntropy(data, v.entropy) + return dpapi.DecryptBytesEntropy(data, v.seed) } -func (v *Vault) filepathFromKey(key string) string { - return filepath.Join(v.path, fileNameFromKey(v.entropy, key)) +func tightenPermissions(path string) error { + return systemAdministratorsOnly(path, false) } func systemAdministratorsOnly(path string, inherit bool) error { @@ -152,3 +43,7 @@ func systemAdministratorsOnly(path string, inherit bool) error { acl.GrantSid(0xF10F0000, systemSID), // full control of all acl's acl.GrantSid(0xF10F0000, administratorsSID)) } + +func writeFile(fp string, data []byte) error { + return os.WriteFile(fp, data, 0600) +} diff --git a/internal/pkg/composable/controller.go b/internal/pkg/composable/controller.go index 5f042628293..6ae3220e59c 100644 --- a/internal/pkg/composable/controller.go +++ b/internal/pkg/composable/controller.go @@ -127,7 +127,7 @@ func (c *controller) Run(ctx context.Context) error { state.signal = stateChangedChan go func(name string, state *contextProviderState) { defer wg.Done() - err := state.provider.Run(state) + err := state.provider.Run(ctx, state) if err != nil && !errors.Is(err, context.Canceled) { err = errors.New(err, fmt.Sprintf("failed to run provider '%s'", name), errors.TypeConfig, errors.M("provider", name)) c.logger.Errorf("%s", err) diff --git a/internal/pkg/composable/providers/agent/agent.go b/internal/pkg/composable/providers/agent/agent.go index 2fb5bb284e5..a560d4cac5f 100644 --- a/internal/pkg/composable/providers/agent/agent.go +++ b/internal/pkg/composable/providers/agent/agent.go @@ -5,6 +5,8 @@ package agent import ( + "context" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/info" "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/composable" @@ -21,8 +23,8 @@ func init() { type contextProvider struct{} // Run runs the Agent context provider. -func (*contextProvider) Run(comm corecomp.ContextProviderComm) error { - a, err := info.NewAgentInfo(false) +func (*contextProvider) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { + a, err := info.NewAgentInfo(ctx, false) if err != nil { return err } diff --git a/internal/pkg/composable/providers/agent/agent_test.go b/internal/pkg/composable/providers/agent/agent_test.go index cd15e8058ea..5a4e9e2bd6a 100644 --- a/internal/pkg/composable/providers/agent/agent_test.go +++ b/internal/pkg/composable/providers/agent/agent_test.go @@ -23,8 +23,11 @@ func TestContextProvider(t *testing.T) { provider, err := builder(nil, nil, true) require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + comm := ctesting.NewContextComm(context.Background()) - err = provider.Run(comm) + err = provider.Run(ctx, comm) require.NoError(t, err) current := comm.Current() diff --git a/internal/pkg/composable/providers/env/env.go b/internal/pkg/composable/providers/env/env.go index ac6ef4be446..905f868712a 100644 --- a/internal/pkg/composable/providers/env/env.go +++ b/internal/pkg/composable/providers/env/env.go @@ -5,6 +5,7 @@ package env import ( + "context" "os" "strings" @@ -22,7 +23,7 @@ func init() { type contextProvider struct{} // Run runs the environment context provider. -func (*contextProvider) Run(comm corecomp.ContextProviderComm) error { +func (*contextProvider) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { err := comm.Set(getEnvMapping()) if err != nil { return errors.New(err, "failed to set mapping", errors.TypeUnexpected) diff --git a/internal/pkg/composable/providers/env/env_test.go b/internal/pkg/composable/providers/env/env_test.go index a03f37ee577..e4eafdf90cc 100644 --- a/internal/pkg/composable/providers/env/env_test.go +++ b/internal/pkg/composable/providers/env/env_test.go @@ -20,8 +20,11 @@ func TestContextProvider(t *testing.T) { provider, err := builder(nil, nil, true) require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + comm := ctesting.NewContextComm(context.Background()) - err = provider.Run(comm) + err = provider.Run(ctx, comm) require.NoError(t, err) assert.Equal(t, getEnvMapping(), comm.Current()) diff --git a/internal/pkg/composable/providers/host/host.go b/internal/pkg/composable/providers/host/host.go index 343d8d04488..d63189bf71a 100644 --- a/internal/pkg/composable/providers/host/host.go +++ b/internal/pkg/composable/providers/host/host.go @@ -5,6 +5,7 @@ package host import ( + "context" "fmt" "reflect" "runtime" @@ -47,7 +48,7 @@ type contextProvider struct { } // Run runs the environment context provider. -func (c *contextProvider) Run(comm corecomp.ContextProviderComm) error { +func (c *contextProvider) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { current, err := c.fetcher() if err != nil { return err diff --git a/internal/pkg/composable/providers/host/host_test.go b/internal/pkg/composable/providers/host/host_test.go index 319186f4e44..cfb74bd288c 100644 --- a/internal/pkg/composable/providers/host/host_test.go +++ b/internal/pkg/composable/providers/host/host_test.go @@ -56,7 +56,7 @@ func TestContextProvider(t *testing.T) { }) go func() { - _ = provider.Run(comm) + _ = provider.Run(ctx, comm) }() // wait for it to be called once @@ -114,9 +114,7 @@ func TestFQDNFeatureFlagToggle(t *testing.T) { }() ctx, cancel := context.WithCancel(context.Background()) - defer func() { - cancel() - }() + defer cancel() comm := ctesting.NewContextComm(ctx) calledChan := make(chan struct{}) @@ -130,7 +128,7 @@ func TestFQDNFeatureFlagToggle(t *testing.T) { // Run the provider go func() { - err = hostProvider.Run(comm) + err = hostProvider.Run(ctx, comm) }() // Trigger the FQDN feature flag callback by diff --git a/internal/pkg/composable/providers/kubernetesleaderelection/kubernetes_leaderelection.go b/internal/pkg/composable/providers/kubernetesleaderelection/kubernetes_leaderelection.go index 1fc6c7e958d..b1388005ab7 100644 --- a/internal/pkg/composable/providers/kubernetesleaderelection/kubernetes_leaderelection.go +++ b/internal/pkg/composable/providers/kubernetesleaderelection/kubernetes_leaderelection.go @@ -46,7 +46,7 @@ func ContextProviderBuilder(logger *logger.Logger, c *config.Config, managed boo } // Run runs the leaderelection provider. -func (p *contextProvider) Run(comm corecomp.ContextProviderComm) error { +func (p *contextProvider) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { client, err := kubernetes.GetKubernetesClient(p.config.KubeConfig, p.config.KubeClientOptions) if err != nil { // info only; return nil (do nothing) @@ -54,7 +54,7 @@ func (p *contextProvider) Run(comm corecomp.ContextProviderComm) error { return nil } - agentInfo, err := info.NewAgentInfo(false) + agentInfo, err := info.NewAgentInfo(ctx, false) if err != nil { return err } diff --git a/internal/pkg/composable/providers/kubernetessecrets/kubernetes_secrets.go b/internal/pkg/composable/providers/kubernetessecrets/kubernetes_secrets.go index 543d0cd6b28..1537a232dd1 100644 --- a/internal/pkg/composable/providers/kubernetessecrets/kubernetes_secrets.go +++ b/internal/pkg/composable/providers/kubernetessecrets/kubernetes_secrets.go @@ -91,7 +91,7 @@ func (p *contextProviderK8sSecrets) Fetch(key string) (string, bool) { } // Run initializes the k8s secrets context provider. -func (p *contextProviderK8sSecrets) Run(comm corecomp.ContextProviderComm) error { +func (p *contextProviderK8sSecrets) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { client, err := getK8sClientFunc(p.config.KubeConfig, p.config.KubeClientOptions) if err != nil { p.logger.Debugf("Kubernetes_secrets provider skipped, unable to connect: %s", err) diff --git a/internal/pkg/composable/providers/kubernetessecrets/kubernetes_secrets_test.go b/internal/pkg/composable/providers/kubernetessecrets/kubernetes_secrets_test.go index f633a9f062e..9924c84e6bc 100644 --- a/internal/pkg/composable/providers/kubernetessecrets/kubernetes_secrets_test.go +++ b/internal/pkg/composable/providers/kubernetessecrets/kubernetes_secrets_test.go @@ -66,7 +66,7 @@ func Test_K8sSecretsProvider_Fetch(t *testing.T) { comm := ctesting.NewContextComm(ctx) go func() { - _ = fp.Run(comm) + _ = fp.Run(ctx, comm) }() for { @@ -121,7 +121,7 @@ func Test_K8sSecretsProvider_FetchWrongSecret(t *testing.T) { comm := ctesting.NewContextComm(ctx) go func() { - _ = fp.Run(comm) + _ = fp.Run(ctx, comm) }() for { diff --git a/internal/pkg/composable/providers/local/local.go b/internal/pkg/composable/providers/local/local.go index b54e6142ee0..ac22e710f92 100644 --- a/internal/pkg/composable/providers/local/local.go +++ b/internal/pkg/composable/providers/local/local.go @@ -5,6 +5,7 @@ package local import ( + "context" "fmt" "github.com/elastic/elastic-agent/internal/pkg/agent/errors" @@ -23,7 +24,7 @@ type contextProvider struct { } // Run runs the environment context provider. -func (c *contextProvider) Run(comm corecomp.ContextProviderComm) error { +func (c *contextProvider) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { err := comm.Set(c.Mapping) if err != nil { return errors.New(err, "failed to set mapping", errors.TypeUnexpected) diff --git a/internal/pkg/composable/providers/local/local_test.go b/internal/pkg/composable/providers/local/local_test.go index dfec629b88a..1219b2864f8 100644 --- a/internal/pkg/composable/providers/local/local_test.go +++ b/internal/pkg/composable/providers/local/local_test.go @@ -29,8 +29,11 @@ func TestContextProvider(t *testing.T) { provider, err := builder(nil, cfg, true) require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + comm := ctesting.NewContextComm(context.Background()) - err = provider.Run(comm) + err = provider.Run(ctx, comm) require.NoError(t, err) assert.Equal(t, mapping, comm.Current()) diff --git a/internal/pkg/composable/providers/path/path.go b/internal/pkg/composable/providers/path/path.go index 389a21fe6bc..1900bd6dd92 100644 --- a/internal/pkg/composable/providers/path/path.go +++ b/internal/pkg/composable/providers/path/path.go @@ -5,6 +5,8 @@ package path import ( + "context" + "github.com/elastic/elastic-agent/internal/pkg/agent/application/paths" "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/composable" @@ -14,13 +16,13 @@ import ( ) func init() { - composable.Providers.AddContextProvider("path", ContextProviderBuilder) + _ = composable.Providers.AddContextProvider("path", ContextProviderBuilder) } type contextProvider struct{} // Run runs the Agent context provider. -func (*contextProvider) Run(comm corecomp.ContextProviderComm) error { +func (*contextProvider) Run(ctx context.Context, comm corecomp.ContextProviderComm) error { err := comm.Set(map[string]interface{}{ "home": paths.Home(), "data": paths.Data(), diff --git a/internal/pkg/composable/providers/path/path_test.go b/internal/pkg/composable/providers/path/path_test.go index 094865d3fbd..df9cac07eaf 100644 --- a/internal/pkg/composable/providers/path/path_test.go +++ b/internal/pkg/composable/providers/path/path_test.go @@ -21,8 +21,10 @@ func TestContextProvider(t *testing.T) { provider, err := builder(nil, nil, true) require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() comm := ctesting.NewContextComm(context.Background()) - err = provider.Run(comm) + err = provider.Run(ctx, comm) require.NoError(t, err) current := comm.Current() diff --git a/internal/pkg/config/operations/inspector.go b/internal/pkg/config/operations/inspector.go index a5f45d27c6b..f1066c22531 100644 --- a/internal/pkg/config/operations/inspector.go +++ b/internal/pkg/config/operations/inspector.go @@ -5,6 +5,7 @@ package operations import ( + "context" "fmt" "github.com/elastic/elastic-agent/internal/pkg/agent/application/info" @@ -25,8 +26,8 @@ var ( // LoadFullAgentConfig load agent config based on provided paths and defined capabilities. // In case fleet is used, config from policy action is returned. -func LoadFullAgentConfig(logger *logger.Logger, cfgPath string, failOnFleetMissing bool) (*config.Config, error) { - rawConfig, err := loadConfig(cfgPath) +func LoadFullAgentConfig(ctx context.Context, logger *logger.Logger, cfgPath string, failOnFleetMissing bool) (*config.Config, error) { + rawConfig, err := loadConfig(ctx, cfgPath) if err != nil { return nil, err } @@ -54,7 +55,7 @@ func LoadFullAgentConfig(logger *logger.Logger, cfgPath string, failOnFleetMissi return c, nil } - fleetConfig, err := loadFleetConfig(logger) + fleetConfig, err := loadFleetConfig(ctx, logger) if err != nil { return nil, err } else if fleetConfig == nil { @@ -74,7 +75,7 @@ func LoadFullAgentConfig(logger *logger.Logger, cfgPath string, failOnFleetMissi return rawConfig, nil } -func loadConfig(configPath string) (*config.Config, error) { +func loadConfig(ctx context.Context, configPath string) (*config.Config, error) { rawConfig, err := config.LoadFile(configPath) if err != nil { return nil, err @@ -82,7 +83,7 @@ func loadConfig(configPath string) (*config.Config, error) { path := paths.AgentConfigFile() - store := storage.NewEncryptedDiskStore(path) + store := storage.NewEncryptedDiskStore(ctx, path) reader, err := store.Load() if err != nil { return nil, errors.New(err, "could not initialize config store", @@ -108,8 +109,8 @@ func loadConfig(configPath string) (*config.Config, error) { return rawConfig, nil } -func loadFleetConfig(l *logger.Logger) (map[string]interface{}, error) { - stateStore, err := store.NewStateStoreWithMigration(l, paths.AgentActionStoreFile(), paths.AgentStateStoreFile()) +func loadFleetConfig(ctx context.Context, l *logger.Logger) (map[string]interface{}, error) { + stateStore, err := store.NewStateStoreWithMigration(ctx, l, paths.AgentActionStoreFile(), paths.AgentStateStoreFile()) if err != nil { return nil, err } diff --git a/internal/pkg/core/composable/providers.go b/internal/pkg/core/composable/providers.go index f6d2a8f3e26..1c4018fd479 100644 --- a/internal/pkg/core/composable/providers.go +++ b/internal/pkg/core/composable/providers.go @@ -26,7 +26,7 @@ type ContextProviderComm interface { // ContextProvider is the interface that a context provider must implement. type ContextProvider interface { // Run runs the context provider. - Run(ContextProviderComm) error + Run(context.Context, ContextProviderComm) error } // CloseableProvider is an interface that providers may choose to implement diff --git a/internal/pkg/testutils/testutils.go b/internal/pkg/testutils/testutils.go index fcd7cbbe2b6..4c3c0781cca 100644 --- a/internal/pkg/testutils/testutils.go +++ b/internal/pkg/testutils/testutils.go @@ -5,6 +5,7 @@ package testutils import ( + "context" "runtime" "testing" @@ -22,7 +23,7 @@ import ( func InitStorage(t *testing.T) { storage.DisableEncryptionDarwin() if runtime.GOOS != "darwin" { - err := secret.CreateAgentSecret() + err := secret.CreateAgentSecret(context.Background()) if err != nil { t.Fatal(err) } diff --git a/pkg/component/runtime/manager_test.go b/pkg/component/runtime/manager_test.go index ab81806e9a5..e489ad71056 100644 --- a/pkg/component/runtime/manager_test.go +++ b/pkg/component/runtime/manager_test.go @@ -68,7 +68,7 @@ func TestManager_SimpleComponentErr(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager( newDebugLogger(t), newDebugLogger(t), @@ -172,7 +172,7 @@ func TestManager_FakeInput_StartStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -291,7 +291,7 @@ func TestManager_FakeInput_Features(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - agentInfo, _ := info.NewAgentInfo(true) + agentInfo, _ := info.NewAgentInfo(ctx, true) m, err := NewManager( newDebugLogger(t), newDebugLogger(t), @@ -488,7 +488,7 @@ func TestManager_FakeInput_Limits(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - agentInfo, _ := info.NewAgentInfo(true) + agentInfo, _ := info.NewAgentInfo(ctx, true) m, err := NewManager( newDebugLogger(t), newDebugLogger(t), @@ -649,7 +649,7 @@ func TestManager_FakeShipper_Limits(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - agentInfo, _ := info.NewAgentInfo(true) + agentInfo, _ := info.NewAgentInfo(ctx, true) m, err := NewManager( newDebugLogger(t), newDebugLogger(t), @@ -810,7 +810,7 @@ func TestManager_FakeInput_BadUnitToGood(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -976,7 +976,7 @@ func TestManager_FakeInput_GoodUnitToBad(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -1132,7 +1132,7 @@ func TestManager_FakeInput_NoDeadlock(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -1266,7 +1266,7 @@ func TestManager_FakeInput_Configure(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -1386,7 +1386,7 @@ func TestManager_FakeInput_RemoveUnit(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -1539,7 +1539,7 @@ func TestManager_FakeInput_ActionState(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -1663,7 +1663,7 @@ func TestManager_FakeInput_Restarts(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -1798,7 +1798,7 @@ func TestManager_FakeInput_Restarts_ConfigKill(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -1940,7 +1940,7 @@ func TestManager_FakeInput_KeepsRestarting(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -2082,7 +2082,7 @@ func TestManager_FakeInput_RestartsOnMissedCheckins(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -2197,7 +2197,7 @@ func TestManager_FakeInput_InvalidAction(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -2314,7 +2314,7 @@ func TestManager_FakeInput_MultiComponent(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - agentInfo, _ := info.NewAgentInfo(true) + agentInfo, _ := info.NewAgentInfo(ctx, true) m, err := NewManager( newDebugLogger(t), newDebugLogger(t), @@ -2527,7 +2527,7 @@ func TestManager_FakeInput_LogLevel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager( newDebugLogger(t), newDebugLogger(t), @@ -2679,7 +2679,7 @@ func TestManager_FakeShipper(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager(newDebugLogger(t), newDebugLogger(t), "localhost:0", ai, apmtest.DiscardTracer, newTestMonitoringMgr(), configuration.DefaultGRPCConfig()) require.NoError(t, err) errCh := make(chan error) @@ -2971,7 +2971,7 @@ func TestManager_FakeInput_OutputChange(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ai, _ := info.NewAgentInfo(true) + ai, _ := info.NewAgentInfo(ctx, true) m, err := NewManager( newDebugLogger(t), newDebugLogger(t),