diff --git a/pkg/backup/constants/constants.go b/pkg/backup/constants/constants.go index 51a7112cc9..5506083dff 100644 --- a/pkg/backup/constants/constants.go +++ b/pkg/backup/constants/constants.go @@ -100,4 +100,7 @@ const ( ClusterBackupMeta = "clustermeta" ClusterRestoreMeta = "restoremeta" MetaFile = "backupmeta" + + // AWSRegionEnv is the aws region environment variable + AWSRegionEnv = "AWS_REGION" ) diff --git a/pkg/backup/restore/restore_manager_test.go b/pkg/backup/restore/restore_manager_test.go index 10e4d98a50..a36f8c6602 100644 --- a/pkg/backup/restore/restore_manager_test.go +++ b/pkg/backup/restore/restore_manager_test.go @@ -24,6 +24,7 @@ import ( "github.com/onsi/gomega" . "github.com/onsi/gomega" "github.com/pingcap/tidb-operator/pkg/apis/pingcap/v1alpha1" + "github.com/pingcap/tidb-operator/pkg/backup/constants" "github.com/pingcap/tidb-operator/pkg/backup/testutils" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -436,6 +437,8 @@ func TestBRRestoreByEBS(t *testing.T) { g.Expect(err).To(Succeed()) }() + os.Setenv(constants.AWSRegionEnv, "us-west-1") + for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/backup/util/aws_ebs.go b/pkg/backup/util/aws_ebs.go index 8656987ab6..65feb49dbc 100644 --- a/pkg/backup/util/aws_ebs.go +++ b/pkg/backup/util/aws_ebs.go @@ -14,7 +14,10 @@ package util import ( + "os" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ebs" "github.com/aws/aws-sdk-go/service/ebs/ebsiface" @@ -22,6 +25,7 @@ import ( "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/pingcap/errors" "github.com/pingcap/tidb-operator/pkg/apis/pingcap/v1alpha1" + "github.com/pingcap/tidb-operator/pkg/backup/constants" "go.uber.org/atomic" "golang.org/x/sync/errgroup" corev1 "k8s.io/api/core/v1" @@ -119,7 +123,17 @@ func NewEC2Session(concurrency uint) (*EC2Session, error) { if err != nil { return nil, errors.Trace(err) } - ec2Session := ec2.New(sess) + + region := os.Getenv(constants.AWSRegionEnv) + if region == "" { + ec2Metadata := ec2metadata.New(sess) + region, err = ec2Metadata.Region() + if err != nil { + return nil, errors.Annotate(err, "get ec2 region") + } + } + + ec2Session := ec2.New(sess, aws.NewConfig().WithRegion(region)) return &EC2Session{EC2: ec2Session, concurrency: concurrency}, nil } @@ -201,6 +215,15 @@ func NewEBSSession(concurrency uint) (*EBSSession, error) { if err != nil { return nil, errors.Trace(err) } - ebsSession := ebs.New(sess) + region := os.Getenv(constants.AWSRegionEnv) + if region == "" { + ec2Metadata := ec2metadata.New(sess) + region, err = ec2Metadata.Region() + if err != nil { + return nil, errors.Annotate(err, "get ec2 region") + } + } + + ebsSession := ebs.New(sess, aws.NewConfig().WithRegion(region)) return &EBSSession{EBS: ebsSession, concurrency: concurrency}, nil }