diff --git a/rai/client_test.go b/rai/client_test.go index 116e5b1..bd076df 100644 --- a/rai/client_test.go +++ b/rai/client_test.go @@ -18,6 +18,8 @@ import ( "context" "fmt" "net/http" + "os" + "path/filepath" "strings" "testing" "time" @@ -63,6 +65,34 @@ func findModel(models []Model, name string) *Model { return nil } +// deleteTokenCacheDir deletes the token file cache and the "~/.rai" directory if empty +func deleteTokenCacheDir(t *testing.T) { + fname, err := cachePath() + if err != nil { + t.Error("Failed to get token cache file name") + } + err = os.Remove(fname) + if err != nil { + t.Errorf("Failed to delete token cache file %s\n", fname) + } + err = os.Remove(filepath.Dir(fname)) + if err != nil { + fmt.Printf("Could not delete token cache file directory %s - err: %s\n", filepath.Dir(fname), err) + } +} + +// assertTokenCacheFileCreated asserts that the token file has been created +func assertTokenCacheFileCreated(t *testing.T) { + fpath, err := cachePath() + if err != nil { + t.Error("Failed to get token cache file name") + } + + if _, err := os.Stat(fpath); err != nil { + t.Errorf("Failed to stat token cache file %s, err: %s\n", fpath, err) + } +} + func TestNewClient(t *testing.T) { var testClient *Client var cfg Config @@ -95,6 +125,30 @@ func TestNewClient(t *testing.T) { assert.NotNil(t, err) } +// Test token cache file creation +func TestTokenCacheFile(t *testing.T) { + deleteTokenCacheDir(t) + + var testClient *Client + var cfg Config + + err := getConfig(&cfg) + assert.Nil(t, err) + + opts := ClientOptions{Config: cfg} + testClient = NewClient(context.Background(), &opts) + + token, err := testClient.accessTokenHandler.GetAccessToken() + assert.Nil(t, err) + assert.NotNil(t, token) + + tokenCached, _ := testClient.accessTokenHandler.GetAccessToken() + + assert.Equal(t, token, tokenCached) + + assertTokenCacheFileCreated(t) +} + // Test database management APIs. func TestDatabase(t *testing.T) { client := test.client diff --git a/rai/handlers.go b/rai/handlers.go index aec2b76..48ee8fe 100644 --- a/rai/handlers.go +++ b/rai/handlers.go @@ -22,6 +22,7 @@ import ( "os" "os/user" "path" + "path/filepath" "github.com/pkg/errors" ) @@ -56,8 +57,8 @@ func NewClientCredentialsHandler( return &ClientCredentialsHandler{client: c, creds: creds} } -// Returns the name of the token cache file. -func cacheName() (string, error) { +// Returns the path of the token cache file. +func cachePath() (string, error) { usr, err := user.Current() if err != nil { return "", err @@ -65,15 +66,6 @@ func cacheName() (string, error) { return path.Join(usr.HomeDir, ".rai", "tokens.json"), nil } -// Returns the directory of the token cache file. -func cacheDir() (string, error) { - usr, err := user.Current() - if err != nil { - return "", err - } - return path.Join(usr.HomeDir, ".rai"), nil -} - // Read the access token corresponding to the given ClientID from the local // token cache, returns nil if the token does not exist. func readAccessToken(creds *ClientCredentials) (*AccessToken, error) { @@ -88,7 +80,7 @@ func readAccessToken(creds *ClientCredentials) (*AccessToken, error) { } func readTokenCache() (map[string]*AccessToken, error) { - fname, err := cacheName() + fname, err := cachePath() if err != nil { return nil, err } @@ -116,19 +108,17 @@ func writeAccessToken(clientID string, token *AccessToken) { } func writeTokenCache(cache map[string]*AccessToken) { - dname, err := cacheDir() + fname, err := cachePath() if err != nil { - fmt.Println(errors.Wrapf(err, "failed to find token directory")) - } else { - err = os.MkdirAll(dname, 0775) - if err != nil { - fmt.Println(errors.Wrapf(err, "failed to create token directory")) - } + return } - fname, err := cacheName() + + dirName := filepath.Dir(fname) + err = os.MkdirAll(dirName, 0775) if err != nil { - return + fmt.Println(errors.Wrapf(err, "failed to create token directory")) } + f, err := os.OpenFile(fname, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { fmt.Println(errors.Wrapf(err, "failed to open token file"))