From b8f37f73077c1aaaccf2ea02ef04da67a804b653 Mon Sep 17 00:00:00 2001 From: Wei Lim Date: Tue, 29 Aug 2023 16:27:00 -0700 Subject: [PATCH] Add msal cache 1.0 -> 1.1 upgrade (#2664) When retrieving cache entries, automatically upgrade cache entries that were saved in msal cache v1.0 format to v1.1 format to avoid stale cache entries causing random auth failures. For users that run both old and newer versions of azd (and consequently the msal libraries), both versions of the cli will work with the same cache file. In some corner cases where the old CLI has been most recently used to login, the refresh token may be ignored by the new cli with requires reauthentication. Otherwise, authentication is expected to work as normal without any unintentional side-effects. Related to https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/453 Fixes #2659 --- cli/azd/.vscode/cspell-azd-dictionary.txt | 1 + cli/azd/pkg/auth/cache.go | 85 +++++++++++++++++ cli/azd/pkg/auth/cache_test.go | 106 ++++++++++++++++++++++ 3 files changed, 192 insertions(+) diff --git a/cli/azd/.vscode/cspell-azd-dictionary.txt b/cli/azd/.vscode/cspell-azd-dictionary.txt index a27619f0f60..39547d43957 100644 --- a/cli/azd/.vscode/cspell-azd-dictionary.txt +++ b/cli/azd/.vscode/cspell-azd-dictionary.txt @@ -123,6 +123,7 @@ pyproject pyvenv reauthentication relogin +remarshal restoreapp retriable rzip diff --git a/cli/azd/pkg/auth/cache.go b/cli/azd/pkg/auth/cache.go index 3ff4ec82ffe..4f08a7ce2ff 100644 --- a/cli/azd/pkg/auth/cache.go +++ b/cli/azd/pkg/auth/cache.go @@ -5,11 +5,24 @@ package auth import ( "context" + "encoding/json" "errors" + "log" + "strings" + "unicode" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" ) +// Known entries from msal cache contract. This is not an exhaustive list. +var contractFields = []string{ + "AccessToken", + "RefreshToken", + "IdToken", + "Account", + "AppMetadata", +} + // The MSAL cache key for the current user. The stored MSAL cached data contains // all accounts with stored credentials, across all tenants. // Currently, the underlying MSAL cache data is represented as [Contract] inside the library. @@ -36,6 +49,46 @@ func (a *msalCacheAdapter) Replace(ctx context.Context, cache cache.Unmarshaler, return err } + // In msal v1.0, keys were stored with mixed casing; in v1.1., it was changed to lower case. + // This handles upgrades where we have a v1.0 cache, and we need to convert it to v1.1, + // by normalizing the appropriate key entries. + c := map[string]json.RawMessage{} + if err = json.Unmarshal(val, &c); err == nil { + for _, contractKey := range contractFields { + if _, found := c[contractKey]; found { + msg := []byte(c[contractKey]) + inner := map[string]json.RawMessage{} + + err := json.Unmarshal(msg, &inner) + if err != nil { + log.Printf("msal-upgrade: failed to unmarshal inner: %v", err) + continue + } + + updated := normalizeKeys(inner) + if !updated { + continue + } + + newMsg, err := json.Marshal(inner) + if err != nil { + log.Printf("msal-upgrade: failed to remarshal inner: %v", err) + continue + } + + c[contractKey] = json.RawMessage(newMsg) + } + } + + if newVal, err := json.Marshal(c); err == nil { + val = newVal + } else { + log.Printf("msal-upgrade: failed to remarshal msal cache: %v", err) + } + } else { + log.Printf("msal-upgrade: failed to unmarshal msal cache: %v", err) + } + // Replace the msal cache contents with the new value retrieved. if err := cache.Unmarshal(val); err != nil { return err @@ -52,6 +105,38 @@ func (a *msalCacheAdapter) Export(ctx context.Context, cache cache.Marshaler, _ return a.cache.Set(cCurrentUserCacheKey, val) } +// Normalize keys by removing upper-case keys and replacing them with lower-case keys. +// In the case where a lower-case key and upper-case key exists, the lower-case key entry +// takes precedence. +func normalizeKeys(m map[string]json.RawMessage) (normalized bool) { + for k, v := range m { + if hasUpper(k) { + // An upper-case key entry exists. Delete it as it is no longer allowed. + delete(m, k) + + // If a lower-case key entry exists, that supersedes it and we are done. + // Otherwise, we can safely upgrade the cache entry by re-adding it with lower case. + lower := strings.ToLower(k) + if _, isLower := m[lower]; !isLower { + m[lower] = v + } + + normalized = true + } + } + + return normalized +} + +func hasUpper(s string) bool { + for _, r := range s { + if unicode.IsUpper(r) && unicode.IsLetter(r) { + return true + } + } + return false +} + type Cache interface { Read(key string) ([]byte, error) Set(key string, value []byte) error diff --git a/cli/azd/pkg/auth/cache_test.go b/cli/azd/pkg/auth/cache_test.go index 9e8f3871a34..cf89d6690de 100644 --- a/cli/azd/pkg/auth/cache_test.go +++ b/cli/azd/pkg/auth/cache_test.go @@ -5,6 +5,7 @@ package auth import ( "context" + "encoding/json" "math/rand" "testing" @@ -113,3 +114,108 @@ func TestCredentialCache(t *testing.T) { _, err = c.Read("nonExist") require.ErrorIs(t, err, errCacheKeyNotFound) } + +type mockContractHolder struct { + contract *mockContract +} + +// Marshal implements cache.Marshaler in msal/apps/cache. +func (c *mockContractHolder) Marshal() ([]byte, error) { + return json.Marshal(c.contract) +} + +// Unmarshal implements cache.Unmarshaler in msal/apps/cache. +func (c *mockContractHolder) Unmarshal(b []byte) error { + contract := &mockContract{} + + err := json.Unmarshal(b, contract) + if err != nil { + return err + } + + c.contract = contract + return nil +} + +type val struct { + Value string `json:"value"` +} + +// mockContract that simulates the MSAL cache contract. +type mockContract struct { + AccessTokens map[string]val `json:"AccessToken,omitempty"` + RefreshTokens map[string]val `json:"RefreshToken,omitempty"` + IDTokens map[string]val `json:"IdToken,omitempty"` + Accounts map[string]val `json:"Account,omitempty"` + AppMetaData map[string]val `json:"AppMetadata,omitempty"` + + // mock remainder fields + Remainder map[string]val `json:"Remainder,omitempty"` +} + +func TestKeyNormalization(t *testing.T) { + entries := map[string]val{ + "Upper": {"Upper"}, + "lower": {"lower"}, + "Upper-And-Lower": {"Upper-And-Lower"}, + "upper-and-lower": {"upper-and-lower"}, + } + orig := mockContract{ + AccessTokens: entries, + RefreshTokens: entries, + IDTokens: entries, + Accounts: entries, + AppMetaData: entries, + Remainder: map[string]val{ + "remainder": {"remainder"}, + }, + } + + normalizedEntries := map[string]val{ + "upper": {"Upper"}, + "lower": {"lower"}, + "upper-and-lower": {"upper-and-lower"}, + } + normalized := mockContract{ + AccessTokens: normalizedEntries, + RefreshTokens: normalizedEntries, + IDTokens: normalizedEntries, + Accounts: normalizedEntries, + AppMetaData: normalizedEntries, + Remainder: map[string]val{ + "remainder": {"remainder"}, + }, + } + + ctx := context.Background() + c := msalCacheAdapter{&memoryCache{ + cache: map[string][]byte{}, + inner: nil, + }} + + // Replace (retrieve) when cache is empty, expect nil + h := mockContractHolder{} + err := c.Replace(ctx, &h, cache.ReplaceHints{}) + require.NoError(t, err) + require.Nil(t, h.contract) + + // Export (save) with original entry + h.contract = &orig + err = c.Export(ctx, &h, cache.ExportHints{}) + require.NoError(t, err) + require.JSONEq(t, mustJson(orig), mustJson(h.contract)) + + // Replace (retrieve) that will normalize the keys + err = c.Replace(ctx, &h, cache.ReplaceHints{}) + require.NoError(t, err) + require.JSONEq(t, mustJson(normalized), mustJson(h.contract)) +} + +func mustJson(v any) string { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(b) +}