Skip to content

Commit

Permalink
Fix: Secondary fallback is not used for agent with local agent (#3518)
Browse files Browse the repository at this point in the history
  • Loading branch information
michalpristas authored Oct 5, 2023
1 parent 6636206 commit 7f842a3
Show file tree
Hide file tree
Showing 11 changed files with 285 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Kind can be one of:
# - breaking-change: a change to previously-documented behavior
# - deprecation: functionality that is being removed in a later release
# - bug-fix: fixes a problem in a previous version
# - enhancement: extends functionality but does not break or fix existing behavior
# - feature: new functionality
# - known-issue: problems that we are aware of in a given version
# - security: impacts on the security of a product or a user’s deployment.
# - upgrade: important information for someone upgrading from a prior version
# - other: does not fit into any of the other categories
kind: feature

# Change summary; a 80ish characters long description of the change.
summary: Secondary fallback for package signature verification

# Long description; in case the summary is not enough to describe the change
# this field accommodate a description without length limits.
description: Ability to upgrade securely in Air gapped environment where fleet server is the only reachable URI.

# Affected component; a word indicating the component this changeset affects.
component: elastic-agent

# PR number; optional; the PR number that added the changeset.
# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added.
# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number.
# Please provide it if you are adding a fragment for a different PR.
pr: https://github.com/elastic/elastic-agent/pull/3453

# Issue number; optional; the GitHub issue related to this changeset (either closes or is part of).
# If not present is automatically filled by the tooling with the issue linked to the PR number.
issue: https://github.com/elastic/elastic-agent/issues/3264
2 changes: 1 addition & 1 deletion internal/pkg/agent/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func New(
EndpointSignedComponentModifier(),
)

managed, err = newManagedConfigManager(ctx, log, agentInfo, cfg, store, runtime, fleetInitTimeout)
managed, err = newManagedConfigManager(ctx, log, agentInfo, cfg, store, runtime, fleetInitTimeout, upgrader)
if err != nil {
return nil, nil, nil, err
}
Expand Down
60 changes: 37 additions & 23 deletions internal/pkg/agent/application/managed_mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/elastic/elastic-agent-client/v7/pkg/client"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/actions"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/actions/handlers"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/coordinator"
"github.com/elastic/elastic-agent/internal/pkg/agent/application/dispatcher"
Expand Down Expand Up @@ -38,17 +39,18 @@ import (
const dispatchFlushInterval = time.Minute * 5

type managedConfigManager struct {
log *logger.Logger
agentInfo *info.AgentInfo
cfg *configuration.Configuration
client *remote.Client
store storage.Store
stateStore *store.StateStore
actionQueue *queue.ActionQueue
dispatcher *dispatcher.ActionDispatcher
runtime *runtime.Manager
coord *coordinator.Coordinator
fleetInitTimeout time.Duration
log *logger.Logger
agentInfo *info.AgentInfo
cfg *configuration.Configuration
client *remote.Client
store storage.Store
stateStore *store.StateStore
actionQueue *queue.ActionQueue
dispatcher *dispatcher.ActionDispatcher
runtime *runtime.Manager
coord *coordinator.Coordinator
fleetInitTimeout time.Duration
initialClientSetters []actions.ClientSetter

ch chan coordinator.ConfigChange
errCh chan error
Expand All @@ -62,6 +64,7 @@ func newManagedConfigManager(
storeSaver storage.Store,
runtime *runtime.Manager,
fleetInitTimeout time.Duration,
clientSetters ...actions.ClientSetter,
) (*managedConfigManager, error) {
client, err := fleetclient.NewAuthWithConfig(log, cfg.Fleet.AccessAPIKey, cfg.Fleet.Client)
if err != nil {
Expand All @@ -88,18 +91,19 @@ func newManagedConfigManager(
}

return &managedConfigManager{
log: log,
agentInfo: agentInfo,
cfg: cfg,
client: client,
store: storeSaver,
stateStore: stateStore,
actionQueue: actionQueue,
dispatcher: actionDispatcher,
runtime: runtime,
fleetInitTimeout: fleetInitTimeout,
ch: make(chan coordinator.ConfigChange),
errCh: make(chan error),
log: log,
agentInfo: agentInfo,
cfg: cfg,
client: client,
store: storeSaver,
stateStore: stateStore,
actionQueue: actionQueue,
dispatcher: actionDispatcher,
runtime: runtime,
fleetInitTimeout: fleetInitTimeout,
ch: make(chan coordinator.ConfigChange),
errCh: make(chan error),
initialClientSetters: clientSetters,
}, nil
}

Expand Down Expand Up @@ -195,6 +199,16 @@ func (m *managedConfigManager) Run(ctx context.Context) error {
if m.cfg.Fleet.Server == nil {
policyChanger.AddSetter(gateway)
policyChanger.AddSetter(ack)

for _, cs := range m.initialClientSetters {
policyChanger.AddSetter(cs)
}
} else {
// locally managed fleet server
// init with local address
for _, cs := range m.initialClientSetters {
cs.SetClient(m.client)
}
}

// Proxy errors from the gateway to our own channel.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (v *Verifier) verifyAsc(fullPath string, skipDefaultPgp bool, pgpSources ..
if len(check) == 0 {
continue
}
raw, err := download.PgpBytesFromSource(v.log, check, v.client)
raw, err := download.PgpBytesFromSource(v.log, check, &v.client)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (v *Verifier) verifyAsc(a artifact.Artifact, version string, skipDefaultPgp
if len(check) == 0 {
continue
}
raw, err := download.PgpBytesFromSource(v.log, check, v.client)
raw, err := download.PgpBytesFromSource(v.log, check, &v.client)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const (
var (
ErrRemotePGPDownloadFailed = errors.New("Remote PGP download failed")
ErrInvalidLocation = errors.New("Remote PGP location is invalid")
ErrUnknownPGPSource = errors.New("unknown pgp source")
)

// warnLogger is a logger that only needs to implement Warnf, as that is the only functions
Expand Down Expand Up @@ -180,7 +181,7 @@ func VerifyGPGSignature(file string, asciiArmorSignature, publicKey []byte) erro
return nil
}

func PgpBytesFromSource(log warnLogger, source string, client http.Client) ([]byte, error) {
func PgpBytesFromSource(log warnLogger, source string, client HTTPClient) ([]byte, error) {
if strings.HasPrefix(source, PgpSourceRawPrefix) {
return []byte(strings.TrimPrefix(source, PgpSourceRawPrefix)), nil
}
Expand All @@ -189,11 +190,14 @@ func PgpBytesFromSource(log warnLogger, source string, client http.Client) ([]by
pgpBytes, err := fetchPgpFromURI(strings.TrimPrefix(source, PgpSourceURIPrefix), client)
if errors.Is(err, ErrRemotePGPDownloadFailed) || errors.Is(err, ErrInvalidLocation) {
log.Warnf("Skipped remote PGP located at %q because it's unavailable: %v", strings.TrimPrefix(source, PgpSourceURIPrefix), err)
} else if err != nil {
log.Warnf("Failed to fetch remote PGP")
}

return pgpBytes, nil
}

return nil, errors.New("unknown pgp source")
return nil, ErrUnknownPGPSource
}

func CheckValidDownloadUri(rawURI string) error {
Expand All @@ -209,7 +213,7 @@ func CheckValidDownloadUri(rawURI string) error {
return nil
}

func fetchPgpFromURI(uri string, client http.Client) ([]byte, error) {
func fetchPgpFromURI(uri string, client HTTPClient) ([]byte, error) {
if err := CheckValidDownloadUri(uri); err != nil {
return nil, err
}
Expand All @@ -221,7 +225,7 @@ func fetchPgpFromURI(uri string, client http.Client) ([]byte, error) {
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return nil, multierror.Append(err, ErrRemotePGPDownloadFailed)
}
Expand All @@ -233,3 +237,7 @@ func fetchPgpFromURI(uri string, client http.Client) ([]byte, error) {

return ioutil.ReadAll(resp.Body)
}

type HTTPClient interface {
Do(*http.Request) (*http.Response, error)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package download

import (
"bytes"
"io"
"net/http"
"testing"

"github.com/stretchr/testify/require"

"github.com/elastic/elastic-agent/internal/pkg/agent/errors"
"github.com/elastic/elastic-agent/pkg/core/logger"
)

func TestPgpBytesFromSource(t *testing.T) {
testCases := []struct {
Name string
Source string
ClientDoErr error
ClientBody []byte
ClientStatus int

ExpectedPGP []byte
ExpectedErr error
ExpectedLogMessage string
}{
{
"successful call",
PgpSourceURIPrefix + "https://location/path",
nil,
[]byte("pgp-body"),
200,
[]byte("pgp-body"),
nil,
"",
},
{
"unknown source call",
"https://location/path",
nil,
[]byte("pgp-body"),
200,
nil,
ErrUnknownPGPSource,
"",
},
{
"invalid location is filtered call",
PgpSourceURIPrefix + "http://location/path",
nil,
[]byte("pgp-body"),
200,
nil,
nil,
"Skipped remote PGP located ",
},
{
"do error is filtered",
PgpSourceURIPrefix + "https://location/path",
errors.New("error"),
[]byte("pgp-body"),
200,
nil,
nil,
"Skipped remote PGP located",
},
{
"invalid status code is filtered out",
PgpSourceURIPrefix + "https://location/path",
nil,
[]byte("pgp-body"),
500,
nil,
nil,
"Failed to fetch remote PGP",
},
{
"invalid status code is filtered out",
PgpSourceURIPrefix + "https://location/path",
nil,
[]byte("pgp-body"),
404,
nil,
nil,
"Failed to fetch remote PGP",
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
log, obs := logger.NewTesting(tc.Name)
mockClient := &MockClient{
DoFunc: func(req *http.Request) (*http.Response, error) {
if tc.ClientDoErr != nil {
return nil, tc.ClientDoErr
}

return &http.Response{
StatusCode: tc.ClientStatus,
Body: io.NopCloser(bytes.NewReader(tc.ClientBody)),
}, nil
},
}

resPgp, resErr := PgpBytesFromSource(log, tc.Source, mockClient)
require.Equal(t, tc.ExpectedErr, resErr)
require.Equal(t, tc.ExpectedPGP, resPgp)
if tc.ExpectedLogMessage != "" {
logs := obs.FilterMessageSnippet(tc.ExpectedLogMessage)
require.NotEqual(t, 0, logs.Len())
}

})
}
}

type MockClient struct {
DoFunc func(req *http.Request) (*http.Response, error)
}

func (m *MockClient) Do(req *http.Request) (*http.Response, error) {
return m.DoFunc(req)
}
Loading

0 comments on commit 7f842a3

Please sign in to comment.