Skip to content

Commit

Permalink
Add msal cache 1.0 -> 1.1 upgrade (#2664)
Browse files Browse the repository at this point in the history
When retrieving cache entries, automatically upgrade cache entries that were saved in msal cache v1.0 format to v1.1 format to avoid stale cache entries causing random auth failures.

For users that run both old and newer versions of azd (and consequently the msal libraries), both versions of the cli will work with the same cache file. In some corner cases where the old CLI has been most recently used to login, the refresh token may be ignored by the new cli with requires reauthentication. Otherwise, authentication is expected to work as normal without any unintentional side-effects.

Related to AzureAD/microsoft-authentication-library-for-go#453

Fixes #2659
  • Loading branch information
weikanglim authored Aug 29, 2023
1 parent 019d75b commit b8f37f7
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 0 deletions.
1 change: 1 addition & 0 deletions cli/azd/.vscode/cspell-azd-dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ pyproject
pyvenv
reauthentication
relogin
remarshal
restoreapp
retriable
rzip
Expand Down
85 changes: 85 additions & 0 deletions cli/azd/pkg/auth/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,24 @@ package auth

import (
"context"
"encoding/json"
"errors"
"log"
"strings"
"unicode"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
)

// Known entries from msal cache contract. This is not an exhaustive list.
var contractFields = []string{
"AccessToken",
"RefreshToken",
"IdToken",
"Account",
"AppMetadata",
}

// The MSAL cache key for the current user. The stored MSAL cached data contains
// all accounts with stored credentials, across all tenants.
// Currently, the underlying MSAL cache data is represented as [Contract] inside the library.
Expand All @@ -36,6 +49,46 @@ func (a *msalCacheAdapter) Replace(ctx context.Context, cache cache.Unmarshaler,
return err
}

// In msal v1.0, keys were stored with mixed casing; in v1.1., it was changed to lower case.
// This handles upgrades where we have a v1.0 cache, and we need to convert it to v1.1,
// by normalizing the appropriate key entries.
c := map[string]json.RawMessage{}
if err = json.Unmarshal(val, &c); err == nil {
for _, contractKey := range contractFields {
if _, found := c[contractKey]; found {
msg := []byte(c[contractKey])
inner := map[string]json.RawMessage{}

err := json.Unmarshal(msg, &inner)
if err != nil {
log.Printf("msal-upgrade: failed to unmarshal inner: %v", err)
continue
}

updated := normalizeKeys(inner)
if !updated {
continue
}

newMsg, err := json.Marshal(inner)
if err != nil {
log.Printf("msal-upgrade: failed to remarshal inner: %v", err)
continue
}

c[contractKey] = json.RawMessage(newMsg)
}
}

if newVal, err := json.Marshal(c); err == nil {
val = newVal
} else {
log.Printf("msal-upgrade: failed to remarshal msal cache: %v", err)
}
} else {
log.Printf("msal-upgrade: failed to unmarshal msal cache: %v", err)
}

// Replace the msal cache contents with the new value retrieved.
if err := cache.Unmarshal(val); err != nil {
return err
Expand All @@ -52,6 +105,38 @@ func (a *msalCacheAdapter) Export(ctx context.Context, cache cache.Marshaler, _
return a.cache.Set(cCurrentUserCacheKey, val)
}

// Normalize keys by removing upper-case keys and replacing them with lower-case keys.
// In the case where a lower-case key and upper-case key exists, the lower-case key entry
// takes precedence.
func normalizeKeys(m map[string]json.RawMessage) (normalized bool) {
for k, v := range m {
if hasUpper(k) {
// An upper-case key entry exists. Delete it as it is no longer allowed.
delete(m, k)

// If a lower-case key entry exists, that supersedes it and we are done.
// Otherwise, we can safely upgrade the cache entry by re-adding it with lower case.
lower := strings.ToLower(k)
if _, isLower := m[lower]; !isLower {
m[lower] = v
}

normalized = true
}
}

return normalized
}

func hasUpper(s string) bool {
for _, r := range s {
if unicode.IsUpper(r) && unicode.IsLetter(r) {
return true
}
}
return false
}

type Cache interface {
Read(key string) ([]byte, error)
Set(key string, value []byte) error
Expand Down
106 changes: 106 additions & 0 deletions cli/azd/pkg/auth/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package auth

import (
"context"
"encoding/json"
"math/rand"
"testing"

Expand Down Expand Up @@ -113,3 +114,108 @@ func TestCredentialCache(t *testing.T) {
_, err = c.Read("nonExist")
require.ErrorIs(t, err, errCacheKeyNotFound)
}

type mockContractHolder struct {
contract *mockContract
}

// Marshal implements cache.Marshaler in msal/apps/cache.
func (c *mockContractHolder) Marshal() ([]byte, error) {
return json.Marshal(c.contract)
}

// Unmarshal implements cache.Unmarshaler in msal/apps/cache.
func (c *mockContractHolder) Unmarshal(b []byte) error {
contract := &mockContract{}

err := json.Unmarshal(b, contract)
if err != nil {
return err
}

c.contract = contract
return nil
}

type val struct {
Value string `json:"value"`
}

// mockContract that simulates the MSAL cache contract.
type mockContract struct {
AccessTokens map[string]val `json:"AccessToken,omitempty"`
RefreshTokens map[string]val `json:"RefreshToken,omitempty"`
IDTokens map[string]val `json:"IdToken,omitempty"`
Accounts map[string]val `json:"Account,omitempty"`
AppMetaData map[string]val `json:"AppMetadata,omitempty"`

// mock remainder fields
Remainder map[string]val `json:"Remainder,omitempty"`
}

func TestKeyNormalization(t *testing.T) {
entries := map[string]val{
"Upper": {"Upper"},
"lower": {"lower"},
"Upper-And-Lower": {"Upper-And-Lower"},
"upper-and-lower": {"upper-and-lower"},
}
orig := mockContract{
AccessTokens: entries,
RefreshTokens: entries,
IDTokens: entries,
Accounts: entries,
AppMetaData: entries,
Remainder: map[string]val{
"remainder": {"remainder"},
},
}

normalizedEntries := map[string]val{
"upper": {"Upper"},
"lower": {"lower"},
"upper-and-lower": {"upper-and-lower"},
}
normalized := mockContract{
AccessTokens: normalizedEntries,
RefreshTokens: normalizedEntries,
IDTokens: normalizedEntries,
Accounts: normalizedEntries,
AppMetaData: normalizedEntries,
Remainder: map[string]val{
"remainder": {"remainder"},
},
}

ctx := context.Background()
c := msalCacheAdapter{&memoryCache{
cache: map[string][]byte{},
inner: nil,
}}

// Replace (retrieve) when cache is empty, expect nil
h := mockContractHolder{}
err := c.Replace(ctx, &h, cache.ReplaceHints{})
require.NoError(t, err)
require.Nil(t, h.contract)

// Export (save) with original entry
h.contract = &orig
err = c.Export(ctx, &h, cache.ExportHints{})
require.NoError(t, err)
require.JSONEq(t, mustJson(orig), mustJson(h.contract))

// Replace (retrieve) that will normalize the keys
err = c.Replace(ctx, &h, cache.ReplaceHints{})
require.NoError(t, err)
require.JSONEq(t, mustJson(normalized), mustJson(h.contract))
}

func mustJson(v any) string {
b, err := json.Marshal(v)
if err != nil {
panic(err)
}

return string(b)
}

0 comments on commit b8f37f7

Please sign in to comment.