Skip to content

Commit

Permalink
Add Managed Identity option to azure storage
Browse files Browse the repository at this point in the history
updates
  • Loading branch information
m7hm7t committed Jun 3, 2024
1 parent 6dd09cd commit 0b5da6e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 9 deletions.
44 changes: 38 additions & 6 deletions pkg/storages/azure/folder.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
BufferSizeSetting = "AZURE_BUFFER_SIZE"
MaxBuffersSetting = "AZURE_MAX_BUFFERS"
TryTimeoutSetting = "AZURE_TRY_TIMEOUT"
ClientIDSetting = "AZURE_CLIENT_ID"
minBufferSize = 1024
defaultBufferSize = 8 * 1024 * 1024
minBuffers = 1
Expand All @@ -40,8 +41,9 @@ const (
type AzureAuthType string

const (
AzureAccessKeyAuth AzureAuthType = "AzureAccessKeyAuth"
AzureSASTokenAuth AzureAuthType = "AzureSASTokenAuth"
AzureAccessKeyAuth AzureAuthType = "AzureAccessKeyAuth"
AzureSASTokenAuth AzureAuthType = "AzureSASTokenAuth"
AzureManagedIdentityAuth AzureAuthType = "AzureManagedIdentityAuth"
)

var SettingList = []string{
Expand All @@ -52,6 +54,7 @@ var SettingList = []string{
EndpointSuffix,
BufferSizeSetting,
MaxBuffersSetting,
ClientIDSetting,
}

func NewFolderError(err error, format string, args ...interface{}) storage.Error {
Expand All @@ -78,6 +81,31 @@ func NewFolder(
}
}

func getContainerClientWithManagedIndetity(
accountName string,
storageEndpointSuffix string,
containerName string,
timeout time.Duration,
clientID string) (*azblob.ContainerClient, error) {
cred, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(clientID),
})
if err != nil {
return nil, err
}

containerURLString := fmt.Sprintf("https://%s.blob.%s/%s", accountName, storageEndpointSuffix, containerName)
_, err = url.Parse(containerURLString)
if err != nil {
return nil, NewFolderError(err, "Unable to parse service URL")
}

containerClient, err := azblob.NewContainerClient(containerURLString, cred, &azblob.ClientOptions{
Retry: policy.RetryOptions{TryTimeout: timeout},
})
return containerClient, err
}

func getContainerClientWithSASToken(
accountName string,
storageEndpointSuffix string,
Expand Down Expand Up @@ -136,9 +164,9 @@ func getContainerClient(
return containerClient, err
}

func configureAuthType(settings map[string]string) (AzureAuthType, string, string) {
func configureAuthType(settings map[string]string) (AzureAuthType, string, string, string) {
var ok bool
var accountToken, accessKey string
var accountToken, accessKey, clientID string
var authType AzureAuthType

if accessKey, ok = settings[AccessKeySetting]; ok {
Expand All @@ -149,9 +177,11 @@ func configureAuthType(settings map[string]string) (AzureAuthType, string, strin
if !strings.HasPrefix(accountToken, "?") {
accountToken = "?" + accountToken
}
} else if clientID, ok = settings[ClientIDSetting]; ok {
authType = AzureManagedIdentityAuth
}

return authType, accountToken, accessKey
return authType, accountToken, accessKey, clientID
}

func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder, error) {
Expand All @@ -161,7 +191,7 @@ func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder,
return nil, NewCredentialError(AccountSetting)
}

authType, accountToken, accountKey := configureAuthType(settings)
authType, accountToken, accountKey, clientID := configureAuthType(settings)

var credential *azblob.SharedKeyCredential
var err error
Expand Down Expand Up @@ -199,6 +229,8 @@ func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder,
var containerClient *azblob.ContainerClient
if authType == AzureSASTokenAuth {
containerClient, err = getContainerClientWithSASToken(accountName, storageEndpointSuffix, containerName, timeout, accountToken)
} else if authType == AzureManagedIdentityAuth {
containerClient, err = getContainerClientWithManagedIndetity(accountName, storageEndpointSuffix, containerName, timeout, clientID)
} else if authType == AzureAccessKeyAuth {
containerClient, err = getContainerClientWithAccessKey(accountName, storageEndpointSuffix, containerName, timeout, credential)
} else {
Expand Down
30 changes: 27 additions & 3 deletions pkg/storages/azure/folder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package azure

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/wal-g/wal-g/pkg/storages/storage"
Expand All @@ -22,24 +23,47 @@ var ConfigureAuthType = configureAuthType

func TestConfigureAccessKeyAuthType(t *testing.T) {
settings := map[string]string{AccessKeySetting: "foo"}
authType, accountToken, accessKey := ConfigureAuthType(settings)
authType, accountToken, accessKey, clientID := ConfigureAuthType(settings)
assert.Equal(t, authType, AzureAccessKeyAuth)
assert.Empty(t, accountToken)
assert.Equal(t, accessKey, "foo")
assert.Empty(t, clientID)
}

func TestConfigureSASTokenAuth(t *testing.T) {
settings := map[string]string{SasTokenSetting: "foo"}
authType, accountToken, accessKey := ConfigureAuthType(settings)
authType, accountToken, accessKey, clientID := ConfigureAuthType(settings)
assert.Equal(t, authType, AzureSASTokenAuth)
assert.Equal(t, accountToken, "?foo")
assert.Empty(t, accessKey)
assert.Empty(t, clientID)
}

func TestConfigureDefaultAuth(t *testing.T) {
settings := make(map[string]string)
authType, accountToken, accessKey := ConfigureAuthType(settings)
authType, accountToken, accessKey, clientID := ConfigureAuthType(settings)
assert.Empty(t, authType)
assert.Empty(t, accountToken)
assert.Empty(t, accessKey)
assert.Empty(t, clientID)
}

func TestConfigureManagedIdentityAuth(t *testing.T) {
settings := map[string]string{ClientIDSetting: "foo"}
authType, accountToken, accessKey, clientID := ConfigureAuthType(settings)
assert.Equal(t, authType, AzureManagedIdentityAuth)
assert.Empty(t, accountToken)
assert.Empty(t, accessKey)
assert.Equal(t, clientID, "foo")
}
func TestGetContainerClientWithManagedIdentity(t *testing.T) {
accountName := "test-account"
storageEndpointSuffix := "test-endpoint"
containerName := "test-container"
timeout := time.Minute
clientID := "test-client-id"

containerClient, err := getContainerClientWithManagedIndetity(accountName, storageEndpointSuffix, containerName, timeout, clientID)
assert.NoError(t, err)
assert.NotNil(t, containerClient)
}

0 comments on commit 0b5da6e

Please sign in to comment.