Skip to content

Commit

Permalink
Update IMDS consumers to use helper with fixed retry set
Browse files Browse the repository at this point in the history
  • Loading branch information
ndbaker1 committed Sep 20, 2024
1 parent 4bf3646 commit 07f12b8
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 62 deletions.
16 changes: 10 additions & 6 deletions nodeadm/cmd/nodeadm/init/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/integrii/flaggy"
"go.uber.org/zap"
"k8s.io/utils/strings/slices"

"github.com/awslabs/amazon-eks-ami/nodeadm/internal/api"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/aws/ecr"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/aws/imds"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/cli"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/configprovider"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/containerd"
Expand Down Expand Up @@ -146,14 +146,18 @@ func (c *initCmd) Run(log *zap.Logger, opts *cli.GlobalOptions) error {
// perform in-place updates when allowed by the user
func enrichConfig(log *zap.Logger, cfg *api.NodeConfig) error {
log.Info("Fetching instance details..")
imdsClient := imds.New(imds.Options{})
awsConfig, err := config.LoadDefaultConfig(context.TODO(), config.WithClientLogMode(aws.LogRetries), config.WithEC2IMDSRegion(func(o *config.UseEC2IMDSRegion) {
o.Client = imdsClient
}))
awsConfig, err := config.LoadDefaultConfig(context.TODO(),
config.WithClientLogMode(aws.LogRetries),
config.WithEC2IMDSRegion(func(o *config.UseEC2IMDSRegion) {
// Use our pre-configured IMDS client to avoid hitting common retry
// issues with the default config.
o.Client = imds.Client
}),
)
if err != nil {
return err
}
instanceDetails, err := api.GetInstanceDetails(context.TODO(), cfg.Spec.FeatureGates, imdsClient, ec2.NewFromConfig(awsConfig))
instanceDetails, err := api.GetInstanceDetails(context.TODO(), cfg.Spec.FeatureGates, ec2.NewFromConfig(awsConfig))
if err != nil {
return err
}
Expand Down
15 changes: 5 additions & 10 deletions nodeadm/internal/api/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,29 @@ package api
import (
"context"
"fmt"
"io"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2extra "github.com/awslabs/amazon-eks-ami/nodeadm/internal/aws/ec2"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/aws/imds"
)

// Fetch information about the ec2 instance using IMDS data.
// This information is stored into the internal config to avoid redundant calls
// to IMDS when looking for instance metadata
func GetInstanceDetails(ctx context.Context, featureGates map[Feature]bool, imdsClient *imds.Client, ec2Client *ec2.Client) (*InstanceDetails, error) {
instanceIdenitityDocument, err := imdsClient.GetInstanceIdentityDocument(ctx, &imds.GetInstanceIdentityDocumentInput{})
func GetInstanceDetails(ctx context.Context, featureGates map[Feature]bool, ec2Client *ec2.Client) (*InstanceDetails, error) {
instanceIdenitityDocument, err := imds.GetInstanceIdentityDocument(ctx)
if err != nil {
return nil, err
}

macResponse, err := imdsClient.GetMetadata(ctx, &imds.GetMetadataInput{Path: "mac"})
if err != nil {
return nil, err
}
mac, err := io.ReadAll(macResponse.Content)
mac, err := imds.GetProperty(ctx, "mac")
if err != nil {
return nil, err
}

privateDNSName := ""
var privateDNSName string
if !IsFeatureEnabled(InstanceIdNodeName, featureGates) {
privateDNSName, err = getPrivateDNSName(ec2Client, instanceIdenitityDocument.InstanceID)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion nodeadm/internal/aws/ecr/ecr.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (r *ECRRegistry) GetSandboxImage() string {

func GetEKSRegistry(region string) (ECRRegistry, error) {
account, region := getEKSRegistryCoordinates(region)
servicesDomain, err := imds.GetProperty(imds.ServicesDomain)
servicesDomain, err := imds.GetProperty(context.TODO(), imds.ServicesDomain)
if err != nil {
return "", err
}
Expand Down
40 changes: 31 additions & 9 deletions nodeadm/internal/aws/imds/imds.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,70 @@ package imds
import (
"context"
"io"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
)

var client *imds.Client
var Client *imds.Client

func init() {
client = imds.New(imds.Options{
Client = imds.New(imds.Options{
DisableDefaultTimeout: true,
Retryer: retry.NewStandard(func(so *retry.StandardOptions) {
so.MaxAttempts = 15
so.MaxBackoff = 1 * time.Second
so.Retryables = append(so.Retryables,
&instanceCredential404{},
)
}),
})
}

type instanceCredential404 struct{}

func (e *instanceCredential404) IsErrorRetryable(err error) aws.Ternary {
// Theres not a specific error type that i can see, as this request failed
// due to missing instance credentials. The events are processed as follows:
// 1. https://github.com/aws/aws-sdk-go-v2/blob/06150d96305d6b6c19db0a2e5d1c1f4fa4a95612/feature/ec2/imds/request_middleware.go#L189
// 2. https://github.com/aws/aws-sdk-go-v2/blob/06150d96305d6b6c19db0a2e5d1c1f4fa4a95612/aws/retry/middleware.go#L238-L242
if strings.Contains(err.Error(), "request to EC2 IMDS failed") {
return aws.TrueTernary
}
return aws.UnknownTernary
}

type IMDSProperty string

const (
ServicesDomain IMDSProperty = "services/domain"
)

func GetUserData() ([]byte, error) {
resp, err := client.GetUserData(context.TODO(), &imds.GetUserDataInput{})
func GetInstanceIdentityDocument(ctx context.Context) (*imds.GetInstanceIdentityDocumentOutput, error) {
return Client.GetInstanceIdentityDocument(ctx, &imds.GetInstanceIdentityDocumentInput{})
}

func GetUserData(ctx context.Context) ([]byte, error) {
res, err := Client.GetUserData(ctx, &imds.GetUserDataInput{})
if err != nil {
return nil, err
}
return io.ReadAll(resp.Content)
return io.ReadAll(res.Content)
}

func GetProperty(prop IMDSProperty) (string, error) {
bytes, err := GetPropertyBytes(prop)
func GetProperty(ctx context.Context, prop IMDSProperty) (string, error) {
bytes, err := GetPropertyBytes(ctx, prop)
if err != nil {
return "", err
}
return string(bytes), nil
}

func GetPropertyBytes(prop IMDSProperty) ([]byte, error) {
res, err := client.GetMetadata(context.TODO(), &imds.GetMetadataInput{Path: string(prop)})
func GetPropertyBytes(ctx context.Context, prop IMDSProperty) ([]byte, error) {
res, err := Client.GetMetadata(ctx, &imds.GetMetadataInput{Path: string(prop)})
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion nodeadm/internal/configprovider/userdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package configprovider
import (
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"fmt"
"io"
Expand Down Expand Up @@ -31,7 +32,7 @@ type userDataProvider interface {
type imdsUserDataProvider struct{}

func (p *imdsUserDataProvider) GetUserData() ([]byte, error) {
return imds.GetUserData()
return imds.GetUserData(context.TODO())
}

type userDataConfigProvider struct {
Expand Down
35 changes: 11 additions & 24 deletions nodeadm/internal/kubelet/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
_ "embed"
"encoding/json"
"fmt"
"io"
"net"
"net/url"
"os"
Expand All @@ -20,10 +19,10 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
k8skubelet "k8s.io/kubelet/config/v1beta1"

"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/smithy-go/ptr"

"github.com/awslabs/amazon-eks-ami/nodeadm/internal/api"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/aws/imds"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/containerd"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/system"
"github.com/awslabs/amazon-eks-ami/nodeadm/internal/util"
Expand Down Expand Up @@ -203,7 +202,7 @@ func (ksc *kubeletConfig) withOutpostSetup(cfg *api.NodeConfig) error {
}

func (ksc *kubeletConfig) withNodeIp(cfg *api.NodeConfig, flags map[string]string) error {
nodeIp, err := getNodeIp(context.TODO(), imds.New(imds.Options{}), cfg)
nodeIp, err := getNodeIp(context.TODO(), cfg)
if err != nil {
return err
}
Expand Down Expand Up @@ -262,11 +261,11 @@ func (ksc *kubeletConfig) withCloudProvider(kubeletVersion string, cfg *api.Node
func (ksc *kubeletConfig) withDefaultReservedResources(cfg *api.NodeConfig) {
ksc.SystemReservedCgroup = ptr.String("/system")
ksc.KubeReservedCgroup = ptr.String("/runtime")
maxPods, ok := MaxPodsPerInstanceType[cfg.Status.Instance.Type]
if !ok {
ksc.MaxPods = CalcMaxPods(cfg.Status.Instance.Region, cfg.Status.Instance.Type)
} else {
if maxPods, ok := MaxPodsPerInstanceType[cfg.Status.Instance.Type]; ok {
// #nosec G115 // known source from ec2 apis within int32 range
ksc.MaxPods = int32(maxPods)
} else {
ksc.MaxPods = CalcMaxPods(cfg.Status.Instance.Region, cfg.Status.Instance.Type)
}
ksc.KubeReserved = map[string]string{
"cpu": fmt.Sprintf("%dm", getCPUMillicoresToReserve()),
Expand Down Expand Up @@ -407,36 +406,24 @@ func getProviderId(availabilityZone, instanceId string) string {
}

// Get the IP of the node depending on the ipFamily configured for the cluster
func getNodeIp(ctx context.Context, imdsClient *imds.Client, cfg *api.NodeConfig) (string, error) {
func getNodeIp(ctx context.Context, cfg *api.NodeConfig) (string, error) {
ipFamily, err := api.GetCIDRIpFamily(cfg.Spec.Cluster.CIDR)
if err != nil {
return "", err
}
switch ipFamily {
case api.IPFamilyIPv4:
ipv4Response, err := imdsClient.GetMetadata(ctx, &imds.GetMetadataInput{
Path: "local-ipv4",
})
if err != nil {
return "", err
}
ip, err := io.ReadAll(ipv4Response.Content)
ipv4, err := imds.GetProperty(ctx, "local-ipv4")
if err != nil {
return "", err
}
return string(ip), nil
return ipv4, nil
case api.IPFamilyIPv6:
ipv6Response, err := imdsClient.GetMetadata(ctx, &imds.GetMetadataInput{
Path: fmt.Sprintf("network/interfaces/macs/%s/ipv6s", cfg.Status.Instance.MAC),
})
if err != nil {
return "", err
}
ip, err := io.ReadAll(ipv6Response.Content)
ipv6, err := imds.GetProperty(ctx, imds.IMDSProperty(fmt.Sprintf("network/interfaces/macs/%s/ipv6s", cfg.Status.Instance.MAC)))
if err != nil {
return "", err
}
return string(ip), nil
return ipv6, nil
default:
return "", fmt.Errorf("invalid ip-family. %s is not one of %v", ipFamily, []api.IPFamily{api.IPFamilyIPv4, api.IPFamilyIPv6})
}
Expand Down
22 changes: 11 additions & 11 deletions nodeadm/internal/system/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ const (
)

type core struct {
Id int `json:"core_id"`
Threads []int `json:"thread_ids"`
SocketID int `json:"socket_id"`
Id int `json:"core_id"`
Threads []uint64 `json:"thread_ids"`
SocketID int `json:"socket_id"`
}

func init() {
Expand Down Expand Up @@ -168,7 +168,7 @@ func getCoresInfo(cpuDirs []string) ([]core, error) {
desiredCore.SocketID = physicalPackageID

if len(desiredCore.Threads) == 0 {
desiredCore.Threads = []int{cpuID}
desiredCore.Threads = []uint64{cpuID}
} else {
desiredCore.Threads = append(desiredCore.Threads, cpuID)
}
Expand All @@ -177,12 +177,12 @@ func getCoresInfo(cpuDirs []string) ([]core, error) {
return cores, nil
}

func getCPUID(str string) (int, error) {
func getCPUID(str string) (uint64, error) {
matches := cpuDirRegExp.FindStringSubmatch(str)
if len(matches) != 2 {
return 0, fmt.Errorf("failed to match regexp, str: %s", str)
}
valInt, err := strconv.Atoi(matches[1])
valInt, err := strconv.ParseUint(matches[1], 10, 16)
if err != nil {
return 0, err
}
Expand All @@ -199,7 +199,7 @@ func getCoreID(cpuPath string) (string, error) {
return strings.TrimSpace(string(coreID)), err
}

func IsCPUOnline(cpuID int) bool {
func IsCPUOnline(cpuID uint64) bool {
cpuOnlinePath, err := filepath.Abs(cpusPath + "/online")
if err != nil {
zap.L().Info("Unable to get absolute path", zap.String("absolutPath", cpusPath+"/online"))
Expand All @@ -217,15 +217,15 @@ func IsCPUOnline(cpuID int) bool {
zap.Error(err))
}

isOnline, err := isCpuOnline(cpuOnlinePath, uint16(cpuID))
isOnline, err := isCpuOnline(cpuOnlinePath, cpuID)
if err != nil {
zap.L().Error("Unable to get online CPUs list", zap.Error(err))
return false
}
return isOnline
}

func isCpuOnline(path string, cpuID uint16) (bool, error) {
func isCpuOnline(path string, cpuID uint64) (bool, error) {
// #nosec G304 // This path is cpuOnlinePath from isCPUOnline
fileContent, err := os.ReadFile(path)
if err != nil {
Expand Down Expand Up @@ -254,15 +254,15 @@ func isCpuOnline(path string, cpuID uint16) (bool, error) {
return false, fmt.Errorf("invalid values in %s", path)
}
// Return true, if the CPU under consideration is in the range of online CPUs.
if cpuID >= uint16(min) && cpuID <= uint16(max) {
if cpuID >= min && cpuID <= max {
return true, nil
}
case 1:
value, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return false, err
}
if uint16(value) == cpuID {
if value == cpuID {
return true, nil
}
}
Expand Down

0 comments on commit 07f12b8

Please sign in to comment.