Skip to content

Commit

Permalink
refactor: use new logger util to replace logrus (#1023)
Browse files Browse the repository at this point in the history
  • Loading branch information
binbin-li authored Aug 25, 2023
1 parent 49adf7b commit 1032c6f
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 58 deletions.
10 changes: 10 additions & 0 deletions internal/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ const (
Server componentType = "server"
// ReferrerStore is the component type for the referrer store.
ReferrerStore componentType = "referrerStore"
// Cache is the component type for the cache.
Cache componentType = "cache"
// CertProvider is the component type for certificate provider.
CertProvider componentType = "certificateProvider"
// AuthProvider is the component type for auth provider.
AuthProvider componentType = "authProvider"
// PolicyProvider is the component type for policy provider.
PolicyProvider componentType = "policyProvider"
// Verifier is the component type for verifier.
Verifier componentType = "verifier"

traceIDHeaderName = "traceIDHeaderName"
)
Expand Down
16 changes: 10 additions & 6 deletions pkg/cache/dapr/dapr.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ import (
"time"

"github.com/dapr/go-sdk/client"
"github.com/deislabs/ratify/internal/logger"
"github.com/deislabs/ratify/pkg/cache"
"github.com/deislabs/ratify/pkg/featureflag"
"github.com/sirupsen/logrus"
)

const DaprCacheType = "dapr"

var logOpt = logger.Option{
ComponentType: logger.Cache,
}

type factory struct{}

type daprCache struct {
Expand Down Expand Up @@ -67,11 +71,11 @@ func (d *daprCache) Get(ctx context.Context, key string) (string, bool) {
func (d *daprCache) Set(ctx context.Context, key string, value interface{}) bool {
bytes, err := json.Marshal(value)
if err != nil {
logrus.Error("Error marshalling value for redis: ", err)
logger.GetLogger(ctx, logOpt).Error("Error marshalling value for redis: ", err)
return false
}
if err := d.daprClient.SaveState(ctx, d.cacheName, key, bytes, nil); err != nil {
logrus.Error("Error saving value to redis: ", err)
logger.GetLogger(ctx, logOpt).Error("Error saving value to redis: ", err)
return false
}
return true
Expand All @@ -80,21 +84,21 @@ func (d *daprCache) Set(ctx context.Context, key string, value interface{}) bool
func (d *daprCache) SetWithTTL(ctx context.Context, key string, value interface{}, ttl time.Duration) bool {
bytes, err := json.Marshal(value)
if err != nil {
logrus.Error("Error marshalling value for redis: ", err)
logger.GetLogger(ctx, logOpt).Error("Error marshalling value for redis: ", err)
return false
}
ttlString := strconv.Itoa(int(ttl.Seconds()))
md := map[string]string{"ttlInSeconds": ttlString}
if err := d.daprClient.SaveState(ctx, d.cacheName, key, bytes, md); err != nil {
logrus.Error("Error saving value to redis: ", err)
logger.GetLogger(ctx, logOpt).Error("Error saving value to redis: ", err)
return false
}
return true
}

func (d *daprCache) Delete(ctx context.Context, key string) bool {
if err := d.daprClient.DeleteState(ctx, d.cacheName, key, nil); err != nil {
logrus.Error("Error deleting value from redis: ", err)
logger.GetLogger(ctx, logOpt).Error("Error deleting value from redis: ", err)
return false
}
return true
Expand Down
18 changes: 11 additions & 7 deletions pkg/cache/ristretto/ristretto.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ import (
"time"

"github.com/cespare/xxhash/v2"
"github.com/deislabs/ratify/internal/logger"
"github.com/deislabs/ratify/pkg/cache"
"github.com/dgraph-io/ristretto"
"github.com/dgraph-io/ristretto/z"
"github.com/sirupsen/logrus"
)

const RistrettoCacheType = "ristretto"

var logOpt = logger.Option{
ComponentType: logger.Cache,
}

type factory struct {
once sync.Once
}
Expand All @@ -42,7 +46,7 @@ func init() {
cache.Register(RistrettoCacheType, &factory{})
}

func (f *factory) Create(_ context.Context, _ string, cacheSize int) (cache.CacheProvider, error) {
func (f *factory) Create(ctx context.Context, _ string, cacheSize int) (cache.CacheProvider, error) {
var err error
var memoryCache *ristretto.Cache
f.once.Do(func() {
Expand All @@ -54,7 +58,7 @@ func (f *factory) Create(_ context.Context, _ string, cacheSize int) (cache.Cach
})
})
if err != nil {
logrus.Errorf("could not create cache, err: %v", err)
logger.GetLogger(ctx, logOpt).Errorf("could not create cache, err: %v", err)
return &ristrettoCache{}, err
}

Expand All @@ -72,19 +76,19 @@ func (r *ristrettoCache) Get(_ context.Context, key string) (string, bool) {
return returnValue, ok
}

func (r *ristrettoCache) Set(_ context.Context, key string, value interface{}) bool {
func (r *ristrettoCache) Set(ctx context.Context, key string, value interface{}) bool {
bytes, err := json.Marshal(value)
if err != nil {
logrus.Error("Error marshalling value for ristretto: ", err)
logger.GetLogger(ctx, logOpt).Error("Error marshalling value for ristretto: ", err)
return false
}
return r.memoryCache.Set(key, string(bytes), 1)
}

func (r *ristrettoCache) SetWithTTL(_ context.Context, key string, value interface{}, ttl time.Duration) bool {
func (r *ristrettoCache) SetWithTTL(ctx context.Context, key string, value interface{}, ttl time.Duration) bool {
bytes, err := json.Marshal(value)
if err != nil {
logrus.Error("Error marshalling value for ristretto: ", err)
logger.GetLogger(ctx, logOpt).Error("Error marshalling value for ristretto: ", err)
return false
}
return r.memoryCache.SetWithTTL(key, string(bytes), 1, ttl)
Expand Down
30 changes: 17 additions & 13 deletions pkg/certificateprovider/azurekeyvault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ import (
"time"

re "github.com/deislabs/ratify/errors"
"github.com/deislabs/ratify/internal/logger"
"github.com/deislabs/ratify/pkg/certificateprovider"
"github.com/deislabs/ratify/pkg/certificateprovider/azurekeyvault/types"
"github.com/deislabs/ratify/pkg/metrics"
"golang.org/x/crypto/pkcs12"

kv "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)

Expand All @@ -45,6 +45,10 @@ const (
PEMContentType string = "application/x-pem-file"
)

var logOpt = logger.Option{
ComponentType: logger.CertProvider,
}

type akvCertProvider struct{}

// init calls to register the provider
Expand Down Expand Up @@ -80,7 +84,7 @@ func (s *akvCertProvider) GetCertificates(ctx context.Context, attrib map[string
return nil, nil, re.ErrorCodeConfigInvalid.NewError(re.CertProvider, providerName, re.EmptyLink, nil, fmt.Sprintf("cloudName %s is not valid", cloudName), re.HideStackTrace)
}

keyVaultCerts, err := getKeyvaultRequestObj(attrib)
keyVaultCerts, err := getKeyvaultRequestObj(ctx, attrib)
if err != nil {
return nil, nil, re.ErrorCodeConfigInvalid.NewError(re.CertProvider, providerName, re.AKVLink, err, "failed to get keyvault request object from provider attributes", re.HideStackTrace)
}
Expand All @@ -89,7 +93,7 @@ func (s *akvCertProvider) GetCertificates(ctx context.Context, attrib map[string
return nil, nil, re.ErrorCodeConfigInvalid.NewError(re.CertProvider, providerName, re.EmptyLink, nil, "no keyvault certificate configured", re.PrintStackTrace)
}

logrus.Debugf("vaultURI %s", keyvaultURI)
logger.GetLogger(ctx, logOpt).Debugf("vaultURI %s", keyvaultURI)

kvClient, err := initializeKvClient(ctx, azureCloudEnv.KeyVaultEndpoint, tenantID, workloadIdentityClientID)
if err != nil {
Expand All @@ -99,7 +103,7 @@ func (s *akvCertProvider) GetCertificates(ctx context.Context, attrib map[string
certs := []*x509.Certificate{}
certsStatus := []map[string]string{}
for _, keyVaultCert := range keyVaultCerts {
logrus.Debugf("fetching secret from key vault, certName %v, keyvault %v", keyVaultCert.CertificateName, keyvaultURI)
logger.GetLogger(ctx, logOpt).Debugf("fetching secret from key vault, certName %v, keyvault %v", keyVaultCert.CertificateName, keyvaultURI)

// fetch the object from Key Vault
// GetSecret is required so we can fetch the entire cert chain. See issue https://github.com/deislabs/ratify/issues/695 for details
Expand All @@ -110,7 +114,7 @@ func (s *akvCertProvider) GetCertificates(ctx context.Context, attrib map[string
return nil, nil, fmt.Errorf("failed to get secret objectName:%s, objectVersion:%s, error: %w", keyVaultCert.CertificateName, keyVaultCert.CertificateVersion, err)
}

certResult, certProperty, err := getCertsFromSecretBundle(secretBundle, keyVaultCert.CertificateName)
certResult, certProperty, err := getCertsFromSecretBundle(ctx, secretBundle, keyVaultCert.CertificateName)

if err != nil {
return nil, nil, fmt.Errorf("failed to get certificates from secret bundle:%w", err)
Expand All @@ -132,21 +136,21 @@ func getCertStatusMap(certsStatus []map[string]string) certificateprovider.Certi
}

// parse the requested keyvault cert object from the input attributes
func getKeyvaultRequestObj(attrib map[string]string) ([]types.KeyVaultCertificate, error) {
func getKeyvaultRequestObj(ctx context.Context, attrib map[string]string) ([]types.KeyVaultCertificate, error) {
keyVaultCerts := []types.KeyVaultCertificate{}

certificatesStrings := types.GetCertificates(attrib)
if certificatesStrings == "" {
return nil, re.ErrorCodeConfigInvalid.NewError(re.CertProvider, providerName, re.EmptyLink, nil, "certificates is not set", re.HideStackTrace)
}

logrus.Debugf("certificates string defined in ratify certStore class, certificates %v", certificatesStrings)
logger.GetLogger(ctx, logOpt).Debugf("certificates string defined in ratify certStore class, certificates %v", certificatesStrings)

objects, err := types.GetCertificatesArray(certificatesStrings)
if err != nil {
return nil, re.ErrorCodeDataDecodingFailure.NewError(re.CertProvider, providerName, re.EmptyLink, err, "failed to yaml unmarshal objects", re.HideStackTrace)
}
logrus.Debugf("unmarshaled objects yaml, objectsArray %v", objects.Array)
logger.GetLogger(ctx, logOpt).Debugf("unmarshaled objects yaml, objectsArray %v", objects.Array)

for i, object := range objects.Array {
var keyVaultCert types.KeyVaultCertificate
Expand All @@ -159,7 +163,7 @@ func getKeyvaultRequestObj(attrib map[string]string) ([]types.KeyVaultCertificat
keyVaultCerts = append(keyVaultCerts, keyVaultCert)
}

logrus.Debugf("unmarshaled %v key vault objects, keyVaultObjects: %v", len(keyVaultCerts), keyVaultCerts)
logger.GetLogger(ctx, logOpt).Debugf("unmarshaled %v key vault objects, keyVaultObjects: %v", len(keyVaultCerts), keyVaultCerts)
return keyVaultCerts, nil
}

Expand Down Expand Up @@ -221,7 +225,7 @@ func initializeKvClient(ctx context.Context, keyVaultEndpoint, tenantID, clientI

// Parse the secret bundle and return an array of certificates
// In a certificate chain scenario, all certificates from root to leaf will be returned
func getCertsFromSecretBundle(secretBundle kv.SecretBundle, certName string) ([]*x509.Certificate, []map[string]string, error) {
func getCertsFromSecretBundle(ctx context.Context, secretBundle kv.SecretBundle, certName string) ([]*x509.Certificate, []map[string]string, error) {
if secretBundle.ContentType == nil || secretBundle.Value == nil || secretBundle.ID == nil {
return nil, nil, re.ErrorCodeCertInvalid.NewError(re.CertProvider, providerName, re.EmptyLink, nil, "found invalid secret bundle for certificate %s, contentType, value, and id must not be nil", re.HideStackTrace)
}
Expand Down Expand Up @@ -264,7 +268,7 @@ func getCertsFromSecretBundle(secretBundle kv.SecretBundle, certName string) ([]
for block != nil {
switch block.Type {
case "PRIVATE KEY":
logrus.Warnf("azure keyvault certificate provider: certificate %s, version %s private key skipped. Please see doc to learn how to create a new certificate in keyvault with non exportable keys. https://learn.microsoft.com/en-us/azure/key-vault/certificates/how-to-export-certificate?tabs=azure-cli#exportable-and-non-exportable-keys", certName, version)
logger.GetLogger(ctx, logOpt).Warnf("azure keyvault certificate provider: certificate %s, version %s private key skipped. Please see doc to learn how to create a new certificate in keyvault with non exportable keys. https://learn.microsoft.com/en-us/azure/key-vault/certificates/how-to-export-certificate?tabs=azure-cli#exportable-and-non-exportable-keys", certName, version)
case "CERTIFICATE":
var pemData []byte
pemData = append(pemData, pem.EncodeToMemory(block)...)
Expand All @@ -278,15 +282,15 @@ func getCertsFromSecretBundle(secretBundle kv.SecretBundle, certName string) ([]
certsStatus = append(certsStatus, certProperty)
}
default:
logrus.Warnf("certificate '%s', version '%s': azure keyvault certificate provider detected unknown block type %s", certName, version, block.Type)
logger.GetLogger(ctx, logOpt).Warnf("certificate '%s', version '%s': azure keyvault certificate provider detected unknown block type %s", certName, version, block.Type)
}

block, rest = pem.Decode(rest)
if block == nil && len(rest) > 0 {
return nil, nil, re.ErrorCodeCertInvalid.NewError(re.CertProvider, providerName, re.EmptyLink, nil, fmt.Sprintf("certificate '%s', version '%s': azure keyvault certificate provider error, block is nil and remaining block to parse > 0", certName, version), re.HideStackTrace)
}
}
logrus.Debugf("azurekeyvault certprovider getCertsFromSecretBundle: %v certificates parsed, Certificate '%s', version '%s'", len(results), certName, version)
logger.GetLogger(ctx, logOpt).Debugf("azurekeyvault certprovider getCertsFromSecretBundle: %v certificates parsed, Certificate '%s', version '%s'", len(results), certName, version)
return results, certsStatus, nil
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/certificateprovider/azurekeyvault/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func TestGetKeyvaultRequestObj(t *testing.T) {
attrib["tenantID"] = "TestIDABC"
attrib["certificates"] = "array:\n- |\n certificateName: wabbit-networks-io \n certificateVersion: \"testversion\"\n"

result, err := getKeyvaultRequestObj(attrib)
result, err := getKeyvaultRequestObj(context.Background(), attrib)

if err != nil {
logrus.Infof("err %v", err)
Expand Down Expand Up @@ -327,7 +327,7 @@ func Test(t *testing.T) {
ContentType: &tc.contentType,
}

certs, status, err := getCertsFromSecretBundle(testdata, "certName")
certs, status, err := getCertsFromSecretBundle(context.Background(), testdata, "certName")
if tc.expectedErr {
assert.NotNil(t, err)
assert.Nil(t, certs)
Expand Down Expand Up @@ -363,7 +363,7 @@ func TestGetKeyvaultRequestObj_error(t *testing.T) {

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
_, err := getKeyvaultRequestObj(tc.attrib)
_, err := getKeyvaultRequestObj(context.Background(), tc.attrib)
if tc.expectedErr {
assert.NotNil(t, err)
} else {
Expand Down
4 changes: 2 additions & 2 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
"time"

re "github.com/deislabs/ratify/errors"
"github.com/deislabs/ratify/internal/logger"
provider "github.com/deislabs/ratify/pkg/common/oras/authprovider"
"github.com/sirupsen/logrus"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand Down Expand Up @@ -129,7 +129,7 @@ func (d *azureManagedIdentityAuthProvider) Provide(ctx context.Context, artifact
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "could not refresh azure managed identity token", re.HideStackTrace)
}
d.identityToken = newToken
logrus.Info("successfully refreshed azure managed identity token")
logger.GetLogger(ctx, logOpt).Info("successfully refreshed azure managed identity token")
}
// add protocol to generate complete URI
serverURL := "https://" + artifactHostName
Expand Down
4 changes: 2 additions & 2 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ import (
"time"

re "github.com/deislabs/ratify/errors"
"github.com/deislabs/ratify/internal/logger"
provider "github.com/deislabs/ratify/pkg/common/oras/authprovider"
"github.com/deislabs/ratify/pkg/metrics"
"github.com/deislabs/ratify/pkg/utils/azureauth"

"github.com/Azure/azure-sdk-for-go/services/preview/containerregistry/runtime/2019-08-15-preview/containerregistry"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
"github.com/sirupsen/logrus"
)

type AzureWIProviderFactory struct{} //nolint:revive // ignore linter to have unique type name
Expand Down Expand Up @@ -123,7 +123,7 @@ func (d *azureWIAuthProvider) Provide(ctx context.Context, artifact string) (pro
return provider.AuthConfig{}, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureWorkloadIdentityLink, nil, "could not refresh AAD token", re.HideStackTrace)
}
d.aadToken = newToken
logrus.Info("successfully refreshed AAD token")
logger.GetLogger(ctx, logOpt).Info("successfully refreshed AAD token")
}

// add protocol to generate complete URI
Expand Down
10 changes: 9 additions & 1 deletion pkg/common/oras/authprovider/azure/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,18 @@ limitations under the License.

package azure

import "time"
import (
"time"

"github.com/deislabs/ratify/internal/logger"
)

const (
dockerTokenLoginUsernameGUID = "00000000-0000-0000-0000-000000000000"
AADResource = "https://containerregistry.azure.net/.default"
defaultACRExpiryDuration time.Duration = 3 * time.Hour
)

var logOpt = logger.Option{
ComponentType: logger.AuthProvider,
}
8 changes: 6 additions & 2 deletions pkg/policyprovider/regopolicy/regopolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"os"

re "github.com/deislabs/ratify/errors"
"github.com/deislabs/ratify/internal/logger"
"github.com/deislabs/ratify/pkg/common"
"github.com/deislabs/ratify/pkg/executor/types"
"github.com/deislabs/ratify/pkg/ocispecs"
Expand All @@ -32,7 +33,6 @@ import (
opa "github.com/deislabs/ratify/pkg/policyprovider/policyengine/opaengine"
query "github.com/deislabs/ratify/pkg/policyprovider/policyquery/rego"
policyTypes "github.com/deislabs/ratify/pkg/policyprovider/types"
"github.com/sirupsen/logrus"
)

type policyEnforcer struct {
Expand All @@ -51,6 +51,10 @@ type policyEnforcerConf struct {
// Factory is a factory for creating rego policy enforcers.
type Factory struct{}

var logOpt = logger.Option{
ComponentType: logger.PolicyProvider,
}

// init calls Register for our rego policy provider.
func init() {
pf.Register(policyTypes.RegoPolicy, &Factory{})
Expand Down Expand Up @@ -121,7 +125,7 @@ func (e *policyEnforcer) OverallVerifyResult(ctx context.Context, verifierReport
nestedReports["verifierReports"] = verifierReports
result, err := e.OpaEngine.Evaluate(ctx, nestedReports)
if err != nil {
logrus.Errorf("failed to evaluate policy: %v", err)
logger.GetLogger(ctx, logOpt).Errorf("failed to evaluate policy: %v", err)
return false
}
return result
Expand Down
Loading

0 comments on commit 1032c6f

Please sign in to comment.