Skip to content

Commit

Permalink
Add test for token file creation
Browse files Browse the repository at this point in the history
  • Loading branch information
ginal committed Jul 11, 2024
1 parent 3199e5e commit fa3420a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 21 deletions.
52 changes: 52 additions & 0 deletions rai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ import (
"context"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -63,6 +66,31 @@ func findModel(models []Model, name string) *Model {
return nil
}

// deleteTokenCacheDir deletes the token file cache and the "~/.rai" directory if empty
func deleteTokenCacheDir(t *testing.T) {
fname, err := cachePath()
if err != nil {
t.Error("Failed to get token cache file name")
}
err = os.Remove(fname)
if err != nil {
t.Errorf("Failed to delete token cache file %s\n", fname)
}
_ = os.Remove(filepath.Dir(fname))
}

// assertTokenCacheFileCreated asserts that the token file has been created
func assertTokenCacheFileCreated(t *testing.T) {
fpath, err := cachePath()
if err != nil {
t.Error("Failed to get token cache file name")
}

if _, err := os.Stat(fpath); err != nil {
t.Error(errors.Wrapf(err, "Failed to stat token cache file %s", fpath))
}
}

func TestNewClient(t *testing.T) {
var testClient *Client
var cfg Config
Expand Down Expand Up @@ -95,6 +123,30 @@ func TestNewClient(t *testing.T) {
assert.NotNil(t, err)
}

// Test token cache file creation
func TestTokenCacheFile(t *testing.T) {
deleteTokenCacheDir(t)

var testClient *Client
var cfg Config

err := getConfig(&cfg)
assert.Nil(t, err)

opts := ClientOptions{Config: cfg}
testClient = NewClient(context.Background(), &opts)

token, err := testClient.accessTokenHandler.GetAccessToken()
assert.Nil(t, err)
assert.NotNil(t, token)

tokenCached, _ := testClient.accessTokenHandler.GetAccessToken()

assert.Equal(t, token, tokenCached)

assertTokenCacheFileCreated(t)
}

// Test database management APIs.
func TestDatabase(t *testing.T) {
client := test.client
Expand Down
32 changes: 11 additions & 21 deletions rai/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"os"
"os/user"
"path"
"path/filepath"

"github.com/pkg/errors"
)
Expand Down Expand Up @@ -56,24 +57,15 @@ func NewClientCredentialsHandler(
return &ClientCredentialsHandler{client: c, creds: creds}
}

// Returns the name of the token cache file.
func cacheName() (string, error) {
// Returns the path of the token cache file.
func cachePath() (string, error) {
usr, err := user.Current()
if err != nil {
return "", err
}
return path.Join(usr.HomeDir, ".rai", "tokens.json"), nil
}

// Returns the directory of the token cache file.
func cacheDir() (string, error) {
usr, err := user.Current()
if err != nil {
return "", err
}
return path.Join(usr.HomeDir, ".rai"), nil
}

// Read the access token corresponding to the given ClientID from the local
// token cache, returns nil if the token does not exist.
func readAccessToken(creds *ClientCredentials) (*AccessToken, error) {
Expand All @@ -88,7 +80,7 @@ func readAccessToken(creds *ClientCredentials) (*AccessToken, error) {
}

func readTokenCache() (map[string]*AccessToken, error) {
fname, err := cacheName()
fname, err := cachePath()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -116,19 +108,17 @@ func writeAccessToken(clientID string, token *AccessToken) {
}

func writeTokenCache(cache map[string]*AccessToken) {
dname, err := cacheDir()
fname, err := cachePath()
if err != nil {
fmt.Println(errors.Wrapf(err, "failed to find token directory"))
} else {
err = os.MkdirAll(dname, 0775)
if err != nil {
fmt.Println(errors.Wrapf(err, "failed to create token directory"))
}
return
}
fname, err := cacheName()

dirName := filepath.Dir(fname)
err = os.MkdirAll(dirName, 0775)
if err != nil {
return
fmt.Println(errors.Wrapf(err, "failed to create token directory"))
}

f, err := os.OpenFile(fname, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
fmt.Println(errors.Wrapf(err, "failed to open token file"))
Expand Down

0 comments on commit fa3420a

Please sign in to comment.