diff --git a/api/handler/v1beta1/appeal_test.go b/api/handler/v1beta1/appeal_test.go index a419a701c..527d46c91 100644 --- a/api/handler/v1beta1/appeal_test.go +++ b/api/handler/v1beta1/appeal_test.go @@ -658,7 +658,7 @@ func (s *GrpcHandlersSuite) TestGetAppeal() { UpdatedAt: timestamppb.New(timeNow), }, } - s.appealService.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), expectedID).Return(expectedAppeal, nil).Once() + s.appealService.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedID).Return(expectedAppeal, nil).Once() req := &guardianv1beta1.GetAppealRequest{ Id: expectedID, @@ -674,7 +674,7 @@ func (s *GrpcHandlersSuite) TestGetAppeal() { s.setup() expectedError := errors.New("random error") - s.appealService.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.appealService.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(nil, expectedError).Once() req := &guardianv1beta1.GetAppealRequest{ @@ -690,7 +690,7 @@ func (s *GrpcHandlersSuite) TestGetAppeal() { s.Run("should return not found error if appeal not found", func() { s.setup() - s.appealService.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.appealService.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(nil, nil).Once() req := &guardianv1beta1.GetAppealRequest{ @@ -711,7 +711,7 @@ func (s *GrpcHandlersSuite) TestGetAppeal() { "foo": make(chan int), }, } - s.appealService.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.appealService.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(invalidAppeal, nil).Once() req := &guardianv1beta1.GetAppealRequest{ @@ -726,6 +726,7 @@ func (s *GrpcHandlersSuite) TestGetAppeal() { } func (s *GrpcHandlersSuite) TestCancelAppeal() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should return appeal details on success", func() { s.setup() timeNow := time.Now() @@ -809,7 +810,7 @@ func (s *GrpcHandlersSuite) TestCancelAppeal() { UpdatedAt: timestamppb.New(timeNow), }, } - s.appealService.EXPECT().Cancel(mock.AnythingOfType("*context.emptyCtx"), expectedID).Return(expectedAppeal, nil).Once() + s.appealService.EXPECT().Cancel(mockCtx, expectedID).Return(expectedAppeal, nil).Once() req := &guardianv1beta1.CancelAppealRequest{ Id: expectedID, @@ -863,7 +864,7 @@ func (s *GrpcHandlersSuite) TestCancelAppeal() { s.Run(tc.name, func() { s.setup() - s.appealService.EXPECT().Cancel(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.appealService.EXPECT().Cancel(mockCtx, mock.Anything). Return(nil, tc.expectedError).Once() req := &guardianv1beta1.CancelAppealRequest{ @@ -886,7 +887,7 @@ func (s *GrpcHandlersSuite) TestCancelAppeal() { "foo": make(chan int), }, } - s.appealService.EXPECT().Cancel(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.appealService.EXPECT().Cancel(mockCtx, mock.Anything). Return(invalidAppeal, nil).Once() req := &guardianv1beta1.CancelAppealRequest{ diff --git a/api/handler/v1beta1/approval_test.go b/api/handler/v1beta1/approval_test.go index 01fb2fdad..2c42050d1 100644 --- a/api/handler/v1beta1/approval_test.go +++ b/api/handler/v1beta1/approval_test.go @@ -635,7 +635,7 @@ func (s *GrpcHandlersSuite) TestAddApprover() { CreatedAt: timeNow, UpdatedAt: timeNow, } - s.appealService.EXPECT().AddApprover(mock.AnythingOfType("*context.emptyCtx"), appealID, approvalID, email).Return(expectedAppeal, nil).Once() + s.appealService.EXPECT().AddApprover(mock.Anything, appealID, approvalID, email).Return(expectedAppeal, nil).Once() expectedResponse := &guardianv1beta1.AddApproverResponse{ Appeal: &guardianv1beta1.Appeal{ Id: expectedAppeal.ID, @@ -760,7 +760,7 @@ func (s *GrpcHandlersSuite) TestDeleteApprover() { CreatedAt: timeNow, UpdatedAt: timeNow, } - s.appealService.EXPECT().DeleteApprover(mock.AnythingOfType("*context.emptyCtx"), appealID, approvalID, email).Return(expectedAppeal, nil).Once() + s.appealService.EXPECT().DeleteApprover(mock.MatchedBy(func(ctx context.Context) bool { return true }), appealID, approvalID, email).Return(expectedAppeal, nil).Once() expectedResponse := &guardianv1beta1.DeleteApproverResponse{ Appeal: &guardianv1beta1.Appeal{ Id: expectedAppeal.ID, diff --git a/api/handler/v1beta1/grant_test.go b/api/handler/v1beta1/grant_test.go index e9eca3339..242842317 100644 --- a/api/handler/v1beta1/grant_test.go +++ b/api/handler/v1beta1/grant_test.go @@ -191,7 +191,7 @@ func (s *GrpcHandlersSuite) TestGetGrant() { }, } s.grantService.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), grantID). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), grantID). Return(dummyGrant, nil).Once() req := &guardianv1beta1.GetGrantRequest{Id: grantID} @@ -225,7 +225,7 @@ func (s *GrpcHandlersSuite) TestGetGrant() { s.setup() s.grantService.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(nil, tc.expectedError).Once() req := &guardianv1beta1.GetGrantRequest{Id: "test-id"} @@ -249,7 +249,7 @@ func (s *GrpcHandlersSuite) TestGetGrant() { }, } s.grantService.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(expectedGrant, nil).Once() req := &guardianv1beta1.GetGrantRequest{Id: "test-id"} @@ -287,7 +287,7 @@ func (s *GrpcHandlersSuite) TestListUserRoles() { s.setup() s.grantService.EXPECT(). - ListUserRoles(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + ListUserRoles(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(nil, nil).Once() req := &guardianv1beta1.ListUserRolesRequest{} @@ -324,7 +324,7 @@ func (s *GrpcHandlersSuite) TestUpdateGrant() { } now := time.Now() s.grantService.EXPECT(). - Update(mock.AnythingOfType("*context.emptyCtx"), expectedGrant). + Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedGrant). Run(func(_a0 context.Context, g *domain.Grant) { g.UpdatedAt = now }). @@ -370,7 +370,7 @@ func (s *GrpcHandlersSuite) TestUpdateGrant() { s.setup() s.grantService.EXPECT(). - Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Grant")). + Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Grant")). Return(tc.expectedError).Once() req := &guardianv1beta1.UpdateGrantRequest{ @@ -443,7 +443,7 @@ func (s *GrpcHandlersSuite) TestImportFromProvider() { }, } s.grantService.EXPECT(). - ImportFromProvider(mock.AnythingOfType("*context.emptyCtx"), grant.ImportFromProviderCriteria{ + ImportFromProvider(mock.MatchedBy(func(ctx context.Context) bool { return true }), grant.ImportFromProviderCriteria{ ProviderID: "test-provider-id", ResourceIDs: []string{"test-resource-id"}, ResourceTypes: []string{"test-resource-type"}, diff --git a/api/handler/v1beta1/grpc.go b/api/handler/v1beta1/grpc.go index 436a32282..fe5063488 100644 --- a/api/handler/v1beta1/grpc.go +++ b/api/handler/v1beta1/grpc.go @@ -4,6 +4,8 @@ import ( "context" "strings" + "github.com/goto/guardian/pkg/log" + "github.com/goto/guardian/core/appeal" "github.com/goto/guardian/core/grant" @@ -124,6 +126,7 @@ type GRPCServer struct { adapter ProtoAdapter authenticatedUserContextKey interface{} + logger log.Logger guardianv1beta1.UnimplementedGuardianServiceServer } @@ -138,6 +141,7 @@ func NewGRPCServer( grantService grantService, adapter ProtoAdapter, authenticatedUserContextKey interface{}, + logger log.Logger, ) *GRPCServer { return &GRPCServer{ resourceService: resourceService, @@ -149,6 +153,7 @@ func NewGRPCServer( grantService: grantService, adapter: adapter, authenticatedUserContextKey: authenticatedUserContextKey, + logger: logger, } } diff --git a/api/handler/v1beta1/grpc_test.go b/api/handler/v1beta1/grpc_test.go index d7476cbac..eebb5b67c 100644 --- a/api/handler/v1beta1/grpc_test.go +++ b/api/handler/v1beta1/grpc_test.go @@ -3,6 +3,8 @@ package v1beta1_test import ( "testing" + "github.com/goto/guardian/pkg/log" + "github.com/goto/guardian/api/handler/v1beta1" "github.com/goto/guardian/api/handler/v1beta1/mocks" "github.com/stretchr/testify/suite" @@ -21,6 +23,7 @@ type GrpcHandlersSuite struct { approvalService *mocks.ApprovalService grantService *mocks.GrantService grpcServer *v1beta1.GRPCServer + logger log.Logger } func TestGrpcHandler(t *testing.T) { @@ -35,6 +38,7 @@ func (s *GrpcHandlersSuite) setup() { s.appealService = new(mocks.AppealService) s.approvalService = new(mocks.ApprovalService) s.grantService = new(mocks.GrantService) + s.logger = log.NewNoop() s.grpcServer = v1beta1.NewGRPCServer( s.resourceService, s.activityService, @@ -45,5 +49,6 @@ func (s *GrpcHandlersSuite) setup() { s.grantService, v1beta1.NewAdapter(), authEmailTestContextKey{}, + s.logger, ) } diff --git a/api/handler/v1beta1/policy_test.go b/api/handler/v1beta1/policy_test.go index fa76bdaa7..7bed430e5 100644 --- a/api/handler/v1beta1/policy_test.go +++ b/api/handler/v1beta1/policy_test.go @@ -29,7 +29,7 @@ func (s *GrpcHandlersSuite) TestListPolicies() { dummyPolicies := []*domain.Policy{ {ID: "test-policy"}, } - s.policyService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(dummyPolicies, nil).Once() + s.policyService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(dummyPolicies, nil).Once() req := &guardianv1beta1.ListPoliciesRequest{} res, err := s.grpcServer.ListPolicies(context.Background(), req) @@ -43,7 +43,7 @@ func (s *GrpcHandlersSuite) TestListPolicies() { s.setup() expectedError := errors.New("random error") - s.policyService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(nil, expectedError).Once() + s.policyService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(nil, expectedError).Once() req := &guardianv1beta1.ListPoliciesRequest{} res, err := s.grpcServer.ListPolicies(context.Background(), req) @@ -64,7 +64,7 @@ func (s *GrpcHandlersSuite) TestListPolicies() { }, }, } - s.policyService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(dummyPolicies, nil).Once() + s.policyService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(dummyPolicies, nil).Once() req := &guardianv1beta1.ListPoliciesRequest{} res, err := s.grpcServer.ListPolicies(context.Background(), req) @@ -176,7 +176,7 @@ func (s *GrpcHandlersSuite) TestGetPolicy() { UpdatedAt: timestamppb.New(timeNow), }, } - s.policyService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), "test-policy", uint(1)). + s.policyService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), "test-policy", uint(1)). Return(dummyPolicy, nil).Once() req := &guardianv1beta1.GetPolicyRequest{ @@ -193,7 +193,7 @@ func (s *GrpcHandlersSuite) TestGetPolicy() { s.Run("should return not found error if policy not found", func() { s.setup() - s.policyService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("uint")). + s.policyService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string"), mock.AnythingOfType("uint")). Return(nil, policy.ErrPolicyNotFound).Once() req := &guardianv1beta1.GetPolicyRequest{} @@ -208,7 +208,7 @@ func (s *GrpcHandlersSuite) TestGetPolicy() { s.setup() expectedError := errors.New("random error") - s.policyService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("uint")). + s.policyService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string"), mock.AnythingOfType("uint")). Return(nil, expectedError).Once() req := &guardianv1beta1.GetPolicyRequest{} @@ -229,7 +229,7 @@ func (s *GrpcHandlersSuite) TestGetPolicy() { Config: make(chan int), // invalid json }, } - s.policyService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("uint")). + s.policyService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string"), mock.AnythingOfType("uint")). Return(dummyPolicy, nil).Once() req := &guardianv1beta1.GetPolicyRequest{} @@ -366,7 +366,7 @@ func (s *GrpcHandlersSuite) TestCreatePolicy() { UpdatedAt: timestamppb.New(timeNow), }, } - s.policyService.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), expectedPolicy). + s.policyService.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedPolicy). Run(func(_a0 context.Context, _a1 *domain.Policy) { _a1.CreatedAt = timeNow _a1.UpdatedAt = timeNow @@ -442,7 +442,7 @@ func (s *GrpcHandlersSuite) TestCreatePolicy() { s.setup() expectedError := errors.New("random error") - s.policyService.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Policy")).Return(expectedError).Once() + s.policyService.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Policy")).Return(expectedError).Once() req := &guardianv1beta1.CreatePolicyRequest{} res, err := s.grpcServer.CreatePolicy(context.Background(), req) @@ -460,7 +460,7 @@ func (s *GrpcHandlersSuite) TestCreatePolicy() { Config: make(chan int), // invalid json }, } - s.policyService.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Policy")).Return(nil). + s.policyService.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Policy")).Return(nil). Run(func(_a0 context.Context, _a1 *domain.Policy) { *_a1 = *invalidPolicy }).Once() @@ -573,7 +573,7 @@ func (s *GrpcHandlersSuite) TestUpdatePolicy() { UpdatedAt: timestamppb.New(timeNow), }, } - s.policyService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), expectedPolicy). + s.policyService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedPolicy). Run(func(_a0 context.Context, _a1 *domain.Policy) { _a1.CreatedAt = timeNow _a1.UpdatedAt = timeNow @@ -636,7 +636,7 @@ func (s *GrpcHandlersSuite) TestUpdatePolicy() { s.setup() expectedError := policy.ErrPolicyNotFound - s.policyService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Policy")).Return(expectedError).Once() + s.policyService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Policy")).Return(expectedError).Once() req := &guardianv1beta1.UpdatePolicyRequest{} res, err := s.grpcServer.UpdatePolicy(context.Background(), req) @@ -650,7 +650,7 @@ func (s *GrpcHandlersSuite) TestUpdatePolicy() { s.setup() expectedError := policy.ErrEmptyIDParam - s.policyService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Policy")).Return(expectedError).Once() + s.policyService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Policy")).Return(expectedError).Once() req := &guardianv1beta1.UpdatePolicyRequest{} res, err := s.grpcServer.UpdatePolicy(context.Background(), req) @@ -664,7 +664,7 @@ func (s *GrpcHandlersSuite) TestUpdatePolicy() { s.setup() expectedError := errors.New("random error") - s.policyService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Policy")).Return(expectedError).Once() + s.policyService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Policy")).Return(expectedError).Once() req := &guardianv1beta1.UpdatePolicyRequest{} res, err := s.grpcServer.UpdatePolicy(context.Background(), req) @@ -682,7 +682,7 @@ func (s *GrpcHandlersSuite) TestUpdatePolicy() { Config: make(chan int), // invalid json }, } - s.policyService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Policy")).Return(nil). + s.policyService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Policy")).Return(nil). Run(func(_a0 context.Context, _a1 *domain.Policy) { *_a1 = *invalidPolicy }).Once() diff --git a/api/handler/v1beta1/provider.go b/api/handler/v1beta1/provider.go index 9a49010ec..982f5bfe8 100644 --- a/api/handler/v1beta1/provider.go +++ b/api/handler/v1beta1/provider.go @@ -82,6 +82,7 @@ func (s *GRPCServer) CreateProvider(ctx context.Context, req *guardianv1beta1.Cr } if err := s.providerService.Create(ctx, p); err != nil { + s.logger.Error(ctx, "failed to create provider", "provider_urn", p.URN, "type", p.Type, "error", err) return nil, status.Errorf(codes.Internal, "failed to create provider: %v", err) } @@ -110,6 +111,7 @@ func (s *GRPCServer) UpdateProvider(ctx context.Context, req *guardianv1beta1.Up } if err := s.providerService.Update(ctx, p); err != nil { + s.logger.Error(ctx, "failed to update provider", "provider_id", id, "provider_urn", p.URN, "type", p.Type, "error", err) return nil, status.Errorf(codes.Internal, "failed to update provider: %v", err) } diff --git a/api/handler/v1beta1/provider_test.go b/api/handler/v1beta1/provider_test.go index a7229eb9c..5aeb7bf0c 100644 --- a/api/handler/v1beta1/provider_test.go +++ b/api/handler/v1beta1/provider_test.go @@ -88,7 +88,7 @@ func (s *GrpcHandlersSuite) TestListProvider() { }, }, } - s.providerService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")). + s.providerService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })). Return(dummyProviders, nil).Once() req := &guardianv1beta1.ListProvidersRequest{} @@ -103,7 +103,7 @@ func (s *GrpcHandlersSuite) TestListProvider() { s.setup() expectedError := errors.New("random error") - s.providerService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")). + s.providerService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })). Return(nil, expectedError).Once() req := &guardianv1beta1.ListProvidersRequest{} @@ -132,7 +132,7 @@ func (s *GrpcHandlersSuite) TestListProvider() { }, }, } - s.providerService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")). + s.providerService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })). Return(expectedProviders, nil).Once() req := &guardianv1beta1.ListProvidersRequest{} @@ -205,7 +205,7 @@ func (s *GrpcHandlersSuite) TestGetProvider() { UpdatedAt: timestamppb.New(timeNow), }, } - s.providerService.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), expectedProvider.ID).Return(expectedProvider, nil).Once() + s.providerService.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedProvider.ID).Return(expectedProvider, nil).Once() req := &guardianv1beta1.GetProviderRequest{ Id: expectedProvider.ID, @@ -221,7 +221,7 @@ func (s *GrpcHandlersSuite) TestGetProvider() { s.setup() expectedError := provider.ErrRecordNotFound - s.providerService.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + s.providerService.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(nil, expectedError).Once() req := &guardianv1beta1.GetProviderRequest{} @@ -236,7 +236,7 @@ func (s *GrpcHandlersSuite) TestGetProvider() { s.setup() expectedError := errors.New("random error") - s.providerService.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + s.providerService.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(nil, expectedError).Once() req := &guardianv1beta1.GetProviderRequest{} @@ -255,7 +255,7 @@ func (s *GrpcHandlersSuite) TestGetProvider() { Credentials: make(chan int), // invalid json }, } - s.providerService.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + s.providerService.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(expectedProvider, nil).Once() req := &guardianv1beta1.GetProviderRequest{} @@ -285,7 +285,7 @@ func (s *GrpcHandlersSuite) TestGetProviderTypes() { }, }, } - s.providerService.EXPECT().GetTypes(mock.AnythingOfType("*context.emptyCtx")). + s.providerService.EXPECT().GetTypes(mock.MatchedBy(func(ctx context.Context) bool { return true })). Return(expectedProviderTypes, nil).Once() req := &guardianv1beta1.GetProviderTypesRequest{} @@ -300,7 +300,7 @@ func (s *GrpcHandlersSuite) TestGetProviderTypes() { s.setup() expectedError := errors.New("random error") - s.providerService.EXPECT().GetTypes(mock.AnythingOfType("*context.emptyCtx")). + s.providerService.EXPECT().GetTypes(mock.MatchedBy(func(ctx context.Context) bool { return true })). Return(nil, expectedError).Once() req := &guardianv1beta1.GetProviderTypesRequest{} @@ -395,7 +395,7 @@ func (s *GrpcHandlersSuite) TestCreateProvider() { UpdatedAt: timestamppb.New(timeNow), }, } - s.providerService.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), expectedProvider).Return(nil). + s.providerService.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedProvider).Return(nil). Run(func(_a0 context.Context, _a1 *domain.Provider) { _a1.ID = expectedID _a1.CreatedAt = timeNow @@ -447,7 +447,7 @@ func (s *GrpcHandlersSuite) TestCreateProvider() { s.setup() expectedError := errors.New("random error") - s.providerService.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Provider")).Return(expectedError).Once() + s.providerService.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Provider")).Return(expectedError).Once() req := &guardianv1beta1.CreateProviderRequest{} res, err := s.grpcServer.CreateProvider(context.Background(), req) @@ -465,7 +465,7 @@ func (s *GrpcHandlersSuite) TestCreateProvider() { Credentials: make(chan int), // invalid json }, } - s.providerService.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Provider")).Return(nil). + s.providerService.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Provider")).Return(nil). Run(func(_a0 context.Context, _a1 *domain.Provider) { *_a1 = *expectedProvider }).Once() @@ -539,7 +539,7 @@ func (s *GrpcHandlersSuite) TestUpdatedProvider() { UpdatedAt: timestamppb.New(timeNow), }, } - s.providerService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), expectedProvider).Return(nil). + s.providerService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedProvider).Return(nil). Run(func(_a0 context.Context, _a1 *domain.Provider) { _a1.CreatedAt = timeNow _a1.UpdatedAt = timeNow @@ -579,7 +579,7 @@ func (s *GrpcHandlersSuite) TestUpdatedProvider() { s.setup() expectedError := errors.New("random error") - s.providerService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Provider")).Return(expectedError).Once() + s.providerService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Provider")).Return(expectedError).Once() req := &guardianv1beta1.UpdateProviderRequest{} res, err := s.grpcServer.UpdateProvider(context.Background(), req) @@ -597,7 +597,7 @@ func (s *GrpcHandlersSuite) TestUpdatedProvider() { Credentials: make(chan int), // invalid json }, } - s.providerService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Provider")).Return(nil). + s.providerService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Provider")).Return(nil). Run(func(_a0 context.Context, _a1 *domain.Provider) { *_a1 = *expectedProvider }).Once() @@ -617,7 +617,7 @@ func (s *GrpcHandlersSuite) TestDeleteProvider() { expectedResponse := &guardianv1beta1.DeleteProviderResponse{} expectedID := "test-id" - s.providerService.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), expectedID).Return(nil).Once() + s.providerService.EXPECT().Delete(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedID).Return(nil).Once() req := &guardianv1beta1.DeleteProviderRequest{ Id: expectedID, @@ -633,7 +633,7 @@ func (s *GrpcHandlersSuite) TestDeleteProvider() { s.setup() expectedError := provider.ErrRecordNotFound - s.providerService.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + s.providerService.EXPECT().Delete(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(expectedError).Once() req := &guardianv1beta1.DeleteProviderRequest{} @@ -648,7 +648,7 @@ func (s *GrpcHandlersSuite) TestDeleteProvider() { s.setup() expectedError := errors.New("random error") - s.providerService.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + s.providerService.EXPECT().Delete(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(expectedError).Once() req := &guardianv1beta1.DeleteProviderRequest{} @@ -686,7 +686,7 @@ func (s *GrpcHandlersSuite) TestListRoles() { }, }, } - s.providerService.EXPECT().GetRoles(mock.AnythingOfType("*context.emptyCtx"), expectedProviderID, expectedResourceType). + s.providerService.EXPECT().GetRoles(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedProviderID, expectedResourceType). Return(expectedRoles, nil).Once() req := &guardianv1beta1.ListRolesRequest{ @@ -704,7 +704,7 @@ func (s *GrpcHandlersSuite) TestListRoles() { s.setup() expectedError := errors.New("random error") - s.providerService.EXPECT().GetRoles(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")). + s.providerService.EXPECT().GetRoles(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string"), mock.AnythingOfType("string")). Return(nil, expectedError).Once() req := &guardianv1beta1.ListRolesRequest{} @@ -725,7 +725,7 @@ func (s *GrpcHandlersSuite) TestListRoles() { }, }, } - s.providerService.EXPECT().GetRoles(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string"), mock.AnythingOfType("string")). + s.providerService.EXPECT().GetRoles(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string"), mock.AnythingOfType("string")). Return(invalidRoles, nil).Once() req := &guardianv1beta1.ListRolesRequest{} diff --git a/api/handler/v1beta1/resource_test.go b/api/handler/v1beta1/resource_test.go index 5a72502d5..c5122d6f9 100644 --- a/api/handler/v1beta1/resource_test.go +++ b/api/handler/v1beta1/resource_test.go @@ -145,7 +145,7 @@ func (s *GrpcHandlersSuite) TestGetResource() { CreatedAt: timeNow, UpdatedAt: timeNow, } - s.resourceService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), expectedID).Return(expectedResource, nil).Once() + s.resourceService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedID).Return(expectedResource, nil).Once() expectedResponse := &guardianv1beta1.GetResourceResponse{ Resource: &guardianv1beta1.Resource{ Id: expectedID, @@ -165,7 +165,7 @@ func (s *GrpcHandlersSuite) TestGetResource() { s.Run("should return not found error if resource not found", func() { s.setup() - s.resourceService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(nil, resource.ErrRecordNotFound) + s.resourceService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil, resource.ErrRecordNotFound) req := &guardianv1beta1.GetResourceRequest{Id: "unknown-id"} res, err := s.grpcServer.GetResource(context.Background(), req) @@ -179,7 +179,7 @@ func (s *GrpcHandlersSuite) TestGetResource() { s.setup() expectedError := errors.New("randome error") - s.resourceService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(nil, expectedError) + s.resourceService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil, expectedError) req := &guardianv1beta1.GetResourceRequest{Id: "unknown-id"} res, err := s.grpcServer.GetResource(context.Background(), req) @@ -202,7 +202,7 @@ func (s *GrpcHandlersSuite) TestGetResource() { "key": make(chan int), // invalid json }, } - s.resourceService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), expectedID).Return(expectedResource, nil).Once() + s.resourceService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedID).Return(expectedResource, nil).Once() req := &guardianv1beta1.GetResourceRequest{Id: expectedID} res, err := s.grpcServer.GetResource(context.Background(), req) @@ -223,7 +223,7 @@ func (s *GrpcHandlersSuite) TestUpdateResource() { ID: expectedID, Name: "new-name", } - s.resourceService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), expectedResource).Return(nil). + s.resourceService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedResource).Return(nil). Run(func(_a0 context.Context, _a1 *domain.Resource) { _a1.CreatedAt = timeNow _a1.UpdatedAt = timeNow @@ -253,7 +253,7 @@ func (s *GrpcHandlersSuite) TestUpdateResource() { s.Run("should return not found error if resource not found", func() { s.setup() - s.resourceService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Resource")).Return(resource.ErrRecordNotFound) + s.resourceService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Resource")).Return(resource.ErrRecordNotFound) req := &guardianv1beta1.UpdateResourceRequest{Id: "unknown-id"} res, err := s.grpcServer.UpdateResource(context.Background(), req) @@ -267,7 +267,7 @@ func (s *GrpcHandlersSuite) TestUpdateResource() { s.setup() expectedError := errors.New("randome error") - s.resourceService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Resource")).Return(expectedError) + s.resourceService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Resource")).Return(expectedError) req := &guardianv1beta1.UpdateResourceRequest{Id: "unknown-id"} res, err := s.grpcServer.UpdateResource(context.Background(), req) @@ -280,7 +280,7 @@ func (s *GrpcHandlersSuite) TestUpdateResource() { s.Run("should return error if there is an error when parsing the resource", func() { s.setup() - s.resourceService.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Resource")).Return(nil). + s.resourceService.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Resource")).Return(nil). Run(func(_a0 context.Context, _a1 *domain.Resource) { _a1.Details = map[string]interface{}{ "key": make(chan int), // invalid json @@ -301,7 +301,7 @@ func (s *GrpcHandlersSuite) TestDeleteResource() { s.setup() expectedID := "test-id" - s.resourceService.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), expectedID).Return(nil) + s.resourceService.EXPECT().Delete(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedID).Return(nil) req := &guardianv1beta1.DeleteResourceRequest{Id: expectedID} res, err := s.grpcServer.DeleteResource(context.Background(), req) @@ -314,7 +314,7 @@ func (s *GrpcHandlersSuite) TestDeleteResource() { s.Run("should return not found error if resource not found", func() { s.setup() - s.resourceService.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")).Return(resource.ErrRecordNotFound) + s.resourceService.EXPECT().Delete(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")).Return(resource.ErrRecordNotFound) req := &guardianv1beta1.DeleteResourceRequest{Id: "unknown-id"} res, err := s.grpcServer.DeleteResource(context.Background(), req) @@ -328,7 +328,7 @@ func (s *GrpcHandlersSuite) TestDeleteResource() { s.setup() expectedError := errors.New("randome error") - s.resourceService.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")).Return(expectedError) + s.resourceService.EXPECT().Delete(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")).Return(expectedError) req := &guardianv1beta1.DeleteResourceRequest{Id: "unknown-id"} res, err := s.grpcServer.DeleteResource(context.Background(), req) diff --git a/cli/job.go b/cli/job.go index 3e34d7cd1..36e7da916 100644 --- a/cli/job.go +++ b/cli/job.go @@ -4,13 +4,14 @@ import ( "context" "fmt" + "github.com/goto/guardian/pkg/log" + "github.com/MakeNowJust/heredoc" "github.com/go-playground/validator/v10" "github.com/goto/guardian/internal/server" "github.com/goto/guardian/jobs" "github.com/goto/guardian/pkg/crypto" "github.com/goto/guardian/plugins/notifiers" - "github.com/goto/salt/log" "github.com/spf13/cobra" ) @@ -63,7 +64,7 @@ func runJobCmd() *cobra.Command { return fmt.Errorf("loading config: %w", err) } - logger := log.NewLogrus(log.LogrusWithLevel(config.LogLevel)) + logger := log.NewCtxLogger(config.LogLevel, []string{config.AuditLogTraceIDHeaderKey}) crypto := crypto.NewAES(config.EncryptionSecretKeyKey) validator := validator.New() notifier, err := notifiers.NewClient(&config.Notifier, logger) diff --git a/core/activity/service.go b/core/activity/service.go index cc30bcbdb..7568552ba 100644 --- a/core/activity/service.go +++ b/core/activity/service.go @@ -6,7 +6,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/goto/guardian/domain" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" ) //go:generate mockery --name=repository --exported --with-expecter diff --git a/core/appeal/mocks/notifier.go b/core/appeal/mocks/notifier.go index 9d8970869..ec2ae4834 100644 --- a/core/appeal/mocks/notifier.go +++ b/core/appeal/mocks/notifier.go @@ -1,8 +1,10 @@ -// Code generated by mockery v2.33.0. DO NOT EDIT. +// Code generated by mockery v2.32.0. DO NOT EDIT. package mocks import ( + context "context" + domain "github.com/goto/guardian/domain" mock "github.com/stretchr/testify/mock" ) @@ -20,13 +22,13 @@ func (_m *Notifier) EXPECT() *Notifier_Expecter { return &Notifier_Expecter{mock: &_m.Mock} } -// Notify provides a mock function with given fields: _a0 -func (_m *Notifier) Notify(_a0 []domain.Notification) []error { - ret := _m.Called(_a0) +// Notify provides a mock function with given fields: _a0, _a1 +func (_m *Notifier) Notify(_a0 context.Context, _a1 []domain.Notification) []error { + ret := _m.Called(_a0, _a1) var r0 []error - if rf, ok := ret.Get(0).(func([]domain.Notification) []error); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, []domain.Notification) []error); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]error) @@ -42,14 +44,15 @@ type Notifier_Notify_Call struct { } // Notify is a helper method to define mock.On call -// - _a0 []domain.Notification -func (_e *Notifier_Expecter) Notify(_a0 interface{}) *Notifier_Notify_Call { - return &Notifier_Notify_Call{Call: _e.mock.On("Notify", _a0)} +// - _a0 context.Context +// - _a1 []domain.Notification +func (_e *Notifier_Expecter) Notify(_a0 interface{}, _a1 interface{}) *Notifier_Notify_Call { + return &Notifier_Notify_Call{Call: _e.mock.On("Notify", _a0, _a1)} } -func (_c *Notifier_Notify_Call) Run(run func(_a0 []domain.Notification)) *Notifier_Notify_Call { +func (_c *Notifier_Notify_Call) Run(run func(_a0 context.Context, _a1 []domain.Notification)) *Notifier_Notify_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]domain.Notification)) + run(args[0].(context.Context), args[1].([]domain.Notification)) }) return _c } @@ -59,7 +62,7 @@ func (_c *Notifier_Notify_Call) Return(_a0 []error) *Notifier_Notify_Call { return _c } -func (_c *Notifier_Notify_Call) RunAndReturn(run func([]domain.Notification) []error) *Notifier_Notify_Call { +func (_c *Notifier_Notify_Call) RunAndReturn(run func(context.Context, []domain.Notification) []error) *Notifier_Notify_Call { _c.Call.Return(run) return _c } diff --git a/core/appeal/service.go b/core/appeal/service.go index 375b027a1..04b5413dd 100644 --- a/core/appeal/service.go +++ b/core/appeal/service.go @@ -11,9 +11,9 @@ import ( "github.com/goto/guardian/core/grant" "github.com/goto/guardian/domain" "github.com/goto/guardian/pkg/evaluator" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/plugins/notifiers" "github.com/goto/guardian/utils" - "github.com/goto/salt/log" "golang.org/x/sync/errgroup" ) @@ -262,7 +262,7 @@ func (s *Service) Create(ctx context.Context, appeals []*domain.Appeal, opts ... return fmt.Errorf("validating cross-individual appeal: %w", err) } - if err := s.addCreatorDetails(appeal, policy); err != nil { + if err := s.addCreatorDetails(ctx, appeal, policy); err != nil { return fmt.Errorf("retrieving creator details: %w", err) } @@ -323,7 +323,7 @@ func (s *Service) Create(ctx context.Context, appeals []*domain.Appeal, opts ... } if err := s.auditLogger.Log(ctx, AuditKeyBulkInsert, appeals); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } for _, a := range appeals { @@ -355,13 +355,13 @@ func (s *Service) Create(ctx context.Context, appeals []*domain.Appeal, opts ... }) } - notifications = append(notifications, s.getApprovalNotifications(a)...) + notifications = append(notifications, s.getApprovalNotifications(ctx, a)...) } if len(notifications) > 0 { - if errs := s.notifier.Notify(notifications); errs != nil { + if errs := s.notifier.Notify(ctx, notifications); errs != nil { for _, err1 := range errs { - s.logger.Error("failed to send notifications", "error", err1.Error()) + s.logger.Error(ctx, "failed to send notifications", "error", err1.Error()) } } } @@ -567,12 +567,12 @@ func (s *Service) UpdateApproval(ctx context.Context, approvalAction domain.Appr }, }) } else { - notifications = append(notifications, s.getApprovalNotifications(appeal)...) + notifications = append(notifications, s.getApprovalNotifications(ctx, appeal)...) } if len(notifications) > 0 { - if errs := s.notifier.Notify(notifications); errs != nil { + if errs := s.notifier.Notify(ctx, notifications); errs != nil { for _, err1 := range errs { - s.logger.Error("failed to send notifications", "error", err1.Error()) + s.logger.Error(ctx, "failed to send notifications", "error", err1.Error()) } } } @@ -585,7 +585,7 @@ func (s *Service) UpdateApproval(ctx context.Context, approvalAction domain.Appr } if auditKey != "" { if err := s.auditLogger.Log(ctx, auditKey, approvalAction); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } } @@ -627,7 +627,7 @@ func (s *Service) Cancel(ctx context.Context, id string) (*domain.Appeal, error) if err := s.auditLogger.Log(ctx, AuditKeyCancel, map[string]interface{}{ "appeal_id": id, }); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } return appeal, nil @@ -666,18 +666,18 @@ func (s *Service) AddApprover(ctx context.Context, appealID, approvalID, email s approval.Approvers = append(approval.Approvers, email) if err := s.auditLogger.Log(ctx, AuditKeyAddApprover, approval); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } duration := domain.PermanentDurationLabel if !appeal.IsDurationEmpty() { duration, err = utils.GetReadableDuration(appeal.Options.Duration) if err != nil { - s.logger.Error("failed to get readable duration", "error", err, "appeal_id", appeal.ID) + s.logger.Error(ctx, "failed to get readable duration", "error", err, "appeal_id", appeal.ID) } } - if errs := s.notifier.Notify([]domain.Notification{ + if errs := s.notifier.Notify(ctx, []domain.Notification{ { User: email, Labels: map[string]string{ @@ -705,7 +705,7 @@ func (s *Service) AddApprover(ctx context.Context, appealID, approvalID, email s }, }); errs != nil { for _, err1 := range errs { - s.logger.Error("failed to send notifications", "error", err1.Error()) + s.logger.Error(ctx, "failed to send notifications", "error", err1.Error()) } } @@ -756,7 +756,7 @@ func (s *Service) DeleteApprover(ctx context.Context, appealID, approvalID, emai approval.Approvers = newApprovers if err := s.auditLogger.Log(ctx, AuditKeyDeleteApprover, approval); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } return appeal, nil @@ -858,7 +858,7 @@ func (s *Service) getPoliciesMap(ctx context.Context) (map[string]map[uint]*doma return policiesMap, nil } -func (s *Service) getApprovalNotifications(appeal *domain.Appeal) []domain.Notification { +func (s *Service) getApprovalNotifications(ctx context.Context, appeal *domain.Appeal) []domain.Notification { notifications := []domain.Notification{} approval := appeal.GetNextPendingApproval() @@ -867,7 +867,7 @@ func (s *Service) getApprovalNotifications(appeal *domain.Appeal) []domain.Notif if !appeal.IsDurationEmpty() { duration, err = utils.GetReadableDuration(appeal.Options.Duration) if err != nil { - s.logger.Error("failed to get readable duration", "error", err, "appeal_id", appeal.ID) + s.logger.Error(ctx, "failed to get readable duration", "error", err, "appeal_id", appeal.ID) } } @@ -1091,7 +1091,7 @@ func getPolicy(a *domain.Appeal, p *domain.Provider, policiesMap map[string]map[ return policiesMap[policyConfig.ID][uint(policyConfig.Version)], nil } -func (s *Service) addCreatorDetails(a *domain.Appeal, p *domain.Policy) error { +func (s *Service) addCreatorDetails(ctx context.Context, a *domain.Appeal, p *domain.Policy) error { if p.IAM == nil { return nil } @@ -1108,7 +1108,7 @@ func (s *Service) addCreatorDetails(a *domain.Appeal, p *domain.Policy) error { userDetails, err := iamClient.GetUser(a.CreatedBy) if err != nil { if p.AppealConfig != nil && p.AppealConfig.AllowCreatorDetailsFailure { - s.logger.Warn("fetching creator's user iam", "error", err) + s.logger.Warn(ctx, "fetching creator's user iam", "error", err) return nil } return fmt.Errorf("fetching creator's user iam: %w", err) @@ -1141,6 +1141,8 @@ func (s *Service) addCreatorDetails(a *domain.Appeal, p *domain.Policy) error { } a.Creator = creator + s.logger.Debug(ctx, "added creator details", "creator", creator) + return nil } diff --git a/core/appeal/service_test.go b/core/appeal/service_test.go index 74b49a8b9..b705f868f 100644 --- a/core/appeal/service_test.go +++ b/core/appeal/service_test.go @@ -14,7 +14,7 @@ import ( "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" "github.com/goto/guardian/mocks" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -85,7 +85,7 @@ func (s *ServiceTestSuite) TestGetByID() { s.Run("should return error if got any from repository", func() { expectedError := errors.New("repository error") - s.mockRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(nil, expectedError).Once() + s.mockRepository.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil, expectedError).Once() id := uuid.New().String() actualResult, actualError := s.service.GetByID(context.Background(), id) @@ -99,7 +99,7 @@ func (s *ServiceTestSuite) TestGetByID() { expectedResult := &domain.Appeal{ ID: expectedID, } - s.mockRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), expectedID).Return(expectedResult, nil).Once() + s.mockRepository.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedID).Return(expectedResult, nil).Once() actualResult, actualError := s.service.GetByID(context.Background(), expectedID) @@ -111,7 +111,7 @@ func (s *ServiceTestSuite) TestGetByID() { func (s *ServiceTestSuite) TestFind() { s.Run("should return error if got any from repository", func() { expectedError := errors.New("unexpected repository error") - s.mockRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(nil, expectedError).Once() + s.mockRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil, expectedError).Once() actualResult, actualError := s.service.Find(context.Background(), &domain.ListAppealsFilter{}) @@ -131,7 +131,7 @@ func (s *ServiceTestSuite) TestFind() { Role: "viewer", }, } - s.mockRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx"), expectedFilters).Return(expectedResult, nil).Once() + s.mockRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedFilters).Return(expectedResult, nil).Once() actualResult, actualError := s.service.Find(context.Background(), expectedFilters) @@ -187,7 +187,7 @@ func (s *ServiceTestSuite) TestCreate() { s.mockPolicyService.On("Find", mock.Anything).Return(expectedPolicies, nil).Once() expectedError := errors.New("appeal repository error") s.mockRepository.EXPECT(). - Find(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(nil, expectedError).Once() actualError := s.service.Create(context.Background(), []*domain.Appeal{}) @@ -556,10 +556,10 @@ func (s *ServiceTestSuite) TestCreate() { s.mockProviderService.On("Find", mock.Anything).Return(tc.providers, nil).Once() s.mockPolicyService.On("Find", mock.Anything).Return(tc.policies, nil).Once() s.mockRepository.EXPECT(). - Find(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(tc.existingAppeals, nil).Once() s.mockGrantService.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.ListGrantsFilter")). + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.ListGrantsFilter")). Return(tc.activeGrants, nil) if tc.callMockValidateAppeal { s.mockProviderService.On("ValidateAppeal", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.expectedAppealValidationError).Once() @@ -588,14 +588,14 @@ func (s *ServiceTestSuite) TestCreate() { s.mockProviderService.On("Find", mock.Anything).Return(expectedProviders, nil).Once() s.mockPolicyService.On("Find", mock.Anything).Return(expectedPolicies, nil).Once() s.mockRepository.EXPECT(). - Find(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedPendingAppeals, nil).Once() s.mockGrantService.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.ListGrantsFilter")). + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.ListGrantsFilter")). Return(expectedActiveGrants, nil).Once() expectedError := errors.New("repository error") s.mockRepository.EXPECT(). - BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedError).Once() actualError := s.service.Create(context.Background(), []*domain.Appeal{}) @@ -909,10 +909,10 @@ func (s *ServiceTestSuite) TestCreate() { AccountIDs: []string{"test@email.com", "addOnBehalfApprovedNotification-user"}, } s.mockRepository.EXPECT(). - Find(mock.AnythingOfType("*context.emptyCtx"), expectedExistingAppealsFilters). + Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedExistingAppealsFilters). Return(expectedExistingAppeals, nil).Once() s.mockGrantService.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), domain.ListGrantsFilter{ + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), domain.ListGrantsFilter{ Statuses: []string{string(domain.GrantStatusActive)}, }). Return(expectedActiveGrants, nil).Once() @@ -932,7 +932,7 @@ func (s *ServiceTestSuite) TestCreate() { s.mockIAMClient.On("GetUser", accountID).Return(expectedCreatorResponse, nil).Once() s.mockIAMClient.On("GetUser", accountID).Return(nil, errors.New("404 not found")).Once() s.mockRepository.EXPECT(). - BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), expectedAppealsInsertionParam). + BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedAppealsInsertionParam). Return(nil). Run(func(_a0 context.Context, appeals []*domain.Appeal) { for i, a := range appeals { @@ -943,7 +943,7 @@ func (s *ServiceTestSuite) TestCreate() { } }). Once() - s.mockNotifier.On("Notify", mock.Anything).Return(nil).Once() + s.mockNotifier.On("Notify", mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil).Once() s.mockAuditLogger.On("Log", mock.Anything, appeal.AuditKeyBulkInsert, mock.Anything).Return(nil).Once() appeals := []*domain.Appeal{ @@ -1046,10 +1046,10 @@ func (s *ServiceTestSuite) TestCreate() { s.mockProviderService.On("Find", mock.Anything).Return([]*domain.Provider{dummyProvider}, nil).Once() s.mockPolicyService.On("Find", mock.Anything).Return([]*domain.Policy{dummyPolicy, overriddingPolicy}, nil).Once() s.mockRepository.EXPECT(). - Find(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return([]*domain.Appeal{}, nil).Once() s.mockGrantService.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.ListGrantsFilter")). + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.ListGrantsFilter")). Return([]domain.Grant{}, nil).Once() s.mockProviderService.On("ValidateAppeal", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) s.mockProviderService.On("GetPermissions", mock.Anything, dummyProvider.Config, dummyResource.Type, input.Role). @@ -1058,9 +1058,9 @@ func (s *ServiceTestSuite) TestCreate() { s.mockIAMManager.On("GetClient", mock.Anything, mock.Anything).Return(s.mockIAMClient, nil) s.mockIAMClient.On("GetUser", input.AccountID).Return(map[string]interface{}{}, nil) s.mockRepository.EXPECT(). - BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(nil).Once() - s.mockNotifier.On("Notify", mock.Anything).Return(nil).Once() + s.mockNotifier.On("Notify", mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil).Once() s.mockAuditLogger.On("Log", mock.Anything, appeal.AuditKeyBulkInsert, mock.Anything).Return(nil).Once() s.mockGrantService.On("List", mock.Anything, mock.Anything).Return([]domain.Grant{}, nil).Once() s.mockGrantService.On("Prepare", mock.Anything, mock.Anything).Return(&domain.Grant{}, nil).Once() @@ -1283,10 +1283,10 @@ func (s *ServiceTestSuite) TestCreateAppeal__WithExistingAppealAndWithAutoApprov AccountIDs: []string{accountID}, } s.mockRepository.EXPECT(). - Find(mock.AnythingOfType("*context.emptyCtx"), expectedExistingAppealsFilters). + Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedExistingAppealsFilters). Return(expectedExistingAppeals, nil).Once() s.mockGrantService.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), domain.ListGrantsFilter{ + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), domain.ListGrantsFilter{ Statuses: []string{string(domain.GrantStatusActive)}, AccountIDs: []string{appeals[0].AccountID}, ResourceIDs: []string{appeals[0].ResourceID}, @@ -1296,7 +1296,7 @@ func (s *ServiceTestSuite) TestCreateAppeal__WithExistingAppealAndWithAutoApprov Return(expectedExistingGrants, nil).Once() // duplicate call with slight change in filters but the code needs it in order to work. appeal create code needs to be refactored. s.mockGrantService.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), domain.ListGrantsFilter{ + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), domain.ListGrantsFilter{ Statuses: []string{string(domain.GrantStatusActive)}, AccountIDs: []string{appeals[0].AccountID}, ResourceIDs: []string{appeals[0].ResourceID}, @@ -1311,7 +1311,7 @@ func (s *ServiceTestSuite) TestCreateAppeal__WithExistingAppealAndWithAutoApprov s.mockIAMClient.On("GetUser", accountID).Return(expectedCreatorUser, nil) s.mockGrantService.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), domain.ListGrantsFilter{ + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), domain.ListGrantsFilter{ AccountIDs: []string{accountID}, ResourceIDs: []string{"1"}, Statuses: []string{string(domain.GrantStatusActive)}, @@ -1326,10 +1326,10 @@ func (s *ServiceTestSuite) TestCreateAppeal__WithExistingAppealAndWithAutoApprov Permissions: []string{"test-permission"}, } s.mockGrantService.EXPECT(). - Prepare(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.Appeal")). + Prepare(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.Appeal")). Return(preparedGrant, nil).Once() s.mockGrantService.EXPECT(). - Revoke(mock.AnythingOfType("*context.emptyCtx"), currentActiveGrant.ID, domain.SystemActorName, appeal.RevokeReasonForExtension, + Revoke(mock.MatchedBy(func(ctx context.Context) bool { return true }), currentActiveGrant.ID, domain.SystemActorName, appeal.RevokeReasonForExtension, mock.AnythingOfType("grant.Option"), mock.AnythingOfType("grant.Option"), ). Return(preparedGrant, nil).Once() @@ -1340,7 +1340,7 @@ func (s *ServiceTestSuite) TestCreateAppeal__WithExistingAppealAndWithAutoApprov s.mockProviderService.On("GrantAccess", mock.Anything, appeals[0]).Return(nil).Once() s.mockRepository.EXPECT(). - BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), expectedAppealsInsertionParam). + BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedAppealsInsertionParam). Return(nil). Run(func(_a0 context.Context, appeals []*domain.Appeal) { for i, a := range appeals { @@ -1350,7 +1350,7 @@ func (s *ServiceTestSuite) TestCreateAppeal__WithExistingAppealAndWithAutoApprov } } }).Once() - s.mockNotifier.On("Notify", mock.Anything).Return(nil).Once() + s.mockNotifier.On("Notify", mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil).Once() s.mockAuditLogger.On("Log", mock.Anything, appeal.AuditKeyBulkInsert, mock.Anything).Return(nil).Once() actualError := s.service.Create(context.Background(), appeals) @@ -1454,32 +1454,32 @@ func (s *ServiceTestSuite) TestCreateAppeal__WithAdditionalAppeals() { // 1.a main appeal creation expectedResourceFilters := domain.ListResourcesFilter{IDs: []string{appealsPayload[0].Resource.ID}} - s.mockResourceService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx"), expectedResourceFilters).Return([]*domain.Resource{resources[0]}, nil).Once() - s.mockProviderService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(providers, nil).Once() - s.mockPolicyService.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(policies, nil).Once() - s.mockGrantService.EXPECT().List(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.ListGrantsFilter")).Return([]domain.Grant{}, nil).Once().Run(func(args mock.Arguments) { + s.mockResourceService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedResourceFilters).Return([]*domain.Resource{resources[0]}, nil).Once() + s.mockProviderService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(providers, nil).Once() + s.mockPolicyService.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(policies, nil).Once() + s.mockGrantService.EXPECT().List(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.ListGrantsFilter")).Return([]domain.Grant{}, nil).Once().Run(func(args mock.Arguments) { filter := args.Get(1).(domain.ListGrantsFilter) s.Equal([]string{appealsPayload[0].AccountID}, filter.AccountIDs) s.Equal([]string{appealsPayload[0].Resource.ID}, filter.ResourceIDs) s.Equal([]string{appealsPayload[0].Role}, filter.Roles) }) - s.mockRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.ListAppealsFilter")).Return([]*domain.Appeal{}, nil).Once() - s.mockProviderService.EXPECT().ValidateAppeal(mock.AnythingOfType("*context.emptyCtx"), appealsPayload[0], providers[0], policies[0]).Return(nil).Once() - s.mockProviderService.EXPECT().GetPermissions(mock.AnythingOfType("*context.emptyCtx"), providers[0].Config, appealsPayload[0].Resource.Type, appealsPayload[0].Role).Return([]interface{}{"test-permission-1"}, nil).Once() - s.mockGrantService.EXPECT().List(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.ListGrantsFilter")).Return([]domain.Grant{}, nil).Once().Run(func(args mock.Arguments) { + s.mockRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.ListAppealsFilter")).Return([]*domain.Appeal{}, nil).Once() + s.mockProviderService.EXPECT().ValidateAppeal(mock.MatchedBy(func(ctx context.Context) bool { return true }), appealsPayload[0], providers[0], policies[0]).Return(nil).Once() + s.mockProviderService.EXPECT().GetPermissions(mock.MatchedBy(func(ctx context.Context) bool { return true }), providers[0].Config, appealsPayload[0].Resource.Type, appealsPayload[0].Role).Return([]interface{}{"test-permission-1"}, nil).Once() + s.mockGrantService.EXPECT().List(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.ListGrantsFilter")).Return([]domain.Grant{}, nil).Once().Run(func(args mock.Arguments) { filter := args.Get(1).(domain.ListGrantsFilter) s.Equal([]string{appealsPayload[0].AccountID}, filter.AccountIDs) s.Equal([]string{appealsPayload[0].Resource.ID}, filter.ResourceIDs) }) expectedGrant := &domain.Grant{ID: "main-grant"} - s.mockGrantService.EXPECT().Prepare(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.Appeal")).Return(expectedGrant, nil).Once().Run(func(args mock.Arguments) { + s.mockGrantService.EXPECT().Prepare(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.Appeal")).Return(expectedGrant, nil).Once().Run(func(args mock.Arguments) { appeal := args.Get(1).(domain.Appeal) s.Equal(appealsPayload[0].AccountID, appeal.AccountID) s.Equal(appealsPayload[0].Role, appeal.Role) s.Equal(appealsPayload[0].ResourceID, appeal.ResourceID) s.Equal(len(policies[0].Steps), len(appeal.Approvals)) }) - s.mockPolicyService.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), policies[0].ID, policies[0].Version).Return(policies[0], nil).Once() + s.mockPolicyService.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), policies[0].ID, policies[0].Version).Return(policies[0], nil).Once() // 2.a additional appeal creation s.mockResourceService.EXPECT().Get(mock.AnythingOfType("*context.cancelCtx"), targetResource).Return(resources[1], nil).Once() @@ -1527,20 +1527,20 @@ func (s *ServiceTestSuite) TestCreateAppeal__WithAdditionalAppeals() { s.Equal(targetResource.ID, appeal.Resource.ID) }) s.mockAuditLogger.EXPECT().Log(mock.AnythingOfType("*context.cancelCtx"), appeal.AuditKeyBulkInsert, mock.Anything).Return(nil).Once() - s.mockNotifier.EXPECT().Notify(mock.Anything).Return(nil).Once() + s.mockNotifier.EXPECT().Notify(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil).Once() // 1.b grant access for the main appeal - s.mockProviderService.EXPECT().GrantAccess(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.Grant")).Return(nil).Once().Run(func(args mock.Arguments) { + s.mockProviderService.EXPECT().GrantAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.Grant")).Return(nil).Once().Run(func(args mock.Arguments) { grant := args.Get(1).(domain.Grant) s.Equal(expectedGrant.ID, grant.ID) }) - s.mockRepository.EXPECT().BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("[]*domain.Appeal")).Return(nil).Once().Run(func(args mock.Arguments) { + s.mockRepository.EXPECT().BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("[]*domain.Appeal")).Return(nil).Once().Run(func(args mock.Arguments) { appeals := args.Get(1).([]*domain.Appeal) appeal := appeals[0] s.Equal(appealsPayload[0].Resource.ID, appeal.Resource.ID) }) - s.mockAuditLogger.EXPECT().Log(mock.AnythingOfType("*context.emptyCtx"), appeal.AuditKeyBulkInsert, mock.Anything).Return(nil).Once() - s.mockNotifier.EXPECT().Notify(mock.Anything).Return(nil).Once() + s.mockAuditLogger.EXPECT().Log(mock.MatchedBy(func(ctx context.Context) bool { return true }), appeal.AuditKeyBulkInsert, mock.Anything).Return(nil).Once() + s.mockNotifier.EXPECT().Notify(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil).Once() err := s.service.Create(context.Background(), appealsPayload) @@ -1600,7 +1600,7 @@ func (s *ServiceTestSuite) TestUpdateApproval() { s.Run("should return error if got any from repository while getting appeal details", func() { expectedError := errors.New("repository error") - s.mockRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(nil, expectedError).Once() + s.mockRepository.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil, expectedError).Once() actualResult, actualError := s.service.UpdateApproval(context.Background(), validApprovalActionParam) @@ -1611,7 +1611,7 @@ func (s *ServiceTestSuite) TestUpdateApproval() { s.Run("should return error if appeal not found", func() { expectedError := appeal.ErrAppealNotFound - s.mockRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(nil, expectedError).Once() + s.mockRepository.EXPECT().GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil, expectedError).Once() actualResult, actualError := s.service.UpdateApproval(context.Background(), validApprovalActionParam) @@ -1787,7 +1787,7 @@ func (s *ServiceTestSuite) TestUpdateApproval() { Approvals: tc.approvals, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), validApprovalActionParam.AppealID). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), validApprovalActionParam.AppealID). Return(expectedAppeal, nil).Once() actualResult, actualError := s.service.UpdateApproval(context.Background(), validApprovalActionParam) @@ -1822,7 +1822,7 @@ func (s *ServiceTestSuite) TestUpdateApproval() { s.Run("should return error if got any from approvalService.AdvanceApproval", func() { s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() expectedError := errors.New("unexpected error") @@ -1874,7 +1874,7 @@ func (s *ServiceTestSuite) TestUpdateApproval() { expectedRevokedGrant.Status = domain.GrantStatusInactive s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(appealDetails, nil).Once() s.mockPolicyService.EXPECT().GetOne(mock.Anything, mock.Anything, mock.Anything).Return(&domain.Policy{}, nil).Once() @@ -1891,8 +1891,8 @@ func (s *ServiceTestSuite) TestUpdateApproval() { Revoke(mock.Anything, expectedRevokedGrant.ID, domain.SystemActorName, appeal.RevokeReasonForExtension, mock.Anything, mock.Anything). Return(expectedNewGrant, nil).Once() - s.mockRepository.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), appealDetails).Return(nil).Once() - s.mockNotifier.EXPECT().Notify(mock.Anything).Return(nil).Once() + s.mockRepository.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), appealDetails).Return(nil).Once() + s.mockNotifier.EXPECT().Notify(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, mock.Anything, mock.Anything). Return(nil).Once() @@ -2270,7 +2270,7 @@ func (s *ServiceTestSuite) TestUpdateApproval() { s.setup() s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), validApprovalActionParam.AppealID). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), validApprovalActionParam.AppealID). Return(tc.expectedAppealDetails, nil).Once() if tc.expectedApprovalAction.Action == "approve" { @@ -2301,8 +2301,8 @@ func (s *ServiceTestSuite) TestUpdateApproval() { tc.expectedResult.Policy = mockPolicy } - s.mockRepository.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), tc.expectedResult).Return(nil).Once() - s.mockNotifier.EXPECT().Notify(mock.Anything).Return(nil).Once() + s.mockRepository.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), tc.expectedResult).Return(nil).Once() + s.mockNotifier.EXPECT().Notify(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, mock.Anything, mock.Anything). Return(nil).Once() @@ -2476,15 +2476,16 @@ func (s *ServiceTestSuite) TestAddApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), appealID). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), appealID). Return(expectedAppeal, nil).Once() s.mockApprovalService.EXPECT(). - AddApprover(mock.AnythingOfType("*context.emptyCtx"), approvalID, newApprover). + AddApprover(mock.MatchedBy(func(ctx context.Context) bool { return true }), approvalID, newApprover). Return(nil).Once() s.mockAuditLogger.EXPECT(). - Log(mock.AnythingOfType("*context.emptyCtx"), appeal.AuditKeyAddApprover, expectedApproval).Return(nil).Once() - s.mockNotifier.EXPECT().Notify(mock.Anything). - Run(func(notifications []domain.Notification) { + Log(mock.MatchedBy(func(ctx context.Context) bool { return true }), appeal.AuditKeyAddApprover, expectedApproval).Return(nil).Once() + s.mockNotifier.EXPECT(). + Notify(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). + Run(func(ctx context.Context, notifications []domain.Notification) { s.Len(notifications, 1) n := notifications[0] s.Equal(tc.newApprover, n.User) @@ -2542,7 +2543,7 @@ func (s *ServiceTestSuite) TestAddApprover() { s.Run("should return error if getting appeal details returns an error", func() { expectedError := errors.New("unexpected error") s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(nil, expectedError).Once() appeal, err := s.service.AddApprover(context.Background(), uuid.New().String(), uuid.New().String(), "user@example.com") @@ -2564,7 +2565,7 @@ func (s *ServiceTestSuite) TestAddApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() appeal, err := s.service.AddApprover(context.Background(), uuid.New().String(), approvalID, "user@example.com") @@ -2585,7 +2586,7 @@ func (s *ServiceTestSuite) TestAddApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() appeal, err := s.service.AddApprover(context.Background(), uuid.New().String(), uuid.New().String(), "user@example.com") @@ -2608,7 +2609,7 @@ func (s *ServiceTestSuite) TestAddApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() appeal, err := s.service.AddApprover(context.Background(), uuid.New().String(), approvalID, "user@example.com") @@ -2632,7 +2633,7 @@ func (s *ServiceTestSuite) TestAddApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() appeal, err := s.service.AddApprover(context.Background(), uuid.New().String(), approvalID, "user@example.com") @@ -2656,7 +2657,7 @@ func (s *ServiceTestSuite) TestAddApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() s.mockApprovalService.EXPECT().AddApprover(mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() @@ -2715,13 +2716,13 @@ func (s *ServiceTestSuite) TestDeleteApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), appealID). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), appealID). Return(expectedAppeal, nil).Once() s.mockApprovalService.EXPECT(). - DeleteApprover(mock.AnythingOfType("*context.emptyCtx"), approvalID, approverEmail). + DeleteApprover(mock.MatchedBy(func(ctx context.Context) bool { return true }), approvalID, approverEmail). Return(nil).Once() s.mockAuditLogger.EXPECT(). - Log(mock.AnythingOfType("*context.emptyCtx"), appeal.AuditKeyDeleteApprover, expectedApproval).Return(nil).Once() + Log(mock.MatchedBy(func(ctx context.Context) bool { return true }), appeal.AuditKeyDeleteApprover, expectedApproval).Return(nil).Once() actualAppeal, actualError := s.service.DeleteApprover(context.Background(), appealID, approvalID, approverEmail) @@ -2773,7 +2774,7 @@ func (s *ServiceTestSuite) TestDeleteApprover() { s.Run("should return error if getting appeal details returns an error", func() { expectedError := errors.New("unexpected error") s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(nil, expectedError).Once() appeal, err := s.service.DeleteApprover(context.Background(), uuid.New().String(), uuid.New().String(), "user@example.com") @@ -2795,7 +2796,7 @@ func (s *ServiceTestSuite) TestDeleteApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() appeal, err := s.service.DeleteApprover(context.Background(), uuid.New().String(), approvalID, "user@example.com") @@ -2818,7 +2819,7 @@ func (s *ServiceTestSuite) TestDeleteApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() appeal, err := s.service.DeleteApprover(context.Background(), uuid.New().String(), approvalID, "user@example.com") @@ -2842,7 +2843,7 @@ func (s *ServiceTestSuite) TestDeleteApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() appeal, err := s.service.DeleteApprover(context.Background(), uuid.New().String(), approvalID, "user@example.com") @@ -2866,7 +2867,7 @@ func (s *ServiceTestSuite) TestDeleteApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() appeal, err := s.service.DeleteApprover(context.Background(), uuid.New().String(), approvalID, "user@example.com") @@ -2891,7 +2892,7 @@ func (s *ServiceTestSuite) TestDeleteApprover() { }, } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedAppeal, nil).Once() s.mockApprovalService.EXPECT().DeleteApprover(mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() @@ -2907,7 +2908,7 @@ func (s *ServiceTestSuite) TestGetAppealsTotalCount() { s.Run("should return error if got error from repository", func() { expectedError := errors.New("repository error") s.mockRepository.EXPECT(). - GetAppealsTotalCount(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetAppealsTotalCount(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(0, expectedError).Once() actualCount, actualError := s.service.GetAppealsTotalCount(context.Background(), &domain.ListAppealsFilter{}) @@ -2919,7 +2920,7 @@ func (s *ServiceTestSuite) TestGetAppealsTotalCount() { s.Run("should return appeals count from repository", func() { expectedCount := int64(1) s.mockRepository.EXPECT(). - GetAppealsTotalCount(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetAppealsTotalCount(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedCount, nil).Once() actualCount, actualError := s.service.GetAppealsTotalCount(context.Background(), &domain.ListAppealsFilter{}) diff --git a/core/approval/service_test.go b/core/approval/service_test.go index 23a46799a..f4ae97c46 100644 --- a/core/approval/service_test.go +++ b/core/approval/service_test.go @@ -39,7 +39,7 @@ func (s *ServiceTestSuite) TestListApprovals() { s.Run("should return error if got error from repository", func() { expectedError := errors.New("repository error") s.mockRepository.EXPECT(). - ListApprovals(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + ListApprovals(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(nil, expectedError).Once() actualApprovals, actualError := s.service.ListApprovals(context.Background(), &domain.ListApprovalsFilter{}) @@ -55,7 +55,7 @@ func (s *ServiceTestSuite) TestListApprovals() { }, } s.mockRepository.EXPECT(). - ListApprovals(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + ListApprovals(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedApprovals, nil).Once() actualApprovals, actualError := s.service.ListApprovals(context.Background(), &domain.ListApprovalsFilter{}) @@ -69,7 +69,7 @@ func (s *ServiceTestSuite) TestGetApprovalsTotalCount() { s.Run("should return error if got error from repository", func() { expectedError := errors.New("repository error") s.mockRepository.EXPECT(). - GetApprovalsTotalCount(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetApprovalsTotalCount(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(0, expectedError).Once() actualCount, actualError := s.service.GetApprovalsTotalCount(context.Background(), &domain.ListApprovalsFilter{}) @@ -81,7 +81,7 @@ func (s *ServiceTestSuite) TestGetApprovalsTotalCount() { s.Run("should return approvals count from repository", func() { expectedCount := int64(1) s.mockRepository.EXPECT(). - GetApprovalsTotalCount(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetApprovalsTotalCount(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedCount, nil).Once() actualCount, actualError := s.service.GetApprovalsTotalCount(context.Background(), &domain.ListApprovalsFilter{}) @@ -95,7 +95,7 @@ func (s *ServiceTestSuite) TestBulkInsert() { s.Run("should return error if got error from repository", func() { expectedError := errors.New("repository error") s.mockRepository.EXPECT(). - BulkInsert(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + BulkInsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedError).Once() actualError := s.service.BulkInsert(context.Background(), []*domain.Approval{}) @@ -110,7 +110,7 @@ func (s *ServiceTestSuite) TestAddApprover() { ApprovalID: uuid.New().String(), Email: "user@example.com", } - s.mockRepository.EXPECT().AddApprover(mock.AnythingOfType("*context.emptyCtx"), expectedApprover).Return(nil) + s.mockRepository.EXPECT().AddApprover(mock.Anything, expectedApprover).Return(nil) err := s.service.AddApprover(context.Background(), expectedApprover.ApprovalID, expectedApprover.Email) @@ -120,7 +120,7 @@ func (s *ServiceTestSuite) TestAddApprover() { s.Run("should return error if repository returns an error", func() { expectedError := errors.New("unexpected error") - s.mockRepository.EXPECT().AddApprover(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(expectedError) + s.mockRepository.EXPECT().AddApprover(mock.Anything, mock.Anything).Return(expectedError) err := s.service.AddApprover(context.Background(), "", "") @@ -134,7 +134,7 @@ func (s *ServiceTestSuite) TestDeleteApprover() { approvalID := uuid.New().String() approverEmail := "user@example.com" - s.mockRepository.EXPECT().DeleteApprover(mock.AnythingOfType("*context.emptyCtx"), approvalID, approverEmail).Return(nil) + s.mockRepository.EXPECT().DeleteApprover(mock.MatchedBy(func(ctx context.Context) bool { return true }), approvalID, approverEmail).Return(nil) err := s.service.DeleteApprover(context.Background(), approvalID, approverEmail) @@ -144,7 +144,7 @@ func (s *ServiceTestSuite) TestDeleteApprover() { s.Run("should return error if repository returns an error", func() { expectedError := errors.New("unexpected error") - s.mockRepository.EXPECT().DeleteApprover(mock.AnythingOfType("*context.emptyCtx"), mock.Anything, mock.Anything).Return(expectedError) + s.mockRepository.EXPECT().DeleteApprover(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything, mock.Anything).Return(expectedError) err := s.service.DeleteApprover(context.Background(), "", "") diff --git a/core/grant/mocks/notifier.go b/core/grant/mocks/notifier.go index 39d658ad2..bfd899e31 100644 --- a/core/grant/mocks/notifier.go +++ b/core/grant/mocks/notifier.go @@ -1,8 +1,10 @@ -// Code generated by mockery v2.33.0. DO NOT EDIT. +// Code generated by mockery v2.32.0. DO NOT EDIT. package mocks import ( + context "context" + domain "github.com/goto/guardian/domain" mock "github.com/stretchr/testify/mock" @@ -21,13 +23,13 @@ func (_m *Notifier) EXPECT() *Notifier_Expecter { return &Notifier_Expecter{mock: &_m.Mock} } -// Notify provides a mock function with given fields: _a0 -func (_m *Notifier) Notify(_a0 []domain.Notification) []error { - ret := _m.Called(_a0) +// Notify provides a mock function with given fields: _a0, _a1 +func (_m *Notifier) Notify(_a0 context.Context, _a1 []domain.Notification) []error { + ret := _m.Called(_a0, _a1) var r0 []error - if rf, ok := ret.Get(0).(func([]domain.Notification) []error); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, []domain.Notification) []error); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]error) @@ -43,14 +45,15 @@ type Notifier_Notify_Call struct { } // Notify is a helper method to define mock.On call -// - _a0 []domain.Notification -func (_e *Notifier_Expecter) Notify(_a0 interface{}) *Notifier_Notify_Call { - return &Notifier_Notify_Call{Call: _e.mock.On("Notify", _a0)} +// - _a0 context.Context +// - _a1 []domain.Notification +func (_e *Notifier_Expecter) Notify(_a0 interface{}, _a1 interface{}) *Notifier_Notify_Call { + return &Notifier_Notify_Call{Call: _e.mock.On("Notify", _a0, _a1)} } -func (_c *Notifier_Notify_Call) Run(run func(_a0 []domain.Notification)) *Notifier_Notify_Call { +func (_c *Notifier_Notify_Call) Run(run func(_a0 context.Context, _a1 []domain.Notification)) *Notifier_Notify_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]domain.Notification)) + run(args[0].(context.Context), args[1].([]domain.Notification)) }) return _c } @@ -60,7 +63,7 @@ func (_c *Notifier_Notify_Call) Return(_a0 []error) *Notifier_Notify_Call { return _c } -func (_c *Notifier_Notify_Call) RunAndReturn(run func([]domain.Notification) []error) *Notifier_Notify_Call { +func (_c *Notifier_Notify_Call) RunAndReturn(run func(context.Context, []domain.Notification) []error) *Notifier_Notify_Call { _c.Call.Return(run) return _c } diff --git a/core/grant/service.go b/core/grant/service.go index 289eb8abd..d57c1933e 100644 --- a/core/grant/service.go +++ b/core/grant/service.go @@ -9,10 +9,10 @@ import ( "github.com/go-playground/validator/v10" "github.com/goto/guardian/domain" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/pkg/slices" "github.com/goto/guardian/plugins/notifiers" "github.com/goto/guardian/utils" - "github.com/goto/salt/log" ) const ( @@ -129,13 +129,14 @@ func (s *Service) Update(ctx context.Context, payload *domain.Grant) error { grantDetails.Owner = updatedGrant.Owner grantDetails.UpdatedAt = updatedGrant.UpdatedAt *payload = *grantDetails + s.logger.Info(ctx, "grant updated", "grant_id", grantDetails.ID, "updatedGrant", updatedGrant) if err := s.auditLogger.Log(ctx, AuditKeyUpdate, map[string]interface{}{ "grant_id": grantDetails.ID, "payload": updatedGrant, "updated_grant": payload, }); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } if previousOwner != updatedGrant.Owner { @@ -165,9 +166,9 @@ func (s *Service) Update(ctx context.Context, payload *domain.Grant) error { Message: message, }) } - if errs := s.notifier.Notify(notifications); errs != nil { + if errs := s.notifier.Notify(ctx, notifications); errs != nil { for _, err1 := range errs { - s.logger.Error("failed to send notifications", "error", err1.Error()) + s.logger.Error(ctx, "failed to send notifications", "error", err1.Error()) } } } @@ -217,7 +218,7 @@ func (s *Service) Revoke(ctx context.Context, id, actor, reason string, opts ... } if !options.skipNotification { - if errs := s.notifier.Notify([]domain.Notification{{ + if errs := s.notifier.Notify(ctx, []domain.Notification{{ User: grant.CreatedBy, Labels: map[string]string{ "appeal_id": grant.AppealID, @@ -235,16 +236,18 @@ func (s *Service) Revoke(ctx context.Context, id, actor, reason string, opts ... }, }}); errs != nil { for _, err1 := range errs { - s.logger.Error("failed to send notifications", "error", err1.Error()) + s.logger.Error(ctx, "failed to send notifications", "error", err1.Error()) } } } + s.logger.Info(ctx, "grant revoked", "grant_id", id) + if err := s.auditLogger.Log(ctx, AuditKeyRevoke, map[string]interface{}{ "grant_id": id, "reason": reason, }); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } return grant, nil @@ -317,9 +320,9 @@ func (s *Service) BulkRevoke(ctx context.Context, filter domain.RevokeGrantsFilt } result = append(result, grant) if len(result) == totalRequests { - s.logger.Info("successful grant revocation", "count", len(successRevoke), "ids", successRevoke) + s.logger.Info(ctx, "successful grant revocation", "count", len(successRevoke), "ids", successRevoke) if len(failedRevoke) > 0 { - s.logger.Info("failed grant revocation", "count", len(failedRevoke), "ids", failedRevoke) + s.logger.Info(ctx, "failed grant revocation", "count", len(failedRevoke), "ids", failedRevoke) } return result, nil } @@ -334,23 +337,23 @@ func (s *Service) expiredInActiveUserAccess(ctx context.Context, timeLimiter cha revokedGrant := &domain.Grant{} *revokedGrant = *grant if err := revokedGrant.Revoke(actor, reason); err != nil { - s.logger.Error("failed to revoke grant", "id", grant.ID, "error", err) + s.logger.Error(ctx, "failed to revoke grant", "id", grant.ID, "error", err) return } if err := s.providerService.RevokeAccess(ctx, *grant); err != nil { done <- grant - s.logger.Error("failed to revoke grant in provider", "id", grant.ID, "error", err) + s.logger.Error(ctx, "failed to revoke grant in provider", "id", grant.ID, "error", err) return } revokedGrant.Status = domain.GrantStatusInactive if err := s.repo.Update(ctx, revokedGrant); err != nil { done <- grant - s.logger.Error("failed to update access-revoke status", "id", grant.ID, "error", err) + s.logger.Error(ctx, "failed to update access-revoke status", "id", grant.ID, "error", err) return } else { done <- revokedGrant - s.logger.Info("grant revoked", "id", grant.ID) + s.logger.Info(ctx, "grant revoked", "id", grant.ID) } } } @@ -497,7 +500,7 @@ func (s *Service) DormancyCheck(ctx context.Context, criteria domain.DormancyChe return fmt.Errorf("getting provider details: %w", err) } - s.logger.Info("getting active grants", "provider_urn", provider.URN) + s.logger.Info(ctx, "getting active grants", "provider_urn", provider.URN) grants, err := s.List(ctx, domain.ListGrantsFilter{ Statuses: []string{string(domain.GrantStatusActive)}, // TODO: evaluate later to use status_in_provider ProviderTypes: []string{provider.Type}, @@ -508,11 +511,11 @@ func (s *Service) DormancyCheck(ctx context.Context, criteria domain.DormancyChe return fmt.Errorf("listing active grants: %w", err) } if len(grants) == 0 { - s.logger.Info("no active grants found", "provider_urn", provider.URN) + s.logger.Info(ctx, "no active grants found", "provider_urn", provider.URN) return nil } grantIDs := getGrantIDs(grants) - s.logger.Info(fmt.Sprintf("found %d active grants", len(grants)), "grant_ids", grantIDs, "provider_urn", provider.URN) + s.logger.Info(ctx, fmt.Sprintf("found %d active grants", len(grants)), "grant_ids", grantIDs, "provider_urn", provider.URN) var accountIDs []string for _, g := range grants { @@ -520,7 +523,7 @@ func (s *Service) DormancyCheck(ctx context.Context, criteria domain.DormancyChe } accountIDs = slices.UniqueStringSlice(accountIDs) - s.logger.Info("getting activities", "provider_urn", provider.URN) + s.logger.Info(ctx, "getting activities", "provider_urn", provider.URN) activities, err := s.providerService.ListActivities(ctx, *provider, domain.ListActivitiesFilter{ AccountIDs: accountIDs, TimestampGte: &startDate, @@ -528,7 +531,7 @@ func (s *Service) DormancyCheck(ctx context.Context, criteria domain.DormancyChe if err != nil { return fmt.Errorf("listing activities for provider %q: %w", provider.URN, err) } - s.logger.Info(fmt.Sprintf("found %d activities", len(activities)), "provider_urn", provider.URN) + s.logger.Info(ctx, fmt.Sprintf("found %d activities", len(activities)), "provider_urn", provider.URN) grantsPointer := make([]*domain.Grant, len(grants)) for i, g := range grants { @@ -539,7 +542,7 @@ func (s *Service) DormancyCheck(ctx context.Context, criteria domain.DormancyChe return fmt.Errorf("correlating grant activities: %w", err) } - s.logger.Info("checking grants dormancy...", "provider_urn", provider.URN) + s.logger.Info(ctx, "checking grants dormancy...", "provider_urn", provider.URN) var dormantGrants []*domain.Grant var dormantGrantsIDs []string var dormantGrantsByOwner = map[string][]*domain.Grant{} @@ -556,10 +559,10 @@ func (s *Service) DormancyCheck(ctx context.Context, criteria domain.DormancyChe dormantGrantsByOwner[g.Owner] = append(dormantGrantsByOwner[g.Owner], g) } } - s.logger.Info(fmt.Sprintf("found %d dormant grants", len(dormantGrants)), "grant_ids", dormantGrantsIDs, "provider_urn", provider.URN) + s.logger.Info(ctx, fmt.Sprintf("found %d dormant grants", len(dormantGrants)), "grant_ids", dormantGrantsIDs, "provider_urn", provider.URN) if criteria.DryRun { - s.logger.Info("dry run mode, skipping updating grants expiration date", "provider_urn", provider.URN) + s.logger.Info(ctx, "dry run mode, skipping updating grants expiration date", "provider_urn", provider.URN) return nil } @@ -576,7 +579,7 @@ prepare_notifications: for _, g := range grants { grantMap, err := utils.StructToMap(g) if err != nil { - s.logger.Error("failed to convert grant to map", "error", err) + s.logger.Error(ctx, "failed to convert grant to map", "error", err) continue prepare_notifications } grantsMap = append(grantsMap, grantMap) @@ -600,9 +603,9 @@ prepare_notifications: }) } - if errs := s.notifier.Notify(notifications); errs != nil { + if errs := s.notifier.Notify(ctx, notifications); errs != nil { for _, err1 := range errs { - s.logger.Error("failed to send notifications", "error", err1.Error(), "provider_urn", provider.URN) + s.logger.Error(ctx, "failed to send notifications", "error", err1.Error(), "provider_urn", provider.URN) } } diff --git a/core/grant/service_test.go b/core/grant/service_test.go index aacac2c82..9f52aa550 100644 --- a/core/grant/service_test.go +++ b/core/grant/service_test.go @@ -14,7 +14,7 @@ import ( "github.com/goto/guardian/core/grant" "github.com/goto/guardian/core/grant/mocks" "github.com/goto/guardian/domain" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -57,7 +57,7 @@ func (s *ServiceTestSuite) TestList() { filter := domain.ListGrantsFilter{} expectedGrants := []domain.Grant{} s.mockRepository.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), filter). + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), filter). Return(expectedGrants, nil).Once() grants, err := s.service.List(context.Background(), filter) @@ -72,7 +72,7 @@ func (s *ServiceTestSuite) TestList() { expectedError := errors.New("unexpected error") s.mockRepository.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.ListGrantsFilter")). + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.ListGrantsFilter")). Return(nil, expectedError).Once() grants, err := s.service.List(context.Background(), domain.ListGrantsFilter{}) @@ -90,7 +90,7 @@ func (s *ServiceTestSuite) TestGetByID() { id := uuid.New().String() expectedGrant := &domain.Grant{} s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), id). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), id). Return(expectedGrant, nil). Once() @@ -118,7 +118,7 @@ func (s *ServiceTestSuite) TestGetByID() { expectedError := errors.New("unexpected error") s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string")). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string")). Return(nil, expectedError).Once() grant, err := s.service.GetByID(context.Background(), "test-id") @@ -163,15 +163,15 @@ func (s *ServiceTestSuite) TestUpdate() { expectedUpdatedGrant.UpdatedAt = time.Now() s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), id). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), id). Return(existingGrant, nil).Once() s.mockRepository.EXPECT(). - Update(mock.AnythingOfType("*context.emptyCtx"), expectedUpdateParam). + Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedUpdateParam). Return(nil).Run(func(_a0 context.Context, g *domain.Grant) { g.UpdatedAt = time.Now() }).Once() s.mockAuditLogger.EXPECT(). - Log(mock.AnythingOfType("*context.emptyCtx"), grant.AuditKeyUpdate, mock.AnythingOfType("map[string]interface {}")).Return(nil).Once() + Log(mock.MatchedBy(func(ctx context.Context) bool { return true }), grant.AuditKeyUpdate, mock.AnythingOfType("map[string]interface {}")).Return(nil).Once() notificationMessage := domain.NotificationMessage{ Type: domain.NotificationTypeGrantOwnerChanged, Variables: map[string]interface{}{ @@ -196,7 +196,7 @@ func (s *ServiceTestSuite) TestUpdate() { Message: notificationMessage, }} s.mockNotifier.EXPECT(). - Notify(expectedNotifications).Return(nil).Once() + Notify(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedNotifications).Return(nil).Once() actualError := s.service.Update(context.Background(), updatePayload) s.NoError(actualError) @@ -221,7 +221,7 @@ func (s *ServiceTestSuite) TestUpdate() { } s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), id). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), id). Return(existingGrant, nil).Once() actualError := s.service.Update(context.Background(), updatePayload) @@ -246,10 +246,10 @@ func (s *ServiceTestSuite) TestRevoke() { s.setup() s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), id). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), id). Return(expectedGrantDetails, nil).Once() s.mockRepository.EXPECT(). - Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Grant")). + Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Grant")). Run(func(_a0 context.Context, _a1 *domain.Grant) { s.Equal(id, _a1.ID) s.Equal(actor, _a1.RevokedBy) @@ -258,7 +258,7 @@ func (s *ServiceTestSuite) TestRevoke() { }). Return(nil).Once() s.mockProviderService.EXPECT(). - RevokeAccess(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.Grant")). + RevokeAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.Grant")). Run(func(_a0 context.Context, _a1 domain.Grant) { s.Equal(id, _a1.ID) s.Equal(expectedGrantDetails.AccountID, _a1.AccountID) @@ -268,7 +268,7 @@ func (s *ServiceTestSuite) TestRevoke() { Return(nil).Once() s.mockNotifier.EXPECT(). - Notify([]domain.Notification{{ + Notify(mock.MatchedBy(func(ctx context.Context) bool { return true }), []domain.Notification{{ User: expectedGrantDetails.CreatedBy, Labels: map[string]string{ "appeal_id": expectedGrantDetails.AppealID, @@ -287,7 +287,7 @@ func (s *ServiceTestSuite) TestRevoke() { }}). Return(nil).Once() s.mockAuditLogger.EXPECT(). - Log(mock.AnythingOfType("*context.emptyCtx"), grant.AuditKeyRevoke, map[string]interface{}{ + Log(mock.MatchedBy(func(ctx context.Context) bool { return true }), grant.AuditKeyRevoke, map[string]interface{}{ "grant_id": id, "reason": reason, }). @@ -308,10 +308,10 @@ func (s *ServiceTestSuite) TestRevoke() { s.setup() s.mockRepository.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), id). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), id). Return(expectedGrantDetails, nil).Once() s.mockRepository.EXPECT(). - Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Grant")). + Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Grant")). Run(func(_a0 context.Context, _a1 *domain.Grant) { s.Equal(id, _a1.ID) s.Equal(actor, _a1.RevokedBy) @@ -321,7 +321,7 @@ func (s *ServiceTestSuite) TestRevoke() { Return(nil).Once() s.mockAuditLogger.EXPECT(). - Log(mock.AnythingOfType("*context.emptyCtx"), grant.AuditKeyRevoke, map[string]interface{}{ + Log(mock.MatchedBy(func(ctx context.Context) bool { return true }), grant.AuditKeyRevoke, map[string]interface{}{ "grant_id": id, "reason": reason, }). @@ -377,12 +377,12 @@ func (s *ServiceTestSuite) TestBulkRevoke() { } s.mockRepository.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), expectedListGrantsFilter). + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedListGrantsFilter). Return(expectedGrants, nil).Once() for _, g := range expectedGrants { grant := g s.mockProviderService.EXPECT(). - RevokeAccess(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.Grant")). + RevokeAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.Grant")). Run(func(_a0 context.Context, _a1 domain.Grant) { s.Equal(grant.ID, _a1.ID) s.Equal(grant.AccountID, _a1.AccountID) @@ -392,7 +392,7 @@ func (s *ServiceTestSuite) TestBulkRevoke() { Return(nil).Once() s.mockRepository.EXPECT(). - Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Grant")). + Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("*domain.Grant")). Run(func(_a0 context.Context, _a1 *domain.Grant) { s.Equal(grant.ID, _a1.ID) s.Equal(actor, _a1.RevokedBy) @@ -848,17 +848,17 @@ func (s *ServiceTestSuite) TestImportFromProvider() { s.setup() s.mockProviderService.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), "test-provider-id"). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), "test-provider-id"). Return(&tc.provider, nil).Once() expectedListResourcesFilter := domain.ListResourcesFilter{ ProviderType: "test-provider-type", ProviderURN: "test-provider-urn", } s.mockResourceService.EXPECT(). - Find(mock.AnythingOfType("*context.emptyCtx"), expectedListResourcesFilter). + Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedListResourcesFilter). Return(dummyResources, nil).Once() s.mockProviderService.EXPECT(). - ListAccess(mock.AnythingOfType("*context.emptyCtx"), tc.provider, dummyResources). + ListAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), tc.provider, dummyResources). Return(tc.importedGrants, nil).Once() expectedListGrantsFilter := domain.ListGrantsFilter{ ProviderTypes: []string{"test-provider-type"}, @@ -866,15 +866,15 @@ func (s *ServiceTestSuite) TestImportFromProvider() { Statuses: []string{string(domain.GrantStatusActive)}, } s.mockRepository.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), expectedListGrantsFilter). + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedListGrantsFilter). Return(tc.existingGrants, nil).Once() s.mockRepository.EXPECT(). - BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("[]*domain.Grant")). + BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("[]*domain.Grant")). Return(nil).Once() s.mockRepository.EXPECT(). - BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), tc.expectedDeactivatedGrants). + BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), tc.expectedDeactivatedGrants). Return(nil).Once() newGrants, err := s.service.ImportFromProvider(context.Background(), grant.ImportFromProviderCriteria{ @@ -943,7 +943,7 @@ func (s *ServiceTestSuite) TestDormancyCheck() { } s.mockProviderService.EXPECT(). - GetByID(mock.AnythingOfType("*context.emptyCtx"), dummyProvider.ID). + GetByID(mock.MatchedBy(func(ctx context.Context) bool { return true }), dummyProvider.ID). Return(dummyProvider, nil).Once() expectedListGrantsFilter := domain.ListGrantsFilter{ Statuses: []string{string(domain.GrantStatusActive)}, @@ -952,7 +952,7 @@ func (s *ServiceTestSuite) TestDormancyCheck() { CreatedAtLte: timeNow.Add(-dormancyCheckCriteria.Period), } s.mockRepository.EXPECT(). - List(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.ListGrantsFilter")). + List(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.ListGrantsFilter")). Run(func(_a0 context.Context, f domain.ListGrantsFilter) { s.Empty(cmp.Diff(expectedListGrantsFilter, f, cmpopts.EquateApproxTime(time.Second))) }). @@ -963,7 +963,7 @@ func (s *ServiceTestSuite) TestDormancyCheck() { TimestampGte: ×tampGte, } s.mockProviderService.EXPECT(). - ListActivities(mock.AnythingOfType("*context.emptyCtx"), *dummyProvider, mock.AnythingOfType("domain.ListActivitiesFilter")). + ListActivities(mock.MatchedBy(func(ctx context.Context) bool { return true }), *dummyProvider, mock.AnythingOfType("domain.ListActivitiesFilter")). Run(func(_a0 context.Context, _a1 domain.Provider, f domain.ListActivitiesFilter) { s.Empty(cmp.Diff(expectedListActivitiesFilter, f, cmpopts.EquateApproxTime(time.Second), @@ -972,10 +972,10 @@ func (s *ServiceTestSuite) TestDormancyCheck() { }). Return(dummyActivities, nil).Once() s.mockProviderService.EXPECT(). - CorrelateGrantActivities(mock.AnythingOfType("*context.emptyCtx"), *dummyProvider, mock.AnythingOfType("[]*domain.Grant"), dummyActivities). + CorrelateGrantActivities(mock.MatchedBy(func(ctx context.Context) bool { return true }), *dummyProvider, mock.AnythingOfType("[]*domain.Grant"), dummyActivities). Return(nil).Once() s.mockRepository.EXPECT(). - BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("[]*domain.Grant")). + BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("[]*domain.Grant")). Run(func(_a0 context.Context, grants []*domain.Grant) { s.Empty(cmp.Diff(expectedUpdatedGrants, grants, cmpopts.EquateApproxTime(time.Second), @@ -983,7 +983,7 @@ func (s *ServiceTestSuite) TestDormancyCheck() { }). Return(nil).Once() - s.mockNotifier.EXPECT().Notify(mock.Anything).Return(nil).Once() // TODO + s.mockNotifier.EXPECT().Notify(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil).Once() // TODO err := s.service.DormancyCheck(context.Background(), dormancyCheckCriteria) s.NoError(err) @@ -994,7 +994,7 @@ func (s *ServiceTestSuite) TestGetGrantsTotalCount() { s.Run("should return error if got error from repository", func() { expectedError := errors.New("repository error") s.mockRepository.EXPECT(). - GetGrantsTotalCount(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetGrantsTotalCount(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(0, expectedError).Once() actualCount, actualError := s.service.GetGrantsTotalCount(context.Background(), domain.ListGrantsFilter{}) @@ -1006,7 +1006,7 @@ func (s *ServiceTestSuite) TestGetGrantsTotalCount() { s.Run("should return Grants count from repository", func() { expectedCount := int64(1) s.mockRepository.EXPECT(). - GetGrantsTotalCount(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetGrantsTotalCount(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedCount, nil).Once() actualCount, actualError := s.service.GetGrantsTotalCount(context.Background(), domain.ListGrantsFilter{}) @@ -1024,7 +1024,7 @@ func (s *ServiceTestSuite) TestListUserRoles() { "role-2", } s.mockRepository.EXPECT(). - ListUserRoles(mock.AnythingOfType("*context.emptyCtx"), "user@example.com"). + ListUserRoles(mock.MatchedBy(func(ctx context.Context) bool { return true }), "user@example.com"). Return(expectedOutput, nil).Once() roles, err := s.service.ListUserRoles(context.Background(), "user@example.com") @@ -1045,7 +1045,7 @@ func (s *ServiceTestSuite) TestListUserRoles() { s.setup() expectedError := errors.New("repository error") s.mockRepository.EXPECT(). - ListUserRoles(mock.AnythingOfType("*context.emptyCtx"), "user"). + ListUserRoles(mock.MatchedBy(func(ctx context.Context) bool { return true }), "user"). Return(nil, expectedError).Once() roles, actualError := s.service.ListUserRoles(context.Background(), "user") diff --git a/core/policy/service.go b/core/policy/service.go index fa45a96f5..360e132c2 100644 --- a/core/policy/service.go +++ b/core/policy/service.go @@ -11,8 +11,8 @@ import ( "github.com/go-playground/validator/v10" "github.com/goto/guardian/domain" "github.com/goto/guardian/pkg/evaluator" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/utils" - "github.com/goto/salt/log" ) const ( @@ -111,7 +111,7 @@ func (s *Service) Create(ctx context.Context, p *domain.Policy) error { } if err := s.auditLogger.Log(ctx, AuditKeyPolicyCreate, p); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } } @@ -197,7 +197,7 @@ func (s *Service) Update(ctx context.Context, p *domain.Policy) error { } if err := s.auditLogger.Log(ctx, AuditKeyPolicyUpdate, p); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } } diff --git a/core/policy/service_test.go b/core/policy/service_test.go index cb92b8ea8..3c72dcc26 100644 --- a/core/policy/service_test.go +++ b/core/policy/service_test.go @@ -341,7 +341,7 @@ func (s *ServiceTestSuite) TestCreate() { s.Run("should return error if got error from the policy repository", func() { expectedError := errors.New("error from repository") - s.mockPolicyRepository.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(expectedError).Once() + s.mockPolicyRepository.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(expectedError).Once() s.mockCrypto.EXPECT().Encrypt("test-password").Return("test-password", nil).Once() s.mockCrypto.EXPECT().Decrypt("test-password").Return("test-password", nil).Once() @@ -357,7 +357,7 @@ func (s *ServiceTestSuite) TestCreate() { } expectedVersion := uint(1) - s.mockPolicyRepository.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), p).Return(nil).Once() + s.mockPolicyRepository.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), p).Return(nil).Once() s.mockCrypto.EXPECT().Encrypt("test-password").Return("test-password", nil).Once() s.mockCrypto.EXPECT().Decrypt("test-password").Return("test-password", nil).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, policy.AuditKeyPolicyCreate, mock.Anything).Return(nil).Once() @@ -371,7 +371,7 @@ func (s *ServiceTestSuite) TestCreate() { }) s.Run("should pass the model from the param", func() { - s.mockPolicyRepository.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), validPolicy).Return(nil).Once() + s.mockPolicyRepository.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), validPolicy).Return(nil).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, policy.AuditKeyPolicyCreate, mock.Anything).Return(nil).Once() s.mockCrypto.EXPECT().Encrypt("test-password").Return("test-password", nil).Once() s.mockCrypto.EXPECT().Decrypt("test-password").Return("test-password", nil).Once() @@ -624,7 +624,7 @@ func (s *ServiceTestSuite) TestPolicyRequirements() { Once() } } - s.mockPolicyRepository.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), p).Return(nil).Once() + s.mockPolicyRepository.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), p).Return(nil).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, policy.AuditKeyPolicyCreate, mock.Anything).Return(nil).Once() actualError := s.service.Create(context.Background(), p) @@ -637,7 +637,7 @@ func (s *ServiceTestSuite) TestPolicyRequirements() { func (s *ServiceTestSuite) TestFind() { s.Run("should return nil and error if got error from repository", func() { expectedError := errors.New("error from repository") - s.mockPolicyRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(nil, expectedError).Once() + s.mockPolicyRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(nil, expectedError).Once() actualResult, actualError := s.service.Find(context.Background()) @@ -661,7 +661,7 @@ func (s *ServiceTestSuite) TestFind() { }, }, } - s.mockPolicyRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(expectedResult, nil).Once() + s.mockPolicyRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(expectedResult, nil).Once() s.mockCrypto.EXPECT().Decrypt("test-password").Return("test-password", nil).Once() actualResult, actualError := s.service.Find(context.Background()) @@ -675,7 +675,7 @@ func (s *ServiceTestSuite) TestFind() { func (s *ServiceTestSuite) TestGetOne() { s.Run("should return nil and error if got error from repository", func() { expectedError := errors.New("error from repository") - s.mockPolicyRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything, mock.Anything).Return(nil, expectedError).Once() + s.mockPolicyRepository.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything, mock.Anything).Return(nil, expectedError).Once() actualResult, actualError := s.service.GetOne(context.Background(), "", 0) @@ -697,7 +697,7 @@ func (s *ServiceTestSuite) TestGetOne() { }, }, } - s.mockPolicyRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything, mock.Anything).Return(expectedResult, nil).Once() + s.mockPolicyRepository.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything, mock.Anything).Return(expectedResult, nil).Once() s.mockCrypto.EXPECT().Decrypt("test-password").Return("test-password", nil).Once() actualResult, actualError := s.service.GetOne(context.Background(), "", 0) @@ -749,8 +749,8 @@ func (s *ServiceTestSuite) TestUpdate() { Version: 5, } expectedNewVersion := uint(6) - s.mockPolicyRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), p.ID, uint(0)).Return(expectedLatestPolicy, nil).Once() - s.mockPolicyRepository.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), p).Return(nil) + s.mockPolicyRepository.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), p.ID, uint(0)).Return(expectedLatestPolicy, nil).Once() + s.mockPolicyRepository.EXPECT().Create(mock.MatchedBy(func(ctx context.Context) bool { return true }), p).Return(nil) s.mockCrypto.EXPECT().Encrypt("test-password").Return("test-password", nil).Once() s.mockCrypto.EXPECT().Decrypt("test-password").Return("test-password", nil).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, policy.AuditKeyPolicyUpdate, mock.Anything).Return(nil).Once() diff --git a/core/provider/mocks/client.go b/core/provider/mocks/client.go index afbc59985..a02cd69cc 100644 --- a/core/provider/mocks/client.go +++ b/core/provider/mocks/client.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.10.0. DO NOT EDIT. +// Code generated by mockery v2.32.0. DO NOT EDIT. package mocks @@ -42,7 +42,7 @@ type Client_CreateConfig_Call struct { } // CreateConfig is a helper method to define mock.On call -// - _a0 *domain.ProviderConfig +// - _a0 *domain.ProviderConfig func (_e *Client_Expecter) CreateConfig(_a0 interface{}) *Client_CreateConfig_Call { return &Client_CreateConfig_Call{Call: _e.mock.On("CreateConfig", _a0)} } @@ -59,6 +59,11 @@ func (_c *Client_CreateConfig_Call) Return(_a0 error) *Client_CreateConfig_Call return _c } +func (_c *Client_CreateConfig_Call) RunAndReturn(run func(*domain.ProviderConfig) error) *Client_CreateConfig_Call { + _c.Call.Return(run) + return _c +} + // GetAccountTypes provides a mock function with given fields: func (_m *Client) GetAccountTypes() []string { ret := _m.Called() @@ -97,11 +102,20 @@ func (_c *Client_GetAccountTypes_Call) Return(_a0 []string) *Client_GetAccountTy return _c } +func (_c *Client_GetAccountTypes_Call) RunAndReturn(run func() []string) *Client_GetAccountTypes_Call { + _c.Call.Return(run) + return _c +} + // GetPermissions provides a mock function with given fields: p, resourceType, role func (_m *Client) GetPermissions(p *domain.ProviderConfig, resourceType string, role string) ([]interface{}, error) { ret := _m.Called(p, resourceType, role) var r0 []interface{} + var r1 error + if rf, ok := ret.Get(0).(func(*domain.ProviderConfig, string, string) ([]interface{}, error)); ok { + return rf(p, resourceType, role) + } if rf, ok := ret.Get(0).(func(*domain.ProviderConfig, string, string) []interface{}); ok { r0 = rf(p, resourceType, role) } else { @@ -110,7 +124,6 @@ func (_m *Client) GetPermissions(p *domain.ProviderConfig, resourceType string, } } - var r1 error if rf, ok := ret.Get(1).(func(*domain.ProviderConfig, string, string) error); ok { r1 = rf(p, resourceType, role) } else { @@ -126,9 +139,9 @@ type Client_GetPermissions_Call struct { } // GetPermissions is a helper method to define mock.On call -// - p *domain.ProviderConfig -// - resourceType string -// - role string +// - p *domain.ProviderConfig +// - resourceType string +// - role string func (_e *Client_Expecter) GetPermissions(p interface{}, resourceType interface{}, role interface{}) *Client_GetPermissions_Call { return &Client_GetPermissions_Call{Call: _e.mock.On("GetPermissions", p, resourceType, role)} } @@ -145,22 +158,30 @@ func (_c *Client_GetPermissions_Call) Return(_a0 []interface{}, _a1 error) *Clie return _c } -// GetResources provides a mock function with given fields: pc -func (_m *Client) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { - ret := _m.Called(pc) +func (_c *Client_GetPermissions_Call) RunAndReturn(run func(*domain.ProviderConfig, string, string) ([]interface{}, error)) *Client_GetPermissions_Call { + _c.Call.Return(run) + return _c +} + +// GetResources provides a mock function with given fields: ctx, pc +func (_m *Client) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { + ret := _m.Called(ctx, pc) var r0 []*domain.Resource - if rf, ok := ret.Get(0).(func(*domain.ProviderConfig) []*domain.Resource); ok { - r0 = rf(pc) + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *domain.ProviderConfig) ([]*domain.Resource, error)); ok { + return rf(ctx, pc) + } + if rf, ok := ret.Get(0).(func(context.Context, *domain.ProviderConfig) []*domain.Resource); ok { + r0 = rf(ctx, pc) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*domain.Resource) } } - var r1 error - if rf, ok := ret.Get(1).(func(*domain.ProviderConfig) error); ok { - r1 = rf(pc) + if rf, ok := ret.Get(1).(func(context.Context, *domain.ProviderConfig) error); ok { + r1 = rf(ctx, pc) } else { r1 = ret.Error(1) } @@ -174,14 +195,15 @@ type Client_GetResources_Call struct { } // GetResources is a helper method to define mock.On call -// - pc *domain.ProviderConfig -func (_e *Client_Expecter) GetResources(pc interface{}) *Client_GetResources_Call { - return &Client_GetResources_Call{Call: _e.mock.On("GetResources", pc)} +// - ctx context.Context +// - pc *domain.ProviderConfig +func (_e *Client_Expecter) GetResources(ctx interface{}, pc interface{}) *Client_GetResources_Call { + return &Client_GetResources_Call{Call: _e.mock.On("GetResources", ctx, pc)} } -func (_c *Client_GetResources_Call) Run(run func(pc *domain.ProviderConfig)) *Client_GetResources_Call { +func (_c *Client_GetResources_Call) Run(run func(ctx context.Context, pc *domain.ProviderConfig)) *Client_GetResources_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*domain.ProviderConfig)) + run(args[0].(context.Context), args[1].(*domain.ProviderConfig)) }) return _c } @@ -191,11 +213,20 @@ func (_c *Client_GetResources_Call) Return(_a0 []*domain.Resource, _a1 error) *C return _c } +func (_c *Client_GetResources_Call) RunAndReturn(run func(context.Context, *domain.ProviderConfig) ([]*domain.Resource, error)) *Client_GetResources_Call { + _c.Call.Return(run) + return _c +} + // GetRoles provides a mock function with given fields: pc, resourceType func (_m *Client) GetRoles(pc *domain.ProviderConfig, resourceType string) ([]*domain.Role, error) { ret := _m.Called(pc, resourceType) var r0 []*domain.Role + var r1 error + if rf, ok := ret.Get(0).(func(*domain.ProviderConfig, string) ([]*domain.Role, error)); ok { + return rf(pc, resourceType) + } if rf, ok := ret.Get(0).(func(*domain.ProviderConfig, string) []*domain.Role); ok { r0 = rf(pc, resourceType) } else { @@ -204,7 +235,6 @@ func (_m *Client) GetRoles(pc *domain.ProviderConfig, resourceType string) ([]*d } } - var r1 error if rf, ok := ret.Get(1).(func(*domain.ProviderConfig, string) error); ok { r1 = rf(pc, resourceType) } else { @@ -220,8 +250,8 @@ type Client_GetRoles_Call struct { } // GetRoles is a helper method to define mock.On call -// - pc *domain.ProviderConfig -// - resourceType string +// - pc *domain.ProviderConfig +// - resourceType string func (_e *Client_Expecter) GetRoles(pc interface{}, resourceType interface{}) *Client_GetRoles_Call { return &Client_GetRoles_Call{Call: _e.mock.On("GetRoles", pc, resourceType)} } @@ -238,6 +268,11 @@ func (_c *Client_GetRoles_Call) Return(_a0 []*domain.Role, _a1 error) *Client_Ge return _c } +func (_c *Client_GetRoles_Call) RunAndReturn(run func(*domain.ProviderConfig, string) ([]*domain.Role, error)) *Client_GetRoles_Call { + _c.Call.Return(run) + return _c +} + // GetType provides a mock function with given fields: func (_m *Client) GetType() string { ret := _m.Called() @@ -274,13 +309,18 @@ func (_c *Client_GetType_Call) Return(_a0 string) *Client_GetType_Call { return _c } -// GrantAccess provides a mock function with given fields: _a0, _a1 -func (_m *Client) GrantAccess(_a0 *domain.ProviderConfig, _a1 domain.Grant) error { - ret := _m.Called(_a0, _a1) +func (_c *Client_GetType_Call) RunAndReturn(run func() string) *Client_GetType_Call { + _c.Call.Return(run) + return _c +} + +// GrantAccess provides a mock function with given fields: _a0, _a1, _a2 +func (_m *Client) GrantAccess(_a0 context.Context, _a1 *domain.ProviderConfig, _a2 domain.Grant) error { + ret := _m.Called(_a0, _a1, _a2) var r0 error - if rf, ok := ret.Get(0).(func(*domain.ProviderConfig, domain.Grant) error); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(context.Context, *domain.ProviderConfig, domain.Grant) error); ok { + r0 = rf(_a0, _a1, _a2) } else { r0 = ret.Error(0) } @@ -294,15 +334,16 @@ type Client_GrantAccess_Call struct { } // GrantAccess is a helper method to define mock.On call -// - _a0 *domain.ProviderConfig -// - _a1 domain.Grant -func (_e *Client_Expecter) GrantAccess(_a0 interface{}, _a1 interface{}) *Client_GrantAccess_Call { - return &Client_GrantAccess_Call{Call: _e.mock.On("GrantAccess", _a0, _a1)} +// - _a0 context.Context +// - _a1 *domain.ProviderConfig +// - _a2 domain.Grant +func (_e *Client_Expecter) GrantAccess(_a0 interface{}, _a1 interface{}, _a2 interface{}) *Client_GrantAccess_Call { + return &Client_GrantAccess_Call{Call: _e.mock.On("GrantAccess", _a0, _a1, _a2)} } -func (_c *Client_GrantAccess_Call) Run(run func(_a0 *domain.ProviderConfig, _a1 domain.Grant)) *Client_GrantAccess_Call { +func (_c *Client_GrantAccess_Call) Run(run func(_a0 context.Context, _a1 *domain.ProviderConfig, _a2 domain.Grant)) *Client_GrantAccess_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*domain.ProviderConfig), args[1].(domain.Grant)) + run(args[0].(context.Context), args[1].(*domain.ProviderConfig), args[2].(domain.Grant)) }) return _c } @@ -312,11 +353,20 @@ func (_c *Client_GrantAccess_Call) Return(_a0 error) *Client_GrantAccess_Call { return _c } +func (_c *Client_GrantAccess_Call) RunAndReturn(run func(context.Context, *domain.ProviderConfig, domain.Grant) error) *Client_GrantAccess_Call { + _c.Call.Return(run) + return _c +} + // ListAccess provides a mock function with given fields: _a0, _a1, _a2 func (_m *Client) ListAccess(_a0 context.Context, _a1 domain.ProviderConfig, _a2 []*domain.Resource) (domain.MapResourceAccess, error) { ret := _m.Called(_a0, _a1, _a2) var r0 domain.MapResourceAccess + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, domain.ProviderConfig, []*domain.Resource) (domain.MapResourceAccess, error)); ok { + return rf(_a0, _a1, _a2) + } if rf, ok := ret.Get(0).(func(context.Context, domain.ProviderConfig, []*domain.Resource) domain.MapResourceAccess); ok { r0 = rf(_a0, _a1, _a2) } else { @@ -325,7 +375,6 @@ func (_m *Client) ListAccess(_a0 context.Context, _a1 domain.ProviderConfig, _a2 } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, domain.ProviderConfig, []*domain.Resource) error); ok { r1 = rf(_a0, _a1, _a2) } else { @@ -341,9 +390,9 @@ type Client_ListAccess_Call struct { } // ListAccess is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 domain.ProviderConfig -// - _a2 []*domain.Resource +// - _a0 context.Context +// - _a1 domain.ProviderConfig +// - _a2 []*domain.Resource func (_e *Client_Expecter) ListAccess(_a0 interface{}, _a1 interface{}, _a2 interface{}) *Client_ListAccess_Call { return &Client_ListAccess_Call{Call: _e.mock.On("ListAccess", _a0, _a1, _a2)} } @@ -360,13 +409,18 @@ func (_c *Client_ListAccess_Call) Return(_a0 domain.MapResourceAccess, _a1 error return _c } -// RevokeAccess provides a mock function with given fields: _a0, _a1 -func (_m *Client) RevokeAccess(_a0 *domain.ProviderConfig, _a1 domain.Grant) error { - ret := _m.Called(_a0, _a1) +func (_c *Client_ListAccess_Call) RunAndReturn(run func(context.Context, domain.ProviderConfig, []*domain.Resource) (domain.MapResourceAccess, error)) *Client_ListAccess_Call { + _c.Call.Return(run) + return _c +} + +// RevokeAccess provides a mock function with given fields: _a0, _a1, _a2 +func (_m *Client) RevokeAccess(_a0 context.Context, _a1 *domain.ProviderConfig, _a2 domain.Grant) error { + ret := _m.Called(_a0, _a1, _a2) var r0 error - if rf, ok := ret.Get(0).(func(*domain.ProviderConfig, domain.Grant) error); ok { - r0 = rf(_a0, _a1) + if rf, ok := ret.Get(0).(func(context.Context, *domain.ProviderConfig, domain.Grant) error); ok { + r0 = rf(_a0, _a1, _a2) } else { r0 = ret.Error(0) } @@ -380,15 +434,16 @@ type Client_RevokeAccess_Call struct { } // RevokeAccess is a helper method to define mock.On call -// - _a0 *domain.ProviderConfig -// - _a1 domain.Grant -func (_e *Client_Expecter) RevokeAccess(_a0 interface{}, _a1 interface{}) *Client_RevokeAccess_Call { - return &Client_RevokeAccess_Call{Call: _e.mock.On("RevokeAccess", _a0, _a1)} +// - _a0 context.Context +// - _a1 *domain.ProviderConfig +// - _a2 domain.Grant +func (_e *Client_Expecter) RevokeAccess(_a0 interface{}, _a1 interface{}, _a2 interface{}) *Client_RevokeAccess_Call { + return &Client_RevokeAccess_Call{Call: _e.mock.On("RevokeAccess", _a0, _a1, _a2)} } -func (_c *Client_RevokeAccess_Call) Run(run func(_a0 *domain.ProviderConfig, _a1 domain.Grant)) *Client_RevokeAccess_Call { +func (_c *Client_RevokeAccess_Call) Run(run func(_a0 context.Context, _a1 *domain.ProviderConfig, _a2 domain.Grant)) *Client_RevokeAccess_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*domain.ProviderConfig), args[1].(domain.Grant)) + run(args[0].(context.Context), args[1].(*domain.ProviderConfig), args[2].(domain.Grant)) }) return _c } @@ -397,3 +452,22 @@ func (_c *Client_RevokeAccess_Call) Return(_a0 error) *Client_RevokeAccess_Call _c.Call.Return(_a0) return _c } + +func (_c *Client_RevokeAccess_Call) RunAndReturn(run func(context.Context, *domain.ProviderConfig, domain.Grant) error) *Client_RevokeAccess_Call { + _c.Call.Return(run) + return _c +} + +// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewClient(t interface { + mock.TestingT + Cleanup(func()) +}) *Client { + mock := &Client{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/provider/service.go b/core/provider/service.go index f99f3fc14..341763430 100644 --- a/core/provider/service.go +++ b/core/provider/service.go @@ -11,10 +11,10 @@ import ( "github.com/go-playground/validator/v10" "github.com/goto/guardian/domain" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/plugins/providers" "github.com/goto/guardian/utils" "github.com/goto/salt/audit" - "github.com/goto/salt/log" ) const ( @@ -114,11 +114,13 @@ func (s *Service) Create(ctx context.Context, p *domain.Provider) error { accountTypes := c.GetAccountTypes() if err := s.validateAccountTypes(p.Config, accountTypes); err != nil { + s.logger.Error(ctx, "failed to validate account types", "type", p.Type, "provider_urn", p.URN, "error", err) return err } if p.Config.Appeal != nil { if err := s.validateAppealConfig(p.Config.Appeal); err != nil { + s.logger.Error(ctx, "failed to validate appeal config", "type", p.Type, "provider_urn", p.URN, "error", err) return err } } @@ -126,6 +128,7 @@ func (s *Service) Create(ctx context.Context, p *domain.Provider) error { if err := c.CreateConfig(p.Config); err != nil { return err } + s.logger.Debug(ctx, "provider config created", "provider_urn", p.URN) dryRun := isDryRun(ctx) @@ -135,25 +138,25 @@ func (s *Service) Create(ctx context.Context, p *domain.Provider) error { } if err := s.auditLogger.Log(ctx, AuditKeyCreate, p); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } + } else { + s.logger.Info(ctx, "dry run enabled, skipping provider creation", "provider_urn", p.URN) } go func() { - s.logger.Info("provider create fetching resources", "provider_urn", p.URN) + s.logger.Info(ctx, "provider create fetching resources", "provider_urn", p.URN) ctx := audit.WithActor(context.Background(), domain.SystemActorName) resources, err := s.getResources(ctx, p) if err != nil { - s.logger.Error("failed to fetch resources", "error", err) + s.logger.Error(ctx, "failed to fetch resources", "error", err) } + s.logger.Debug(ctx, "provider create fetched resources", "provider_urn", p.URN, "count", len(resources)) if !dryRun { if err := s.resourceService.BulkUpsert(ctx, resources); err != nil { - s.logger.Error("failed to insert resources to db", "error", err) + s.logger.Error(ctx, "failed to insert resources to db", "error", err) } else { - s.logger.Info("resources added", - "provider_urn", p.URN, - "count", len(resources), - ) + s.logger.Info(ctx, "resources added", "provider_urn", p.URN, "count", len(resources)) } } }() @@ -192,11 +195,13 @@ func (s *Service) Update(ctx context.Context, p *domain.Provider) error { accountTypes := c.GetAccountTypes() if err := s.validateAccountTypes(p.Config, accountTypes); err != nil { + s.logger.Error(ctx, "failed to validate account types", "type", p.Type, "provider_urn", p.URN, "error", err) return err } if p.Config.Appeal != nil { if err := s.validateAppealConfig(p.Config.Appeal); err != nil { + s.logger.Error(ctx, "failed to validate appeal config", "type", p.Type, "provider_urn", p.URN, "error", err) return err } } @@ -204,6 +209,7 @@ func (s *Service) Update(ctx context.Context, p *domain.Provider) error { if err := c.CreateConfig(p.Config); err != nil { return err } + s.logger.Debug(ctx, "provider config created", "provider_urn", p.URN) dryRun := isDryRun(ctx) @@ -213,23 +219,26 @@ func (s *Service) Update(ctx context.Context, p *domain.Provider) error { } if err := s.auditLogger.Log(ctx, AuditKeyUpdate, p); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } + } else { + s.logger.Info(ctx, "dry run enabled, skipping provider update", "provider_urn", p.URN) } go func() { - s.logger.Info("provider update fetching resources", "provider_urn", p.URN) + s.logger.Info(ctx, "provider update fetching resources", "provider_urn", p.URN) ctx := audit.WithActor(context.Background(), domain.SystemActorName) resources, err := s.getResources(ctx, p) if err != nil { - s.logger.Error("failed to fetch resources", "error", err) + s.logger.Error(ctx, "failed to fetch resources", "error", err) } + s.logger.Debug(ctx, "provider create fetched resources", "provider_urn", p.URN, "count", len(resources)) if !dryRun { if err := s.resourceService.BulkUpsert(ctx, resources); err != nil { - s.logger.Error("failed to insert resources to db", "error", err) + s.logger.Error(ctx, "failed to insert resources to db", "error", err) } else { - s.logger.Info("resources added", "provider_urn", p.URN, "count", len(resources)) + s.logger.Info(ctx, "resources added", "provider_urn", p.URN, "count", len(resources)) } } }() @@ -246,19 +255,16 @@ func (s *Service) FetchResources(ctx context.Context) error { failedProviders := make([]string, 0) for _, p := range providers { - s.logger.Info("fetching resources", "provider_urn", p.URN) + s.logger.Info(ctx, "fetching resources", "provider_urn", p.URN) resources, err := s.getResources(ctx, p) if err != nil { - s.logger.Error("failed to get resources", "error", err) + s.logger.Error(ctx, "failed to get resources", "error", err) continue } - s.logger.Info("resources added", - "provider_urn", p.URN, - "count", len(flattenResources(resources)), - ) + s.logger.Info(ctx, "resources added", "provider_urn", p.URN, "count", len(flattenResources(resources))) if err := s.resourceService.BulkUpsert(ctx, resources); err != nil { failedProviders = append(failedProviders, p.URN) - s.logger.Error("failed to add resources", "provider_urn", p.URN) + s.logger.Error(ctx, "failed to add resources", "provider_urn", p.URN) } } @@ -400,7 +406,7 @@ func (s *Service) GrantAccess(ctx context.Context, a domain.Grant) error { return err } - return c.GrantAccess(p.Config, a) + return c.GrantAccess(ctx, p.Config, a) } func (s *Service) RevokeAccess(ctx context.Context, a domain.Grant) error { @@ -418,7 +424,7 @@ func (s *Service) RevokeAccess(ctx context.Context, a domain.Grant) error { return err } - return c.RevokeAccess(p.Config, a) + return c.RevokeAccess(ctx, p.Config, a) } func (s *Service) Delete(ctx context.Context, id string) error { @@ -427,6 +433,7 @@ func (s *Service) Delete(ctx context.Context, id string) error { return fmt.Errorf("getting provider details: %w", err) } + s.logger.Info(ctx, "retrieving related resources", "provider", id) resources, err := s.resourceService.Find(ctx, domain.ListResourcesFilter{ ProviderType: p.Type, ProviderURN: p.URN, @@ -438,6 +445,8 @@ func (s *Service) Delete(ctx context.Context, id string) error { for _, r := range resources { resourceIds = append(resourceIds, r.ID) } + s.logger.Info(ctx, "deleting resources", "provider", id, "count", len(resourceIds)) + // TODO: execute in transaction if err := s.resourceService.BatchDelete(ctx, resourceIds); err != nil { return fmt.Errorf("batch deleting resources: %w", err) @@ -446,9 +455,10 @@ func (s *Service) Delete(ctx context.Context, id string) error { if err := s.repository.Delete(ctx, id); err != nil { return err } + s.logger.Info(ctx, "provider deleted", "provider", id) if err := s.auditLogger.Log(ctx, AuditKeyDelete, p); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } return nil @@ -544,7 +554,7 @@ func (s *Service) getResources(ctx context.Context, p *domain.Provider) ([]*doma } } - newProviderResources, err := c.GetResources(p.Config) + newProviderResources, err := c.GetResources(ctx, p.Config) if err != nil { return nil, fmt.Errorf("error fetching resources for %v: %w", p.ID, err) } @@ -580,7 +590,6 @@ func (s *Service) getResources(ctx context.Context, p *domain.Provider) ([]*doma r.Details = existingDetails } } - existingProviderResources[er.ID] = true break } diff --git a/core/provider/service_test.go b/core/provider/service_test.go index 6be24ba5b..3404070e1 100644 --- a/core/provider/service_test.go +++ b/core/provider/service_test.go @@ -13,7 +13,7 @@ import ( providermocks "github.com/goto/guardian/core/provider/mocks" "github.com/goto/guardian/core/resource" "github.com/goto/guardian/domain" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -33,7 +33,7 @@ type ServiceTestSuite struct { } func (s *ServiceTestSuite) SetupTest() { - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) validator := validator.New() s.mockProviderRepository = new(providermocks.Repository) s.mockResourceService = new(providermocks.ResourceService) @@ -58,6 +58,8 @@ func (s *ServiceTestSuite) TestCreate() { Config: config, } + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) + s.Run("should return error if unable to retrieve provider", func() { expectedError := provider.ErrInvalidProviderType @@ -114,7 +116,7 @@ func (s *ServiceTestSuite) TestCreate() { expectedError := errors.New("error from repository") s.mockProvider.On("GetAccountTypes").Return([]string{"user"}).Once() s.mockProvider.On("CreateConfig", mock.Anything).Return(nil).Once() - s.mockProviderRepository.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(expectedError).Once() + s.mockProviderRepository.EXPECT().Create(mockCtx, mock.Anything).Return(expectedError).Once() actualError := s.service.Create(context.Background(), p) @@ -124,7 +126,7 @@ func (s *ServiceTestSuite) TestCreate() { s.Run("should pass the model from the param and trigger fetch resources on success", func() { s.mockProvider.On("GetAccountTypes").Return([]string{"user"}).Once() s.mockProvider.On("CreateConfig", mock.Anything).Return(nil).Once() - s.mockProviderRepository.EXPECT().Create(mock.AnythingOfType("*context.emptyCtx"), p).Return(nil).Once() + s.mockProviderRepository.EXPECT().Create(mockCtx, p).Return(nil).Once() s.mockAuditLogger.On("Log", mock.Anything, provider.AuditKeyCreate, mock.Anything).Return(nil).Once() expectedResources := []*domain.Resource{} @@ -132,7 +134,7 @@ func (s *ServiceTestSuite) TestCreate() { ProviderType: p.Type, ProviderURN: p.URN, }).Return([]*domain.Resource{}, nil).Once() - s.mockProvider.On("GetResources", p.Config).Return(expectedResources, nil).Once() + s.mockProvider.On("GetResources", mockCtx, p.Config).Return(expectedResources, nil).Once() s.mockResourceService.On("BulkUpsert", mock.Anything, expectedResources).Return(nil).Once() actualError := s.service.Create(context.Background(), p) @@ -152,7 +154,7 @@ func (s *ServiceTestSuite) TestCreate() { ProviderType: p.Type, ProviderURN: p.URN, }).Return([]*domain.Resource{}, nil).Once() - s.mockProvider.On("GetResources", p.Config).Return(expectedResources, nil).Once() + s.mockProvider.On("GetResources", mockCtx, p.Config).Return(expectedResources, nil).Once() ctx := provider.WithDryRun(context.Background()) @@ -167,9 +169,10 @@ func (s *ServiceTestSuite) TestCreate() { } func (s *ServiceTestSuite) TestFind() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should return nil and error if got error from repository", func() { expectedError := errors.New("error from repository") - s.mockProviderRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(nil, expectedError).Once() + s.mockProviderRepository.EXPECT().Find(mockCtx).Return(nil, expectedError).Once() actualResult, actualError := s.service.Find(context.Background()) @@ -179,7 +182,7 @@ func (s *ServiceTestSuite) TestFind() { s.Run("should return list of records on success", func() { expectedResult := []*domain.Provider{} - s.mockProviderRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(expectedResult, nil).Once() + s.mockProviderRepository.EXPECT().Find(mockCtx).Return(expectedResult, nil).Once() actualResult, actualError := s.service.Find(context.Background()) @@ -190,6 +193,7 @@ func (s *ServiceTestSuite) TestFind() { } func (s *ServiceTestSuite) TestUpdateValidation() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("validation", func() { s.Run("should return error if got error on account types validation", func() { p := &domain.Provider{ @@ -199,10 +203,10 @@ func (s *ServiceTestSuite) TestUpdateValidation() { }, } - s.mockProviderRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.mockProviderRepository.EXPECT().GetByID(mockCtx, mock.Anything). Return(&domain.Provider{}, nil). Once() - s.mockProviderRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything, mock.Anything). + s.mockProviderRepository.EXPECT().GetOne(mockCtx, mock.Anything, mock.Anything). Return(&domain.Provider{}, nil) s.mockProvider.On("GetAccountTypes").Return([]string{"non-user-only"}).Once() actualError := s.service.Update(context.Background(), p) @@ -220,7 +224,7 @@ func (s *ServiceTestSuite) TestUpdateValidation() { }, } - s.mockProviderRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.mockProviderRepository.EXPECT().GetByID(mockCtx, mock.Anything). Return(&domain.Provider{}, nil). Once() s.mockProvider.On("GetAccountTypes").Return([]string{"user"}).Once() @@ -232,6 +236,7 @@ func (s *ServiceTestSuite) TestUpdateValidation() { } func (s *ServiceTestSuite) TestUpdate() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should update record on success", func() { testCases := []struct { updatePayload *domain.Provider @@ -277,7 +282,7 @@ func (s *ServiceTestSuite) TestUpdate() { for _, tc := range testCases { s.mockProvider.On("GetAccountTypes").Return([]string{"user"}).Once() s.mockProvider.On("CreateConfig", mock.Anything).Return(nil).Once() - s.mockProviderRepository.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), tc.expectedNewProvider).Return(nil) + s.mockProviderRepository.EXPECT().Update(mock.MatchedBy(func(ctx context.Context) bool { return true }), tc.expectedNewProvider).Return(nil) s.mockAuditLogger.On("Log", mock.Anything, provider.AuditKeyUpdate, mock.Anything).Return(nil).Once() expectedResources := []*domain.Resource{} @@ -285,7 +290,7 @@ func (s *ServiceTestSuite) TestUpdate() { ProviderType: tc.updatePayload.Type, ProviderURN: tc.updatePayload.URN, }).Return([]*domain.Resource{}, nil).Once() - s.mockProvider.On("GetResources", tc.updatePayload.Config).Return(expectedResources, nil).Once() + s.mockProvider.On("GetResources", mockCtx, tc.updatePayload.Config).Return(expectedResources, nil).Once() s.mockResourceService.On("BulkUpsert", mock.Anything, expectedResources).Return(nil).Once() actualError := s.service.Update(context.Background(), tc.updatePayload) @@ -323,7 +328,7 @@ func (s *ServiceTestSuite) TestUpdate() { ProviderType: p.Type, ProviderURN: p.URN, }).Return([]*domain.Resource{}, nil).Once() - s.mockProvider.On("GetResources", p.Config).Return(expectedResources, nil).Once() + s.mockProvider.On("GetResources", mockCtx, p.Config).Return(expectedResources, nil).Once() actualError := s.service.Update(ctx, p) @@ -335,9 +340,10 @@ func (s *ServiceTestSuite) TestUpdate() { } func (s *ServiceTestSuite) TestFetchResources() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should return error if got any from provider repository", func() { expectedError := errors.New("any error") - s.mockProviderRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(nil, expectedError).Once() + s.mockProviderRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(nil, expectedError).Once() actualError := s.service.FetchResources(context.Background()) @@ -354,9 +360,9 @@ func (s *ServiceTestSuite) TestFetchResources() { } s.Run("should return error if got any from resource service", func() { - s.mockProviderRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(providers, nil).Once() + s.mockProviderRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(providers, nil).Once() for _, p := range providers { - s.mockProvider.On("GetResources", p.Config).Return([]*domain.Resource{}, nil).Once() + s.mockProvider.On("GetResources", mockCtx, p.Config).Return([]*domain.Resource{}, nil).Once() } expectedError := errors.New("failed to add resources providers - [mock_provider]") s.mockResourceService.On("BulkUpsert", mock.Anything, mock.Anything).Return(expectedError).Once() @@ -431,8 +437,8 @@ func (s *ServiceTestSuite) TestFetchResources() { } expectedProvider := providers[0] - s.mockProviderRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return([]*domain.Provider{expectedProvider}, nil).Once() - s.mockProvider.EXPECT().GetResources(expectedProvider.Config).Return(newResources, nil).Once() + s.mockProviderRepository.EXPECT().Find(mockCtx).Return([]*domain.Provider{expectedProvider}, nil).Once() + s.mockProvider.EXPECT().GetResources(mockCtx, expectedProvider.Config).Return(newResources, nil).Once() s.mockResourceService.EXPECT().BulkUpsert(mock.Anything, mock.AnythingOfType("[]*domain.Resource")). Run(func(_a0 context.Context, resources []*domain.Resource) { s.Empty(cmp.Diff(expectedResources, resources, cmpopts.IgnoreFields(domain.Resource{}, "ID", "CreatedAt", "UpdatedAt"))) @@ -454,7 +460,7 @@ func (s *ServiceTestSuite) TestFetchResources() { }}, }, } - s.mockProviderRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(providersWithResourceFilter, nil).Once() + s.mockProviderRepository.EXPECT().Find(mockCtx).Return(providersWithResourceFilter, nil).Once() expectedResources := []*domain.Resource{} for _, p := range providersWithResourceFilter { resources := []*domain.Resource{ @@ -470,7 +476,7 @@ func (s *ServiceTestSuite) TestFetchResources() { URN: "resource2", }, } - s.mockProvider.On("GetResources", p.Config).Return(resources, nil).Once() + s.mockProvider.On("GetResources", mockCtx, p.Config).Return(resources, nil).Once() expectedResources = append(expectedResources, resources[1]) } s.mockResourceService.On("BulkUpsert", mock.Anything, expectedResources).Return(nil) @@ -491,7 +497,7 @@ func (s *ServiceTestSuite) TestFetchResources() { }}, }, } - s.mockProviderRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx")).Return(providersWithResourceFilter, nil).Once() + s.mockProviderRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(providersWithResourceFilter, nil).Once() expectedResources := []*domain.Resource{} for _, p := range providersWithResourceFilter { resources := []*domain.Resource{ @@ -508,7 +514,7 @@ func (s *ServiceTestSuite) TestFetchResources() { Details: map[string]interface{}{"category": "transaction"}, }, } - s.mockProvider.On("GetResources", p.Config).Return(resources, nil).Once() + s.mockProvider.On("GetResources", mockCtx, p.Config).Return(resources, nil).Once() expectedResources = append(expectedResources, resources[1]) } s.mockResourceService.On("BulkUpsert", mock.Anything, expectedResources).Return(nil) @@ -520,6 +526,7 @@ func (s *ServiceTestSuite) TestFetchResources() { } func (s *ServiceTestSuite) TestGrantAccess() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should return error if got error on appeal param validation", func() { testCases := []struct { appealParam domain.Grant @@ -556,7 +563,7 @@ func (s *ServiceTestSuite) TestGrantAccess() { s.Run("should return error if got any from provider repository", func() { expectedError := errors.New("any error") - s.mockProviderRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything, mock.Anything). + s.mockProviderRepository.EXPECT().GetOne(mockCtx, mock.Anything, mock.Anything). Return(nil, expectedError). Once() @@ -566,7 +573,7 @@ func (s *ServiceTestSuite) TestGrantAccess() { }) s.Run("should return error if provider not found", func() { - s.mockProviderRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything, mock.Anything). + s.mockProviderRepository.EXPECT().GetOne(mockCtx, mock.Anything, mock.Anything). Return(nil, provider.ErrRecordNotFound). Once() expectedError := provider.ErrRecordNotFound @@ -581,11 +588,11 @@ func (s *ServiceTestSuite) TestGrantAccess() { Config: &domain.ProviderConfig{}, } s.mockProviderRepository. - EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), validAppeal.Resource.ProviderType, validAppeal.Resource.ProviderURN). + EXPECT().GetOne(mockCtx, validAppeal.Resource.ProviderType, validAppeal.Resource.ProviderURN). Return(provider, nil). Once() expectedError := errors.New("any error") - s.mockProvider.On("GrantAccess", mock.Anything, mock.Anything). + s.mockProvider.On("GrantAccess", mockCtx, mock.Anything, mock.Anything). Return(expectedError). Once() @@ -599,11 +606,11 @@ func (s *ServiceTestSuite) TestGrantAccess() { Config: &domain.ProviderConfig{}, } s.mockProviderRepository. - EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), validAppeal.Resource.ProviderType, validAppeal.Resource.ProviderURN). + EXPECT().GetOne(mockCtx, validAppeal.Resource.ProviderType, validAppeal.Resource.ProviderURN). Return(provider, nil). Once() s.mockProvider. - On("GrantAccess", provider.Config, validAppeal). + On("GrantAccess", mockCtx, provider.Config, validAppeal). Return(nil). Once() @@ -614,6 +621,7 @@ func (s *ServiceTestSuite) TestGrantAccess() { } func (s *ServiceTestSuite) TestRevokeAccess() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should return error if got error on appeal param validation", func() { testCases := []struct { appealParam domain.Grant @@ -650,7 +658,7 @@ func (s *ServiceTestSuite) TestRevokeAccess() { s.Run("should return error if got any from provider repository", func() { expectedError := errors.New("any error") - s.mockProviderRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything, mock.Anything). + s.mockProviderRepository.EXPECT().GetOne(mockCtx, mock.Anything, mock.Anything). Return(nil, expectedError). Once() @@ -660,7 +668,7 @@ func (s *ServiceTestSuite) TestRevokeAccess() { }) s.Run("should return error if provider not found", func() { - s.mockProviderRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything, mock.Anything). + s.mockProviderRepository.EXPECT().GetOne(mockCtx, mock.Anything, mock.Anything). Return(nil, provider.ErrRecordNotFound). Once() expectedError := provider.ErrRecordNotFound @@ -675,11 +683,11 @@ func (s *ServiceTestSuite) TestRevokeAccess() { Config: &domain.ProviderConfig{}, } s.mockProviderRepository. - EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), validAppeal.Resource.ProviderType, validAppeal.Resource.ProviderURN). + EXPECT().GetOne(mockCtx, validAppeal.Resource.ProviderType, validAppeal.Resource.ProviderURN). Return(provider, nil). Once() expectedError := errors.New("any error") - s.mockProvider.On("RevokeAccess", mock.Anything, mock.Anything). + s.mockProvider.On("RevokeAccess", mockCtx, mock.Anything, mock.Anything). Return(expectedError). Once() @@ -693,11 +701,11 @@ func (s *ServiceTestSuite) TestRevokeAccess() { Config: &domain.ProviderConfig{}, } s.mockProviderRepository. - EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), validAppeal.Resource.ProviderType, validAppeal.Resource.ProviderURN). + EXPECT().GetOne(mockCtx, validAppeal.Resource.ProviderType, validAppeal.Resource.ProviderURN). Return(provider, nil). Once() s.mockProvider. - On("RevokeAccess", provider.Config, validAppeal). + On("RevokeAccess", mockCtx, provider.Config, validAppeal). Return(nil). Once() @@ -708,9 +716,10 @@ func (s *ServiceTestSuite) TestRevokeAccess() { } func (s *ServiceTestSuite) TestDelete() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should return error if provider repository returns error", func() { expectedError := errors.New("random error") - s.mockProviderRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(nil, expectedError).Once() + s.mockProviderRepository.EXPECT().GetByID(mockCtx, mock.Anything).Return(nil, expectedError).Once() err := s.service.Delete(context.Background(), "test-provider") @@ -718,7 +727,7 @@ func (s *ServiceTestSuite) TestDelete() { }) s.Run("should return error if resourceService.Find returns error", func() { - s.mockProviderRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(&domain.Provider{}, nil).Once() + s.mockProviderRepository.EXPECT().GetByID(mockCtx, mock.Anything).Return(&domain.Provider{}, nil).Once() expectedError := errors.New("random error") s.mockResourceService.On("Find", mock.Anything, mock.Anything).Return(nil, expectedError).Once() @@ -728,10 +737,10 @@ func (s *ServiceTestSuite) TestDelete() { }) s.Run("should return error if resourceService.BatchDelete returns error", func() { - s.mockProviderRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(&domain.Provider{}, nil).Once() + s.mockProviderRepository.EXPECT().GetByID(mockCtx, mock.Anything).Return(&domain.Provider{}, nil).Once() s.mockResourceService.On("Find", mock.Anything, mock.Anything).Return([]*domain.Resource{}, nil).Once() expectedError := errors.New("random error") - s.mockResourceService.On("BatchDelete", mock.Anything, mock.Anything).Return(expectedError).Once() + s.mockResourceService.On("BatchDelete", mockCtx, mock.Anything, mock.Anything).Return(expectedError).Once() err := s.service.Delete(context.Background(), "test-provider") @@ -739,11 +748,11 @@ func (s *ServiceTestSuite) TestDelete() { }) s.Run("should return error if providerRepository.Delete returns error", func() { - s.mockProviderRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(&domain.Provider{}, nil).Once() + s.mockProviderRepository.EXPECT().GetByID(mockCtx, mock.Anything).Return(&domain.Provider{}, nil).Once() s.mockResourceService.On("Find", mock.Anything, mock.Anything).Return([]*domain.Resource{}, nil).Once() - s.mockResourceService.On("BatchDelete", mock.Anything, mock.Anything).Return(nil).Once() + s.mockResourceService.On("BatchDelete", mockCtx, mock.Anything, mock.Anything).Return(nil).Once() expectedError := errors.New("random error") - s.mockProviderRepository.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(expectedError).Once() + s.mockProviderRepository.EXPECT().Delete(mockCtx, mock.Anything).Return(expectedError).Once() err := s.service.Delete(context.Background(), "test-provider") @@ -758,13 +767,13 @@ func (s *ServiceTestSuite) TestDelete() { } dummyResources := []*domain.Resource{{ID: "a"}, {ID: "b"}} - s.mockProviderRepository.EXPECT().GetByID(mock.AnythingOfType("*context.emptyCtx"), testID).Return(dummyProvider, nil).Once() + s.mockProviderRepository.EXPECT().GetByID(mockCtx, testID).Return(dummyProvider, nil).Once() s.mockResourceService.On("Find", mock.Anything, domain.ListResourcesFilter{ ProviderType: dummyProvider.Type, ProviderURN: dummyProvider.URN, }).Return(dummyResources, nil).Once() - s.mockResourceService.On("BatchDelete", mock.Anything, []string{"a", "b"}).Return(nil).Once() - s.mockProviderRepository.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), testID).Return(nil).Once() + s.mockResourceService.On("BatchDelete", mockCtx, []string{"a", "b"}).Return(nil) + s.mockProviderRepository.EXPECT().Delete(mockCtx, testID).Return(nil).Once() s.mockAuditLogger.On("Log", mock.Anything, provider.AuditKeyDelete, dummyProvider).Return(nil).Once() err := s.service.Delete(context.Background(), "test-provider") @@ -1199,7 +1208,7 @@ func (s *ServiceTestSuite) TestListAccess() { }, } s.mockProvider.EXPECT(). - ListAccess(mock.AnythingOfType("*context.emptyCtx"), *p.Config, resources). + ListAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), *p.Config, resources). Return(returnedAccess, nil).Once() actualAccess, err := s.service.ListAccess(context.Background(), *p, resources) diff --git a/core/resource/service.go b/core/resource/service.go index ac080a01d..914938607 100644 --- a/core/resource/service.go +++ b/core/resource/service.go @@ -4,7 +4,7 @@ import ( "context" "github.com/goto/guardian/domain" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/imdario/mergo" ) @@ -79,7 +79,7 @@ func (s *Service) BulkUpsert(ctx context.Context, resources []*domain.Resource) } if err := s.auditLogger.Log(ctx, AuditKeyResoruceBulkUpsert, resources); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } return nil @@ -99,6 +99,7 @@ func (s *Service) Update(ctx context.Context, r *domain.Resource) error { if err := mergo.Merge(r, existingResource); err != nil { return err } + s.logger.Debug(ctx, "merged existing resource with updated resource", "resource", r.ID) res := &domain.Resource{ ID: r.ID, @@ -106,13 +107,15 @@ func (s *Service) Update(ctx context.Context, r *domain.Resource) error { Labels: r.Labels, } if err := s.repo.Update(ctx, res); err != nil { + s.logger.Error(ctx, "failed to update resource", "resource", r.ID, "error", err) return err } + s.logger.Info(ctx, "resource updated", "resource", r.ID) r.UpdatedAt = res.UpdatedAt if err := s.auditLogger.Log(ctx, AuditKeyResourceUpdate, r); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } return nil @@ -147,11 +150,13 @@ func (s *Service) Get(ctx context.Context, ri *domain.ResourceIdentifier) (*doma func (s *Service) Delete(ctx context.Context, id string) error { if err := s.repo.Delete(ctx, id); err != nil { + s.logger.Error(ctx, "failed to delete resource", "resource", id, "error", err) return err } + s.logger.Info(ctx, "resource deleted", "resource", id) if err := s.auditLogger.Log(ctx, AuditKeyResourceDelete, map[string]interface{}{"id": id}); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } return nil @@ -159,11 +164,13 @@ func (s *Service) Delete(ctx context.Context, id string) error { func (s *Service) BatchDelete(ctx context.Context, ids []string) error { if err := s.repo.BatchDelete(ctx, ids); err != nil { + s.logger.Error(ctx, "failed to delete resources", "resources", len(ids), "error", err) return err } + s.logger.Info(ctx, "resources deleted", "resources", len(ids)) if err := s.auditLogger.Log(ctx, AuditKeyResourceBatchDelete, map[string]interface{}{"ids": ids}); err != nil { - s.logger.Error("failed to record audit log", "error", err) + s.logger.Error(ctx, "failed to record audit log", "error", err) } return nil diff --git a/core/resource/service_test.go b/core/resource/service_test.go index 7feadea26..434f51325 100644 --- a/core/resource/service_test.go +++ b/core/resource/service_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "github.com/goto/guardian/pkg/log" + "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/goto/guardian/core/resource" @@ -19,6 +21,7 @@ type ServiceTestSuite struct { mockRepository *mocks.Repository mockAuditLogger *mocks.AuditLogger service *resource.Service + logger log.Logger authenticatedUserEmail string } @@ -26,9 +29,11 @@ type ServiceTestSuite struct { func (s *ServiceTestSuite) SetupTest() { s.mockRepository = new(mocks.Repository) s.mockAuditLogger = new(mocks.AuditLogger) + s.logger = log.NewCtxLogger("info", []string{"test"}) s.service = resource.NewService(resource.ServiceDeps{ Repository: s.mockRepository, AuditLogger: s.mockAuditLogger, + Logger: s.logger, }) s.authenticatedUserEmail = "user@example.com" } @@ -36,7 +41,7 @@ func (s *ServiceTestSuite) SetupTest() { func (s *ServiceTestSuite) TestFind() { s.Run("should return nil and error if got error from repository", func() { expectedError := errors.New("error from repository") - s.mockRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(nil, expectedError).Once() + s.mockRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(nil, expectedError).Once() actualResult, actualError := s.service.Find(context.Background(), domain.ListResourcesFilter{}) @@ -47,7 +52,7 @@ func (s *ServiceTestSuite) TestFind() { s.Run("should return list of records on success", func() { expectedFilters := domain.ListResourcesFilter{} expectedResult := []*domain.Resource{} - s.mockRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx"), expectedFilters).Return(expectedResult, nil).Once() + s.mockRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedFilters).Return(expectedResult, nil).Once() actualResult, actualError := s.service.Find(context.Background(), expectedFilters) @@ -61,7 +66,7 @@ func (s *ServiceTestSuite) TestGetResourcesTotalCount() { s.Run("should return error if got error from repository", func() { expectedError := errors.New("repository error") s.mockRepository.EXPECT(). - GetResourcesTotalCount(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetResourcesTotalCount(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(0, expectedError).Once() actualCount, actualError := s.service.GetResourcesTotalCount(context.Background(), domain.ListResourcesFilter{}) @@ -73,7 +78,7 @@ func (s *ServiceTestSuite) TestGetResourcesTotalCount() { s.Run("should return Resources count from repository", func() { expectedCount := int64(1) s.mockRepository.EXPECT(). - GetResourcesTotalCount(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + GetResourcesTotalCount(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything). Return(expectedCount, nil).Once() actualCount, actualError := s.service.GetResourcesTotalCount(context.Background(), domain.ListResourcesFilter{}) @@ -86,7 +91,7 @@ func (s *ServiceTestSuite) TestGetResourcesTotalCount() { func (s *ServiceTestSuite) TestBulkUpsert() { s.Run("should return error if got error from repository", func() { expectedError := errors.New("error from repository") - s.mockRepository.EXPECT().BulkUpsert(mock.AnythingOfType("*context.emptyCtx"), mock.Anything).Return(expectedError).Once() + s.mockRepository.EXPECT().BulkUpsert(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything).Return(expectedError).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, resource.AuditKeyResoruceBulkUpsert, mock.Anything).Return(nil) actualError := s.service.BulkUpsert(context.Background(), []*domain.Resource{}) @@ -96,6 +101,7 @@ func (s *ServiceTestSuite) TestBulkUpsert() { } func (s *ServiceTestSuite) TestUpdate() { + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should return error if got error getting existing record", func() { testCases := []struct { expectedExistingResource *domain.Resource @@ -120,7 +126,7 @@ func (s *ServiceTestSuite) TestUpdate() { } expectedError := tc.expectedError s.mockRepository.EXPECT(). - GetOne(mock.AnythingOfType("*context.emptyCtx"), expectedResource.ID). + GetOne(mockCtx, expectedResource.ID). Return(tc.expectedExistingResource, tc.expectedRepositoryError).Once() actualError := s.service.Update(context.Background(), expectedResource) @@ -131,9 +137,9 @@ func (s *ServiceTestSuite) TestUpdate() { s.Run("should return error if got error from repository", func() { expectedError := errors.New("error from repository") - s.mockRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.mockRepository.EXPECT().GetOne(mockCtx, mock.Anything). Return(&domain.Resource{}, nil).Once() - s.mockRepository.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.Anything). + s.mockRepository.EXPECT().Update(mockCtx, mock.Anything). Return(expectedError).Once() actualError := s.service.Update(context.Background(), &domain.Resource{}) @@ -233,8 +239,8 @@ func (s *ServiceTestSuite) TestUpdate() { for _, tc := range testCases { s.Run(tc.name, func() { - s.mockRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), tc.resourceUpdatePayload.ID).Return(tc.existingResource, nil).Once() - s.mockRepository.EXPECT().Update(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("*domain.Resource")). + s.mockRepository.EXPECT().GetOne(mockCtx, tc.resourceUpdatePayload.ID).Return(tc.existingResource, nil).Once() + s.mockRepository.EXPECT().Update(mockCtx, mock.AnythingOfType("*domain.Resource")). Run(func(_a0 context.Context, updateResourcePayload *domain.Resource) { s.Empty(cmp.Diff(tc.expectedUpdatedValues, updateResourcePayload, cmpopts.IgnoreFields(domain.Resource{}, "UpdatedAt", "CreatedAt"))) }).Return(nil).Once() @@ -255,7 +261,7 @@ func (s *ServiceTestSuite) TestGet() { expectedResource := &domain.Resource{ ID: "1", } - s.mockRepository.EXPECT().GetOne(mock.AnythingOfType("*context.emptyCtx"), expectedResource.ID). + s.mockRepository.EXPECT().GetOne(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedResource.ID). Return(expectedResource, nil).Once() actualResource, actualError := s.service.Get(context.Background(), &domain.ResourceIdentifier{ID: expectedResource.ID}) @@ -272,7 +278,7 @@ func (s *ServiceTestSuite) TestGet() { Type: "test-type", URN: "test-urn", } - s.mockRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx"), domain.ListResourcesFilter{ + s.mockRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), domain.ListResourcesFilter{ ProviderType: "test-provider", ProviderURN: "test-provider-urn", ResourceType: "test-type", @@ -294,7 +300,7 @@ func (s *ServiceTestSuite) TestGet() { s.Run("should return not found if resource not found", func() { expectedError := resource.ErrRecordNotFound - s.mockRepository.EXPECT().Find(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("domain.ListResourcesFilter")). + s.mockRepository.EXPECT().Find(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("domain.ListResourcesFilter")). Return([]*domain.Resource{}, nil).Once() actualResource, actualError := s.service.Get(context.Background(), &domain.ResourceIdentifier{ @@ -313,7 +319,7 @@ func (s *ServiceTestSuite) TestDelete() { s.Run("should delete resource", func() { expectedResourceID := "test-resource-id" - s.mockRepository.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), expectedResourceID). + s.mockRepository.EXPECT().Delete(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedResourceID). Return(nil).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, resource.AuditKeyResourceDelete, mock.Anything).Return(nil) @@ -326,7 +332,7 @@ func (s *ServiceTestSuite) TestDelete() { expectedResourceID := "test-resource-id" expectedError := errors.New("test-error") - s.mockRepository.EXPECT().Delete(mock.AnythingOfType("*context.emptyCtx"), expectedResourceID). + s.mockRepository.EXPECT().Delete(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedResourceID). Return(expectedError).Once() actualError := s.service.Delete(context.Background(), expectedResourceID) @@ -339,7 +345,7 @@ func (s *ServiceTestSuite) TestBatchDelete() { s.Run("should delete resources", func() { expectedResourceIDs := []string{"test-resource-id"} - s.mockRepository.EXPECT().BatchDelete(mock.AnythingOfType("*context.emptyCtx"), expectedResourceIDs). + s.mockRepository.EXPECT().BatchDelete(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedResourceIDs). Return(nil).Once() s.mockAuditLogger.EXPECT().Log(mock.Anything, resource.AuditKeyResourceBatchDelete, mock.Anything).Return(nil) @@ -352,7 +358,7 @@ func (s *ServiceTestSuite) TestBatchDelete() { expectedResourceIDs := []string{"test-resource-id"} expectedError := errors.New("test-error") - s.mockRepository.EXPECT().BatchDelete(mock.AnythingOfType("*context.emptyCtx"), expectedResourceIDs). + s.mockRepository.EXPECT().BatchDelete(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedResourceIDs). Return(expectedError).Once() actualError := s.service.BatchDelete(context.Background(), expectedResourceIDs) diff --git a/domain/audit.go b/domain/audit.go new file mode 100644 index 000000000..823d09fbf --- /dev/null +++ b/domain/audit.go @@ -0,0 +1,5 @@ +package domain + +const ( + TraceIDKey = "trace_id" +) diff --git a/internal/server/config.yaml b/internal/server/config.yaml index 8d451438e..134bf86c8 100644 --- a/internal/server/config.yaml +++ b/internal/server/config.yaml @@ -44,18 +44,19 @@ JOBS: USER_CRITERIA: '$user.is_active == true' REASSIGN_OWNERSHIP_TO: '$user.manager_email' TELEMETRY: - ENABLED: true - SERVICE_NAME: "guardian" - # Example for new relic - EXPORTER: otlp - OTLP: - HEADERS: - api-key: - ENDPOINT: "otlp.nr-data.net:4317" + ENABLED: true + SERVICE_NAME: "guardian" + # Example for new relic + EXPORTER: otlp + OTLP: + HEADERS: + api-key: + ENDPOINT: "otlp.nr-data.net:4317" AUTH: PROVIDER: default # can be "default" or "oidc" DEFAULT: - HEADER_KEY: X-Auth-Email # AUTHENTICATED_USER_HEADER_KEY takes priority for backward-compatibility + # AUTHENTICATED_USER_HEADER_KEY takes priority for backward-compatibility + HEADER_KEY: X-Auth-Email OIDC: AUDIENCE: "some-kind-of-audience.com" - ELIGIBLE_EMAIL_DOMAINS: "emaildomain1.com,emaildomain2.com" \ No newline at end of file + ELIGIBLE_EMAIL_DOMAINS: "emaildomain1.com,emaildomain2.com" diff --git a/internal/server/server.go b/internal/server/server.go index e09b940ac..54648a879 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -8,6 +8,8 @@ import ( "strings" "time" + "github.com/goto/guardian/domain" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -17,10 +19,10 @@ import ( "github.com/goto/guardian/internal/store/postgres" "github.com/goto/guardian/pkg/auth" "github.com/goto/guardian/pkg/crypto" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/pkg/tracing" "github.com/goto/guardian/plugins/notifiers" audit_repos "github.com/goto/salt/audit/repositories" - "github.com/goto/salt/log" "github.com/goto/salt/mux" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" @@ -45,7 +47,7 @@ const ( // RunServer runs the application server func RunServer(config *Config) error { - logger := log.NewLogrus(log.LogrusWithLevel(config.LogLevel)) + logger := log.NewCtxLogger(config.LogLevel, []string{domain.TraceIDKey}) crypto := crypto.NewAES(config.EncryptionSecretKeyKey) validator := validator.New() notifier, err := notifiers.NewClient(&config.Notifier, logger) @@ -86,7 +88,7 @@ func RunServer(config *Config) error { grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( grpc_recovery.UnaryServerInterceptor( grpc_recovery.WithRecoveryHandler(func(p interface{}) (err error) { - logger.Error(string(debug.Stack())) + logger.Error(context.Background(), string(debug.Stack())) return status.Errorf(codes.Internal, "Internal error, please check log") }), ), @@ -113,6 +115,7 @@ func RunServer(config *Config) error { services.GrantService, protoAdapter, authUserContextKey[config.Auth.Provider], + logger, )) // init http proxy @@ -163,7 +166,7 @@ func RunServer(config *Config) error { }) baseMux.Handle("/api/", http.StripPrefix("/api", gwmux)) - logger.Info(fmt.Sprintf("server running on %s", address)) + logger.Info(runtimeCtx, fmt.Sprintf("server running on %s", address)) return mux.Serve(runtimeCtx, address, mux.WithHTTP(baseMux), diff --git a/internal/server/services.go b/internal/server/services.go index 06e274d47..222c348dd 100644 --- a/internal/server/services.go +++ b/internal/server/services.go @@ -18,6 +18,7 @@ import ( "github.com/goto/guardian/domain" "github.com/goto/guardian/internal/store/postgres" "github.com/goto/guardian/pkg/auth" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/plugins/identities" "github.com/goto/guardian/plugins/notifiers" "github.com/goto/guardian/plugins/providers/bigquery" @@ -30,7 +31,6 @@ import ( "github.com/goto/guardian/plugins/providers/tableau" "github.com/goto/salt/audit" audit_repos "github.com/goto/salt/audit/repositories" - "github.com/goto/salt/log" "google.golang.org/grpc/metadata" ) @@ -87,7 +87,7 @@ func InitServices(deps ServiceDeps) (*Services, error) { if traceID == "" { traceID = uuid.New().String() } - md["trace_id"] = traceID + md[domain.TraceIDKey] = traceID return md }), diff --git a/internal/store/postgres/activity_repository_test.go b/internal/store/postgres/activity_repository_test.go index 9eff7362e..9dd3d9c3d 100644 --- a/internal/store/postgres/activity_repository_test.go +++ b/internal/store/postgres/activity_repository_test.go @@ -11,7 +11,7 @@ import ( "github.com/goto/guardian/core/activity" "github.com/goto/guardian/domain" "github.com/goto/guardian/internal/store/postgres" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/stretchr/testify/suite" ) @@ -36,7 +36,7 @@ func TestActivityRepository(t *testing.T) { } func (s *ActivityRepositoryTestSuite) SetupSuite() { - logger := log.NewLogrus(log.LogrusWithLevel("debug")) + logger := log.NewCtxLogger("info", []string{"test"}) store, pool, resource, err := newTestStore(logger) if err != nil { s.T().Fatal(err) diff --git a/internal/store/postgres/appeal_repository_test.go b/internal/store/postgres/appeal_repository_test.go index 2e4f4c1f1..fc11c02e3 100644 --- a/internal/store/postgres/appeal_repository_test.go +++ b/internal/store/postgres/appeal_repository_test.go @@ -11,7 +11,7 @@ import ( "github.com/goto/guardian/core/appeal" "github.com/goto/guardian/domain" "github.com/goto/guardian/internal/store/postgres" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/suite" ) @@ -30,7 +30,7 @@ type AppealRepositoryTestSuite struct { func (s *AppealRepositoryTestSuite) SetupSuite() { var err error - logger := log.NewLogrus(log.LogrusWithLevel("debug")) + logger := log.NewCtxLogger("debug", []string{"test"}) s.store, s.pool, s.resource, err = newTestStore(logger) if err != nil { s.T().Fatal(err) diff --git a/internal/store/postgres/approval_repository_test.go b/internal/store/postgres/approval_repository_test.go index 38dc84d99..59ef82d38 100644 --- a/internal/store/postgres/approval_repository_test.go +++ b/internal/store/postgres/approval_repository_test.go @@ -7,7 +7,7 @@ import ( "github.com/goto/guardian/domain" "github.com/goto/guardian/internal/store/postgres" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/suite" ) @@ -36,7 +36,7 @@ func TestApprovalRepository(t *testing.T) { func (s *ApprovalRepositoryTestSuite) SetupSuite() { var err error - logger := log.NewLogrus(log.LogrusWithLevel("debug")) + logger := log.NewCtxLogger("debug", []string{"test"}) s.store, s.pool, s.resource, err = newTestStore(logger) if err != nil { s.T().Fatal(err) diff --git a/internal/store/postgres/grant_repository_test.go b/internal/store/postgres/grant_repository_test.go index 592ae8089..937cce904 100644 --- a/internal/store/postgres/grant_repository_test.go +++ b/internal/store/postgres/grant_repository_test.go @@ -12,7 +12,7 @@ import ( "github.com/goto/guardian/core/grant" "github.com/goto/guardian/domain" "github.com/goto/guardian/internal/store/postgres" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/suite" ) @@ -40,7 +40,7 @@ func TestGrantRepository(t *testing.T) { func (s *GrantRepositoryTestSuite) SetupSuite() { var err error - logger := log.NewLogrus(log.LogrusWithLevel("debug")) + logger := log.NewCtxLogger("debug", []string{"test"}) s.store, s.pool, s.resource, err = newTestStore(logger) if err != nil { s.T().Fatal(err) diff --git a/internal/store/postgres/policy_repository_test.go b/internal/store/postgres/policy_repository_test.go index c4c114655..57e31d2df 100644 --- a/internal/store/postgres/policy_repository_test.go +++ b/internal/store/postgres/policy_repository_test.go @@ -11,7 +11,7 @@ import ( "github.com/goto/guardian/core/policy" "github.com/goto/guardian/domain" "github.com/goto/guardian/internal/store/postgres" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/suite" ) @@ -27,7 +27,7 @@ type PolicyRepositoryTestSuite struct { func (s *PolicyRepositoryTestSuite) SetupSuite() { var err error - logger := log.NewLogrus(log.LogrusWithLevel("debug")) + logger := log.NewCtxLogger("debug", []string{"test"}) s.store, s.pool, s.resource, err = newTestStore(logger) if err != nil { s.T().Fatal(err) diff --git a/internal/store/postgres/provider_repository_test.go b/internal/store/postgres/provider_repository_test.go index 5cef713c1..127fd9478 100644 --- a/internal/store/postgres/provider_repository_test.go +++ b/internal/store/postgres/provider_repository_test.go @@ -12,7 +12,7 @@ import ( "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" "github.com/goto/guardian/internal/store/postgres" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/suite" ) @@ -30,7 +30,7 @@ type ProviderRepositoryTestSuite struct { func (s *ProviderRepositoryTestSuite) SetupSuite() { var err error - logger := log.NewLogrus(log.LogrusWithLevel("debug")) + logger := log.NewCtxLogger("debug", []string{"test"}) s.store, s.pool, s.resource, err = newTestStore(logger) if err != nil { s.T().Fatal(err) diff --git a/internal/store/postgres/resource_repository.go b/internal/store/postgres/resource_repository.go index 990f3ec6b..f1640a68e 100644 --- a/internal/store/postgres/resource_repository.go +++ b/internal/store/postgres/resource_repository.go @@ -21,6 +21,17 @@ func NewResourceRepository(db *gorm.DB) *ResourceRepository { return &ResourceRepository{db} } +/* +only one active provider for a given resource & resource type. +eg: tables for gojek-integration should be onboarded through one-project, intersection with another provider for same project & resource must not be present +index will have `global_urn` and a check on deleted_at is null. because there can be deleted providers(which is fine) + + + + + +*/ + // Find records based on filters func (r *ResourceRepository) Find(ctx context.Context, filter domain.ListResourcesFilter) ([]*domain.Resource, error) { if err := utils.ValidateStruct(filter); err != nil { diff --git a/internal/store/postgres/resource_repository_test.go b/internal/store/postgres/resource_repository_test.go index c0229553f..b2ee48e49 100644 --- a/internal/store/postgres/resource_repository_test.go +++ b/internal/store/postgres/resource_repository_test.go @@ -11,7 +11,7 @@ import ( "github.com/goto/guardian/core/resource" "github.com/goto/guardian/domain" "github.com/goto/guardian/internal/store/postgres" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/ory/dockertest/v3" "github.com/stretchr/testify/suite" ) @@ -28,7 +28,7 @@ type ResourceRepositoryTestSuite struct { func (s *ResourceRepositoryTestSuite) SetupSuite() { var err error - logger := log.NewLogrus(log.LogrusWithLevel("debug")) + logger := log.NewCtxLogger("debug", []string{"test"}) s.store, s.pool, s.resource, err = newTestStore(logger) if err != nil { s.T().Fatal(err) diff --git a/internal/store/postgres/store_test.go b/internal/store/postgres/store_test.go index 968fbd234..83d6f3cce 100644 --- a/internal/store/postgres/store_test.go +++ b/internal/store/postgres/store_test.go @@ -1,12 +1,13 @@ package postgres_test import ( + "context" "fmt" "time" "github.com/goto/guardian/internal/store" "github.com/goto/guardian/internal/store/postgres" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" ) @@ -22,6 +23,7 @@ var ( ) func newTestStore(logger log.Logger) (*postgres.Store, *dockertest.Pool, *dockertest.Resource, error) { + ctx := context.Background() opts := &dockertest.RunOptions{ Repository: "postgres", Tag: "13", @@ -62,17 +64,17 @@ func newTestStore(logger log.Logger) (*postgres.Store, *dockertest.Pool, *docker Stream: true, }) if err != nil { - logger.Fatal("could not connect to postgres container log output", "error", err) + logger.Fatal(ctx, "could not connect to postgres container log output", "error", err) } defer func() { err = logWaiter.Close() if err != nil { - logger.Fatal("could not close container log", "error", err) + logger.Fatal(ctx, "could not close container log", "error", err) } err = logWaiter.Wait() if err != nil { - logger.Fatal("could not wait for container log to close", "error", err) + logger.Fatal(ctx, "could not wait for container log to close", "error", err) } }() } @@ -100,7 +102,7 @@ func newTestStore(logger log.Logger) (*postgres.Store, *dockertest.Pool, *docker err = setup(st) if err != nil { - logger.Fatal("failed to setup and migrate DB", "error", err) + logger.Fatal(ctx, "failed to setup and migrate DB", "error", err) } return st, pool, resource, nil } diff --git a/jobs/fetch_resources.go b/jobs/fetch_resources.go index bb5b1057d..a2ac9d0fa 100644 --- a/jobs/fetch_resources.go +++ b/jobs/fetch_resources.go @@ -9,6 +9,6 @@ import ( func (h *handler) FetchResources(ctx context.Context, cfg Config) error { ctx = audit.WithActor(ctx, domain.SystemActorName) - h.logger.Info("running fetch resources job") + h.logger.Info(ctx, "running fetch resources job") return h.providerService.FetchResources(ctx) } diff --git a/jobs/grant_dormancy_check.go b/jobs/grant_dormancy_check.go index e87b77777..1b688d360 100644 --- a/jobs/grant_dormancy_check.go +++ b/jobs/grant_dormancy_check.go @@ -35,14 +35,14 @@ func (h *handler) GrantDormancyCheck(ctx context.Context, c Config) error { } for _, p := range providers { - h.logger.Info(fmt.Sprintf("checking dormancy for grants under provider: %q", p.URN)) + h.logger.Info(ctx, fmt.Sprintf("checking dormancy for grants under provider: %q", p.URN)) if err := h.grantService.DormancyCheck(ctx, domain.DormancyCheckCriteria{ ProviderID: p.ID, Period: period, RetainDuration: retainGrantFor, DryRun: cfg.DryRun, }); err != nil { - h.logger.Error(fmt.Sprintf("failed to check dormancy for provider %q", p.URN), "error", err) + h.logger.Error(ctx, fmt.Sprintf("failed to check dormancy for provider %q", p.URN), "error", err) } } diff --git a/jobs/grant_expiration_reminder.go b/jobs/grant_expiration_reminder.go index 771f58611..d723997f4 100644 --- a/jobs/grant_expiration_reminder.go +++ b/jobs/grant_expiration_reminder.go @@ -9,11 +9,11 @@ import ( ) func (h *handler) GrantExpirationReminder(ctx context.Context, cfg Config) error { - h.logger.Info("running grant expiration reminder job") + h.logger.Info(ctx, "running grant expiration reminder job") daysBeforeExpired := []int{7, 3, 1} for _, d := range daysBeforeExpired { - h.logger.Info("retrieving active grants", "expiration_window_in_days", d) + h.logger.Info(ctx, "retrieving active grants", "expiration_window_in_days", d) now := time.Now().AddDate(0, 0, d) year, month, day := now.Date() @@ -26,12 +26,13 @@ func (h *handler) GrantExpirationReminder(ctx context.Context, cfg Config) error } grants, err := h.grantService.List(ctx, filters) if err != nil { - h.logger.Error("failed to retrieve active grants", + h.logger.Error(ctx, "failed to retrieve active grants", "expiration_window_in_days", d, "error", err, ) continue } + h.logger.Info(ctx, "retrieved active grants", "count", len(grants), "expiration_window_in_days", d) // TODO: group notifications by username var notifications []domain.Notification @@ -56,9 +57,9 @@ func (h *handler) GrantExpirationReminder(ctx context.Context, cfg Config) error }) } - if errs := h.notifier.Notify(notifications); errs != nil { + if errs := h.notifier.Notify(ctx, notifications); errs != nil { for _, err1 := range errs { - h.logger.Error("failed to send notifications", "error", err1) + h.logger.Error(ctx, "failed to send notifications", "error", err1) } } } diff --git a/jobs/grant_expiration_revoke.go b/jobs/grant_expiration_revoke.go index 735172209..db050caa5 100644 --- a/jobs/grant_expiration_revoke.go +++ b/jobs/grant_expiration_revoke.go @@ -9,7 +9,7 @@ import ( ) func (h *handler) RevokeExpiredGrants(ctx context.Context, cfg Config) error { - h.logger.Info("running revoke expired grants job") + h.logger.Info(ctx, "running revoke expired grants job") falseBool := false filters := domain.ListGrantsFilter{ @@ -18,20 +18,21 @@ func (h *handler) RevokeExpiredGrants(ctx context.Context, cfg Config) error { IsPermanent: &falseBool, } - h.logger.Info("retrieving active grant...") + h.logger.Info(ctx, "retrieving active grants...") grants, err := h.grantService.List(ctx, filters) if err != nil { return err } + h.logger.Info(ctx, "retrieved active grants", "count", len(grants)) successRevoke := []string{} failedRevoke := []map[string]interface{}{} for _, g := range grants { - h.logger.Info("revoking grant", "id", g.ID) + h.logger.Info(ctx, "revoking grant", "id", g.ID) ctx = audit.WithActor(ctx, domain.SystemActorName) if _, err := h.grantService.Revoke(ctx, g.ID, domain.SystemActorName, "Automatically revoked"); err != nil { - h.logger.Error("failed to revoke grant", + h.logger.Error(ctx, "failed to revoke grant", "id", g.ID, "error", err, ) @@ -41,7 +42,7 @@ func (h *handler) RevokeExpiredGrants(ctx context.Context, cfg Config) error { "error": err.Error(), }) } else { - h.logger.Info("grant revoked", "id", g.ID) + h.logger.Info(ctx, "grant revoked", "id", g.ID) successRevoke = append(successRevoke, g.ID) } } @@ -50,9 +51,9 @@ func (h *handler) RevokeExpiredGrants(ctx context.Context, cfg Config) error { return err } - h.logger.Info("successful grant revocation", "count", len(successRevoke), "ids", successRevoke) + h.logger.Info(ctx, "successful grant revocation", "count", len(successRevoke), "ids", successRevoke) if len(failedRevoke) > 0 { - h.logger.Info("failed grant revocation", "count", len(failedRevoke), "ids", failedRevoke) + h.logger.Info(ctx, "failed grant revocation", "count", len(failedRevoke), "ids", failedRevoke) } return nil diff --git a/jobs/handler.go b/jobs/handler.go index 24016d33b..7cced2d5b 100644 --- a/jobs/handler.go +++ b/jobs/handler.go @@ -6,8 +6,8 @@ import ( "github.com/go-playground/validator/v10" "github.com/goto/guardian/core/grant" "github.com/goto/guardian/domain" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/plugins/notifiers" - "github.com/goto/salt/log" ) //go:generate mockery --name=grantService --exported --with-expecter diff --git a/jobs/revoke_grants_by_user_criteria.go b/jobs/revoke_grants_by_user_criteria.go index 32031e62e..7dff4ea07 100644 --- a/jobs/revoke_grants_by_user_criteria.go +++ b/jobs/revoke_grants_by_user_criteria.go @@ -17,8 +17,8 @@ type RevokeGrantsByUserCriteriaConfig struct { } func (h *handler) RevokeGrantsByUserCriteria(ctx context.Context, c Config) error { - h.logger.Info(fmt.Sprintf("starting %q job", TypeRevokeGrantsByUserCriteria)) - defer h.logger.Info(fmt.Sprintf("finished %q job", TypeRevokeGrantsByUserCriteria)) + h.logger.Info(ctx, fmt.Sprintf("starting %q job", TypeRevokeGrantsByUserCriteria)) + defer h.logger.Info(ctx, fmt.Sprintf("finished %q job", TypeRevokeGrantsByUserCriteria)) var cfg RevokeGrantsByUserCriteriaConfig if err := c.Decode(&cfg); err != nil { @@ -35,7 +35,7 @@ func (h *handler) RevokeGrantsByUserCriteria(ctx context.Context, c Config) erro return fmt.Errorf("initializing IAM client: %w", err) } - h.logger.Info("getting active grants") + h.logger.Info(ctx, "getting active grants") activeGrants, err := h.grantService.List(ctx, domain.ListGrantsFilter{ Statuses: []string{string(domain.GrantStatusActive)}, }) @@ -43,11 +43,11 @@ func (h *handler) RevokeGrantsByUserCriteria(ctx context.Context, c Config) erro return fmt.Errorf("listing active grants: %w", err) } if len(activeGrants) == 0 { - h.logger.Info("no active grants found") + h.logger.Info(ctx, "no active grants found") return nil } grantIDs := getGrantIDs(activeGrants) - h.logger.Info(fmt.Sprintf("found %d active grants", len(activeGrants)), "grant_ids", grantIDs) + h.logger.Info(ctx, fmt.Sprintf("found %d active grants", len(activeGrants)), "grant_ids", grantIDs) grantsForUser := map[string][]*domain.Grant{} // map[account_id][]grant grantsOwnedByUser := map[string][]*domain.Grant{} // map[owner][]grant @@ -62,48 +62,48 @@ func (h *handler) RevokeGrantsByUserCriteria(ctx context.Context, c Config) erro grantsOwnedByUser[g.Owner] = append(grantsOwnedByUser[g.AccountID], &g) } } - h.logger.Info(fmt.Sprintf("found %d unique users", len(uniqueUserEmails)), "emails", uniqueUserEmails) + h.logger.Info(ctx, fmt.Sprintf("found %d unique users", len(uniqueUserEmails)), "emails", uniqueUserEmails) counter := 0 for email := range uniqueUserEmails { counter++ fmt.Println("") - h.logger.Info(fmt.Sprintf("processing user %d/%d", counter, len(uniqueUserEmails)), "email", email) + h.logger.Info(ctx, fmt.Sprintf("processing user %d/%d", counter, len(uniqueUserEmails)), "email", email) - h.logger.Info("fetching user details", "email", email) + h.logger.Info(ctx, "fetching user details", "email", email) userDetails, err := fetchUserDetails(iamClient, email) if err != nil { - h.logger.Error("failed to fetch user details", "email", email, "error", err) + h.logger.Error(ctx, "failed to fetch user details", "email", email, "error", err) continue } - h.logger.Info("checking criteria against user", "email", email, "criteria", cfg.UserCriteria.String()) + h.logger.Info(ctx, "checking criteria against user", "email", email, "criteria", cfg.UserCriteria.String()) if criteriaSatisfied, err := evaluateCriteria(cfg.UserCriteria, userDetails); err != nil { - h.logger.Error("failed to check criteria", "email", email, "error", err) + h.logger.Error(ctx, "failed to check criteria", "email", email, "error", err) } else if !criteriaSatisfied { - h.logger.Info("criteria not satisfied", "email", email) + h.logger.Info(ctx, "criteria not satisfied", "email", email) continue } - h.logger.Info("evaluating new owner", "email", email, "expression", cfg.ReassignOwnershipTo.String()) + h.logger.Info(ctx, "evaluating new owner", "email", email, "expression", cfg.ReassignOwnershipTo.String()) newOwner, err := h.evaluateNewOwner(cfg.ReassignOwnershipTo, userDetails) if err != nil { - h.logger.Error("evaluating new owner", "email", email, "error", err) + h.logger.Error(ctx, "evaluating new owner", "email", email, "error", err) continue } - h.logger.Info(fmt.Sprintf("evaluated new owner: %q", newOwner), "email", email) + h.logger.Info(ctx, fmt.Sprintf("evaluated new owner: %q", newOwner), "email", email) if !cfg.DryRun { // revoking grants with account_id == email - h.logger.Info("revoking user active grants", "email", email) + h.logger.Info(ctx, "revoking user active grants", "email", email) if revokedGrants, err := h.revokeUserGrants(ctx, email); err != nil { - h.logger.Error("failed to reovke grants", "email", email, "error", err) + h.logger.Error(ctx, "failed to reovke grants", "email", email, "error", err) } else { revokedGrantIDs := []string{} for _, g := range revokedGrants { revokedGrantIDs = append(revokedGrantIDs, g.ID) } - h.logger.Info("grant revocation successful", "count", len(revokedGrantIDs), "grant_ids", revokedGrantIDs) + h.logger.Info(ctx, "grant revocation successful", "count", len(revokedGrantIDs), "grant_ids", revokedGrantIDs) } // reassigning grants owned by the user to the new owner @@ -113,14 +113,14 @@ func (h *handler) RevokeGrantsByUserCriteria(ctx context.Context, c Config) erro for _, g := range successfulGrants { successfulGrantIDs = append(successfulGrantIDs, g.ID) } - h.logger.Info("grant ownership reassignment successful", "count", len(successfulGrantIDs), "grant_ids", successfulGrantIDs) + h.logger.Info(ctx, "grant ownership reassignment successful", "count", len(successfulGrantIDs), "grant_ids", successfulGrantIDs) } if len(failedGrants) > 0 { failedGrantIDs := []string{} for _, g := range failedGrants { failedGrantIDs = append(failedGrantIDs, g.ID) } - h.logger.Error("grant ownership reassignment failed", "count", len(failedGrantIDs), "grant_ids", failedGrantIDs) + h.logger.Error(ctx, "grant ownership reassignment failed", "count", len(failedGrantIDs), "grant_ids", failedGrantIDs) } } } @@ -159,7 +159,7 @@ func (h *handler) revokeUserGrants(ctx context.Context, email string) ([]*domain revokeGrantsFilter := domain.RevokeGrantsFilter{ AccountIDs: []string{email}, } - h.logger.Info("revoking grants", "account_id", email) + h.logger.Info(ctx, "revoking grants", "account_id", email) revokedGrants, err := h.grantService.BulkRevoke(ctx, revokeGrantsFilter, domain.SystemActorName, "Revoked due to user deactivated") if err != nil { return nil, fmt.Errorf("revoking grants for %q: %w", email, err) @@ -195,7 +195,7 @@ func (h *handler) reassignGrantsOwnership(ctx context.Context, ownedGrants []*do g.Owner = newOwner if err := h.grantService.Update(ctx, g); err != nil { failedGrants = append(failedGrants, g) - h.logger.Error("updating grant owner", "grant_id", g.ID, "existing_owner", g.Owner, "new_owner", newOwner, "error", err) + h.logger.Error(ctx, "updating grant owner", "grant_id", g.ID, "existing_owner", g.Owner, "new_owner", newOwner, "error", err) continue } successfulGrants = append(successfulGrants, g) diff --git a/mocks/MetabaseClient.go b/mocks/MetabaseClient.go index 6ead7bee4..86df7b94b 100644 --- a/mocks/MetabaseClient.go +++ b/mocks/MetabaseClient.go @@ -1,8 +1,10 @@ -// Code generated by mockery v2.10.0. DO NOT EDIT. +// Code generated by mockery v2.32.0. DO NOT EDIT. package mocks import ( + context "context" + metabase "github.com/goto/guardian/plugins/providers/metabase" mock "github.com/stretchr/testify/mock" ) @@ -12,22 +14,25 @@ type MetabaseClient struct { mock.Mock } -// GetCollections provides a mock function with given fields: -func (_m *MetabaseClient) GetCollections() ([]*metabase.Collection, error) { - ret := _m.Called() +// GetCollections provides a mock function with given fields: ctx +func (_m *MetabaseClient) GetCollections(ctx context.Context) ([]*metabase.Collection, error) { + ret := _m.Called(ctx) var r0 []*metabase.Collection - if rf, ok := ret.Get(0).(func() []*metabase.Collection); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*metabase.Collection, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*metabase.Collection); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*metabase.Collection) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,22 +40,25 @@ func (_m *MetabaseClient) GetCollections() ([]*metabase.Collection, error) { return r0, r1 } -// GetDatabases provides a mock function with given fields: -func (_m *MetabaseClient) GetDatabases() ([]*metabase.Database, error) { - ret := _m.Called() +// GetDatabases provides a mock function with given fields: ctx +func (_m *MetabaseClient) GetDatabases(ctx context.Context) ([]*metabase.Database, error) { + ret := _m.Called(ctx) var r0 []*metabase.Database - if rf, ok := ret.Get(0).(func() []*metabase.Database); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*metabase.Database, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*metabase.Database); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*metabase.Database) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -58,40 +66,43 @@ func (_m *MetabaseClient) GetDatabases() ([]*metabase.Database, error) { return r0, r1 } -// GetGroups provides a mock function with given fields: -func (_m *MetabaseClient) GetGroups() ([]*metabase.Group, metabase.ResourceGroupDetails, metabase.ResourceGroupDetails, error) { - ret := _m.Called() +// GetGroups provides a mock function with given fields: ctx +func (_m *MetabaseClient) GetGroups(ctx context.Context) ([]*metabase.Group, metabase.ResourceGroupDetails, metabase.ResourceGroupDetails, error) { + ret := _m.Called(ctx) var r0 []*metabase.Group - if rf, ok := ret.Get(0).(func() []*metabase.Group); ok { - r0 = rf() + var r1 metabase.ResourceGroupDetails + var r2 metabase.ResourceGroupDetails + var r3 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*metabase.Group, metabase.ResourceGroupDetails, metabase.ResourceGroupDetails, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*metabase.Group); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*metabase.Group) } } - var r1 metabase.ResourceGroupDetails - if rf, ok := ret.Get(1).(func() metabase.ResourceGroupDetails); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) metabase.ResourceGroupDetails); ok { + r1 = rf(ctx) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(metabase.ResourceGroupDetails) } } - var r2 metabase.ResourceGroupDetails - if rf, ok := ret.Get(2).(func() metabase.ResourceGroupDetails); ok { - r2 = rf() + if rf, ok := ret.Get(2).(func(context.Context) metabase.ResourceGroupDetails); ok { + r2 = rf(ctx) } else { if ret.Get(2) != nil { r2 = ret.Get(2).(metabase.ResourceGroupDetails) } } - var r3 error - if rf, ok := ret.Get(3).(func() error); ok { - r3 = rf() + if rf, ok := ret.Get(3).(func(context.Context) error); ok { + r3 = rf(ctx) } else { r3 = ret.Error(3) } @@ -99,13 +110,13 @@ func (_m *MetabaseClient) GetGroups() ([]*metabase.Group, metabase.ResourceGroup return r0, r1, r2, r3 } -// GrantCollectionAccess provides a mock function with given fields: resource, user, role -func (_m *MetabaseClient) GrantCollectionAccess(resource *metabase.Collection, user string, role string) error { - ret := _m.Called(resource, user, role) +// GrantCollectionAccess provides a mock function with given fields: ctx, resource, user, role +func (_m *MetabaseClient) GrantCollectionAccess(ctx context.Context, resource *metabase.Collection, user string, role string) error { + ret := _m.Called(ctx, resource, user, role) var r0 error - if rf, ok := ret.Get(0).(func(*metabase.Collection, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *metabase.Collection, string, string) error); ok { + r0 = rf(ctx, resource, user, role) } else { r0 = ret.Error(0) } @@ -113,13 +124,13 @@ func (_m *MetabaseClient) GrantCollectionAccess(resource *metabase.Collection, u return r0 } -// GrantDatabaseAccess provides a mock function with given fields: resource, user, role, groups -func (_m *MetabaseClient) GrantDatabaseAccess(resource *metabase.Database, user string, role string, groups map[string]*metabase.Group) error { - ret := _m.Called(resource, user, role, groups) +// GrantDatabaseAccess provides a mock function with given fields: ctx, resource, user, role, groups +func (_m *MetabaseClient) GrantDatabaseAccess(ctx context.Context, resource *metabase.Database, user string, role string, groups map[string]*metabase.Group) error { + ret := _m.Called(ctx, resource, user, role, groups) var r0 error - if rf, ok := ret.Get(0).(func(*metabase.Database, string, string, map[string]*metabase.Group) error); ok { - r0 = rf(resource, user, role, groups) + if rf, ok := ret.Get(0).(func(context.Context, *metabase.Database, string, string, map[string]*metabase.Group) error); ok { + r0 = rf(ctx, resource, user, role, groups) } else { r0 = ret.Error(0) } @@ -127,13 +138,13 @@ func (_m *MetabaseClient) GrantDatabaseAccess(resource *metabase.Database, user return r0 } -// GrantGroupAccess provides a mock function with given fields: groupID, email -func (_m *MetabaseClient) GrantGroupAccess(groupID int, email string) error { - ret := _m.Called(groupID, email) +// GrantGroupAccess provides a mock function with given fields: ctx, groupID, email +func (_m *MetabaseClient) GrantGroupAccess(ctx context.Context, groupID int, email string) error { + ret := _m.Called(ctx, groupID, email) var r0 error - if rf, ok := ret.Get(0).(func(int, string) error); ok { - r0 = rf(groupID, email) + if rf, ok := ret.Get(0).(func(context.Context, int, string) error); ok { + r0 = rf(ctx, groupID, email) } else { r0 = ret.Error(0) } @@ -141,13 +152,13 @@ func (_m *MetabaseClient) GrantGroupAccess(groupID int, email string) error { return r0 } -// GrantTableAccess provides a mock function with given fields: resource, user, role, groups -func (_m *MetabaseClient) GrantTableAccess(resource *metabase.Table, user string, role string, groups map[string]*metabase.Group) error { - ret := _m.Called(resource, user, role, groups) +// GrantTableAccess provides a mock function with given fields: ctx, resource, user, role, groups +func (_m *MetabaseClient) GrantTableAccess(ctx context.Context, resource *metabase.Table, user string, role string, groups map[string]*metabase.Group) error { + ret := _m.Called(ctx, resource, user, role, groups) var r0 error - if rf, ok := ret.Get(0).(func(*metabase.Table, string, string, map[string]*metabase.Group) error); ok { - r0 = rf(resource, user, role, groups) + if rf, ok := ret.Get(0).(func(context.Context, *metabase.Table, string, string, map[string]*metabase.Group) error); ok { + r0 = rf(ctx, resource, user, role, groups) } else { r0 = ret.Error(0) } @@ -155,13 +166,13 @@ func (_m *MetabaseClient) GrantTableAccess(resource *metabase.Table, user string return r0 } -// RevokeCollectionAccess provides a mock function with given fields: resource, user, role -func (_m *MetabaseClient) RevokeCollectionAccess(resource *metabase.Collection, user string, role string) error { - ret := _m.Called(resource, user, role) +// RevokeCollectionAccess provides a mock function with given fields: ctx, resource, user, role +func (_m *MetabaseClient) RevokeCollectionAccess(ctx context.Context, resource *metabase.Collection, user string, role string) error { + ret := _m.Called(ctx, resource, user, role) var r0 error - if rf, ok := ret.Get(0).(func(*metabase.Collection, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *metabase.Collection, string, string) error); ok { + r0 = rf(ctx, resource, user, role) } else { r0 = ret.Error(0) } @@ -169,13 +180,13 @@ func (_m *MetabaseClient) RevokeCollectionAccess(resource *metabase.Collection, return r0 } -// RevokeDatabaseAccess provides a mock function with given fields: resource, user, role -func (_m *MetabaseClient) RevokeDatabaseAccess(resource *metabase.Database, user string, role string) error { - ret := _m.Called(resource, user, role) +// RevokeDatabaseAccess provides a mock function with given fields: ctx, resource, user, role +func (_m *MetabaseClient) RevokeDatabaseAccess(ctx context.Context, resource *metabase.Database, user string, role string) error { + ret := _m.Called(ctx, resource, user, role) var r0 error - if rf, ok := ret.Get(0).(func(*metabase.Database, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *metabase.Database, string, string) error); ok { + r0 = rf(ctx, resource, user, role) } else { r0 = ret.Error(0) } @@ -183,13 +194,13 @@ func (_m *MetabaseClient) RevokeDatabaseAccess(resource *metabase.Database, user return r0 } -// RevokeGroupAccess provides a mock function with given fields: groupID, email -func (_m *MetabaseClient) RevokeGroupAccess(groupID int, email string) error { - ret := _m.Called(groupID, email) +// RevokeGroupAccess provides a mock function with given fields: ctx, groupID, email +func (_m *MetabaseClient) RevokeGroupAccess(ctx context.Context, groupID int, email string) error { + ret := _m.Called(ctx, groupID, email) var r0 error - if rf, ok := ret.Get(0).(func(int, string) error); ok { - r0 = rf(groupID, email) + if rf, ok := ret.Get(0).(func(context.Context, int, string) error); ok { + r0 = rf(ctx, groupID, email) } else { r0 = ret.Error(0) } @@ -197,16 +208,30 @@ func (_m *MetabaseClient) RevokeGroupAccess(groupID int, email string) error { return r0 } -// RevokeTableAccess provides a mock function with given fields: resource, user, role -func (_m *MetabaseClient) RevokeTableAccess(resource *metabase.Table, user string, role string) error { - ret := _m.Called(resource, user, role) +// RevokeTableAccess provides a mock function with given fields: ctx, resource, user, role +func (_m *MetabaseClient) RevokeTableAccess(ctx context.Context, resource *metabase.Table, user string, role string) error { + ret := _m.Called(ctx, resource, user, role) var r0 error - if rf, ok := ret.Get(0).(func(*metabase.Table, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *metabase.Table, string, string) error); ok { + r0 = rf(ctx, resource, user, role) } else { r0 = ret.Error(0) } return r0 } + +// NewMetabaseClient creates a new instance of MetabaseClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMetabaseClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MetabaseClient { + mock := &MetabaseClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mocks/ShieldClient.go b/mocks/ShieldClient.go index 466cf0c93..8740071f8 100644 --- a/mocks/ShieldClient.go +++ b/mocks/ShieldClient.go @@ -1,8 +1,10 @@ -// Code generated by mockery v2.10.0. DO NOT EDIT. +// Code generated by mockery v2.32.0. DO NOT EDIT. package mocks import ( + context "context" + shield "github.com/goto/guardian/plugins/providers/shield" mock "github.com/stretchr/testify/mock" ) @@ -12,22 +14,25 @@ type ShieldClient struct { mock.Mock } -// GetTeam provides a mock function with given fields: -func (_m *ShieldClient) GetTeams() ([]*shield.Team, error) { - ret := _m.Called() +// GetOrganizations provides a mock function with given fields: ctx +func (_m *ShieldClient) GetOrganizations(ctx context.Context) ([]*shield.Organization, error) { + ret := _m.Called(ctx) - var r0 []*shield.Team - if rf, ok := ret.Get(0).(func() []*shield.Team); ok { - r0 = rf() + var r0 []*shield.Organization + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*shield.Organization, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*shield.Organization); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*shield.Team) + r0 = ret.Get(0).([]*shield.Organization) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,22 +40,25 @@ func (_m *ShieldClient) GetTeams() ([]*shield.Team, error) { return r0, r1 } -// GetProjects provides a mock function with given fields: -func (_m *ShieldClient) GetProjects() ([]*shield.Project, error) { - ret := _m.Called() +// GetProjects provides a mock function with given fields: ctx +func (_m *ShieldClient) GetProjects(ctx context.Context) ([]*shield.Project, error) { + ret := _m.Called(ctx) var r0 []*shield.Project - if rf, ok := ret.Get(0).(func() []*shield.Project); ok { - r0 = rf() + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*shield.Project, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*shield.Project); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*shield.Project) } } - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -58,22 +66,51 @@ func (_m *ShieldClient) GetProjects() ([]*shield.Project, error) { return r0, r1 } -// GetOrganizations provides a mock function with given fields: -func (_m *ShieldClient) GetOrganizations() ([]*shield.Organization, error) { - ret := _m.Called() +// GetSelfUser provides a mock function with given fields: ctx, email +func (_m *ShieldClient) GetSelfUser(ctx context.Context, email string) (*shield.User, error) { + ret := _m.Called(ctx, email) - var r0 []*shield.Organization - if rf, ok := ret.Get(0).(func() []*shield.Organization); ok { - r0 = rf() + var r0 *shield.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*shield.User, error)); ok { + return rf(ctx, email) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *shield.User); ok { + r0 = rf(ctx, email) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*shield.Organization) + r0 = ret.Get(0).(*shield.User) } } + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, email) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTeams provides a mock function with given fields: ctx +func (_m *ShieldClient) GetTeams(ctx context.Context) ([]*shield.Team, error) { + ret := _m.Called(ctx) + + var r0 []*shield.Team var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]*shield.Team, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*shield.Team); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*shield.Team) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -81,13 +118,13 @@ func (_m *ShieldClient) GetOrganizations() ([]*shield.Organization, error) { return r0, r1 } -// GrantTeamAccess provides a mock function with given fields: resource, user, role -func (_m *ShieldClient) GrantTeamAccess(resource *shield.Team, user string, role string) error { - ret := _m.Called(resource, user, role) +// GrantOrganizationAccess provides a mock function with given fields: ctx, organization, userId, role +func (_m *ShieldClient) GrantOrganizationAccess(ctx context.Context, organization *shield.Organization, userId string, role string) error { + ret := _m.Called(ctx, organization, userId, role) var r0 error - if rf, ok := ret.Get(0).(func(*shield.Team, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *shield.Organization, string, string) error); ok { + r0 = rf(ctx, organization, userId, role) } else { r0 = ret.Error(0) } @@ -95,13 +132,13 @@ func (_m *ShieldClient) GrantTeamAccess(resource *shield.Team, user string, role return r0 } -// GrantProjectAccess provides a mock function with given fields: resource, user, role, groups -func (_m *ShieldClient) GrantProjectAccess(resource *shield.Project, user string, role string) error { - ret := _m.Called(resource, user, role) +// GrantProjectAccess provides a mock function with given fields: ctx, project, userId, role +func (_m *ShieldClient) GrantProjectAccess(ctx context.Context, project *shield.Project, userId string, role string) error { + ret := _m.Called(ctx, project, userId, role) var r0 error - if rf, ok := ret.Get(0).(func(*shield.Project, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *shield.Project, string, string) error); ok { + r0 = rf(ctx, project, userId, role) } else { r0 = ret.Error(0) } @@ -109,13 +146,13 @@ func (_m *ShieldClient) GrantProjectAccess(resource *shield.Project, user string return r0 } -// GrantOrganizationAccess provides a mock function with given fields: groupID, email -func (_m *ShieldClient) GrantOrganizationAccess(resource *shield.Organization, user string, role string) error { - ret := _m.Called(resource, user, role) +// GrantTeamAccess provides a mock function with given fields: ctx, team, userId, role +func (_m *ShieldClient) GrantTeamAccess(ctx context.Context, team *shield.Team, userId string, role string) error { + ret := _m.Called(ctx, team, userId, role) var r0 error - if rf, ok := ret.Get(0).(func(*shield.Organization, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *shield.Team, string, string) error); ok { + r0 = rf(ctx, team, userId, role) } else { r0 = ret.Error(0) } @@ -123,13 +160,13 @@ func (_m *ShieldClient) GrantOrganizationAccess(resource *shield.Organization, u return r0 } -// RevokeTeamAccess provides a mock function with given fields: resource, user, role -func (_m *ShieldClient) RevokeTeamAccess(resource *shield.Team, user string, role string) error { - ret := _m.Called(resource, user, role) +// RevokeOrganizationAccess provides a mock function with given fields: ctx, organization, userId, role +func (_m *ShieldClient) RevokeOrganizationAccess(ctx context.Context, organization *shield.Organization, userId string, role string) error { + ret := _m.Called(ctx, organization, userId, role) var r0 error - if rf, ok := ret.Get(0).(func(*shield.Team, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *shield.Organization, string, string) error); ok { + r0 = rf(ctx, organization, userId, role) } else { r0 = ret.Error(0) } @@ -137,13 +174,13 @@ func (_m *ShieldClient) RevokeTeamAccess(resource *shield.Team, user string, rol return r0 } -// RevokeProjectAccess provides a mock function with given fields: resource, user, role -func (_m *ShieldClient) RevokeProjectAccess(resource *shield.Project, user string, role string) error { - ret := _m.Called(resource, user, role) +// RevokeProjectAccess provides a mock function with given fields: ctx, project, userId, role +func (_m *ShieldClient) RevokeProjectAccess(ctx context.Context, project *shield.Project, userId string, role string) error { + ret := _m.Called(ctx, project, userId, role) var r0 error - if rf, ok := ret.Get(0).(func(*shield.Project, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *shield.Project, string, string) error); ok { + r0 = rf(ctx, project, userId, role) } else { r0 = ret.Error(0) } @@ -151,13 +188,13 @@ func (_m *ShieldClient) RevokeProjectAccess(resource *shield.Project, user strin return r0 } -// RevokeOrganizationAccess provides a mock function with given fields: resource, user, role -func (_m *ShieldClient) RevokeOrganizationAccess(resource *shield.Organization, user string, role string) error { - ret := _m.Called(resource, user, role) +// RevokeTeamAccess provides a mock function with given fields: ctx, team, userId, role +func (_m *ShieldClient) RevokeTeamAccess(ctx context.Context, team *shield.Team, userId string, role string) error { + ret := _m.Called(ctx, team, userId, role) var r0 error - if rf, ok := ret.Get(0).(func(*shield.Organization, string, string) error); ok { - r0 = rf(resource, user, role) + if rf, ok := ret.Get(0).(func(context.Context, *shield.Team, string, string) error); ok { + r0 = rf(ctx, team, userId, role) } else { r0 = ret.Error(0) } @@ -165,23 +202,16 @@ func (_m *ShieldClient) RevokeOrganizationAccess(resource *shield.Organization, return r0 } -// GetSelfUser a mock function with given fields: -func (_m *ShieldClient) GetSelfUser(email string) (*shield.User, error) { - ret := _m.Called(email) - - var r0 *shield.User - if rf, ok := ret.Get(0).(func(string) *shield.User); ok { - r0 = rf(email) - } else { - r0 = ret.Get(0).(*shield.User) - } +// NewShieldClient creates a new instance of ShieldClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewShieldClient(t interface { + mock.TestingT + Cleanup(func()) +}) *ShieldClient { + mock := &ShieldClient{} + mock.Mock.Test(t) - var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(email) - } else { - r1 = ret.Error(1) - } + t.Cleanup(func() { mock.AssertExpectations(t) }) - return r0, r1 + return mock } diff --git a/pkg/log/logger.go b/pkg/log/logger.go new file mode 100644 index 000000000..a83511c08 --- /dev/null +++ b/pkg/log/logger.go @@ -0,0 +1,96 @@ +package log + +import ( + "context" + "io" + + saltLog "github.com/goto/salt/log" +) + +type Logger interface { + + // Debug level message with alternating key/value pairs + // key should be string, value could be anything printable + Debug(ctx context.Context, msg string, args ...interface{}) + + // Info level message with alternating key/value pairs + // key should be string, value could be anything printable + Info(ctx context.Context, msg string, args ...interface{}) + + // Warn level message with alternating key/value pairs + // key should be string, value could be anything printable + Warn(ctx context.Context, msg string, args ...interface{}) + + // Error level message with alternating key/value pairs + // key should be string, value could be anything printable + Error(ctx context.Context, msg string, args ...interface{}) + + // Fatal level message with alternating key/value pairs + // key should be string, value could be anything printable + Fatal(ctx context.Context, msg string, args ...interface{}) + + // Level returns priority level for which this logger will filter logs + Level() string + + // Writer used to print logs + Writer() io.Writer +} + +type CtxLogger struct { + log saltLog.Logger + keys []string +} + +// NewCtxLoggerWithSaltLogger returns a logger that will add context value to the log message, wrapped with saltLog.Logger +func NewCtxLoggerWithSaltLogger(log saltLog.Logger, ctxKeys []string) *CtxLogger { + return &CtxLogger{log: log, keys: ctxKeys} +} + +// NewCtxLogger returns a logger that will add context value to the log message +func NewCtxLogger(logLevel string, ctxKeys []string) *CtxLogger { + saltLogger := saltLog.NewLogrus(saltLog.LogrusWithLevel(logLevel)) + return NewCtxLoggerWithSaltLogger(saltLogger, ctxKeys) +} + +func (l *CtxLogger) Debug(ctx context.Context, msg string, args ...interface{}) { + l.log.Debug(msg, l.addCtxToArgs(ctx, args)...) +} + +func (l *CtxLogger) Info(ctx context.Context, msg string, args ...interface{}) { + l.log.Info(msg, l.addCtxToArgs(ctx, args)...) +} + +func (l *CtxLogger) Warn(ctx context.Context, msg string, args ...interface{}) { + l.log.Warn(msg, l.addCtxToArgs(ctx, args)...) +} + +func (l *CtxLogger) Error(ctx context.Context, msg string, args ...interface{}) { + l.log.Error(msg, l.addCtxToArgs(ctx, args)...) +} + +func (l *CtxLogger) Fatal(ctx context.Context, msg string, args ...interface{}) { + l.log.Fatal(msg, l.addCtxToArgs(ctx, args)...) +} + +func (l *CtxLogger) Level() string { + return l.log.Level() +} + +func (l *CtxLogger) Writer() io.Writer { + return l.log.Writer() +} + +// addCtxToArgs adds context value to the existing args slice as key/value pair +func (l *CtxLogger) addCtxToArgs(ctx context.Context, args []interface{}) []interface{} { + if ctx == nil { + return args + } + + for _, key := range l.keys { + if val, ok := ctx.Value(key).(string); ok { + args = append(args, key, val) + } + } + + return args +} diff --git a/pkg/log/logger_test.go b/pkg/log/logger_test.go new file mode 100644 index 000000000..01a961934 --- /dev/null +++ b/pkg/log/logger_test.go @@ -0,0 +1,78 @@ +package log + +import ( + "context" + "testing" + + "github.com/goto/guardian/pkg/log/mocks" +) + +func TestLogger(t *testing.T) { + saltLogger := new(mocks.SaltLogger) + l := NewCtxLoggerWithSaltLogger(saltLogger, []string{"ctx-key-1"}) + + t.Run("empty context", func(t *testing.T) { + t.Run("Debug", func(t *testing.T) { + saltLogger.EXPECT().Debug("this is a test debug message", []interface{}{"keys", "test-value"}...).Once() + l.Debug(nil, "this is a test debug message", "keys", "test-value") + saltLogger.AssertExpectations(t) + }) + + t.Run("Info", func(t *testing.T) { + saltLogger.EXPECT().Info("this is a test info message", []interface{}{"keys", "test-value"}...).Once() + l.Info(nil, "this is a test info message", "keys", "test-value") + saltLogger.AssertExpectations(t) + }) + + t.Run("Warn", func(t *testing.T) { + saltLogger.EXPECT().Warn("this is a test warn message", []interface{}{"keys", "test-value"}...).Once() + l.Warn(nil, "this is a test warn message", "keys", "test-value") + saltLogger.AssertExpectations(t) + }) + + t.Run("Error", func(t *testing.T) { + saltLogger.EXPECT().Error("this is a test error message", []interface{}{"keys", "test-value"}...).Once() + l.Error(nil, "this is a test error message", "keys", "test-value") + saltLogger.AssertExpectations(t) + }) + + t.Run("Fatal", func(t *testing.T) { + saltLogger.EXPECT().Fatal("this is a test fatal message", []interface{}{"keys", "test-value"}...).Once() + l.Fatal(nil, "this is a test fatal message", "keys", "test-value") + saltLogger.AssertExpectations(t) + }) + }) + + t.Run("context with keys", func(t *testing.T) { + ctx := context.WithValue(context.Background(), "ctx-key-1", "ctx-value-1") + t.Run("Debug", func(t *testing.T) { + saltLogger.EXPECT().Debug("this is a test debug message", []interface{}{"key1", "test-value1", "ctx-key-1", "ctx-value-1"}...).Once() + l.Debug(ctx, "this is a test debug message", "key1", "test-value1") + saltLogger.AssertExpectations(t) + }) + + t.Run("Info", func(t *testing.T) { + saltLogger.EXPECT().Info("this is a test info message", []interface{}{"key1", "test-value1", "ctx-key-1", "ctx-value-1"}...).Once() + l.Info(ctx, "this is a test info message", "key1", "test-value1") + saltLogger.AssertExpectations(t) + }) + + t.Run("Warn", func(t *testing.T) { + saltLogger.EXPECT().Warn("this is a test warn message", []interface{}{"key1", "test-value1", "ctx-key-1", "ctx-value-1"}...).Once() + l.Warn(ctx, "this is a test warn message", "key1", "test-value1") + saltLogger.AssertExpectations(t) + }) + + t.Run("Error", func(t *testing.T) { + saltLogger.EXPECT().Error("this is a test error message", []interface{}{"key1", "test-value1", "ctx-key-1", "ctx-value-1"}...).Once() + l.Error(ctx, "this is a test error message", "key1", "test-value1") + saltLogger.AssertExpectations(t) + }) + + t.Run("Fatal", func(t *testing.T) { + saltLogger.EXPECT().Fatal("this is a test fatal message", []interface{}{"key1", "test-value1", "ctx-key-1", "ctx-value-1"}...).Once() + l.Fatal(ctx, "this is a test fatal message", "key1", "test-value1") + saltLogger.AssertExpectations(t) + }) + }) +} diff --git a/pkg/log/mocks/saltLogger.go b/pkg/log/mocks/saltLogger.go new file mode 100644 index 000000000..e0073e298 --- /dev/null +++ b/pkg/log/mocks/saltLogger.go @@ -0,0 +1,340 @@ +// Code generated by mockery v2.32.0. DO NOT EDIT. + +package mocks + +import ( + io "io" + + mock "github.com/stretchr/testify/mock" +) + +// SaltLogger is an autogenerated mock type for the saltLogger type +type SaltLogger struct { + mock.Mock +} + +type SaltLogger_Expecter struct { + mock *mock.Mock +} + +func (_m *SaltLogger) EXPECT() *SaltLogger_Expecter { + return &SaltLogger_Expecter{mock: &_m.Mock} +} + +// Debug provides a mock function with given fields: msg, args +func (_m *SaltLogger) Debug(msg string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// SaltLogger_Debug_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Debug' +type SaltLogger_Debug_Call struct { + *mock.Call +} + +// Debug is a helper method to define mock.On call +// - msg string +// - args ...interface{} +func (_e *SaltLogger_Expecter) Debug(msg interface{}, args ...interface{}) *SaltLogger_Debug_Call { + return &SaltLogger_Debug_Call{Call: _e.mock.On("Debug", + append([]interface{}{msg}, args...)...)} +} + +func (_c *SaltLogger_Debug_Call) Run(run func(msg string, args ...interface{})) *SaltLogger_Debug_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *SaltLogger_Debug_Call) Return() *SaltLogger_Debug_Call { + _c.Call.Return() + return _c +} + +func (_c *SaltLogger_Debug_Call) RunAndReturn(run func(string, ...interface{})) *SaltLogger_Debug_Call { + _c.Call.Return(run) + return _c +} + +// Error provides a mock function with given fields: msg, args +func (_m *SaltLogger) Error(msg string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// SaltLogger_Error_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Error' +type SaltLogger_Error_Call struct { + *mock.Call +} + +// Error is a helper method to define mock.On call +// - msg string +// - args ...interface{} +func (_e *SaltLogger_Expecter) Error(msg interface{}, args ...interface{}) *SaltLogger_Error_Call { + return &SaltLogger_Error_Call{Call: _e.mock.On("Error", + append([]interface{}{msg}, args...)...)} +} + +func (_c *SaltLogger_Error_Call) Run(run func(msg string, args ...interface{})) *SaltLogger_Error_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *SaltLogger_Error_Call) Return() *SaltLogger_Error_Call { + _c.Call.Return() + return _c +} + +func (_c *SaltLogger_Error_Call) RunAndReturn(run func(string, ...interface{})) *SaltLogger_Error_Call { + _c.Call.Return(run) + return _c +} + +// Fatal provides a mock function with given fields: msg, args +func (_m *SaltLogger) Fatal(msg string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// SaltLogger_Fatal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Fatal' +type SaltLogger_Fatal_Call struct { + *mock.Call +} + +// Fatal is a helper method to define mock.On call +// - msg string +// - args ...interface{} +func (_e *SaltLogger_Expecter) Fatal(msg interface{}, args ...interface{}) *SaltLogger_Fatal_Call { + return &SaltLogger_Fatal_Call{Call: _e.mock.On("Fatal", + append([]interface{}{msg}, args...)...)} +} + +func (_c *SaltLogger_Fatal_Call) Run(run func(msg string, args ...interface{})) *SaltLogger_Fatal_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *SaltLogger_Fatal_Call) Return() *SaltLogger_Fatal_Call { + _c.Call.Return() + return _c +} + +func (_c *SaltLogger_Fatal_Call) RunAndReturn(run func(string, ...interface{})) *SaltLogger_Fatal_Call { + _c.Call.Return(run) + return _c +} + +// Info provides a mock function with given fields: msg, args +func (_m *SaltLogger) Info(msg string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// SaltLogger_Info_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Info' +type SaltLogger_Info_Call struct { + *mock.Call +} + +// Info is a helper method to define mock.On call +// - msg string +// - args ...interface{} +func (_e *SaltLogger_Expecter) Info(msg interface{}, args ...interface{}) *SaltLogger_Info_Call { + return &SaltLogger_Info_Call{Call: _e.mock.On("Info", + append([]interface{}{msg}, args...)...)} +} + +func (_c *SaltLogger_Info_Call) Run(run func(msg string, args ...interface{})) *SaltLogger_Info_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *SaltLogger_Info_Call) Return() *SaltLogger_Info_Call { + _c.Call.Return() + return _c +} + +func (_c *SaltLogger_Info_Call) RunAndReturn(run func(string, ...interface{})) *SaltLogger_Info_Call { + _c.Call.Return(run) + return _c +} + +// Level provides a mock function with given fields: +func (_m *SaltLogger) Level() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// SaltLogger_Level_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Level' +type SaltLogger_Level_Call struct { + *mock.Call +} + +// Level is a helper method to define mock.On call +func (_e *SaltLogger_Expecter) Level() *SaltLogger_Level_Call { + return &SaltLogger_Level_Call{Call: _e.mock.On("Level")} +} + +func (_c *SaltLogger_Level_Call) Run(run func()) *SaltLogger_Level_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *SaltLogger_Level_Call) Return(_a0 string) *SaltLogger_Level_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SaltLogger_Level_Call) RunAndReturn(run func() string) *SaltLogger_Level_Call { + _c.Call.Return(run) + return _c +} + +// Warn provides a mock function with given fields: msg, args +func (_m *SaltLogger) Warn(msg string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// SaltLogger_Warn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Warn' +type SaltLogger_Warn_Call struct { + *mock.Call +} + +// Warn is a helper method to define mock.On call +// - msg string +// - args ...interface{} +func (_e *SaltLogger_Expecter) Warn(msg interface{}, args ...interface{}) *SaltLogger_Warn_Call { + return &SaltLogger_Warn_Call{Call: _e.mock.On("Warn", + append([]interface{}{msg}, args...)...)} +} + +func (_c *SaltLogger_Warn_Call) Run(run func(msg string, args ...interface{})) *SaltLogger_Warn_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *SaltLogger_Warn_Call) Return() *SaltLogger_Warn_Call { + _c.Call.Return() + return _c +} + +func (_c *SaltLogger_Warn_Call) RunAndReturn(run func(string, ...interface{})) *SaltLogger_Warn_Call { + _c.Call.Return(run) + return _c +} + +// Writer provides a mock function with given fields: +func (_m *SaltLogger) Writer() io.Writer { + ret := _m.Called() + + var r0 io.Writer + if rf, ok := ret.Get(0).(func() io.Writer); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.Writer) + } + } + + return r0 +} + +// SaltLogger_Writer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Writer' +type SaltLogger_Writer_Call struct { + *mock.Call +} + +// Writer is a helper method to define mock.On call +func (_e *SaltLogger_Expecter) Writer() *SaltLogger_Writer_Call { + return &SaltLogger_Writer_Call{Call: _e.mock.On("Writer")} +} + +func (_c *SaltLogger_Writer_Call) Run(run func()) *SaltLogger_Writer_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *SaltLogger_Writer_Call) Return(_a0 io.Writer) *SaltLogger_Writer_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SaltLogger_Writer_Call) RunAndReturn(run func() io.Writer) *SaltLogger_Writer_Call { + _c.Call.Return(run) + return _c +} + +// NewSaltLogger creates a new instance of SaltLogger. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewSaltLogger(t interface { + mock.TestingT + Cleanup(func()) +}) *SaltLogger { + mock := &SaltLogger{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/log/noop.go b/pkg/log/noop.go new file mode 100644 index 000000000..820c0b225 --- /dev/null +++ b/pkg/log/noop.go @@ -0,0 +1,29 @@ +package log + +import ( + "context" + "io" + "io/ioutil" +) + +type Noop struct{} +type Option func(interface{}) + +func (n *Noop) Debug(ctx context.Context, msg string, args ...interface{}) {} +func (n *Noop) Info(ctx context.Context, msg string, args ...interface{}) {} +func (n *Noop) Warn(ctx context.Context, msg string, args ...interface{}) {} +func (n *Noop) Error(ctx context.Context, msg string, args ...interface{}) {} +func (n *Noop) Fatal(ctx context.Context, msg string, args ...interface{}) {} + +func (n *Noop) Level() string { + return "unsupported" +} +func (n *Noop) Writer() io.Writer { + return ioutil.Discard +} + +// NewNoop returns a no operation logger, useful in tests +// to avoid printing logs to stdout. +func NewNoop(opts ...Option) *Noop { + return &Noop{} +} diff --git a/plugins/notifiers/client.go b/plugins/notifiers/client.go index 914dac659..deb4b27ab 100644 --- a/plugins/notifiers/client.go +++ b/plugins/notifiers/client.go @@ -1,12 +1,13 @@ package notifiers import ( + "context" "errors" "fmt" "net/http" "time" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/mitchellh/mapstructure" "github.com/goto/guardian/domain" @@ -14,7 +15,7 @@ import ( ) type Client interface { - Notify([]domain.Notification) []error + Notify(context.Context, []domain.Notification) []error } const ( @@ -39,7 +40,7 @@ type Config struct { Messages domain.NotificationMessages } -func NewClient(config *Config, logger *log.Logrus) (Client, error) { +func NewClient(config *Config, logger log.Logger) (Client, error) { if config.Provider == ProviderTypeSlack { slackConfig, err := NewSlackConfig(config) if err != nil { @@ -47,6 +48,7 @@ func NewClient(config *Config, logger *log.Logrus) (Client, error) { } httpClient := &http.Client{Timeout: 10 * time.Second} + return slack.NewNotifier(slackConfig, httpClient, logger), nil } diff --git a/plugins/notifiers/slack/client.go b/plugins/notifiers/slack/client.go index a60057e4d..37bb4686b 100644 --- a/plugins/notifiers/slack/client.go +++ b/plugins/notifiers/slack/client.go @@ -2,6 +2,7 @@ package slack import ( "bytes" + "context" "embed" "encoding/json" "errors" @@ -12,7 +13,7 @@ import ( "strings" "github.com/goto/guardian/pkg/evaluator" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/utils" @@ -53,7 +54,7 @@ type Notifier struct { Messages domain.NotificationMessages httpClient utils.HTTPClient defaultMessageFiles embed.FS - logger *log.Logrus + logger log.Logger } type slackIDCacheItem struct { @@ -69,7 +70,7 @@ type Config struct { //go:embed templates/* var defaultTemplates embed.FS -func NewNotifier(config *Config, httpClient utils.HTTPClient, logger *log.Logrus) *Notifier { +func NewNotifier(config *Config, httpClient utils.HTTPClient, logger log.Logger) *Notifier { return &Notifier{ workspaces: config.Workspaces, slackIDCache: map[string]*slackIDCacheItem{}, @@ -80,7 +81,7 @@ func NewNotifier(config *Config, httpClient utils.HTTPClient, logger *log.Logrus } } -func (n *Notifier) Notify(items []domain.Notification) []error { +func (n *Notifier) Notify(ctx context.Context, items []domain.Notification) []error { errs := make([]error, 0) for _, item := range items { var slackWorkspace *SlackWorkspace @@ -116,7 +117,7 @@ func (n *Notifier) Notify(items []domain.Notification) []error { continue } - n.logger.Debug(fmt.Sprintf("%v | sending slack notification to user:%s in workspace:%s", labelSlice, item.User, slackWorkspace.WorkspaceName)) + n.logger.Debug(ctx, fmt.Sprintf("%v | sending slack notification to user:%s in workspace:%s", labelSlice, item.User, slackWorkspace.WorkspaceName)) msg, err := ParseMessage(item.Message, n.Messages, n.defaultMessageFiles) if err != nil { diff --git a/plugins/notifiers/slack/client_test.go b/plugins/notifiers/slack/client_test.go index e4e608326..e4a26b2ee 100644 --- a/plugins/notifiers/slack/client_test.go +++ b/plugins/notifiers/slack/client_test.go @@ -2,6 +2,7 @@ package slack_test import ( "bytes" + "context" "embed" "errors" "fmt" @@ -77,7 +78,7 @@ func (s *ClientTestSuite) TestNotify() { }, }, } - actualErrs := s.notifier.Notify(notifications) + actualErrs := s.notifier.Notify(context.Background(), notifications) s.Equal(expectedErrs, actualErrs) }) diff --git a/plugins/providers/bigquery/config.go b/plugins/providers/bigquery/config.go index 431a0a814..ed9c46eb3 100644 --- a/plugins/providers/bigquery/config.go +++ b/plugins/providers/bigquery/config.go @@ -101,13 +101,13 @@ func NewConfig(pc *domain.ProviderConfig, crypto domain.Crypto) *Config { } // ParseAndValidate validates bigquery config within provider config and make the interface{} config value castable into the expected bigquery config value -func (c *Config) ParseAndValidate() error { - return c.parseAndValidate() +func (c *Config) ParseAndValidate(ctx context.Context) error { + return c.parseAndValidate(ctx) } // EncryptCredentials encrypts the bigquery credentials config -func (c *Config) EncryptCredentials() error { - if err := c.parseAndValidate(); err != nil { +func (c *Config) EncryptCredentials(ctx context.Context) error { + if err := c.parseAndValidate(ctx); err != nil { return err } @@ -124,7 +124,7 @@ func (c *Config) EncryptCredentials() error { return nil } -func (c *Config) parseAndValidate() error { +func (c *Config) parseAndValidate(ctx context.Context) error { if c.valid { return nil } @@ -154,7 +154,7 @@ func (c *Config) parseAndValidate() error { for _, resource := range c.ProviderConfig.Resources { for _, role := range resource.Roles { for i, permission := range role.Permissions { - if permissionConfig, err := c.validatePermission(permission, resource.Type, client); err != nil { + if permissionConfig, err := c.validatePermission(ctx, permission, resource.Type, client); err != nil { permissionValidationErrors = append(permissionValidationErrors, err) } else { role.Permissions[i] = permissionConfig @@ -195,13 +195,12 @@ func (c *Config) validateCredentials(value interface{}) (*Credentials, error) { return &credentials, nil } -func (c *Config) validatePermission(value interface{}, resourceType string, client *bigQueryClient) (*Permission, error) { +func (c *Config) validatePermission(ctx context.Context, value interface{}, resourceType string, client *bigQueryClient) (*Permission, error) { permision, ok := value.(string) if !ok { return nil, ErrInvalidPermissionConfig } - ctx := context.TODO() if resourceType == ResourceTypeDataset { if !utils.ContainsString([]string{DatasetRoleReader, DatasetRoleWriter, DatasetRoleOwner}, permision) { grantableRoles, err := c.getGrantableRolesForDataset(ctx, client) diff --git a/plugins/providers/bigquery/config_test.go b/plugins/providers/bigquery/config_test.go index 840cfbd8f..a24f12fa9 100644 --- a/plugins/providers/bigquery/config_test.go +++ b/plugins/providers/bigquery/config_test.go @@ -1,6 +1,7 @@ package bigquery_test import ( + "context" "encoding/base64" "errors" "testing" @@ -163,7 +164,7 @@ func TestValidate(t *testing.T) { } mockCrypto.On("Encrypt", mock.Anything).Return("", nil).Once() - err := bigquery.NewConfig(pc, mockCrypto).ParseAndValidate() + err := bigquery.NewConfig(pc, mockCrypto).ParseAndValidate(context.Background()) assert.Error(t, err) }) } diff --git a/plugins/providers/bigquery/provider.go b/plugins/providers/bigquery/provider.go index 56a8a8bcc..d1ffbd4d2 100644 --- a/plugins/providers/bigquery/provider.go +++ b/plugins/providers/bigquery/provider.go @@ -14,9 +14,9 @@ import ( "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/pkg/slices" "github.com/goto/guardian/utils" - "github.com/goto/salt/log" "github.com/mitchellh/mapstructure" "github.com/patrickmn/go-cache" "golang.org/x/sync/errgroup" @@ -105,15 +105,16 @@ func (p *Provider) GetType() string { func (p *Provider) CreateConfig(pc *domain.ProviderConfig) error { c := NewConfig(pc, p.encryptor) - if err := c.ParseAndValidate(); err != nil { + ctx := context.TODO() + if err := c.ParseAndValidate(ctx); err != nil { return err } - return c.EncryptCredentials() + return c.EncryptCredentials(ctx) } // GetResources returns BigQuery dataset and table resources -func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { +func (p *Provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return nil, err @@ -127,7 +128,7 @@ func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, resourceTypes := pc.GetResourceTypes() resources := []*domain.Resource{} - eg, ctx := errgroup.WithContext(context.TODO()) + eg, ctx := errgroup.WithContext(ctx) eg.SetLimit(20) var mu sync.Mutex @@ -153,7 +154,7 @@ func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, if datasetFilter != "" { v, err := evaluator.Expression(datasetFilter).EvaluateWithStruct(dataset) if err != nil { - p.logger.Error(fmt.Sprintf("evaluating filter expression %q for dataset %q: %v", datasetFilter, dataset.URN, err)) + p.logger.Error(ctx, fmt.Sprintf("evaluating filter expression %q for dataset %q: %v", datasetFilter, dataset.URN, err)) } if !reflect.ValueOf(v).IsZero() { resources = append(resources, dataset) @@ -193,7 +194,7 @@ func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, return resources, nil } -func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *Provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { if err := validateProviderConfigAndAppealParams(pc, a); err != nil { return err } @@ -208,7 +209,6 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error } permissions := getPermissions(a) - ctx := context.TODO() if a.Resource.Type == ResourceTypeDataset { d := new(Dataset) if err := d.FromDomain(a.Resource); err != nil { @@ -246,7 +246,7 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return ErrInvalidResourceType } -func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *Provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { if err := validateProviderConfigAndAppealParams(pc, a); err != nil { return err } @@ -261,7 +261,6 @@ func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error } permissions := getPermissions(a) - ctx := context.TODO() if a.Resource.Type == ResourceTypeDataset { d := new(Dataset) if err := d.FromDomain(a.Resource); err != nil { @@ -624,11 +623,11 @@ func (p *Provider) getGcloudRoles(ctx context.Context, pd domain.Provider) (map[ func (p *Provider) getGcloudPermissions(ctx context.Context, pd domain.Provider, gcloudRole string) ([]string, error) { roleID := translateDatasetRoleToBigQueryRole(gcloudRole) if permissions, exists := p.rolesCache.Get(roleID); exists { - p.logger.Debug("getting permissions from cache", "role", roleID) + p.logger.Debug(ctx, "getting permissions from cache", "role", roleID) return permissions.([]string), nil } - p.logger.Debug("getting permissions from gcloud", "role", roleID) + p.logger.Debug(ctx, "getting permissions from gcloud", "role", roleID) creds, err := ParseCredentials(pd.Config.Credentials, p.encryptor) if err != nil { return nil, fmt.Errorf("parsing credentials: %w", err) diff --git a/plugins/providers/bigquery/provider_test.go b/plugins/providers/bigquery/provider_test.go index eedb17b1a..571c8f29b 100644 --- a/plugins/providers/bigquery/provider_test.go +++ b/plugins/providers/bigquery/provider_test.go @@ -13,9 +13,9 @@ import ( "github.com/google/go-cmp/cmp" "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/plugins/providers/bigquery" "github.com/goto/guardian/plugins/providers/bigquery/mocks" - "github.com/goto/salt/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -272,6 +272,7 @@ func TestCreateConfig(t *testing.T) { } func TestGetResources(t *testing.T) { + ctx := context.Background() t.Run("should error when credentials are invalid", func(t *testing.T) { encryptor := new(mocks.Encryptor) l := log.NewNoop() @@ -282,7 +283,7 @@ func TestGetResources(t *testing.T) { Credentials: "invalid-creds", } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -360,7 +361,7 @@ func TestGetResources(t *testing.T) { }, }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -422,7 +423,7 @@ func TestGetResources(t *testing.T) { }, } expectedResources := append(resources, children...) - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(context.Background(), pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -431,6 +432,7 @@ func TestGetResources(t *testing.T) { } func TestGrantAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if Provider Config or Appeal doesn't have required parameters", func(t *testing.T) { testCases := []struct { name string @@ -494,7 +496,7 @@ func TestGrantAccess(t *testing.T) { pc := tc.providerConfig a := tc.grant - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, tc.expectedError.Error()) } }) @@ -523,7 +525,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Error(t, actualError) }) @@ -575,7 +577,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"VIEWER"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Equal(t, expectedError, actualError) }) @@ -627,7 +629,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"VIEWER"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -683,7 +685,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"VIEWER"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -691,6 +693,7 @@ func TestGrantAccess(t *testing.T) { } func TestRevokeAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if Provider Config or Appeal doesn't have required parameters", func(t *testing.T) { testCases := []struct { providerConfig *domain.ProviderConfig @@ -753,7 +756,7 @@ func TestRevokeAccess(t *testing.T) { pc := tc.providerConfig a := tc.grant - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, tc.expectedError.Error()) } }) @@ -782,7 +785,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, g) + actualError := p.RevokeAccess(ctx, pc, g) assert.Error(t, actualError) }) @@ -834,7 +837,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"VIEWER"}, } - actualError := p.RevokeAccess(pc, g) + actualError := p.RevokeAccess(ctx, pc, g) assert.Equal(t, expectedError, actualError) }) @@ -886,7 +889,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"VIEWER"}, } - actualError := p.RevokeAccess(pc, g) + actualError := p.RevokeAccess(ctx, pc, g) assert.Nil(t, actualError) }) @@ -941,7 +944,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"VIEWER"}, } - actualError := p.RevokeAccess(pc, g) + actualError := p.RevokeAccess(ctx, pc, g) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -1108,7 +1111,7 @@ func (s *BigQueryProviderTestSuite) TestListAccess() { expectedResourcesAccess := domain.MapResourceAccess{} expectedResources := []*domain.Resource{} s.mockBigQueryClient.EXPECT(). - ListAccess(mock.AnythingOfType("*context.emptyCtx"), expectedResources). + ListAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedResources). Return(expectedResourcesAccess, nil).Once() ctx := context.Background() @@ -1162,12 +1165,13 @@ func (s *BigQueryProviderTestSuite) TestGetActivities_Success() { `resource.type="bigquery_dataset"`, fmt.Sprintf(`protoPayload.methodName=("%s")`, strings.Join(bigquery.BigQueryAuditMetadataMethods, `" OR "`)), }, ` AND `) + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.mockCloudLoggingClient.EXPECT(). - ListLogEntries(mock.AnythingOfType("*context.emptyCtx"), expectedListLogEntriesFilter, 0).Return(expectedBigQueryActivities, nil).Once() + ListLogEntries(mockCtx, expectedListLogEntriesFilter, 0).Return(expectedBigQueryActivities, nil).Once() s.mockBigQueryClient.EXPECT(). - GetRolePermissions(mock.AnythingOfType("*context.emptyCtx"), "roles/bigquery.dataViewer").Return([]string{"bigquery.datasets.get"}, nil).Once() + GetRolePermissions(mockCtx, "roles/bigquery.dataViewer").Return([]string{"bigquery.datasets.get"}, nil).Once() s.mockBigQueryClient.EXPECT(). - GetRolePermissions(mock.AnythingOfType("*context.emptyCtx"), "roles/bigquery.dataEditor").Return([]string{"bigquery.datasets.get"}, nil).Once() + GetRolePermissions(mockCtx, "roles/bigquery.dataEditor").Return([]string{"bigquery.datasets.get"}, nil).Once() expectedActivities := []*domain.Activity{ { @@ -1252,7 +1256,7 @@ func (s *BigQueryProviderTestSuite) TestGetActivities_Success() { s.Run("should return error if there is an error on listing log entries", func() { expectedError := errors.New("error") s.mockCloudLoggingClient.EXPECT(). - ListLogEntries(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string"), 0).Return(nil, expectedError).Once() + ListLogEntries(mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.AnythingOfType("string"), 0).Return(nil, expectedError).Once() _, err := s.provider.GetActivities(context.Background(), *s.validProvider, domain.ListActivitiesFilter{}) @@ -1263,16 +1267,16 @@ func (s *BigQueryProviderTestSuite) TestGetActivities_Success() { func (s *BigQueryProviderTestSuite) TestListActivities() { timeNow := time.Now() - + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) s.Run("should return list of activity on success", func() { expectedLogBucket := &logging.LogBucket{ RetentionDays: 30, } s.mockCloudLoggingClient.EXPECT(). - GetLogBucket(mock.AnythingOfType("*context.emptyCtx"), fmt.Sprintf("projects/%s/locations/global/buckets/_Default", s.dummyProjectID)). + GetLogBucket(mockCtx, fmt.Sprintf("projects/%s/locations/global/buckets/_Default", s.dummyProjectID)). Return(expectedLogBucket, nil).Once() s.mockBigQueryClient.EXPECT(). - CheckGrantedPermission(mock.AnythingOfType("*context.emptyCtx"), []string{bigquery.PrivateLogViewerPermission}). + CheckGrantedPermission(mockCtx, []string{bigquery.PrivateLogViewerPermission}). Return([]string{bigquery.PrivateLogViewerPermission}, nil).Once() expectedBqActivities := []*bigquery.Activity{ { @@ -1289,7 +1293,7 @@ func (s *BigQueryProviderTestSuite) TestListActivities() { }, } s.mockCloudLoggingClient.EXPECT(). - ListLogEntries(mock.AnythingOfType("*context.emptyCtx"), mock.AnythingOfType("string"), 0). + ListLogEntries(mockCtx, mock.AnythingOfType("string"), 0). Return(expectedBqActivities, nil).Once() expectedActivities := []*domain.Activity{ @@ -1343,7 +1347,7 @@ func (s *BigQueryProviderTestSuite) TestListActivities() { s.Run("should return error if specified time range is more than the log bucket's retention period", func() { s.mockCloudLoggingClient.EXPECT(). - GetLogBucket(mock.AnythingOfType("*context.emptyCtx"), fmt.Sprintf("projects/%s/locations/global/buckets/_Default", s.dummyProjectID)). + GetLogBucket(mockCtx, fmt.Sprintf("projects/%s/locations/global/buckets/_Default", s.dummyProjectID)). Return(&logging.LogBucket{ RetentionDays: 30, }, nil).Once() @@ -1359,12 +1363,12 @@ func (s *BigQueryProviderTestSuite) TestListActivities() { s.Run("should return error if credentials doesn't have bigquery.privateLogViewer permission", func() { s.mockCloudLoggingClient.EXPECT(). - GetLogBucket(mock.AnythingOfType("*context.emptyCtx"), fmt.Sprintf("projects/%s/locations/global/buckets/_Default", s.dummyProjectID)). + GetLogBucket(mockCtx, fmt.Sprintf("projects/%s/locations/global/buckets/_Default", s.dummyProjectID)). Return(&logging.LogBucket{ RetentionDays: 30, }, nil).Once() s.mockBigQueryClient.EXPECT(). - CheckGrantedPermission(mock.AnythingOfType("*context.emptyCtx"), []string{bigquery.PrivateLogViewerPermission}). + CheckGrantedPermission(mockCtx, []string{bigquery.PrivateLogViewerPermission}). Return([]string{}, nil).Once() _, err := s.provider.ListActivities(context.Background(), *s.validProvider, domain.ListActivitiesFilter{}) @@ -1397,7 +1401,7 @@ func (s *BigQueryProviderTestSuite) TestCorrelateGrantActivities() { expectedUniqueRoles := []string{"role-1", "role-2", "role-3", "role-4"} s.mockBigQueryClient.EXPECT(). - ListRolePermissions(mock.AnythingOfType("*context.emptyCtx"), expectedUniqueRoles). + ListRolePermissions(mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedUniqueRoles). Return(dummyRolePermissions, nil).Once() expectedAssociatedGrants := map[string][]string{ "g1": {"a1", "a3", "a4"}, diff --git a/plugins/providers/client.go b/plugins/providers/client.go index d76e75d55..ad75082fd 100644 --- a/plugins/providers/client.go +++ b/plugins/providers/client.go @@ -9,9 +9,9 @@ import ( type Client interface { GetType() string CreateConfig(*domain.ProviderConfig) error - GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) - GrantAccess(*domain.ProviderConfig, domain.Grant) error - RevokeAccess(*domain.ProviderConfig, domain.Grant) error + GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) + GrantAccess(context.Context, *domain.ProviderConfig, domain.Grant) error + RevokeAccess(context.Context, *domain.ProviderConfig, domain.Grant) error GetRoles(pc *domain.ProviderConfig, resourceType string) ([]*domain.Role, error) GetAccountTypes() []string ListAccess(context.Context, domain.ProviderConfig, []*domain.Resource) (domain.MapResourceAccess, error) diff --git a/plugins/providers/dataplex/config_test.go b/plugins/providers/dataplex/config_test.go index 080e940e0..099db765b 100644 --- a/plugins/providers/dataplex/config_test.go +++ b/plugins/providers/dataplex/config_test.go @@ -1,6 +1,7 @@ package dataplex_test import ( + "context" "encoding/base64" "errors" "testing" @@ -165,7 +166,7 @@ func TestValidate(t *testing.T) { } mockCrypto.On("Encrypt", mock.Anything).Return("", nil).Once() - err := bigquery.NewConfig(pc, mockCrypto).ParseAndValidate() + err := bigquery.NewConfig(pc, mockCrypto).ParseAndValidate(context.Background()) assert.Error(t, err) }) } diff --git a/plugins/providers/dataplex/provider.go b/plugins/providers/dataplex/provider.go index 29ec97005..01c7263e5 100644 --- a/plugins/providers/dataplex/provider.go +++ b/plugins/providers/dataplex/provider.go @@ -54,7 +54,7 @@ func (p *Provider) CreateConfig(pc *domain.ProviderConfig) error { return c.EncryptCredentials() } -func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { +func (p *Provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return nil, err @@ -82,7 +82,7 @@ func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, return resources, nil } -func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *Provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { if err := validateProviderConfigAndAppealParams(pc, a); err != nil { return err } @@ -98,7 +98,6 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error permissions := getPermissions(a) - ctx := context.Background() if a.Resource.Type == ResourceTypeTag { policy := new(Policy) policy.FromDomain(a.Resource) @@ -116,7 +115,7 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return ErrInvalidResourceType } -func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *Provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { if err := validateProviderConfigAndAppealParams(pc, a); err != nil { return err } @@ -131,7 +130,6 @@ func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error } permissions := getPermissions(a) - ctx := context.Background() if a.Resource.Type == ResourceTypeTag { policy := new(Policy) diff --git a/plugins/providers/dataplex/provider_test.go b/plugins/providers/dataplex/provider_test.go index 603f2475c..68ccbe603 100644 --- a/plugins/providers/dataplex/provider_test.go +++ b/plugins/providers/dataplex/provider_test.go @@ -236,6 +236,7 @@ func TestCreateConfig(t *testing.T) { } func TestGetResources(t *testing.T) { + ctx := context.Background() t.Run("should error when credentials are invalid", func(t *testing.T) { encryptor := new(mocks.Encryptor) p := dataplex.NewProvider("", encryptor) @@ -245,7 +246,7 @@ func TestGetResources(t *testing.T) { Credentials: "invalid-creds", } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -303,7 +304,7 @@ func TestGetResources(t *testing.T) { }, }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -312,6 +313,8 @@ func TestGetResources(t *testing.T) { } func TestGrantAccess(t *testing.T) { + ctx := context.Background() + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) t.Run("should return error if Provider Config or Appeal doesn't have required parameters", func(t *testing.T) { testCases := []struct { name string @@ -375,7 +378,7 @@ func TestGrantAccess(t *testing.T) { pc := tc.providerConfig a := tc.grant - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, tc.expectedError.Error()) } }) @@ -404,7 +407,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Error(t, actualError) }) @@ -423,7 +426,7 @@ func TestGrantAccess(t *testing.T) { ResourceName: "projects/resource-name/locations/us", } policy := &dataplex.Policy{} - client.On("GrantPolicyAccess", mock.AnythingOfType("*context.emptyCtx"), policy, "user:test@email.com", "roles/datacatalog.categoryFineGrainedReader").Return(expectedError).Once() + client.On("GrantPolicyAccess", mockCtx, policy, "user:test@email.com", "roles/datacatalog.categoryFineGrainedReader").Return(expectedError).Once() pc := &domain.ProviderConfig{ Type: "dataplex", @@ -456,7 +459,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"roles/datacatalog.categoryFineGrainedReader"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Equal(t, expectedError, actualError) }) @@ -480,7 +483,7 @@ func TestGrantAccess(t *testing.T) { DisplayName: "", Description: "", } - client.On("GrantPolicyAccess", mock.AnythingOfType("*context.emptyCtx"), policy, "user:test@email.com", "roles/datacatalog.categoryFineGrainedReader").Return(dataplex.ErrPermissionAlreadyExists).Once() + client.On("GrantPolicyAccess", mockCtx, policy, "user:test@email.com", "roles/datacatalog.categoryFineGrainedReader").Return(dataplex.ErrPermissionAlreadyExists).Once() pc := &domain.ProviderConfig{ Type: "dataplex", @@ -515,7 +518,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"roles/datacatalog.categoryFineGrainedReader"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -523,6 +526,7 @@ func TestGrantAccess(t *testing.T) { } func TestRevokeAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if Provider Config or Appeal doesn't have required parameters", func(t *testing.T) { testCases := []struct { providerConfig *domain.ProviderConfig @@ -585,7 +589,7 @@ func TestRevokeAccess(t *testing.T) { pc := tc.providerConfig a := tc.grant - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, tc.expectedError.Error()) } }) @@ -614,7 +618,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, g) + actualError := p.RevokeAccess(ctx, pc, g) assert.Error(t, actualError) }) @@ -635,7 +639,7 @@ func TestRevokeAccess(t *testing.T) { } policy := &dataplex.Policy{} - client.On("RevokePolicyAccess", mock.AnythingOfType("*context.emptyCtx"), policy, "user:test@email.com", "roles/datacatalog.categoryFineGrainedReader").Return(expectedError).Once() + client.On("RevokePolicyAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), policy, "user:test@email.com", "roles/datacatalog.categoryFineGrainedReader").Return(expectedError).Once() pc := &domain.ProviderConfig{ Type: "dataplex", @@ -668,7 +672,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"roles/datacatalog.categoryFineGrainedReader"}, } - actualError := p.RevokeAccess(pc, g) + actualError := p.RevokeAccess(ctx, pc, g) assert.Equal(t, expectedError, actualError) }) @@ -688,7 +692,7 @@ func TestRevokeAccess(t *testing.T) { } policy := &dataplex.Policy{} - client.On("RevokePolicyAccess", mock.AnythingOfType("*context.emptyCtx"), policy, "user:test@email.com", "roles/datacatalog.categoryFineGrainedReader").Return(dataplex.ErrPermissionNotFound).Once() + client.On("RevokePolicyAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), policy, "user:test@email.com", "roles/datacatalog.categoryFineGrainedReader").Return(dataplex.ErrPermissionNotFound).Once() pc := &domain.ProviderConfig{ Type: "dataplex", @@ -721,7 +725,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"roles/datacatalog.categoryFineGrainedReader"}, } - actualError := p.RevokeAccess(pc, g) + actualError := p.RevokeAccess(ctx, pc, g) assert.Nil(t, actualError) }) diff --git a/plugins/providers/gcloudiam/client.go b/plugins/providers/gcloudiam/client.go index fff88dc5c..7e9e0cec5 100644 --- a/plugins/providers/gcloudiam/client.go +++ b/plugins/providers/gcloudiam/client.go @@ -95,8 +95,7 @@ func (c *iamClient) GetGrantableRoles(ctx context.Context, resourceType string) return roles, nil } -func (c *iamClient) GrantAccess(accountType, accountID, role string) error { - ctx := context.TODO() +func (c *iamClient) GrantAccess(ctx context.Context, accountType, accountID, role string) error { policy, err := c.getIamPolicy(ctx) if err != nil { return err @@ -124,8 +123,7 @@ func (c *iamClient) GrantAccess(accountType, accountID, role string) error { return err } -func (c *iamClient) RevokeAccess(accountType, accountID, role string) error { - ctx := context.TODO() +func (c *iamClient) RevokeAccess(ctx context.Context, accountType, accountID, role string) error { policy, err := c.getIamPolicy(ctx) if err != nil { return err diff --git a/plugins/providers/gcloudiam/mocks/GcloudIamClient.go b/plugins/providers/gcloudiam/mocks/GcloudIamClient.go index 8d687aea3..fd55967f7 100644 --- a/plugins/providers/gcloudiam/mocks/GcloudIamClient.go +++ b/plugins/providers/gcloudiam/mocks/GcloudIamClient.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. +// Code generated by mockery v2.32.0. DO NOT EDIT. package mocks @@ -80,13 +80,13 @@ func (_c *GcloudIamClient_GetGrantableRoles_Call) RunAndReturn(run func(context. return _c } -// GrantAccess provides a mock function with given fields: accountType, accountID, role -func (_m *GcloudIamClient) GrantAccess(accountType string, accountID string, role string) error { - ret := _m.Called(accountType, accountID, role) +// GrantAccess provides a mock function with given fields: ctx, accountType, accountID, role +func (_m *GcloudIamClient) GrantAccess(ctx context.Context, accountType string, accountID string, role string) error { + ret := _m.Called(ctx, accountType, accountID, role) var r0 error - if rf, ok := ret.Get(0).(func(string, string, string) error); ok { - r0 = rf(accountType, accountID, role) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, accountType, accountID, role) } else { r0 = ret.Error(0) } @@ -100,16 +100,17 @@ type GcloudIamClient_GrantAccess_Call struct { } // GrantAccess is a helper method to define mock.On call +// - ctx context.Context // - accountType string // - accountID string // - role string -func (_e *GcloudIamClient_Expecter) GrantAccess(accountType interface{}, accountID interface{}, role interface{}) *GcloudIamClient_GrantAccess_Call { - return &GcloudIamClient_GrantAccess_Call{Call: _e.mock.On("GrantAccess", accountType, accountID, role)} +func (_e *GcloudIamClient_Expecter) GrantAccess(ctx interface{}, accountType interface{}, accountID interface{}, role interface{}) *GcloudIamClient_GrantAccess_Call { + return &GcloudIamClient_GrantAccess_Call{Call: _e.mock.On("GrantAccess", ctx, accountType, accountID, role)} } -func (_c *GcloudIamClient_GrantAccess_Call) Run(run func(accountType string, accountID string, role string)) *GcloudIamClient_GrantAccess_Call { +func (_c *GcloudIamClient_GrantAccess_Call) Run(run func(ctx context.Context, accountType string, accountID string, role string)) *GcloudIamClient_GrantAccess_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -119,7 +120,7 @@ func (_c *GcloudIamClient_GrantAccess_Call) Return(_a0 error) *GcloudIamClient_G return _c } -func (_c *GcloudIamClient_GrantAccess_Call) RunAndReturn(run func(string, string, string) error) *GcloudIamClient_GrantAccess_Call { +func (_c *GcloudIamClient_GrantAccess_Call) RunAndReturn(run func(context.Context, string, string, string) error) *GcloudIamClient_GrantAccess_Call { _c.Call.Return(run) return _c } @@ -279,13 +280,13 @@ func (_c *GcloudIamClient_ListServiceAccounts_Call) RunAndReturn(run func(contex return _c } -// RevokeAccess provides a mock function with given fields: accountType, accountID, role -func (_m *GcloudIamClient) RevokeAccess(accountType string, accountID string, role string) error { - ret := _m.Called(accountType, accountID, role) +// RevokeAccess provides a mock function with given fields: ctx, accountType, accountID, role +func (_m *GcloudIamClient) RevokeAccess(ctx context.Context, accountType string, accountID string, role string) error { + ret := _m.Called(ctx, accountType, accountID, role) var r0 error - if rf, ok := ret.Get(0).(func(string, string, string) error); ok { - r0 = rf(accountType, accountID, role) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, accountType, accountID, role) } else { r0 = ret.Error(0) } @@ -299,16 +300,17 @@ type GcloudIamClient_RevokeAccess_Call struct { } // RevokeAccess is a helper method to define mock.On call +// - ctx context.Context // - accountType string // - accountID string // - role string -func (_e *GcloudIamClient_Expecter) RevokeAccess(accountType interface{}, accountID interface{}, role interface{}) *GcloudIamClient_RevokeAccess_Call { - return &GcloudIamClient_RevokeAccess_Call{Call: _e.mock.On("RevokeAccess", accountType, accountID, role)} +func (_e *GcloudIamClient_Expecter) RevokeAccess(ctx interface{}, accountType interface{}, accountID interface{}, role interface{}) *GcloudIamClient_RevokeAccess_Call { + return &GcloudIamClient_RevokeAccess_Call{Call: _e.mock.On("RevokeAccess", ctx, accountType, accountID, role)} } -func (_c *GcloudIamClient_RevokeAccess_Call) Run(run func(accountType string, accountID string, role string)) *GcloudIamClient_RevokeAccess_Call { +func (_c *GcloudIamClient_RevokeAccess_Call) Run(run func(ctx context.Context, accountType string, accountID string, role string)) *GcloudIamClient_RevokeAccess_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -318,7 +320,7 @@ func (_c *GcloudIamClient_RevokeAccess_Call) Return(_a0 error) *GcloudIamClient_ return _c } -func (_c *GcloudIamClient_RevokeAccess_Call) RunAndReturn(run func(string, string, string) error) *GcloudIamClient_RevokeAccess_Call { +func (_c *GcloudIamClient_RevokeAccess_Call) RunAndReturn(run func(context.Context, string, string, string) error) *GcloudIamClient_RevokeAccess_Call { _c.Call.Return(run) return _c } @@ -369,13 +371,12 @@ func (_c *GcloudIamClient_RevokeServiceAccountAccess_Call) RunAndReturn(run func return _c } -type mockConstructorTestingTNewGcloudIamClient interface { +// NewGcloudIamClient creates a new instance of GcloudIamClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewGcloudIamClient(t interface { mock.TestingT Cleanup(func()) -} - -// NewGcloudIamClient creates a new instance of GcloudIamClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewGcloudIamClient(t mockConstructorTestingTNewGcloudIamClient) *GcloudIamClient { +}) *GcloudIamClient { mock := &GcloudIamClient{} mock.Mock.Test(t) diff --git a/plugins/providers/gcloudiam/provider.go b/plugins/providers/gcloudiam/provider.go index 81148c77e..9d03672fe 100644 --- a/plugins/providers/gcloudiam/provider.go +++ b/plugins/providers/gcloudiam/provider.go @@ -6,7 +6,7 @@ import ( "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/mitchellh/mapstructure" "golang.org/x/net/context" "google.golang.org/api/iam/v1" @@ -15,8 +15,8 @@ import ( //go:generate mockery --name=GcloudIamClient --exported --with-expecter type GcloudIamClient interface { GetGrantableRoles(ctx context.Context, resourceType string) ([]*iam.Role, error) - GrantAccess(accountType, accountID, role string) error - RevokeAccess(accountType, accountID, role string) error + GrantAccess(ctx context.Context, accountType, accountID, role string) error + RevokeAccess(ctx context.Context, accountType, accountID, role string) error ListAccess(ctx context.Context, resources []*domain.Resource) (domain.MapResourceAccess, error) ListServiceAccounts(context.Context) ([]*iam.ServiceAccount, error) GrantServiceAccountAccess(ctx context.Context, sa, accountType, accountID, roles string) error @@ -72,7 +72,7 @@ func (p *Provider) CreateConfig(pc *domain.ProviderConfig) error { return c.EncryptCredentials() } -func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { +func (p *Provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { resources := []*domain.Resource{} for _, rc := range pc.Resources { @@ -96,7 +96,7 @@ func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, return nil, fmt.Errorf("initializing iam client: %w", err) } - serviceAccounts, err := client.ListServiceAccounts(context.TODO()) + serviceAccounts, err := client.ListServiceAccounts(ctx) if err != nil { return nil, fmt.Errorf("listing service accounts: %w", err) } @@ -121,7 +121,7 @@ func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, return resources, nil } -func (p *Provider) GrantAccess(pc *domain.ProviderConfig, g domain.Grant) error { +func (p *Provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, g domain.Grant) error { // TODO: validate provider config and appeal var creds Credentials @@ -137,7 +137,7 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, g domain.Grant) error switch g.Resource.Type { case ResourceTypeProject, ResourceTypeOrganization: for _, p := range g.Permissions { - if err := client.GrantAccess(g.AccountType, g.AccountID, p); err != nil { + if err := client.GrantAccess(ctx, g.AccountType, g.AccountID, p); err != nil { if !errors.Is(err, ErrPermissionAlreadyExists) { return err } @@ -147,7 +147,7 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, g domain.Grant) error case ResourceTypeServiceAccount: for _, p := range g.Permissions { - if err := client.GrantServiceAccountAccess(context.TODO(), g.Resource.URN, g.AccountType, g.AccountID, p); err != nil { + if err := client.GrantServiceAccountAccess(ctx, g.Resource.URN, g.AccountType, g.AccountID, p); err != nil { if !errors.Is(err, ErrPermissionAlreadyExists) { return err } @@ -160,7 +160,7 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, g domain.Grant) error } } -func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, g domain.Grant) error { +func (p *Provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, g domain.Grant) error { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return err @@ -174,7 +174,7 @@ func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, g domain.Grant) error switch g.Resource.Type { case ResourceTypeProject, ResourceTypeOrganization: for _, p := range g.Permissions { - if err := client.RevokeAccess(g.AccountType, g.AccountID, p); err != nil { + if err := client.RevokeAccess(ctx, g.AccountType, g.AccountID, p); err != nil { if !errors.Is(err, ErrPermissionNotFound) { return err } @@ -184,7 +184,7 @@ func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, g domain.Grant) error case ResourceTypeServiceAccount: for _, p := range g.Permissions { - if err := client.RevokeServiceAccountAccess(context.TODO(), g.Resource.URN, g.AccountType, g.AccountID, p); err != nil { + if err := client.RevokeServiceAccountAccess(ctx, g.Resource.URN, g.AccountType, g.AccountID, p); err != nil { if !errors.Is(err, ErrPermissionNotFound) { return err } diff --git a/plugins/providers/gcloudiam/provider_test.go b/plugins/providers/gcloudiam/provider_test.go index d6a1ba6e9..d3ba03657 100644 --- a/plugins/providers/gcloudiam/provider_test.go +++ b/plugins/providers/gcloudiam/provider_test.go @@ -1,11 +1,12 @@ package gcloudiam_test import ( + "context" "encoding/base64" "errors" "testing" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/domain" "github.com/goto/guardian/plugins/providers/gcloudiam" @@ -202,7 +203,7 @@ func TestCreateConfig(t *testing.T) { }, } client.EXPECT(). - GetGrantableRoles(mock.AnythingOfType("*context.emptyCtx"), gcloudiam.ResourceTypeProject). + GetGrantableRoles(mock.MatchedBy(func(ctx context.Context) bool { return true }), gcloudiam.ResourceTypeProject). Return(gCloudRolesList, nil).Once() pc := &domain.ProviderConfig{ @@ -248,7 +249,7 @@ func TestCreateConfig(t *testing.T) { }, } client.EXPECT(). - GetGrantableRoles(mock.AnythingOfType("*context.emptyCtx"), gcloudiam.ResourceTypeProject). + GetGrantableRoles(mock.MatchedBy(func(ctx context.Context) bool { return true }), gcloudiam.ResourceTypeProject). Return(gCloudRolesList, nil).Once() crypto.On("Encrypt", `{"type":"service_account"}`).Return(`{"type":"service_account"}`, nil) @@ -294,6 +295,8 @@ func TestGetType(t *testing.T) { } func TestGetResources(t *testing.T) { + ctx := context.Background() + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) t.Run("should error when credentials are invalid", func(t *testing.T) { crypto := new(mocks.Encryptor) l := log.NewNoop() @@ -307,7 +310,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -343,10 +346,10 @@ func TestGetResources(t *testing.T) { }, } client.EXPECT(). - GetGrantableRoles(mock.AnythingOfType("*context.emptyCtx"), gcloudiam.ResourceTypeProject). + GetGrantableRoles(mockCtx, gcloudiam.ResourceTypeProject). Return(projectRoles, nil).Once() client.EXPECT(). - GetGrantableRoles(mock.AnythingOfType("*context.emptyCtx"), gcloudiam.ResourceTypeServiceAccount). + GetGrantableRoles(mockCtx, gcloudiam.ResourceTypeServiceAccount). Return(saRoles, nil).Once() expectedServiceAccounts := []*iam.ServiceAccount{ @@ -356,7 +359,7 @@ func TestGetResources(t *testing.T) { }, } client.EXPECT(). - ListServiceAccounts(mock.AnythingOfType("*context.emptyCtx")). + ListServiceAccounts(mockCtx). Return(expectedServiceAccounts, nil).Once() pc := &domain.ProviderConfig{ @@ -411,7 +414,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) }) @@ -425,7 +428,7 @@ func TestGetResources(t *testing.T) { }, } client.EXPECT(). - GetGrantableRoles(mock.AnythingOfType("*context.emptyCtx"), gcloudiam.ResourceTypeOrganization). + GetGrantableRoles(mockCtx, gcloudiam.ResourceTypeOrganization). Return(gCloudRolesList, nil).Once() pc := &domain.ProviderConfig{ Type: domain.ProviderTypeGCloudIAM, @@ -457,7 +460,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -474,7 +477,7 @@ func TestGetResources(t *testing.T) { {Type: "invalid-resource-type"}, }, } - _, err := p.GetResources(pc) + _, err := p.GetResources(ctx, pc) assert.ErrorIs(t, err, gcloudiam.ErrInvalidResourceType) }) @@ -494,14 +497,14 @@ func TestGetResources(t *testing.T) { }, } - _, actualError := p.GetResources(pc) + _, actualError := p.GetResources(ctx, pc) assert.Error(t, actualError) }) t.Run("should return error if client returns an error", func(t *testing.T) { expectedError := errors.New("client error") - client.On("ListServiceAccounts", mock.AnythingOfType("*context.emptyCtx")).Return(nil, expectedError).Once() + client.On("ListServiceAccounts", mockCtx).Return(nil, expectedError).Once() pc := &domain.ProviderConfig{ Type: domain.ProviderTypeGCloudIAM, @@ -516,7 +519,7 @@ func TestGetResources(t *testing.T) { }, } - _, actualError := p.GetResources(pc) + _, actualError := p.GetResources(ctx, pc) assert.ErrorIs(t, actualError, expectedError) }) @@ -524,6 +527,7 @@ func TestGetResources(t *testing.T) { } func TestGrantAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Encryptor) l := log.NewNoop() @@ -544,7 +548,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -575,13 +579,14 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) t.Run("should return error if there is an error in granting the access", func(t *testing.T) { expectedError := errors.New("client error in granting access") + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) testCases := []struct { name string resourceType string @@ -594,7 +599,7 @@ func TestGrantAccess(t *testing.T) { expectedError: expectedError, setExpectationFunc: func(c *mocks.GcloudIamClient) { c.EXPECT(). - GrantAccess(mock.Anything, mock.Anything, mock.Anything). + GrantAccess(mockCtx, mock.Anything, mock.Anything, mock.Anything). Return(expectedError).Once() }, }, @@ -604,7 +609,7 @@ func TestGrantAccess(t *testing.T) { expectedError: expectedError, setExpectationFunc: func(c *mocks.GcloudIamClient) { c.EXPECT(). - GrantAccess(mock.Anything, mock.Anything, mock.Anything). + GrantAccess(mockCtx, mock.Anything, mock.Anything, mock.Anything). Return(expectedError).Once() }, }, @@ -614,7 +619,7 @@ func TestGrantAccess(t *testing.T) { expectedError: expectedError, setExpectationFunc: func(c *mocks.GcloudIamClient) { c.EXPECT(). - GrantServiceAccountAccess(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + GrantServiceAccountAccess(mockCtx, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(expectedError).Once() }, }, @@ -663,7 +668,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"permission-1"}, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, tc.expectedError.Error()) }) @@ -683,7 +688,7 @@ func TestGrantAccess(t *testing.T) { p.Clients = map[string]gcloudiam.GcloudIamClient{ providerURN: client, } - client.On("GrantAccess", expectedAccountType, expectedAccountID, expectedPermission).Return(nil).Once() + client.On("GrantAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedAccountType, expectedAccountID, expectedPermission).Return(nil).Once() pc := &domain.ProviderConfig{ Resources: []*domain.ResourceConfig{ @@ -718,7 +723,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"roles/bigquery.admin"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Nil(t, actualError) }) @@ -759,15 +764,16 @@ func TestGrantAccess(t *testing.T) { } client.EXPECT(). - GrantServiceAccountAccess(mock.AnythingOfType("*context.emptyCtx"), g.Resource.URN, g.AccountType, g.AccountID, g.Permissions[0]). + GrantServiceAccountAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), g.Resource.URN, g.AccountType, g.AccountID, g.Permissions[0]). Return(nil).Once() - err := p.GrantAccess(pc, g) + err := p.GrantAccess(ctx, pc, g) assert.NoError(t, err) }) } func TestRevokeAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if resource type is unknown", func(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Encryptor) @@ -794,13 +800,14 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) t.Run("should return error if there is an error in revoking the access", func(t *testing.T) { expectedError := errors.New("client error in revoking access") + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) testCases := []struct { name string resourceType string @@ -813,7 +820,7 @@ func TestRevokeAccess(t *testing.T) { expectedError: expectedError, setExpectationFunc: func(c *mocks.GcloudIamClient) { c.EXPECT(). - RevokeAccess(mock.Anything, mock.Anything, mock.Anything). + RevokeAccess(mockCtx, mock.Anything, mock.Anything, mock.Anything). Return(expectedError).Once() }, }, @@ -823,7 +830,7 @@ func TestRevokeAccess(t *testing.T) { expectedError: expectedError, setExpectationFunc: func(c *mocks.GcloudIamClient) { c.EXPECT(). - RevokeAccess(mock.Anything, mock.Anything, mock.Anything). + RevokeAccess(mockCtx, mock.Anything, mock.Anything, mock.Anything). Return(expectedError).Once() }, }, @@ -833,7 +840,7 @@ func TestRevokeAccess(t *testing.T) { expectedError: expectedError, setExpectationFunc: func(c *mocks.GcloudIamClient) { c.EXPECT(). - RevokeServiceAccountAccess(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + RevokeServiceAccountAccess(mockCtx, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(expectedError).Once() }, }, @@ -882,7 +889,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"permission-1"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, tc.expectedError.Error()) }) @@ -936,7 +943,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -977,10 +984,10 @@ func TestRevokeAccess(t *testing.T) { } client.EXPECT(). - RevokeServiceAccountAccess(mock.AnythingOfType("*context.emptyCtx"), g.Resource.URN, g.AccountType, g.AccountID, g.Permissions[0]). + RevokeServiceAccountAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), g.Resource.URN, g.AccountType, g.AccountID, g.Permissions[0]). Return(nil).Once() - err := p.RevokeAccess(pc, g) + err := p.RevokeAccess(ctx, pc, g) assert.NoError(t, err) }) } diff --git a/plugins/providers/gcs/client.go b/plugins/providers/gcs/client.go index 98b1a1716..04ba3571f 100644 --- a/plugins/providers/gcs/client.go +++ b/plugins/providers/gcs/client.go @@ -19,8 +19,8 @@ type gcsClient struct { projectID string } -func newGCSClient(projectID string, credentialsJSON []byte) (*gcsClient, error) { - client, err := storage.NewClient(context.TODO(), option.WithCredentialsJSON(credentialsJSON)) +func newGCSClient(ctx context.Context, projectID string, credentialsJSON []byte) (*gcsClient, error) { + client, err := storage.NewClient(ctx, option.WithCredentialsJSON(credentialsJSON)) if err != nil { return nil, err } diff --git a/plugins/providers/gcs/provider.go b/plugins/providers/gcs/provider.go index 27f1e7109..1d74c9d9c 100644 --- a/plugins/providers/gcs/provider.go +++ b/plugins/providers/gcs/provider.go @@ -63,8 +63,8 @@ func (p *Provider) CreateConfig(pc *domain.ProviderConfig) error { return nil } -func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { - client, err := p.getGCSClient(*pc) +func (p *Provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { + client, err := p.getGCSClient(ctx, *pc) if err != nil { return nil, err } @@ -75,7 +75,7 @@ func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, } var resources []*domain.Resource - buckets, err := client.GetBuckets(context.TODO()) + buckets, err := client.GetBuckets(ctx) if err != nil { return nil, err } @@ -95,14 +95,14 @@ func (p *Provider) GetType() string { return p.typeName } -func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *Provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { if err := validateProviderConfigAndAppealParams(pc, a); err != nil { return fmt.Errorf("invalid provider/appeal config: %w", err) } permissions := getPermissions(a) - client, err := p.getGCSClient(*pc) + client, err := p.getGCSClient(ctx, *pc) if err != nil { return fmt.Errorf("error in getting new client: %w", err) } @@ -115,7 +115,7 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error } for _, p := range permissions { role := iam.RoleName(string(p)) - if err := client.GrantBucketAccess(context.TODO(), *b, identity, role); err != nil { + if err := client.GrantBucketAccess(ctx, *b, identity, role); err != nil { if errors.Is(err, ErrPermissionAlreadyExists) { return nil } @@ -127,14 +127,14 @@ func (p *Provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return ErrInvalidResourceType } -func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *Provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { if err := validateProviderConfigAndAppealParams(pc, a); err != nil { return fmt.Errorf("invalid provider/appeal config: %w", err) } permissions := getPermissions(a) - client, err := p.getGCSClient(*pc) + client, err := p.getGCSClient(ctx, *pc) if err != nil { return fmt.Errorf("error in getting new client: %w", err) } @@ -148,7 +148,7 @@ func (p *Provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error } for _, p := range permissions { var role iam.RoleName = iam.RoleName(string(p)) - if err := client.RevokeBucketAccess(context.TODO(), *b, identity, role); err != nil { + if err := client.RevokeBucketAccess(ctx, *b, identity, role); err != nil { if errors.Is(err, ErrPermissionAlreadyExists) { return nil } @@ -193,7 +193,7 @@ func getPermissions(a domain.Grant) []Permission { } func (p *Provider) ListAccess(ctx context.Context, pc domain.ProviderConfig, resources []*domain.Resource) (domain.MapResourceAccess, error) { - client, err := p.getGCSClient(pc) + client, err := p.getGCSClient(ctx, pc) if err != nil { return nil, err } @@ -201,7 +201,7 @@ func (p *Provider) ListAccess(ctx context.Context, pc domain.ProviderConfig, res return client.ListAccess(ctx, resources) } -func (p *Provider) getGCSClient(pc domain.ProviderConfig) (GCSClient, error) { +func (p *Provider) getGCSClient(ctx context.Context, pc domain.ProviderConfig) (GCSClient, error) { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return nil, fmt.Errorf("decoding credentials: %w", err) @@ -216,7 +216,7 @@ func (p *Provider) getGCSClient(pc domain.ProviderConfig) (GCSClient, error) { return p.Clients[projectID], nil } - client, err := newGCSClient(projectID, []byte(creds.ServiceAccountKey)) + client, err := newGCSClient(ctx, projectID, []byte(creds.ServiceAccountKey)) if err != nil { return nil, err } diff --git a/plugins/providers/gcs/provider_test.go b/plugins/providers/gcs/provider_test.go index 11ae3b741..f610eb5e6 100644 --- a/plugins/providers/gcs/provider_test.go +++ b/plugins/providers/gcs/provider_test.go @@ -152,6 +152,7 @@ func TestCreateConfig(t *testing.T) { } func TestGetResources(t *testing.T) { + ctx := context.Background() t.Run("should return error if error in decoding credentials", func(t *testing.T) { p := initProvider() @@ -169,7 +170,7 @@ func TestGetResources(t *testing.T) { }, }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -208,7 +209,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -261,7 +262,7 @@ func TestGetResources(t *testing.T) { Name: "test-bucket-name", }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -270,6 +271,7 @@ func TestGetResources(t *testing.T) { } func TestGrantAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if Provider Config or Appeal doesn't have required parameters", func(t *testing.T) { testCases := []struct { name string @@ -338,7 +340,7 @@ func TestGrantAccess(t *testing.T) { pc := tc.providerConfig g := tc.grant - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.EqualError(t, actualError, tc.expectedError.Error()) }) } @@ -367,7 +369,7 @@ func TestGrantAccess(t *testing.T) { }, Role: "test-role", } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Error(t, actualError) }) @@ -421,7 +423,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"Storage Legacy Bucket Writer"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Error(t, actualError) }) @@ -472,7 +474,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"Storage Legacy Bucket Writer"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Error(t, actualError) }) @@ -529,13 +531,14 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"Storage Legacy Bucket Writer"}, } - actualError := p.GrantAccess(pc, g) + actualError := p.GrantAccess(ctx, pc, g) assert.Nil(t, actualError) client.AssertExpectations(t) }) } func TestRevokeAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if Provider Config or Appeal doesn't have required parameters", func(t *testing.T) { testCases := []struct { name string @@ -604,7 +607,7 @@ func TestRevokeAccess(t *testing.T) { pc := tc.providerConfig a := tc.grant - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, tc.expectedError.Error()) }) } @@ -633,7 +636,7 @@ func TestRevokeAccess(t *testing.T) { }, Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -687,7 +690,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"Storage Legacy Bucket Writer"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -738,7 +741,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"Storage Legacy Bucket Writer"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -795,7 +798,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"Storage Legacy Bucket Writer"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) }) @@ -866,7 +869,7 @@ func TestListAccess(t *testing.T) { dummyResources := []*domain.Resource{} expectedResourcesAccess := domain.MapResourceAccess{} client.EXPECT(). - ListAccess(mock.AnythingOfType("*context.emptyCtx"), dummyResources). + ListAccess(mock.MatchedBy(func(ctx context.Context) bool { return true }), dummyResources). Return(expectedResourcesAccess, nil).Once() actualResourcesAccess, err := p.ListAccess(context.Background(), *dummyProviderConfig, dummyResources) diff --git a/plugins/providers/grafana/provider.go b/plugins/providers/grafana/provider.go index 61ca20f7a..e8c4d9385 100644 --- a/plugins/providers/grafana/provider.go +++ b/plugins/providers/grafana/provider.go @@ -1,6 +1,8 @@ package grafana import ( + "context" + pv "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" "github.com/mitchellh/mapstructure" @@ -37,7 +39,7 @@ func (p *provider) CreateConfig(pc *domain.ProviderConfig) error { return c.EncryptCredentials() } -func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { +func (p *provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return nil, err @@ -69,7 +71,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, return resources, nil } -func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return err @@ -98,7 +100,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return ErrInvalidResourceType } -func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return err diff --git a/plugins/providers/grafana/provider_test.go b/plugins/providers/grafana/provider_test.go index a8246e460..d8652462d 100644 --- a/plugins/providers/grafana/provider_test.go +++ b/plugins/providers/grafana/provider_test.go @@ -1,6 +1,7 @@ package grafana_test import ( + "context" "errors" "testing" @@ -185,6 +186,7 @@ func TestCreateConfig(t *testing.T) { } func TestGetResources(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) p := grafana.NewProvider("", crypto) @@ -193,7 +195,7 @@ func TestGetResources(t *testing.T) { Credentials: "invalid-creds", } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -211,7 +213,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -233,7 +235,7 @@ func TestGetResources(t *testing.T) { expectedError := errors.New("client error") client.On("GetFolders").Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -262,7 +264,7 @@ func TestGetResources(t *testing.T) { client.On("GetFolders").Return(expectedFolders, nil).Once() client.On("GetDashboards", 1).Return(nil, expectedError).Times(2) - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -305,7 +307,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -313,6 +315,7 @@ func TestGetResources(t *testing.T) { } func TestGrantAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) p := grafana.NewProvider("", crypto) @@ -338,7 +341,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -374,7 +377,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -412,7 +415,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -458,7 +461,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -510,7 +513,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -518,6 +521,7 @@ func TestGrantAccess(t *testing.T) { } func TestRevokeAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) p := grafana.NewProvider("", crypto) @@ -543,7 +547,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -579,7 +583,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -617,7 +621,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -663,7 +667,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -717,7 +721,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"view"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) diff --git a/plugins/providers/metabase/client.go b/plugins/providers/metabase/client.go index b8ee5810a..e4f4d1181 100644 --- a/plugins/providers/metabase/client.go +++ b/plugins/providers/metabase/client.go @@ -2,6 +2,7 @@ package metabase import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -13,8 +14,8 @@ import ( "strings" "sync" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/pkg/tracing" - "github.com/goto/salt/log" "github.com/mitchellh/mapstructure" @@ -45,17 +46,17 @@ const ( type ResourceGroupDetails map[string][]map[string]interface{} type MetabaseClient interface { - GetDatabases() ([]*Database, error) - GetCollections() ([]*Collection, error) - GetGroups() ([]*Group, ResourceGroupDetails, ResourceGroupDetails, error) - GrantDatabaseAccess(resource *Database, user, role string, groups map[string]*Group) error - RevokeDatabaseAccess(resource *Database, user, role string) error - GrantCollectionAccess(resource *Collection, user, role string) error - RevokeCollectionAccess(resource *Collection, user, role string) error - GrantTableAccess(resource *Table, user, role string, groups map[string]*Group) error - RevokeTableAccess(resource *Table, user, role string) error - GrantGroupAccess(groupID int, email string) error - RevokeGroupAccess(groupID int, email string) error + GetDatabases(ctx context.Context) ([]*Database, error) + GetCollections(ctx context.Context) ([]*Collection, error) + GetGroups(ctx context.Context) ([]*Group, ResourceGroupDetails, ResourceGroupDetails, error) + GrantDatabaseAccess(ctx context.Context, resource *Database, user, role string, groups map[string]*Group) error + RevokeDatabaseAccess(ctx context.Context, resource *Database, user, role string) error + GrantCollectionAccess(ctx context.Context, resource *Collection, user, role string) error + RevokeCollectionAccess(ctx context.Context, resource *Collection, user, role string) error + GrantTableAccess(ctx context.Context, resource *Table, user, role string, groups map[string]*Group) error + RevokeTableAccess(ctx context.Context, resource *Table, user, role string) error + GrantGroupAccess(ctx context.Context, groupID int, email string) error + RevokeGroupAccess(ctx context.Context, groupID int, email string) error } type ClientConfig struct { @@ -163,7 +164,7 @@ func NewClient(config *ClientConfig, logger log.Logger) (*client, error) { return c, nil } -func (c *client) GetDatabases() ([]*Database, error) { +func (c *client) GetDatabases(ctx context.Context) ([]*Database, error) { req, err := c.newRequest(http.MethodGet, databaseEndpoint, nil) if err != nil { return nil, err @@ -172,7 +173,7 @@ func (c *client) GetDatabases() ([]*Database, error) { var databases []*Database var response interface{} - _, err = c.do(req, &response) + _, err = c.do(nil, req, &response) if err != nil { return databases, err } @@ -188,11 +189,11 @@ func (c *client) GetDatabases() ([]*Database, error) { if err != nil { return databases, err } - c.logger.Info("Fetch database from request", "total", len(databases), req.URL) + c.logger.Info(ctx, "Fetch database from request", "total", len(databases), req.URL) return databases, err } -func (c *client) GetCollections() ([]*Collection, error) { +func (c *client) GetCollections(ctx context.Context) ([]*Collection, error) { req, err := c.newRequest(http.MethodGet, collectionEndpoint, nil) if err != nil { return nil, err @@ -200,10 +201,10 @@ func (c *client) GetCollections() ([]*Collection, error) { var collections []*Collection result := make([]*Collection, 0) - if _, err := c.do(req, &collections); err != nil { + if _, err := c.do(ctx, req, &collections); err != nil { return nil, err } - c.logger.Info("Fetch collections from request", "total", len(collections), req.URL) + c.logger.Info(ctx, "Fetch collections from request", "total", len(collections), req.URL) collectionIdNameMap := make(map[string]string, 0) for _, collection := range collections { @@ -231,19 +232,19 @@ func (c *client) GetCollections() ([]*Collection, error) { return result, nil } -func (c *client) GetGroups() ([]*Group, ResourceGroupDetails, ResourceGroupDetails, error) { +func (c *client) GetGroups(ctx context.Context) ([]*Group, ResourceGroupDetails, ResourceGroupDetails, error) { wg := sync.WaitGroup{} wg.Add(3) var groups []*Group var err error - go c.fetchGroups(&wg, &groups, err) + go c.fetchGroups(ctx, &wg, &groups, err) databaseResourceGroups := make(ResourceGroupDetails, 0) - go c.fetchDatabasePermissions(&wg, databaseResourceGroups, err) + go c.fetchDatabasePermissions(ctx, &wg, databaseResourceGroups, err) collectionResourceGroups := make(ResourceGroupDetails, 0) - go c.fetchCollectionPermissions(&wg, collectionResourceGroups, err) + go c.fetchCollectionPermissions(ctx, &wg, collectionResourceGroups, err) wg.Wait() @@ -258,21 +259,21 @@ func (c *client) GetGroups() ([]*Group, ResourceGroupDetails, ResourceGroupDetai return groups, databaseResourceGroups, collectionResourceGroups, err } -func (c *client) fetchGroups(wg *sync.WaitGroup, groups *[]*Group, err error) { +func (c *client) fetchGroups(ctx context.Context, wg *sync.WaitGroup, groups *[]*Group, err error) { defer wg.Done() req, err := c.newRequest(http.MethodGet, groupEndpoint, nil) if err != nil { return } - _, err = c.do(req, &groups) + _, err = c.do(ctx, req, &groups) if err != nil { return } - c.logger.Info("Fetch groups from request", "total", len(*groups), req.URL) + c.logger.Info(ctx, "Fetch groups from request", "total", len(*groups), req.URL) } -func (c *client) fetchDatabasePermissions(wg *sync.WaitGroup, resourceGroups ResourceGroupDetails, err error) { +func (c *client) fetchDatabasePermissions(ctx context.Context, wg *sync.WaitGroup, resourceGroups ResourceGroupDetails, err error) { defer wg.Done() req, err := c.newRequest(http.MethodGet, databasePermissionEndpoint, nil) @@ -281,7 +282,7 @@ func (c *client) fetchDatabasePermissions(wg *sync.WaitGroup, resourceGroups Res } graphs := make(map[string]interface{}, 0) - _, err = c.do(req, &graphs) + _, err = c.do(ctx, req, &graphs) if err != nil { return } @@ -297,7 +298,7 @@ func (c *client) fetchDatabasePermissions(wg *sync.WaitGroup, resourceGroups Res for tableId, tablePermission := range tables { perm, ok := tablePermission.(string) if !ok { - c.logger.Warn("Invalid permission type for metabase group", "dbId", dbId, "tableId", tableId, "groupId", groupId, "permission", tablePermission, "type", reflect.TypeOf(tablePermission)) + c.logger.Warn(ctx, "Invalid permission type for metabase group", "dbId", dbId, "tableId", tableId, "groupId", groupId, "permission", tablePermission, "type", reflect.TypeOf(tablePermission)) continue } addGroupToResource(resourceGroups, fmt.Sprintf("%s:%s.%s", table, dbId, tableId), groupId, []string{perm}, err) @@ -314,7 +315,7 @@ func (c *client) fetchDatabasePermissions(wg *sync.WaitGroup, resourceGroups Res } } -func (c *client) fetchCollectionPermissions(wg *sync.WaitGroup, resourceGroups ResourceGroupDetails, err error) { +func (c *client) fetchCollectionPermissions(ctx context.Context, wg *sync.WaitGroup, resourceGroups ResourceGroupDetails, err error) { defer wg.Done() req, err := c.newRequest(http.MethodGet, collectionPermissionEndpoint, nil) @@ -323,17 +324,17 @@ func (c *client) fetchCollectionPermissions(wg *sync.WaitGroup, resourceGroups R } graphs := make(map[string]interface{}, 0) - _, err = c.do(req, &graphs) + _, err = c.do(ctx, req, &graphs) if err != nil { return } - c.logger.Info(fmt.Sprintf("Fetch permissions for collections from request: %v", req.URL)) + c.logger.Info(ctx, fmt.Sprintf("Fetch permissions for collections from request: %v", req.URL)) for groupId, r := range graphs[groups].(map[string]interface{}) { for collectionId, permission := range r.(map[string]interface{}) { if permission != none { p, ok := permission.(string) if !ok { - c.logger.Warn("Invalid permission type for metabase collection", "collectionId", collectionId, "groupId", groupId, "permission", permission, "type", reflect.TypeOf(permission)) + c.logger.Warn(ctx, "Invalid permission type for metabase collection", "collectionId", collectionId, "groupId", groupId, "permission", permission, "type", reflect.TypeOf(permission)) continue } addGroupToResource(resourceGroups, fmt.Sprintf("%s:%s", collection, collectionId), groupId, []string{p}, err) @@ -375,7 +376,7 @@ func addGroupToResource(resourceGroups ResourceGroupDetails, resourceId string, } } -func (c *client) GrantDatabaseAccess(resource *Database, email, role string, groups map[string]*Group) error { +func (c *client) GrantDatabaseAccess(ctx context.Context, resource *Database, email, role string, groups map[string]*Group) error { access, err := c.getDatabaseAccess() if err != nil { return err @@ -436,7 +437,7 @@ func (c *client) GrantDatabaseAccess(resource *Database, email, role string, gro return c.addGroupMember(groupIDInt, user.ID) } -func (c *client) RevokeDatabaseAccess(resource *Database, user, role string) error { +func (c *client) RevokeDatabaseAccess(ctx context.Context, resource *Database, user, role string) error { access, err := c.getDatabaseAccess() if err != nil { return err @@ -463,7 +464,7 @@ func (c *client) RevokeDatabaseAccess(resource *Database, user, role string) err return c.removeMembership(groupIDInt, user) } -func (c *client) GrantCollectionAccess(resource *Collection, email, role string) error { +func (c *client) GrantCollectionAccess(ctx context.Context, resource *Collection, email, role string) error { access, err := c.getCollectionAccess() if err != nil { return err @@ -512,7 +513,7 @@ func (c *client) GrantCollectionAccess(resource *Collection, email, role string) return c.addGroupMember(groupIDInt, user.ID) } -func (c *client) RevokeCollectionAccess(resource *Collection, user, role string) error { +func (c *client) RevokeCollectionAccess(ctx context.Context, resource *Collection, user, role string) error { access, err := c.getCollectionAccess() if err != nil { return err @@ -532,7 +533,7 @@ func (c *client) RevokeCollectionAccess(resource *Collection, user, role string) return c.removeMembership(groupIDInt, user) } -func (c *client) GrantTableAccess(resource *Table, email, role string, groups map[string]*Group) error { +func (c *client) GrantTableAccess(ctx context.Context, resource *Table, email, role string, groups map[string]*Group) error { access, err := c.getDatabaseAccess() if err != nil { return err @@ -597,7 +598,7 @@ func (c *client) GrantTableAccess(resource *Table, email, role string, groups ma return c.addGroupMember(groupIDInt, user.ID) } -func (c *client) RevokeTableAccess(resource *Table, user, role string) error { +func (c *client) RevokeTableAccess(ctx context.Context, resource *Table, user, role string) error { access, err := c.getDatabaseAccess() if err != nil { return err @@ -632,7 +633,7 @@ func (c *client) RevokeTableAccess(resource *Table, user, role string) error { return c.removeMembership(groupIDInt, user) } -func (c *client) GrantGroupAccess(groupID int, email string) error { +func (c *client) GrantGroupAccess(ctx context.Context, groupID int, email string) error { user, err := c.getUser(email) if err != nil { return err @@ -640,7 +641,7 @@ func (c *client) GrantGroupAccess(groupID int, email string) error { for _, userGroupId := range user.GroupIds { if userGroupId == groupID { - c.logger.Warn(fmt.Sprintf("User %s is already member of group %d", email, groupID)) + c.logger.Warn(ctx, fmt.Sprintf("User %s is already member of group %d", email, groupID)) return nil } } @@ -648,7 +649,7 @@ func (c *client) GrantGroupAccess(groupID int, email string) error { return c.addGroupMember(groupID, user.ID) } -func (c *client) RevokeGroupAccess(groupID int, email string) error { +func (c *client) RevokeGroupAccess(ctx context.Context, groupID int, email string) error { return c.removeMembership(groupID, email) } @@ -680,7 +681,7 @@ func (c *client) getUser(email string) (user, error) { var users []user var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(nil, req, &response); err != nil { return user{}, err } @@ -716,7 +717,7 @@ func (c *client) getSessionToken() (string, error) { } var sessionResponse SessionResponse - if _, err := c.do(req, &sessionResponse); err != nil { + if _, err := c.do(nil, req, &sessionResponse); err != nil { return "", err } @@ -730,7 +731,7 @@ func (c *client) getCollectionAccess() (*collectionGraph, error) { } var graph collectionGraph - if _, err := c.do(req, &graph); err != nil { + if _, err := c.do(nil, req, &graph); err != nil { return nil, err } @@ -743,7 +744,7 @@ func (c *client) updateCollectionAccess(access *collectionGraph) error { return err } - if _, err := c.do(req, &access); err != nil { + if _, err := c.do(nil, req, &access); err != nil { return err } @@ -757,7 +758,7 @@ func (c *client) getDatabaseAccess() (*databaseGraph, error) { } var dbGraph databaseGraph - if _, err := c.do(req, &dbGraph); err != nil { + if _, err := c.do(nil, req, &dbGraph); err != nil { return nil, err } @@ -770,7 +771,7 @@ func (c *client) updateDatabaseAccess(dbGraph *databaseGraph) error { return err } - if _, err := c.do(req, &dbGraph); err != nil { + if _, err := c.do(nil, req, &dbGraph); err != nil { return err } @@ -784,7 +785,7 @@ func (c *client) createGroup(group *group) error { return err } - if _, err := c.do(req, group); err != nil { + if _, err := c.do(nil, req, group); err != nil { return err } @@ -800,7 +801,7 @@ func (c *client) getGroup(id int) (*group, error) { var group group - if _, err := c.do(req, &group); err != nil { + if _, err := c.do(nil, req, &group); err != nil { return nil, err } @@ -816,7 +817,7 @@ func (c *client) addGroupMember(groupID, userID int) error { return err } - if _, err := c.do(req, nil); err != nil { + if _, err := c.do(nil, req, nil); err != nil { return err } @@ -831,7 +832,7 @@ func (c *client) removeGroupMember(membershipID int) error { return err } - if _, err := c.do(req, nil); err != nil { + if _, err := c.do(nil, req, nil); err != nil { return err } @@ -912,10 +913,10 @@ func (c *client) newRequest(method, path string, body interface{}) (*http.Reques return req, nil } -func (c *client) do(req *http.Request, v interface{}) (*http.Response, error) { +func (c *client) do(ctx context.Context, req *http.Request, v interface{}) (*http.Response, error) { resp, err := c.httpClient.Do(req) if err != nil { - c.logger.Error(fmt.Sprintf("Failed to execute request %v with error %v", req.URL, err)) + c.logger.Error(ctx, fmt.Sprintf("Failed to execute request %v with error %v", req.URL, err)) return nil, err } defer resp.Body.Close() diff --git a/plugins/providers/metabase/client_test.go b/plugins/providers/metabase/client_test.go index 6c100d2a3..5a7cf1fa1 100644 --- a/plugins/providers/metabase/client_test.go +++ b/plugins/providers/metabase/client_test.go @@ -2,6 +2,7 @@ package metabase_test import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -9,7 +10,7 @@ import ( "net/http" "testing" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/mocks" "github.com/goto/guardian/plugins/providers/metabase" @@ -21,7 +22,7 @@ import ( func TestNewClient(t *testing.T) { t.Run("should return error if config is invalid", func(t *testing.T) { invalidConfig := &metabase.ClientConfig{} - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) actualClient, actualError := metabase.NewClient(invalidConfig, logger) assert.Nil(t, actualClient) @@ -34,7 +35,7 @@ func TestNewClient(t *testing.T) { Password: "test-password", Host: "invalid-url", } - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) actualClient, actualError := metabase.NewClient(invalidHostConfig, logger) assert.Nil(t, actualClient) @@ -49,7 +50,7 @@ func TestNewClient(t *testing.T) { Host: "http://localhost", HTTPClient: mockHttpClient, } - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) expectedError := errors.New("request error") mockHttpClient.On("Do", mock.Anything).Return(nil, expectedError).Once() @@ -69,7 +70,7 @@ func TestNewClient(t *testing.T) { Host: "http://localhost", HTTPClient: mockHttpClient, } - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) sessionToken := "93df71b4-6887-46bd-b4bf-7ad3b94bd6fe" responseJSON := `{"id":"` + sessionToken + `"}` @@ -137,7 +138,7 @@ func (s *ClientTestSuite) TestGetCollections() { {ID: float64(7), Name: "CabFares/DS Analysis/Summary", Slug: "summary", Location: "/2/3/"}, } - result, err1 := s.client.GetCollections() + result, err1 := s.client.GetCollections(context.Background()) var collections []metabase.Collection for _, coll := range result { collections = append(collections, *coll) @@ -157,7 +158,7 @@ func (s *ClientTestSuite) TestGetDatabases() { databaseResponse := http.Response{StatusCode: 400, Body: io.NopCloser(bytes.NewReader([]byte(nil)))} s.mockHttpClient.On("Do", testRequest).Return(&databaseResponse, nil).Once() - result, err1 := s.client.GetDatabases() + result, err1 := s.client.GetDatabases(context.Background()) s.Nil(result) s.Error(err1) }) @@ -171,7 +172,7 @@ func (s *ClientTestSuite) TestGetDatabases() { databaseResponse := http.Response{StatusCode: 500, Body: io.NopCloser(bytes.NewReader([]byte(nil)))} s.mockHttpClient.On("Do", testRequest).Return(&databaseResponse, nil).Once() - result, err1 := s.client.GetDatabases() + result, err1 := s.client.GetDatabases(context.Background()) s.Nil(result) s.Error(err1) }) @@ -200,7 +201,7 @@ func (s *ClientTestSuite) TestGetDatabases() { {ID: 1, Name: "test-Name", CacheFieldValuesSchedule: "testCache", Timezone: "test-time", AutoRunQueries: true, MetadataSyncSchedule: "test-sync", Engine: "test-engine", NativePermissions: "per"}, } - result, err1 := s.client.GetDatabases() + result, err1 := s.client.GetDatabases(context.Background()) var databases []metabase.Database for _, db := range result { databases = append(databases, *db) @@ -225,7 +226,7 @@ func (s *ClientTestSuite) TestGetDatabases() { //Tables: []metabase.Table{{ID: 2, Name: "tab1", DbId: 1, Database: &domain.Resource{ID: "5", ProviderType: "metabase", ProviderURN: "test-URN", Type: "database"} }} }, } - result, err1 := s.client.GetDatabases() + result, err1 := s.client.GetDatabases(context.Background()) var databases []metabase.Database for _, db := range result { databases = append(databases, *db) @@ -294,7 +295,7 @@ func (s *ClientTestSuite) TestGetGroups() { fetchCollectionPermissionsResponse := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(fetchCollectionPermissionsResponseJSON)))} s.mockHttpClient.On("Do", fetchCollectionPermissionstestRequest).Return(&fetchCollectionPermissionsResponse, nil).Once() - actualGroupResponse, actualDatabaseGroupResponse, _, err := s.client.GetGroups() + actualGroupResponse, actualDatabaseGroupResponse, _, err := s.client.GetGroups(context.Background()) s.Nil(err) s.Equal(expectedgroupResponse, actualGroupResponse) @@ -349,7 +350,7 @@ func (s *ClientTestSuite) TestGrantDatabaseAccess() { "gid_1": {ID: 1, Name: "db_1"}, "gid_2": {ID: 2, Name: "db_2"}, } - actualError := s.client.GrantDatabaseAccess(resource, email, role, groups) + actualError := s.client.GrantDatabaseAccess(context.Background(), resource, email, role, groups) s.Nil(actualError) s.mockHttpClient.AssertExpectations(s.T()) }) @@ -391,7 +392,7 @@ func (s *ClientTestSuite) TestGrantCollectionAccess() { Name: "test-collection", } resource := expectedCollection - actualError := s.client.GrantCollectionAccess(resource, email, role) + actualError := s.client.GrantCollectionAccess(context.Background(), resource, email, role) s.Nil(actualError) s.mockHttpClient.AssertExpectations(s.T()) }) @@ -423,7 +424,7 @@ func (s *ClientTestSuite) TestGrantCollectionAccess() { Name: "test-collection", } resource := expectedCollection - actualError := s.client.GrantCollectionAccess(resource, email, role) + actualError := s.client.GrantCollectionAccess(context.Background(), resource, email, role) s.Nil(actualError) s.mockHttpClient.AssertExpectations(s.T()) }) @@ -462,7 +463,7 @@ func (s *ClientTestSuite) TestRevokeCollectionAccess() { Name: "test-collection", } resource := expectedCollection - actualError := s.client.RevokeCollectionAccess(resource, email, role) + actualError := s.client.RevokeCollectionAccess(context.Background(), resource, email, role) s.Nil(actualError) s.mockHttpClient.AssertExpectations(s.T()) }) @@ -482,7 +483,7 @@ func (s *ClientTestSuite) TestGrantGroupAccesss() { groupID := 53 - actualError := s.client.GrantGroupAccess(groupID, email) + actualError := s.client.GrantGroupAccess(context.Background(), groupID, email) s.Nil(actualError) s.mockHttpClient.AssertExpectations(s.T()) @@ -504,7 +505,7 @@ func (s *ClientTestSuite) TestGrantGroupAccesss() { res := http.Response{StatusCode: 200, Body: io.NopCloser(nil)} s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&res, nil).Once() - actualError := s.client.GrantGroupAccess(groupID, email) + actualError := s.client.GrantGroupAccess(context.Background(), groupID, email) s.Nil(actualError) s.mockHttpClient.AssertExpectations(s.T()) @@ -554,7 +555,7 @@ func (s *ClientTestSuite) TestGrantTableAccess() { response := http.Response{StatusCode: 200, Body: io.NopCloser(nil)} // test for addGroupMember s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&response, nil).Once() - actualError := s.client.GrantTableAccess(resource, email, role, groups) + actualError := s.client.GrantTableAccess(context.Background(), resource, email, role, groups) s.Nil(actualError) }) @@ -597,7 +598,7 @@ func (s *ClientTestSuite) TestGrantTableAccess() { response := http.Response{StatusCode: 200, Body: io.NopCloser(nil)} // test for addGroupMember s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&response, nil).Once() - actualError := s.client.GrantTableAccess(resource, email, role, groups) + actualError := s.client.GrantTableAccess(context.Background(), resource, email, role, groups) s.Nil(actualError) }) @@ -630,13 +631,13 @@ func (s *ClientTestSuite) TestRevokeDatabaseAccess() { s.mockHttpClient.On("Do", req).Return(&groupResponse, nil).Once() membershipID := 500 //test removeGroupMember - revokeGroupMemeberURL := fmt.Sprintf("/api/permissions/membership/%d", membershipID) - revokeGroupMemeberRequest, err3 := s.getTestRequest(http.MethodDelete, revokeGroupMemeberURL, nil) + revokeGroupMemberURL := fmt.Sprintf("/api/permissions/membership/%d", membershipID) + revokeGroupMemeberRequest, err3 := s.getTestRequest(http.MethodDelete, revokeGroupMemberURL, nil) s.Require().NoError(err3) revokeGroupMemeberResponse := http.Response{StatusCode: 200, Body: io.NopCloser(nil)} s.mockHttpClient.On("Do", revokeGroupMemeberRequest).Return(&revokeGroupMemeberResponse, nil).Once() - actualError := s.client.RevokeDatabaseAccess(resource, email, role) + actualError := s.client.RevokeDatabaseAccess(context.Background(), resource, email, role) s.Nil(actualError) }) diff --git a/plugins/providers/metabase/provider.go b/plugins/providers/metabase/provider.go index b77962e69..36c091ab5 100644 --- a/plugins/providers/metabase/provider.go +++ b/plugins/providers/metabase/provider.go @@ -1,11 +1,12 @@ package metabase import ( + "context" "strings" pv "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/mitchellh/mapstructure" ) @@ -42,7 +43,7 @@ func (p *provider) CreateConfig(pc *domain.ProviderConfig) error { return c.EncryptCredentials() } -func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { +func (p *provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return nil, err @@ -63,7 +64,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, var databases []*Database var collections []*Collection if _, ok := resourceTypes[ResourceTypeDatabase]; ok { - databases, err = client.GetDatabases() + databases, err = client.GetDatabases(ctx) if err != nil { return nil, err } @@ -72,7 +73,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, if _, ok := resourceTypes[ResourceTypeTable]; ok { if databases == nil { - databases, err = client.GetDatabases() + databases, err = client.GetDatabases(ctx) } if err != nil { return nil, err @@ -81,14 +82,14 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, } if _, ok := resourceTypes[ResourceTypeCollection]; ok { - collections, err = client.GetCollections() + collections, err = client.GetCollections(ctx) if err != nil { return nil, err } resources = p.addCollection(pc, collections, resources) } - groups, databaseResourceGroups, collectionResourceGroups, err := client.GetGroups() + groups, databaseResourceGroups, collectionResourceGroups, err := client.GetGroups(ctx) if err != nil { return nil, err } @@ -111,7 +112,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, collectionResourceMap := make(map[string]*domain.Resource, 0) if databases == nil { - databases, err = client.GetDatabases() + databases, err = client.GetDatabases(ctx) if err != nil { return nil, err } @@ -122,7 +123,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, } if collections == nil { - collections, err = client.GetCollections() + collections, err = client.GetCollections(ctx) if err != nil { return nil, err } @@ -200,7 +201,7 @@ func (p *provider) addTables(pc *domain.ProviderConfig, databases []*Database, r return resources } -func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { // TODO: validate provider config and appeal var creds Credentials @@ -212,7 +213,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } - groups, _, _, err := client.GetGroups() + groups, _, _, err := client.GetGroups(ctx) if err != nil { return err } @@ -230,7 +231,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error } for _, p := range permissions { - if err := client.GrantDatabaseAccess(d, a.AccountID, string(p), groupMap); err != nil { + if err := client.GrantDatabaseAccess(ctx, d, a.AccountID, string(p), groupMap); err != nil { return err } } @@ -243,7 +244,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error } for _, p := range permissions { - if err := client.GrantCollectionAccess(c, a.AccountID, string(p)); err != nil { + if err := client.GrantCollectionAccess(ctx, c, a.AccountID, string(p)); err != nil { return err } } @@ -255,7 +256,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } - if err := client.GrantGroupAccess(g.ID, a.AccountID); err != nil { + if err := client.GrantGroupAccess(ctx, g.ID, a.AccountID); err != nil { return err } return nil @@ -266,7 +267,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error } for _, p := range permissions { - if err := client.GrantTableAccess(t, a.AccountID, string(p), groupMap); err != nil { + if err := client.GrantTableAccess(ctx, t, a.AccountID, string(p), groupMap); err != nil { return err } } @@ -276,7 +277,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return ErrInvalidResourceType } -func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return err @@ -294,7 +295,7 @@ func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error } for _, p := range permissions { - if err := client.RevokeDatabaseAccess(d, a.AccountID, string(p)); err != nil { + if err := client.RevokeDatabaseAccess(ctx, d, a.AccountID, string(p)); err != nil { return err } } @@ -307,7 +308,7 @@ func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error } for _, p := range permissions { - if err := client.RevokeCollectionAccess(c, a.AccountID, string(p)); err != nil { + if err := client.RevokeCollectionAccess(ctx, c, a.AccountID, string(p)); err != nil { return err } } @@ -319,7 +320,7 @@ func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } - if err := client.RevokeGroupAccess(g.ID, a.AccountID); err != nil { + if err := client.RevokeGroupAccess(ctx, g.ID, a.AccountID); err != nil { return err } @@ -331,7 +332,7 @@ func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error } for _, p := range permissions { - if err := client.RevokeTableAccess(t, a.AccountID, string(p)); err != nil { + if err := client.RevokeTableAccess(ctx, t, a.AccountID, string(p)); err != nil { return err } } diff --git a/plugins/providers/metabase/provider_test.go b/plugins/providers/metabase/provider_test.go index 87f75bf17..475ac0eb6 100644 --- a/plugins/providers/metabase/provider_test.go +++ b/plugins/providers/metabase/provider_test.go @@ -1,10 +1,11 @@ package metabase_test import ( + "context" "errors" "testing" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" @@ -17,7 +18,7 @@ import ( func TestGetType(t *testing.T) { t.Run("should return provider type name", func(t *testing.T) { expectedTypeName := domain.ProviderTypeMetabase - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) crypto := new(mocks.Crypto) p := metabase.NewProvider(expectedTypeName, crypto, logger) @@ -32,7 +33,7 @@ func TestCreateConfig(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, @@ -71,7 +72,7 @@ func TestCreateConfig(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, @@ -126,7 +127,7 @@ func TestCreateConfig(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) crypto.On("Encrypt", "test-password").Return("encrypted-test-pasword", nil) p.Clients = map[string]metabase.MetabaseClient{ @@ -236,16 +237,17 @@ func TestCreateConfig(t *testing.T) { } func TestGetResources(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) pc := &domain.ProviderConfig{ Credentials: "invalid-creds", } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -253,7 +255,7 @@ func TestGetResources(t *testing.T) { t.Run("should return error if there are any on client initialization", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) expectedError := errors.New("decrypt error") @@ -264,7 +266,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -274,7 +276,7 @@ func TestGetResources(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, @@ -290,9 +292,9 @@ func TestGetResources(t *testing.T) { }, } expectedError := errors.New("client error") - client.On("GetDatabases").Return(nil, expectedError).Once() + client.On("GetDatabases", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -302,7 +304,7 @@ func TestGetResources(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, @@ -318,9 +320,9 @@ func TestGetResources(t *testing.T) { }, } expectedError := errors.New("client error") - client.On("GetCollections").Return(nil, expectedError).Once() + client.On("GetCollections", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -330,7 +332,7 @@ func TestGetResources(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, @@ -359,13 +361,13 @@ func TestGetResources(t *testing.T) { Tables: []metabase.Table{{ID: 2, Name: "table_1", DbId: 1}}, }, } - client.On("GetDatabases").Return(expectedDatabases, nil).Once() + client.On("GetDatabases", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(expectedDatabases, nil).Once() d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group, {Name: metabase.GuardianGroupPrefix + "database_1_schema:all", DatabaseResources: d, CollectionResources: c}}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group, {Name: metabase.GuardianGroupPrefix + "database_1_schema:all", DatabaseResources: d, CollectionResources: c}}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -375,7 +377,7 @@ func TestGetResources(t *testing.T) { Name: "col_1", }, } - client.On("GetCollections").Return(expectedCollections, nil).Once() + client.On("GetCollections", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return(expectedCollections, nil).Once() expectedResources := []*domain.Resource{ { Type: metabase.ResourceTypeDatabase, @@ -426,7 +428,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -434,9 +436,10 @@ func TestGetResources(t *testing.T) { } func TestGrantAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) pc := &domain.ProviderConfig{ @@ -460,13 +463,13 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Error(t, actualError) }) t.Run("should return decrypt error if there are any on client initialization", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) expectedError := errors.New("decrypt error") crypto.On("Decrypt", "test-password").Return("", expectedError).Once() @@ -496,14 +499,14 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) t.Run("should return error if resource type in unknown", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) expectedError := errors.New("invalid resource type") crypto.On("Decrypt", "test-password").Return("", expectedError).Once() @@ -534,7 +537,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -545,16 +548,17 @@ func TestGrantAccess(t *testing.T) { expectedError := errors.New("client error") crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("GrantDatabaseAccess", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("GrantDatabaseAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(expectedError).Once() d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -587,7 +591,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -595,7 +599,7 @@ func TestGrantAccess(t *testing.T) { t.Run("should return nil error if granting access is successful", func(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) client := new(mocks.MetabaseClient) expectedDatabase := &metabase.Database{ Name: "test-database", @@ -612,7 +616,7 @@ func TestGrantAccess(t *testing.T) { d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -647,7 +651,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -659,7 +663,7 @@ func TestGrantAccess(t *testing.T) { expectedError := errors.New("client error") crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, @@ -669,7 +673,7 @@ func TestGrantAccess(t *testing.T) { d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -702,7 +706,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -717,7 +721,7 @@ func TestGrantAccess(t *testing.T) { } expectedUser := "test@email.com" expectedRole := "viewer" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ @@ -728,7 +732,7 @@ func TestGrantAccess(t *testing.T) { d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -763,7 +767,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -775,16 +779,16 @@ func TestGrantAccess(t *testing.T) { expectedError := errors.New("client error") crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("GrantGroupAccess", mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("GrantGroupAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything, mock.Anything).Return(expectedError).Once() d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -817,7 +821,7 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) client.AssertExpectations(t) @@ -826,7 +830,7 @@ func TestGrantAccess(t *testing.T) { t.Run("should return nil error if granting access is successful", func(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) client := new(mocks.MetabaseClient) expectedGroupID := 999 expectedEmail := "test@email.com" @@ -834,11 +838,11 @@ func TestGrantAccess(t *testing.T) { p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("GrantGroupAccess", expectedGroupID, expectedEmail).Return(nil).Once() + client.On("GrantGroupAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedGroupID, expectedEmail).Return(nil).Once() d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -873,7 +877,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -887,17 +891,17 @@ func TestGrantAccess(t *testing.T) { expectedUser := "test@email.com" crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("GrantTableAccess", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("GrantTableAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -933,7 +937,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -949,18 +953,18 @@ func TestGrantAccess(t *testing.T) { } expectedUser := "test@email.com" expectedRole := "viewer" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("GrantTableAccess", expectedTable, expectedUser, expectedRole, mock.Anything).Return(nil).Once() + client.On("GrantTableAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedTable, expectedUser, expectedRole, mock.Anything).Return(nil).Once() d := []*metabase.GroupResource{{Urn: "database:1", Permissions: []string{"read", "write"}}} c := []*metabase.GroupResource{{Urn: "collection:1", Permissions: []string{"read", "write"}}} group := metabase.Group{Name: "All Users", DatabaseResources: d, CollectionResources: c} - client.On("GetGroups").Return([]*metabase.Group{&group}, + client.On("GetGroups", mock.MatchedBy(func(ctx context.Context) bool { return true })).Return([]*metabase.Group{&group}, metabase.ResourceGroupDetails{"database:1": {{"urn": "group:1", "permissions": []string{"read", "write"}}}}, metabase.ResourceGroupDetails{"collection:1": {{"urn": "group:1", "permissions": []string{"write"}}}}, nil).Once() @@ -996,7 +1000,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -1005,9 +1009,10 @@ func TestGrantAccess(t *testing.T) { } func TestRevokeAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) pc := &domain.ProviderConfig{ @@ -1031,13 +1036,13 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Error(t, actualError) }) t.Run("should return error if there are any on client initialization", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) expectedError := errors.New("decrypt error") crypto.On("Decrypt", "test-password").Return("", expectedError).Once() @@ -1067,14 +1072,14 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) t.Run("should return error if resource type in unknown", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) expectedError := errors.New("invalid resource type") crypto.On("Decrypt", "test-password").Return("", expectedError).Once() @@ -1105,7 +1110,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1116,7 +1121,7 @@ func TestRevokeAccess(t *testing.T) { expectedError := errors.New("client error") crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, @@ -1152,7 +1157,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1160,7 +1165,7 @@ func TestRevokeAccess(t *testing.T) { t.Run("should return nil error if revoking database access is successful", func(t *testing.T) { providerURN := "test-provider-urn" crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) client := new(mocks.MetabaseClient) expectedDatabase := &metabase.Database{ Name: "test-database", @@ -1172,7 +1177,7 @@ func TestRevokeAccess(t *testing.T) { p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("RevokeDatabaseAccess", expectedDatabase, expectedUser, expectedRole, mock.Anything).Return(nil).Once() + client.On("RevokeDatabaseAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedDatabase, expectedUser, expectedRole, mock.Anything).Return(nil).Once() pc := &domain.ProviderConfig{ Credentials: metabase.Credentials{ @@ -1206,7 +1211,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -1219,7 +1224,7 @@ func TestRevokeAccess(t *testing.T) { expectedError := errors.New("client error") crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, @@ -1255,7 +1260,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1270,13 +1275,13 @@ func TestRevokeAccess(t *testing.T) { } expectedUser := "test@email.com" expectedRole := "viewer" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("RevokeCollectionAccess", expectedCollection, expectedUser, expectedRole, mock.Anything).Return(nil).Once() + client.On("RevokeCollectionAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedCollection, expectedUser, expectedRole, mock.Anything).Return(nil).Once() pc := &domain.ProviderConfig{ Credentials: metabase.Credentials{ @@ -1310,7 +1315,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -1323,12 +1328,12 @@ func TestRevokeAccess(t *testing.T) { expectedError := errors.New("client error") crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("RevokeGroupAccess", mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("RevokeGroupAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything, mock.Anything).Return(expectedError).Once() pc := &domain.ProviderConfig{ Credentials: metabase.Credentials{ @@ -1358,7 +1363,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1371,12 +1376,12 @@ func TestRevokeAccess(t *testing.T) { expectedUser := "test@email.com" crypto := new(mocks.Crypto) client := new(mocks.MetabaseClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("RevokeTableAccess", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("RevokeTableAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() pc := &domain.ProviderConfig{ Credentials: metabase.Credentials{ @@ -1410,7 +1415,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1426,13 +1431,13 @@ func TestRevokeAccess(t *testing.T) { } expectedUser := "test@email.com" expectedRole := "viewer" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) p.Clients = map[string]metabase.MetabaseClient{ providerURN: client, } - client.On("RevokeTableAccess", expectedTable, expectedUser, expectedRole, mock.Anything).Return(nil).Once() + client.On("RevokeTableAccess", mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedTable, expectedUser, expectedRole, mock.Anything).Return(nil).Once() pc := &domain.ProviderConfig{ Credentials: metabase.Credentials{ @@ -1466,7 +1471,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -1477,7 +1482,7 @@ func TestRevokeAccess(t *testing.T) { func TestGetAccountTypes(t *testing.T) { expectedAccountType := []string{"user"} crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("", crypto, logger) actualAccountType := p.GetAccountTypes() @@ -1488,7 +1493,7 @@ func TestGetAccountTypes(t *testing.T) { func TestGetRoles(t *testing.T) { t.Run("should return error if resource type is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("metabase", crypto, logger) validConfig := &domain.ProviderConfig{ Type: "metabase", @@ -1521,7 +1526,7 @@ func TestGetRoles(t *testing.T) { t.Run("should return roles specified in the provider config", func(t *testing.T) { crypto := new(mocks.Crypto) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := metabase.NewProvider("metabase", crypto, logger) expectedRoles := []*domain.Role{ diff --git a/plugins/providers/noop/provider.go b/plugins/providers/noop/provider.go index 2fb47071a..b15f8e75a 100644 --- a/plugins/providers/noop/provider.go +++ b/plugins/providers/noop/provider.go @@ -1,11 +1,12 @@ package noop import ( + "context" "errors" "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" ) var ( @@ -68,7 +69,7 @@ func (p *Provider) CreateConfig(cfg *domain.ProviderConfig) error { return nil } -func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { +func (p *Provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { return []*domain.Resource{ { ProviderType: domain.ProviderTypeNoOp, @@ -80,11 +81,11 @@ func (p *Provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, }, nil } -func (p *Provider) GrantAccess(*domain.ProviderConfig, domain.Grant) error { +func (p *Provider) GrantAccess(context.Context, *domain.ProviderConfig, domain.Grant) error { return nil } -func (p *Provider) RevokeAccess(*domain.ProviderConfig, domain.Grant) error { +func (p *Provider) RevokeAccess(context.Context, *domain.ProviderConfig, domain.Grant) error { return nil } diff --git a/plugins/providers/noop/provider_test.go b/plugins/providers/noop/provider_test.go index adf75164c..35cd3ec78 100644 --- a/plugins/providers/noop/provider_test.go +++ b/plugins/providers/noop/provider_test.go @@ -1,18 +1,19 @@ package noop_test import ( + "context" "testing" "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/plugins/providers/noop" - "github.com/goto/salt/log" "github.com/stretchr/testify/assert" ) func TestGetType(t *testing.T) { t.Run("should return the proper type name based on the initialization", func(t *testing.T) { - logger := log.NewLogrus() + logger := log.NewNoop() expectedTypeName := "test-type-name" p := noop.NewProvider(expectedTypeName, logger) @@ -212,7 +213,7 @@ func TestGetResources(t *testing.T) { Name: validConfig.URN, } - actualResources, actualError := p.GetResources(validConfig) + actualResources, actualError := p.GetResources(context.TODO(), validConfig) assert.NoError(t, actualError) assert.Equal(t, []*domain.Resource{expectedResource}, actualResources) @@ -223,7 +224,7 @@ func TestGrantAccess(t *testing.T) { t.Run("should return nil", func(t *testing.T) { p := initProvider() - actualError := p.GrantAccess(nil, domain.Grant{}) + actualError := p.GrantAccess(context.TODO(), nil, domain.Grant{}) assert.NoError(t, actualError) }) @@ -233,7 +234,7 @@ func TestRevokeAccess(t *testing.T) { t.Run("should return nil", func(t *testing.T) { p := initProvider() - actualError := p.RevokeAccess(nil, domain.Grant{}) + actualError := p.RevokeAccess(context.TODO(), nil, domain.Grant{}) assert.NoError(t, actualError) }) @@ -308,6 +309,6 @@ func TestGetAccountTypes(t *testing.T) { } func initProvider() *noop.Provider { - logger := log.NewLogrus() + logger := log.NewNoop() return noop.NewProvider("noop", logger) } diff --git a/plugins/providers/shield/client.go b/plugins/providers/shield/client.go index 621ed780b..b1ab09e4b 100644 --- a/plugins/providers/shield/client.go +++ b/plugins/providers/shield/client.go @@ -2,6 +2,7 @@ package shield import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -11,8 +12,8 @@ import ( "path" "github.com/go-playground/validator/v10" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/pkg/tracing" - "github.com/goto/salt/log" "github.com/mitchellh/mapstructure" ) @@ -32,16 +33,16 @@ const ( type successAccess interface{} type ShieldClient interface { - GetTeams() ([]*Team, error) - GetProjects() ([]*Project, error) - GetOrganizations() ([]*Organization, error) - GrantTeamAccess(team *Team, userId string, role string) error - RevokeTeamAccess(team *Team, userId string, role string) error - GrantProjectAccess(project *Project, userId string, role string) error - RevokeProjectAccess(project *Project, userId string, role string) error - GrantOrganizationAccess(organization *Organization, userId string, role string) error - RevokeOrganizationAccess(organization *Organization, userId string, role string) error - GetSelfUser(email string) (*User, error) + GetTeams(ctx context.Context) ([]*Team, error) + GetProjects(ctx context.Context) ([]*Project, error) + GetOrganizations(ctx context.Context) ([]*Organization, error) + GrantTeamAccess(ctx context.Context, team *Team, userId string, role string) error + RevokeTeamAccess(ctx context.Context, team *Team, userId string, role string) error + GrantProjectAccess(ctx context.Context, project *Project, userId string, role string) error + RevokeProjectAccess(ctx context.Context, project *Project, userId string, role string) error + GrantOrganizationAccess(ctx context.Context, organization *Organization, userId string, role string) error + RevokeOrganizationAccess(ctx context.Context, organization *Organization, userId string, role string) error + GetSelfUser(ctx context.Context, email string) (*User, error) } type client struct { @@ -120,7 +121,7 @@ func (c *client) newRequest(method, path string, body interface{}, authEmail str return req, nil } -func (c *client) GetAdminsOfGivenResourceType(id string, resourceTypeEndPoint string) ([]string, error) { +func (c *client) GetAdminsOfGivenResourceType(ctx context.Context, id string, resourceTypeEndPoint string) ([]string, error) { endPoint := path.Join(resourceTypeEndPoint, "/", id, "/admins") req, err := c.newRequest(http.MethodGet, endPoint, nil, "") if err != nil { @@ -129,7 +130,7 @@ func (c *client) GetAdminsOfGivenResourceType(id string, resourceTypeEndPoint st var users []*User var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return nil, err } if v, ok := response.(map[string]interface{}); ok && v[usersConst] != nil { @@ -144,7 +145,7 @@ func (c *client) GetAdminsOfGivenResourceType(id string, resourceTypeEndPoint st return userEmails, err } -func (c *client) GetTeams() ([]*Team, error) { +func (c *client) GetTeams(ctx context.Context) ([]*Team, error) { req, err := c.newRequest(http.MethodGet, groupsEndpoint, nil, "") if err != nil { return nil, err @@ -152,7 +153,7 @@ func (c *client) GetTeams() ([]*Team, error) { var teams []*Team var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return nil, err } @@ -161,19 +162,19 @@ func (c *client) GetTeams() ([]*Team, error) { } for _, team := range teams { - admins, err := c.GetAdminsOfGivenResourceType(team.ID, groupsEndpoint) + admins, err := c.GetAdminsOfGivenResourceType(ctx, team.ID, groupsEndpoint) if err != nil { return nil, err } team.Admins = admins } - c.logger.Info("Fetch teams from request", "total", len(teams), req.URL) + c.logger.Info(ctx, "Fetch teams from request", "total", len(teams), req.URL) return teams, err } -func (c *client) GetProjects() ([]*Project, error) { +func (c *client) GetProjects(ctx context.Context) ([]*Project, error) { req, err := c.newRequest(http.MethodGet, projectsEndpoint, nil, "") if err != nil { return nil, err @@ -182,7 +183,7 @@ func (c *client) GetProjects() ([]*Project, error) { var projects []*Project var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return nil, err } @@ -191,19 +192,19 @@ func (c *client) GetProjects() ([]*Project, error) { } for _, project := range projects { - admins, err := c.GetAdminsOfGivenResourceType(project.ID, projectsEndpoint) + admins, err := c.GetAdminsOfGivenResourceType(ctx, project.ID, projectsEndpoint) if err != nil { return nil, err } project.Admins = admins } - c.logger.Info("Fetch projects from request", "total", len(projects), req.URL) + c.logger.Info(ctx, "Fetch projects from request", "total", len(projects), req.URL) return projects, err } -func (c *client) GetOrganizations() ([]*Organization, error) { +func (c *client) GetOrganizations(ctx context.Context) ([]*Organization, error) { req, err := c.newRequest(http.MethodGet, organizationEndpoint, nil, "") if err != nil { return nil, err @@ -211,7 +212,7 @@ func (c *client) GetOrganizations() ([]*Organization, error) { var organizations []*Organization var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return nil, err } @@ -220,19 +221,19 @@ func (c *client) GetOrganizations() ([]*Organization, error) { } for _, org := range organizations { - admins, err := c.GetAdminsOfGivenResourceType(org.ID, organizationEndpoint) + admins, err := c.GetAdminsOfGivenResourceType(ctx, org.ID, organizationEndpoint) if err != nil { return nil, err } org.Admins = admins } - c.logger.Info("Fetch organizations from request", "total", len(organizations), req.URL) + c.logger.Info(ctx, "Fetch organizations from request", "total", len(organizations), req.URL) return organizations, err } -func (c *client) GrantTeamAccess(resource *Team, userId string, role string) error { +func (c *client) GrantTeamAccess(ctx context.Context, resource *Team, userId string, role string) error { body := make(map[string][]string) body["userIds"] = append(body["userIds"], userId) @@ -244,7 +245,7 @@ func (c *client) GrantTeamAccess(resource *Team, userId string, role string) err var users []*User var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return err } @@ -255,12 +256,12 @@ func (c *client) GrantTeamAccess(resource *Team, userId string, role string) err } } - c.logger.Info("Team access to the user,", "total users", len(users), req.URL) + c.logger.Info(ctx, "Team access to the user,", "total users", len(users), req.URL) return nil } -func (c *client) GrantProjectAccess(resource *Project, userId string, role string) error { +func (c *client) GrantProjectAccess(ctx context.Context, resource *Project, userId string, role string) error { body := make(map[string][]string) body["userIds"] = append(body["userIds"], userId) @@ -272,7 +273,7 @@ func (c *client) GrantProjectAccess(resource *Project, userId string, role strin var users []*User var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return err } @@ -283,11 +284,11 @@ func (c *client) GrantProjectAccess(resource *Project, userId string, role strin } } - c.logger.Info("Project access to the user,", "total users", len(users), req.URL) + c.logger.Info(ctx, "Project access to the user,", "total users", len(users), req.URL) return nil } -func (c *client) GrantOrganizationAccess(resource *Organization, userId string, role string) error { +func (c *client) GrantOrganizationAccess(ctx context.Context, resource *Organization, userId string, role string) error { body := make(map[string][]string) body["userIds"] = append(body["userIds"], userId) @@ -300,7 +301,7 @@ func (c *client) GrantOrganizationAccess(resource *Organization, userId string, var users []*User var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return err } @@ -311,11 +312,11 @@ func (c *client) GrantOrganizationAccess(resource *Organization, userId string, } } - c.logger.Info("Organization access to the user,", "total users", len(users), req.URL) + c.logger.Info(ctx, "Organization access to the user,", "total users", len(users), req.URL) return nil } -func (c *client) RevokeTeamAccess(resource *Team, userId string, role string) error { +func (c *client) RevokeTeamAccess(ctx context.Context, resource *Team, userId string, role string) error { endPoint := path.Join(groupsEndpoint, "/", resource.ID, "/", role, "/", userId) req, err := c.newRequest(http.MethodDelete, endPoint, "", "") if err != nil { @@ -324,7 +325,7 @@ func (c *client) RevokeTeamAccess(resource *Team, userId string, role string) er var success successAccess var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return err } @@ -335,11 +336,11 @@ func (c *client) RevokeTeamAccess(resource *Team, userId string, role string) er } } - c.logger.Info("Remove access of the user from team,", "Users", userId, req.URL) + c.logger.Info(ctx, "Remove access of the user from team,", "Users", userId, req.URL) return nil } -func (c *client) RevokeProjectAccess(resource *Project, userId string, role string) error { +func (c *client) RevokeProjectAccess(ctx context.Context, resource *Project, userId string, role string) error { endPoint := path.Join(projectsEndpoint, "/", resource.ID, "/", role, "/", userId) req, err := c.newRequest(http.MethodDelete, endPoint, "", "") if err != nil { @@ -348,7 +349,7 @@ func (c *client) RevokeProjectAccess(resource *Project, userId string, role stri var success successAccess var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return err } @@ -359,11 +360,11 @@ func (c *client) RevokeProjectAccess(resource *Project, userId string, role stri } } - c.logger.Info("Remove access of the user from project", "Users", userId, req.URL) + c.logger.Info(ctx, "Remove access of the user from project", "Users", userId, req.URL) return nil } -func (c *client) RevokeOrganizationAccess(resource *Organization, userId string, role string) error { +func (c *client) RevokeOrganizationAccess(ctx context.Context, resource *Organization, userId string, role string) error { endPoint := path.Join(organizationEndpoint, "/", resource.ID, "/", role, "/", userId) req, err := c.newRequest(http.MethodDelete, endPoint, "", "") if err != nil { @@ -372,7 +373,7 @@ func (c *client) RevokeOrganizationAccess(resource *Organization, userId string, var success successAccess var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return err } @@ -383,11 +384,11 @@ func (c *client) RevokeOrganizationAccess(resource *Organization, userId string, } } - c.logger.Info("Remove access of the user from organization", "Users", userId, req.URL) + c.logger.Info(ctx, "Remove access of the user from organization", "Users", userId, req.URL) return nil } -func (c *client) GetSelfUser(email string) (*User, error) { +func (c *client) GetSelfUser(ctx context.Context, email string) (*User, error) { req, err := c.newRequest(http.MethodGet, selfUserEndpoint, nil, email) if err != nil { return nil, err @@ -395,7 +396,7 @@ func (c *client) GetSelfUser(email string) (*User, error) { var user *User var response interface{} - if _, err := c.do(req, &response); err != nil { + if _, err := c.do(ctx, req, &response); err != nil { return nil, err } @@ -403,15 +404,15 @@ func (c *client) GetSelfUser(email string) (*User, error) { err = mapstructure.Decode(v[userConst], &user) } - c.logger.Info("Fetch user from request", "Id", user.ID, req.URL) + c.logger.Info(ctx, "Fetch user from request", "Id", user.ID, req.URL) return user, err } -func (c *client) do(req *http.Request, v interface{}) (*http.Response, error) { +func (c *client) do(ctx context.Context, req *http.Request, v interface{}) (*http.Response, error) { resp, err := c.httpClient.Do(req) if err != nil { - c.logger.Error(fmt.Sprintf("Failed to execute request %v with error %v", req.URL, err)) + c.logger.Error(ctx, fmt.Sprintf("Failed to execute request %v with error %v", req.URL, err)) return nil, err } defer resp.Body.Close() diff --git a/plugins/providers/shield/client_test.go b/plugins/providers/shield/client_test.go index 718be9d26..08073dbd6 100644 --- a/plugins/providers/shield/client_test.go +++ b/plugins/providers/shield/client_test.go @@ -2,6 +2,7 @@ package shield_test import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -10,7 +11,7 @@ import ( "github.com/stretchr/testify/mock" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/mocks" "github.com/goto/guardian/plugins/providers/shield" @@ -21,7 +22,7 @@ import ( func TestNewClient(t *testing.T) { t.Run("should return error if config is invalid", func(t *testing.T) { invalidConfig := &shield.ClientConfig{} - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) actualClient, actualError := shield.NewClient(invalidConfig, logger) assert.Nil(t, actualClient) @@ -34,7 +35,7 @@ func TestNewClient(t *testing.T) { AuthEmail: "test-email", Host: "invalid-url", } - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) actualClient, actualError := shield.NewClient(invalidHostConfig, logger) assert.Nil(t, actualClient) @@ -50,7 +51,7 @@ func TestNewClient(t *testing.T) { Host: "http://localhost", HTTPClient: mockHttpClient, } - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) _, actualError := shield.NewClient(config, logger) mockHttpClient.AssertExpectations(t) @@ -209,7 +210,7 @@ func (s *ClientTestSuite) TestGetTeams() { teamAdminResponse2 := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(teamAdminResponse)))} s.mockHttpClient.On("Do", testAdminsRequest2).Return(&teamAdminResponse2, nil).Once() - result, err1 := s.client.GetTeams() + result, err1 := s.client.GetTeams(context.Background()) var teams []shield.Team for _, team := range result { teams = append(teams, *team) @@ -278,7 +279,7 @@ func (s *ClientTestSuite) TestGetProjects() { projectAdminResponse1 := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(projectAdminResponse)))} s.mockHttpClient.On("Do", testAdminsRequest).Return(&projectAdminResponse1, nil).Once() - result, err1 := s.client.GetProjects() + result, err1 := s.client.GetProjects(context.Background()) var projects []shield.Project for _, project := range result { projects = append(projects, *project) @@ -344,7 +345,7 @@ func (s *ClientTestSuite) TestGetOrganizations() { orgAdminResponse1 := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(orgAdminResponse)))} s.mockHttpClient.On("Do", testAdminsRequest).Return(&orgAdminResponse1, nil).Once() - result, err1 := s.client.GetOrganizations() + result, err1 := s.client.GetOrganizations(context.Background()) var orgs []shield.Organization for _, org := range result { orgs = append(orgs, *org) @@ -388,7 +389,7 @@ func (s *ClientTestSuite) TestGrantTeamAccess() { responseUsers := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(responseJson)))} s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&responseUsers, nil).Once() - actualError := s.client.GrantTeamAccess(teamObj, testUserId, role) + actualError := s.client.GrantTeamAccess(context.Background(), teamObj, testUserId, role) s.Nil(actualError) }) } @@ -427,7 +428,7 @@ func (s *ClientTestSuite) TestGrantProjectAccess() { responseUsers := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(responseJson)))} s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&responseUsers, nil).Once() - actualError := s.client.GrantProjectAccess(projectObj, testUserId, role) + actualError := s.client.GrantProjectAccess(context.Background(), projectObj, testUserId, role) s.Nil(actualError) }) } @@ -465,7 +466,7 @@ func (s *ClientTestSuite) TestGrantOrganizationAccess() { responseUsers := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(responseJson)))} s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&responseUsers, nil).Once() - actualError := s.client.GrantOrganizationAccess(orgObj, testUserId, role) + actualError := s.client.GrantOrganizationAccess(context.Background(), orgObj, testUserId, role) s.Nil(actualError) }) } @@ -491,7 +492,7 @@ func (s *ClientTestSuite) TestRevokeTeamAccess() { responseUsers := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(responseJson)))} s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&responseUsers, nil).Once() - actualError := s.client.RevokeTeamAccess(teamObj, testUserId, role) + actualError := s.client.RevokeTeamAccess(context.Background(), teamObj, testUserId, role) s.Nil(actualError) }) } @@ -518,7 +519,7 @@ func (s *ClientTestSuite) TestRevokeProjectAccess() { responseUsers := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(responseJson)))} s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&responseUsers, nil).Once() - actualError := s.client.RevokeProjectAccess(projectObj, testUserId, role) + actualError := s.client.RevokeProjectAccess(context.Background(), projectObj, testUserId, role) s.Nil(actualError) }) } @@ -544,7 +545,7 @@ func (s *ClientTestSuite) TestRevokeOrganizationAccess() { responseUsers := http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader([]byte(responseJson)))} s.mockHttpClient.On("Do", mock.AnythingOfType("*http.Request")).Return(&responseUsers, nil).Once() - actualError := s.client.RevokeOrganizationAccess(orgObj, testUserId, role) + actualError := s.client.RevokeOrganizationAccess(context.Background(), orgObj, testUserId, role) s.Nil(actualError) }) } @@ -566,7 +567,7 @@ func (s *ClientTestSuite) TestGetSelfUser() { responseUser := http.Response{StatusCode: 500, Body: io.NopCloser(bytes.NewReader([]byte(responseJson)))} s.mockHttpClient.On("Do", testGetSelfRequest).Return(&responseUser, nil).Once() - _, actualError := s.client.GetSelfUser(testUserEmail) + _, actualError := s.client.GetSelfUser(context.Background(), testUserEmail) s.NotNil(actualError) }) s.Run("Should return shield user on success", func() { @@ -599,7 +600,7 @@ func (s *ClientTestSuite) TestGetSelfUser() { Email: "test_user@email.com", } - user, actualError := s.client.GetSelfUser(testUserEmail) + user, actualError := s.client.GetSelfUser(context.Background(), testUserEmail) s.EqualValues(expectedUser, user) s.Nil(actualError) }) diff --git a/plugins/providers/shield/provider.go b/plugins/providers/shield/provider.go index c2c33177b..6b0c797f9 100644 --- a/plugins/providers/shield/provider.go +++ b/plugins/providers/shield/provider.go @@ -1,9 +1,11 @@ package shield import ( + "context" + pv "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/mitchellh/mapstructure" ) @@ -39,7 +41,7 @@ func (p *provider) CreateConfig(pc *domain.ProviderConfig) error { return c.ParseAndValidate() } -func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { +func (p *provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return nil, err @@ -62,7 +64,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, var organizations []*Organization if _, ok := resourceTypes[ResourceTypeTeam]; ok { - teams, err = client.GetTeams() + teams, err = client.GetTeams(ctx) if err != nil { return nil, err } @@ -70,7 +72,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, } if _, ok := resourceTypes[ResourceTypeProject]; ok { - projects, err = client.GetProjects() + projects, err = client.GetProjects(ctx) if err != nil { return nil, err } @@ -78,7 +80,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, } if _, ok := resourceTypes[ResourceTypeOrganization]; ok { - organizations, err = client.GetOrganizations() + organizations, err = client.GetOrganizations(ctx) if err != nil { return nil, err } @@ -141,7 +143,7 @@ func (p *provider) GetRoles(pc *domain.ProviderConfig, resourceType string) ([]* return pv.GetRoles(pc, resourceType) } -func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return err @@ -154,7 +156,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error permissions := a.GetPermissions() var user *User - if user, err = client.GetSelfUser(a.AccountID); err != nil { + if user, err = client.GetSelfUser(ctx, a.AccountID); err != nil { return nil } @@ -165,7 +167,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } for _, p := range permissions { - if err := client.GrantTeamAccess(t, user.ID, p); err != nil { + if err := client.GrantTeamAccess(ctx, t, user.ID, p); err != nil { return err } } @@ -176,7 +178,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } for _, p := range permissions { - if err := client.GrantProjectAccess(pj, user.ID, p); err != nil { + if err := client.GrantProjectAccess(ctx, pj, user.ID, p); err != nil { return err } } @@ -187,7 +189,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } for _, p := range permissions { - if err := client.GrantOrganizationAccess(o, user.ID, p); err != nil { + if err := client.GrantOrganizationAccess(ctx, o, user.ID, p); err != nil { return err } } @@ -197,7 +199,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return ErrInvalidResourceType } -func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return err @@ -210,7 +212,7 @@ func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error permissions := a.GetPermissions() var user *User - if user, err = client.GetSelfUser(a.AccountID); err != nil { + if user, err = client.GetSelfUser(ctx, a.AccountID); err != nil { return nil } @@ -221,7 +223,7 @@ func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } for _, p := range permissions { - if err := client.RevokeTeamAccess(t, user.ID, p); err != nil { + if err := client.RevokeTeamAccess(ctx, t, user.ID, p); err != nil { return err } } @@ -233,7 +235,7 @@ func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } for _, p := range permissions { - if err := client.RevokeProjectAccess(pj, user.ID, p); err != nil { + if err := client.RevokeProjectAccess(ctx, pj, user.ID, p); err != nil { return err } } @@ -245,7 +247,7 @@ func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error return err } for _, p := range permissions { - if err := client.RevokeOrganizationAccess(o, user.ID, p); err != nil { + if err := client.RevokeOrganizationAccess(ctx, o, user.ID, p); err != nil { return err } } diff --git a/plugins/providers/shield/provider_test.go b/plugins/providers/shield/provider_test.go index b7278e32a..f95cf469a 100644 --- a/plugins/providers/shield/provider_test.go +++ b/plugins/providers/shield/provider_test.go @@ -1,10 +1,11 @@ package shield_test import ( + "context" "errors" "testing" - "github.com/goto/salt/log" + "github.com/goto/guardian/pkg/log" "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" @@ -17,7 +18,7 @@ import ( func TestGetType(t *testing.T) { t.Run("should return provider type name", func(t *testing.T) { expectedTypeName := domain.ProviderTypeShield - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider(expectedTypeName, logger) actualTypeName := p.GetType() @@ -30,7 +31,7 @@ func TestCreateConfig(t *testing.T) { t.Run("should return error if there resource config is invalid", func(t *testing.T) { providerURN := "test-provider-urn" client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -82,7 +83,7 @@ func TestCreateConfig(t *testing.T) { t.Run("should not return error if parse and valid of Credentials are correct", func(t *testing.T) { providerURN := "test-provider-urn" client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -172,15 +173,16 @@ func TestCreateConfig(t *testing.T) { } func TestGetResources(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) pc := &domain.ProviderConfig{ Credentials: "invalid-creds", } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -189,7 +191,7 @@ func TestGetResources(t *testing.T) { t.Run("should return error if got any on getting team resources", func(t *testing.T) { providerURN := "test-provider-urn" client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -205,9 +207,9 @@ func TestGetResources(t *testing.T) { }, } expectedError := errors.New("client error") - client.On("GetTeams").Return(nil, expectedError).Once() + client.On("GetTeams", mock.Anything).Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(context.TODO(), pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -216,7 +218,7 @@ func TestGetResources(t *testing.T) { t.Run("should return error if got any on getting project resources", func(t *testing.T) { providerURN := "test-provider-urn" client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -232,9 +234,9 @@ func TestGetResources(t *testing.T) { }, } expectedError := errors.New("client error") - client.On("GetProjects").Return(nil, expectedError).Once() + client.On("GetProjects", mock.Anything).Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -243,7 +245,7 @@ func TestGetResources(t *testing.T) { t.Run("should return error if got any on getting organization resources", func(t *testing.T) { providerURN := "test-provider-urn" client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -259,9 +261,9 @@ func TestGetResources(t *testing.T) { }, } expectedError := errors.New("client error") - client.On("GetOrganizations").Return(nil, expectedError).Once() + client.On("GetOrganizations", mock.Anything).Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -270,7 +272,7 @@ func TestGetResources(t *testing.T) { t.Run("should return list of resources and nil error on success", func(t *testing.T) { providerURN := "test-provider-urn" client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -304,7 +306,7 @@ func TestGetResources(t *testing.T) { Admins: []string{"testTeamAdmin@gmail.com"}, }, } - client.On("GetTeams").Return(expectedTeams, nil).Once() + client.On("GetTeams", mock.Anything).Return(expectedTeams, nil).Once() expectedProjects := []*shield.Project{ { @@ -314,7 +316,7 @@ func TestGetResources(t *testing.T) { Admins: []string{"testProjectAdmin@gmail.com"}, }, } - client.On("GetProjects").Return(expectedProjects, nil).Once() + client.On("GetProjects", mock.Anything).Return(expectedProjects, nil).Once() expectedOrganizations := []*shield.Organization{ { @@ -324,7 +326,7 @@ func TestGetResources(t *testing.T) { }, } - client.On("GetOrganizations").Return(expectedOrganizations, nil).Once() + client.On("GetOrganizations", mock.Anything).Return(expectedOrganizations, nil).Once() expectedResources := []*domain.Resource{ { @@ -365,7 +367,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -373,8 +375,10 @@ func TestGetResources(t *testing.T) { } func TestGrantAccess(t *testing.T) { + ctx := context.Background() + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) t.Run("should return error if credentials is invalid", func(t *testing.T) { - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) pc := &domain.ProviderConfig{ @@ -398,7 +402,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -406,7 +410,7 @@ func TestGrantAccess(t *testing.T) { providerURN := "test-provider-urn" expectedError := errors.New("invalid resource type") client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ @@ -420,7 +424,7 @@ func TestGrantAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -449,7 +453,7 @@ func TestGrantAccess(t *testing.T) { AccountID: expectedUserEmail, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -459,7 +463,7 @@ func TestGrantAccess(t *testing.T) { providerURN := "test-provider-urn" expectedError := errors.New("client error") client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -472,7 +476,7 @@ func TestGrantAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GetSelfUser", mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedUserEmail).Return(expectedUser, nil).Once() client.On("GrantTeamAccess", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() pc := &domain.ProviderConfig{ @@ -514,14 +518,14 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) t.Run("should return nil error if granting access is successful", func(t *testing.T) { providerURN := "test-provider-urn" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) client := new(mocks.ShieldClient) expectedTeam := &shield.Team{ Name: "team_1", @@ -540,7 +544,7 @@ func TestGrantAccess(t *testing.T) { p.Clients = map[string]shield.ShieldClient{ providerURN: client, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GetSelfUser", mock.MatchedBy(func(ctx context.Context) bool { return true }), expectedUserEmail).Return(expectedUser, nil).Once() client.On("GrantTeamAccess", expectedTeam, expectedUser.ID, expectedRole).Return(nil).Once() pc := &domain.ProviderConfig{ @@ -583,7 +587,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -594,7 +598,7 @@ func TestGrantAccess(t *testing.T) { providerURN := "test-provider-urn" expectedError := errors.New("client error") client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -607,8 +611,8 @@ func TestGrantAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("GrantProjectAccess", mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GrantProjectAccess", mockCtx, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -644,14 +648,14 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) t.Run("should return nil error if granting access is successful", func(t *testing.T) { providerURN := "test-provider-urn" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) client := new(mocks.ShieldClient) expectedProject := &shield.Project{ Name: "project_1", @@ -670,8 +674,8 @@ func TestGrantAccess(t *testing.T) { providerURN: client, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("GrantProjectAccess", expectedProject, expectedUser.ID, expectedRole).Return(nil).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GrantProjectAccess", mockCtx, expectedProject, expectedUser.ID, expectedRole).Return(nil).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -708,7 +712,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -719,7 +723,7 @@ func TestGrantAccess(t *testing.T) { providerURN := "test-provider-urn" expectedError := errors.New("client error") client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -732,8 +736,8 @@ func TestGrantAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("GrantOrganizationAccess", mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GrantOrganizationAccess", mockCtx, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -768,14 +772,14 @@ func TestGrantAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) t.Run("should return nil error if granting access is successful", func(t *testing.T) { providerURN := "test-provider-urn" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) client := new(mocks.ShieldClient) expectedOrganization := &shield.Organization{ Name: "org_1", @@ -794,8 +798,8 @@ func TestGrantAccess(t *testing.T) { providerURN: client, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("GrantOrganizationAccess", expectedOrganization, expectedUser.ID, expectedRole).Return(nil).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GrantOrganizationAccess", mockCtx, expectedOrganization, expectedUser.ID, expectedRole).Return(nil).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -831,7 +835,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -839,8 +843,10 @@ func TestGrantAccess(t *testing.T) { } func TestRevokeAccess(t *testing.T) { + ctx := context.Background() + mockCtx := mock.MatchedBy(func(ctx context.Context) bool { return true }) t.Run("should return error if credentials is invalid", func(t *testing.T) { - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) pc := &domain.ProviderConfig{ @@ -864,14 +870,14 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Error(t, actualError) }) t.Run("should return error if resource type in unknown", func(t *testing.T) { providerURN := "test-provider-urn" client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -885,7 +891,7 @@ func TestRevokeAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -913,7 +919,7 @@ func TestRevokeAccess(t *testing.T) { AccountID: expectedUserEmail, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -922,7 +928,7 @@ func TestRevokeAccess(t *testing.T) { providerURN := "test-provider-urn" expectedError := errors.New("client error") client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -935,8 +941,8 @@ func TestRevokeAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("RevokeTeamAccess", mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("RevokeTeamAccess", mockCtx, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -977,14 +983,14 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) t.Run("should return nil error if revoking team access is successful", func(t *testing.T) { providerURN := "test-provider-urn" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) client := new(mocks.ShieldClient) expectedTeam := &shield.Team{ Name: "team_1", @@ -1011,8 +1017,8 @@ func TestRevokeAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("RevokeTeamAccess", expectedTeam, expectedUser.ID, expectedRole).Return(nil).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("RevokeTeamAccess", mockCtx, expectedTeam, expectedUser.ID, expectedRole).Return(nil).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -1055,7 +1061,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -1067,7 +1073,7 @@ func TestRevokeAccess(t *testing.T) { providerURN := "test-provider-urn" expectedError := errors.New("client error") client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -1080,9 +1086,9 @@ func TestRevokeAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() - client.On("RevokeProjectAccess", mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("RevokeProjectAccess", mockCtx, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -1118,7 +1124,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1133,7 +1139,7 @@ func TestRevokeAccess(t *testing.T) { Admins: []string{"testAdmin@email.com"}, } expectedRole := "admins" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ @@ -1147,8 +1153,8 @@ func TestRevokeAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("RevokeProjectAccess", expectedProject, expectedUser.ID, expectedRole).Return(nil).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("RevokeProjectAccess", mockCtx, expectedProject, expectedUser.ID, expectedRole).Return(nil).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -1186,7 +1192,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -1198,7 +1204,7 @@ func TestRevokeAccess(t *testing.T) { providerURN := "test-provider-urn" expectedError := errors.New("client error") client := new(mocks.ShieldClient) - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ providerURN: client, @@ -1211,8 +1217,8 @@ func TestRevokeAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("RevokeOrganizationAccess", mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("RevokeOrganizationAccess", mockCtx, mock.Anything, mock.Anything, mock.Anything).Return(expectedError).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -1248,7 +1254,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1262,7 +1268,7 @@ func TestRevokeAccess(t *testing.T) { Admins: []string{"testAdmin@email.com"}, } expectedRole := "admins" - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) p.Clients = map[string]shield.ShieldClient{ @@ -1275,8 +1281,8 @@ func TestRevokeAccess(t *testing.T) { Email: expectedUserEmail, } - client.On("GetSelfUser", expectedUserEmail).Return(expectedUser, nil).Once() - client.On("RevokeOrganizationAccess", expectedOrganization, expectedUser.ID, expectedRole).Return(nil).Once() + client.On("GetSelfUser", mockCtx, expectedUserEmail).Return(expectedUser, nil).Once() + client.On("RevokeOrganizationAccess", mockCtx, expectedOrganization, expectedUser.ID, expectedRole).Return(nil).Once() pc := &domain.ProviderConfig{ Credentials: shield.Credentials{ @@ -1313,7 +1319,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) client.AssertExpectations(t) @@ -1323,7 +1329,7 @@ func TestRevokeAccess(t *testing.T) { func TestGetAccountTypes(t *testing.T) { expectedAccountType := []string{"user"} - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("", logger) actualAccountType := p.GetAccountTypes() @@ -1333,7 +1339,7 @@ func TestGetAccountTypes(t *testing.T) { func TestGetRoles(t *testing.T) { t.Run("should return error if resource type is invalid", func(t *testing.T) { - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("shield", logger) validConfig := &domain.ProviderConfig{ Type: "shield", @@ -1372,7 +1378,7 @@ func TestGetRoles(t *testing.T) { }) t.Run("should return roles specified in the provider config", func(t *testing.T) { - logger := log.NewLogrus(log.LogrusWithLevel("info")) + logger := log.NewCtxLogger("info", []string{"test"}) p := shield.NewProvider("shield", logger) expectedRoles := []*domain.Role{ diff --git a/plugins/providers/tableau/provider.go b/plugins/providers/tableau/provider.go index 1b6e48fce..43fbb255a 100644 --- a/plugins/providers/tableau/provider.go +++ b/plugins/providers/tableau/provider.go @@ -1,6 +1,8 @@ package tableau import ( + "context" + pv "github.com/goto/guardian/core/provider" "github.com/goto/guardian/domain" "github.com/mitchellh/mapstructure" @@ -37,7 +39,7 @@ func (p *provider) CreateConfig(pc *domain.ProviderConfig) error { return c.EncryptCredentials() } -func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, error) { +func (p *provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return nil, err @@ -119,7 +121,7 @@ func (p *provider) GetResources(pc *domain.ProviderConfig) ([]*domain.Resource, return resources, nil } -func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return err @@ -231,7 +233,7 @@ func (p *provider) GrantAccess(pc *domain.ProviderConfig, a domain.Grant) error return ErrInvalidResourceType } -func (p *provider) RevokeAccess(pc *domain.ProviderConfig, a domain.Grant) error { +func (p *provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, a domain.Grant) error { var creds Credentials if err := mapstructure.Decode(pc.Credentials, &creds); err != nil { return err diff --git a/plugins/providers/tableau/provider_test.go b/plugins/providers/tableau/provider_test.go index 9ce5ceb9d..7a6e1cdf2 100644 --- a/plugins/providers/tableau/provider_test.go +++ b/plugins/providers/tableau/provider_test.go @@ -1,6 +1,7 @@ package tableau_test import ( + "context" "errors" "testing" @@ -279,6 +280,7 @@ func TestCreateConfig(t *testing.T) { } func TestGetResources(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) p := tableau.NewProvider("", crypto) @@ -287,7 +289,7 @@ func TestGetResources(t *testing.T) { Credentials: "invalid-creds", } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.Error(t, actualError) @@ -305,7 +307,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -341,7 +343,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Error(t, actualError) assert.Nil(t, actualResources) @@ -368,7 +370,7 @@ func TestGetResources(t *testing.T) { expectedError := errors.New("client error") client.On("GetWorkbooks").Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -395,7 +397,7 @@ func TestGetResources(t *testing.T) { expectedError := errors.New("client error") client.On("GetFlows").Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -422,7 +424,7 @@ func TestGetResources(t *testing.T) { expectedError := errors.New("client error") client.On("GetDataSources").Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -449,7 +451,7 @@ func TestGetResources(t *testing.T) { expectedError := errors.New("client error") client.On("GetViews").Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -476,7 +478,7 @@ func TestGetResources(t *testing.T) { expectedError := errors.New("client error") client.On("GetMetrics").Return(nil, expectedError).Once() - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Nil(t, actualResources) assert.EqualError(t, actualError, expectedError.Error()) @@ -632,7 +634,7 @@ func TestGetResources(t *testing.T) { }, } - actualResources, actualError := p.GetResources(pc) + actualResources, actualError := p.GetResources(ctx, pc) assert.Equal(t, expectedResources, actualResources) assert.Nil(t, actualError) @@ -641,6 +643,7 @@ func TestGetResources(t *testing.T) { } func TestGrantAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) p := tableau.NewProvider("", crypto) @@ -670,7 +673,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -710,7 +713,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -752,7 +755,7 @@ func TestGrantAccess(t *testing.T) { Role: "test-role", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -851,7 +854,7 @@ func TestGrantAccess(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - actualError := p.GrantAccess(tc.pc, tc.a) + actualError := p.GrantAccess(ctx, tc.pc, tc.a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -910,7 +913,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -1011,7 +1014,7 @@ func TestGrantAccess(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - actualError := p.GrantAccess(tc.pc, tc.a) + actualError := p.GrantAccess(ctx, tc.pc, tc.a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1070,7 +1073,7 @@ func TestGrantAccess(t *testing.T) { ID: "999", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -1171,7 +1174,7 @@ func TestGrantAccess(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - actualError := p.GrantAccess(tc.pc, tc.a) + actualError := p.GrantAccess(ctx, tc.pc, tc.a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1230,7 +1233,7 @@ func TestGrantAccess(t *testing.T) { ID: "99", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -1330,7 +1333,7 @@ func TestGrantAccess(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - actualError := p.GrantAccess(tc.pc, tc.a) + actualError := p.GrantAccess(ctx, tc.pc, tc.a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1389,7 +1392,7 @@ func TestGrantAccess(t *testing.T) { ID: "99", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -1490,7 +1493,7 @@ func TestGrantAccess(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - actualError := p.GrantAccess(tc.pc, tc.a) + actualError := p.GrantAccess(ctx, tc.pc, tc.a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1549,7 +1552,7 @@ func TestGrantAccess(t *testing.T) { ID: "99", } - actualError := p.GrantAccess(pc, a) + actualError := p.GrantAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -1557,6 +1560,7 @@ func TestGrantAccess(t *testing.T) { } func TestRevokeAccess(t *testing.T) { + ctx := context.Background() t.Run("should return error if credentials is invalid", func(t *testing.T) { crypto := new(mocks.Crypto) p := tableau.NewProvider("", crypto) @@ -1586,7 +1590,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Error(t, actualError) }) @@ -1626,7 +1630,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1668,7 +1672,7 @@ func TestRevokeAccess(t *testing.T) { Role: "test-role", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1719,7 +1723,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1776,7 +1780,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -1828,7 +1832,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1886,7 +1890,7 @@ func TestRevokeAccess(t *testing.T) { ID: "999", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -1938,7 +1942,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -1996,7 +2000,7 @@ func TestRevokeAccess(t *testing.T) { ID: "99", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -2048,7 +2052,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -2106,7 +2110,7 @@ func TestRevokeAccess(t *testing.T) { ID: "99", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) }) @@ -2158,7 +2162,7 @@ func TestRevokeAccess(t *testing.T) { Permissions: []string{"test-permission-config"}, } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.EqualError(t, actualError, expectedError.Error()) }) @@ -2216,7 +2220,7 @@ func TestRevokeAccess(t *testing.T) { ID: "99", } - actualError := p.RevokeAccess(pc, a) + actualError := p.RevokeAccess(ctx, pc, a) assert.Nil(t, actualError) })