diff --git a/domain/provider.go b/domain/provider.go index 51799b05f..f4ad859f8 100644 --- a/domain/provider.go +++ b/domain/provider.go @@ -17,6 +17,7 @@ const ( ProviderTypePolicyTag = "dataplex" ProviderTypeShield = "shield" ProviderTypeGitlab = "gitlab" + ProviderTypeGate = "gate" ) // Role is the configuration to define a role and mapping the permissions in the provider diff --git a/internal/server/services.go b/internal/server/services.go index 04a762fe6..04734604d 100644 --- a/internal/server/services.go +++ b/internal/server/services.go @@ -3,9 +3,6 @@ package server import ( "context" - "github.com/goto/guardian/plugins/providers/dataplex" - "github.com/goto/guardian/plugins/providers/gitlab" - "github.com/go-playground/validator/v10" "github.com/google/uuid" "github.com/goto/guardian/core" @@ -26,8 +23,11 @@ import ( "github.com/goto/guardian/plugins/identities" "github.com/goto/guardian/plugins/notifiers" "github.com/goto/guardian/plugins/providers/bigquery" + "github.com/goto/guardian/plugins/providers/dataplex" + "github.com/goto/guardian/plugins/providers/gate" "github.com/goto/guardian/plugins/providers/gcloudiam" "github.com/goto/guardian/plugins/providers/gcs" + "github.com/goto/guardian/plugins/providers/gitlab" "github.com/goto/guardian/plugins/providers/grafana" "github.com/goto/guardian/plugins/providers/metabase" "github.com/goto/guardian/plugins/providers/noop" @@ -122,6 +122,7 @@ func InitServices(deps ServiceDeps) (*Services, error) { dataplex.NewProvider(domain.ProviderTypePolicyTag, deps.Crypto), shield.NewProvider(domain.ProviderTypeShield, deps.Logger), gitlab.NewProvider(domain.ProviderTypeGitlab, deps.Crypto, deps.Logger), + gate.NewProvider(domain.ProviderTypeGate, deps.Crypto), } iamManager := identities.NewManager(deps.Crypto, deps.Validator) diff --git a/pkg/gate/client.go b/pkg/gate/client.go new file mode 100644 index 000000000..ac58d5ae5 --- /dev/null +++ b/pkg/gate/client.go @@ -0,0 +1,139 @@ +package gate + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" +) + +type Client struct { + baseURL *url.URL + options *options +} + +func NewClient(baseURL string, opts ...ClientOption) (*Client, error) { + url, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + client := &Client{ + baseURL: url, + options: &options{ + httpClient: http.DefaultClient, + }, + } + for _, o := range opts { + o(client.options) + } + return client, nil +} + +type ListGroupsRequest struct { + Page int + PerPage int +} + +func (c *Client) ListGroups(ctx context.Context, req *ListGroupsRequest) ([]*Group, *http.Response, error) { + path := "/api/v1/groups" + r, err := c.newRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return nil, nil, err + } + + q := r.URL.Query() + if req.Page != 0 { + q.Add("page", strconv.Itoa(req.Page)) + } + if req.PerPage != 0 { + q.Add("per_page", strconv.Itoa(req.PerPage)) + } + r.URL.RawQuery = q.Encode() + + res, err := c.options.httpClient.Do(r) + if err != nil { + return nil, res, err + } + + var resBody []*Group + if err := parseResponseBody(res.Body, &resBody); err != nil { + return nil, res, err + } + + return resBody, res, nil +} + +func (c *Client) AddUserToGroup(ctx context.Context, groupID, userID int) (*http.Response, error) { + path := fmt.Sprintf("/api/v1/groups/%d/users", groupID) + reqBody := map[string]any{"user_id": userID} + r, err := c.newRequest(ctx, http.MethodPost, path, reqBody) + if err != nil { + return nil, err + } + + res, err := c.options.httpClient.Do(r) + if err != nil { + return res, err + } + + return res, nil +} + +func (c *Client) RemoveUserFromGroup(ctx context.Context, groupID, userID int) (*http.Response, error) { + path := fmt.Sprintf("/api/v1/groups/%d/users/%d", groupID, userID) + r, err := c.newRequest(ctx, http.MethodDelete, path, nil) + if err != nil { + return nil, err + } + + res, err := c.options.httpClient.Do(r) + if err != nil { + return res, err + } + + return res, nil +} + +func (c *Client) newRequest(ctx context.Context, method, path string, body interface{}) (*http.Request, error) { + url, err := c.baseURL.Parse(path) + if err != nil { + return nil, err + } + + var reqBody io.ReadWriter + if body != nil { + reqBody = new(bytes.Buffer) + if err := json.NewEncoder(reqBody).Encode(body); err != nil { + return nil, err + } + } + + req, err := http.NewRequestWithContext(ctx, method, url.String(), reqBody) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/json") + + // auth + if c.options.token != "" { + if c.options.queryParamAuthKey != "" { + q := req.URL.Query() + q.Add(c.options.queryParamAuthKey, c.options.token) + req.URL.RawQuery = q.Encode() + } else { + req.Header.Add("Authorization", c.options.token) + } + } + + return req, nil +} + +func parseResponseBody(resBody io.ReadCloser, v interface{}) error { + defer resBody.Close() + return json.NewDecoder(resBody).Decode(v) +} diff --git a/pkg/gate/model.go b/pkg/gate/model.go new file mode 100644 index 000000000..609151dd0 --- /dev/null +++ b/pkg/gate/model.go @@ -0,0 +1,12 @@ +package gate + +type Group struct { + ID int `json:"id"` + Name string `json:"name"` + GID int `json:"gid"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + DeletedBy string `json:"deleted_by"` + DeletedAt string `json:"deleted_at"` + Description *string `json:"description"` +} diff --git a/pkg/gate/options.go b/pkg/gate/options.go new file mode 100644 index 000000000..6c5eb9c3c --- /dev/null +++ b/pkg/gate/options.go @@ -0,0 +1,31 @@ +package gate + +import ( + "net/http" +) + +type options struct { + httpClient *http.Client + token string + queryParamAuthKey string +} + +type ClientOption func(*options) + +func WithHTTPClient(httpClient *http.Client) ClientOption { + return func(opts *options) { + opts.httpClient = httpClient + } +} + +func WithAPIKey(token string) ClientOption { + return func(opts *options) { + opts.token = token + } +} + +func WithQueryParamAuthMethod() ClientOption { + return func(opts *options) { + opts.queryParamAuthKey = "token" + } +} diff --git a/plugins/providers/gate/mocks/crypto.go b/plugins/providers/gate/mocks/crypto.go new file mode 100644 index 000000000..b7fb0f51a --- /dev/null +++ b/plugins/providers/gate/mocks/crypto.go @@ -0,0 +1,136 @@ +// Code generated by mockery v2.33.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// Encryptor is an autogenerated mock type for the encryptor type +type Encryptor struct { + mock.Mock +} + +type Encryptor_Expecter struct { + mock *mock.Mock +} + +func (_m *Encryptor) EXPECT() *Encryptor_Expecter { + return &Encryptor_Expecter{mock: &_m.Mock} +} + +// Decrypt provides a mock function with given fields: _a0 +func (_m *Encryptor) Decrypt(_a0 string) (string, error) { + ret := _m.Called(_a0) + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Encryptor_Decrypt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Decrypt' +type Encryptor_Decrypt_Call struct { + *mock.Call +} + +// Decrypt is a helper method to define mock.On call +// - _a0 string +func (_e *Encryptor_Expecter) Decrypt(_a0 interface{}) *Encryptor_Decrypt_Call { + return &Encryptor_Decrypt_Call{Call: _e.mock.On("Decrypt", _a0)} +} + +func (_c *Encryptor_Decrypt_Call) Run(run func(_a0 string)) *Encryptor_Decrypt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Encryptor_Decrypt_Call) Return(_a0 string, _a1 error) *Encryptor_Decrypt_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Encryptor_Decrypt_Call) RunAndReturn(run func(string) (string, error)) *Encryptor_Decrypt_Call { + _c.Call.Return(run) + return _c +} + +// Encrypt provides a mock function with given fields: _a0 +func (_m *Encryptor) Encrypt(_a0 string) (string, error) { + ret := _m.Called(_a0) + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Encryptor_Encrypt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Encrypt' +type Encryptor_Encrypt_Call struct { + *mock.Call +} + +// Encrypt is a helper method to define mock.On call +// - _a0 string +func (_e *Encryptor_Expecter) Encrypt(_a0 interface{}) *Encryptor_Encrypt_Call { + return &Encryptor_Encrypt_Call{Call: _e.mock.On("Encrypt", _a0)} +} + +func (_c *Encryptor_Encrypt_Call) Run(run func(_a0 string)) *Encryptor_Encrypt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Encryptor_Encrypt_Call) Return(_a0 string, _a1 error) *Encryptor_Encrypt_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Encryptor_Encrypt_Call) RunAndReturn(run func(string) (string, error)) *Encryptor_Encrypt_Call { + _c.Call.Return(run) + return _c +} + +// NewEncryptor creates a new instance of Encryptor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewEncryptor(t interface { + mock.TestingT + Cleanup(func()) +}) *Encryptor { + mock := &Encryptor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/plugins/providers/gate/provider.go b/plugins/providers/gate/provider.go new file mode 100644 index 000000000..55637d2d1 --- /dev/null +++ b/plugins/providers/gate/provider.go @@ -0,0 +1,289 @@ +package gate + +import ( + "context" + "errors" + "fmt" + "net/http" + "slices" + "strconv" + "sync" + + pv "github.com/goto/guardian/core/provider" + "github.com/goto/guardian/domain" + "github.com/goto/guardian/pkg/gate" + "github.com/goto/guardian/utils" + "github.com/mitchellh/mapstructure" +) + +const GroupResourceType = "group" + +type credentials struct { + Host string `mapstructure:"host" yaml:"host" json:"host"` + APIKey string `mapstructure:"api_key" yaml:"api_key" json:"api_key"` +} + +func (c credentials) validate() error { + if c.Host == "" { + return errors.New("host is required") + } + if c.APIKey == "" { + return errors.New("api_key is required") + } + return nil +} + +func (c *credentials) encrypt(encryptor domain.Encryptor) error { + encryptedAPIKey, err := encryptor.Encrypt(c.APIKey) + if err != nil { + return err + } + + c.APIKey = encryptedAPIKey + return nil +} + +func (c *credentials) decrypt(decryptor domain.Decryptor) error { + decryptedAPIKey, err := decryptor.Decrypt(c.APIKey) + if err != nil { + return err + } + + c.APIKey = decryptedAPIKey + return nil +} + +type config struct { + *domain.ProviderConfig +} + +func (c *config) validate() error { + // validate credentials + if c.Credentials == nil { + return fmt.Errorf("missing credentials") + } + creds, err := c.getCredentials() + if err != nil { + return err + } + if err := creds.validate(); err != nil { + return fmt.Errorf("invalid credentials: %w", err) + } + + // validate resource config + for _, rc := range c.Resources { + if rc.Type != "group" { + return fmt.Errorf("invalid resource type: %q", rc.Type) + } + + for _, role := range rc.Roles { + for _, permission := range role.Permissions { + permissionString, ok := permission.(string) + if !ok { + return fmt.Errorf("unexpected permission type: %T, expected: string", permission) + } + if permissionString != "member" { + return fmt.Errorf("invalid permission: %q", permissionString) + } + } + } + } + + return nil +} + +func (c *config) getCredentials() (*credentials, error) { + if creds, ok := c.Credentials.(credentials); ok { // parsed + return &creds, nil + } else if mapCreds, ok := c.Credentials.(map[string]interface{}); ok { // not parsed + var creds credentials + if err := mapstructure.Decode(mapCreds, &creds); err != nil { + return nil, fmt.Errorf("unable to decode credentials: %w", err) + } + return &creds, nil + } + + return nil, fmt.Errorf("invalid credentials type: %T", c.Credentials) +} + +type provider struct { + pv.UnimplementedClient + pv.PermissionManager + + typeName string + clients map[string]*gate.Client + crypto domain.Crypto + + mutex sync.Mutex +} + +func NewProvider(typeName string, crypto domain.Crypto) *provider { + return &provider{ + typeName: typeName, + clients: map[string]*gate.Client{}, + crypto: crypto, + mutex: sync.Mutex{}, + } +} + +func (p *provider) GetType() string { + return p.typeName +} + +func (p *provider) CreateConfig(pc *domain.ProviderConfig) error { + cfg := &config{pc} + if err := cfg.validate(); err != nil { + return fmt.Errorf("invalid gate config: %w", err) + } + + // encrypt sensitive config + creds, err := cfg.getCredentials() + if err != nil { + return err + } + if err := creds.encrypt(p.crypto); err != nil { + return fmt.Errorf("unable to encrypt credentials: %w", err) + } + pc.Credentials = creds + + return nil +} + +func (p *provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) ([]*domain.Resource, error) { + client, err := p.getClient(pc) + if err != nil { + return nil, err + } + + if !slices.Contains(pc.GetResourceTypes(), GroupResourceType) { + return nil, nil + } + + resources := []*domain.Resource{} + page := 1 + for { + groups, res, err := client.ListGroups(ctx, &gate.ListGroupsRequest{Page: page}) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to list groups: %s", res.Status) + } + + if len(groups) == 0 { + break + } + page += 1 + + for _, group := range groups { + groupID := strconv.Itoa(group.ID) + resources = append(resources, &domain.Resource{ + ProviderType: pc.Type, + ProviderURN: pc.URN, + Type: GroupResourceType, + URN: groupID, + Name: group.Name, + GlobalURN: utils.GetGlobalURN(pc.Type, pc.URN, GroupResourceType, groupID), + }) + } + } + + return resources, nil +} + +func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, g domain.Grant) error { + client, err := p.getClient(pc) + if err != nil { + return err + } + + groupID, err := strconv.Atoi(g.Resource.URN) + if err != nil { + return fmt.Errorf("invalid group ID: %q: %w", g.Resource.URN, err) + } + + userID, err := strconv.Atoi(g.AccountID) + if err != nil { + return fmt.Errorf("invalid user ID: %q: %w", g.AccountID, err) + } + + switch g.Resource.Type { + case GroupResourceType: + res, err := client.AddUserToGroup(ctx, groupID, userID) + if err != nil { + return fmt.Errorf("failed to add user %q to gate group %q: %w", g.AccountID, g.Resource.URN, err) + } + if res.StatusCode != http.StatusNoContent { + return fmt.Errorf("failed to add user %q to gate group %q: %s", g.AccountID, g.Resource.URN, res.Status) + } + default: + return fmt.Errorf("unexpected resource type: %q", g.Resource.Type) + } + + return nil +} + +func (p *provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, g domain.Grant) error { + client, err := p.getClient(pc) + if err != nil { + return err + } + + groupID, err := strconv.Atoi(g.Resource.URN) + if err != nil { + return fmt.Errorf("invalid group ID: %q: %w", g.Resource.URN, err) + } + + userID, err := strconv.Atoi(g.AccountID) + if err != nil { + return fmt.Errorf("invalid user ID: %q: %w", g.AccountID, err) + } + + switch g.Resource.Type { + case GroupResourceType: + res, err := client.RemoveUserFromGroup(ctx, groupID, userID) + if err != nil { + return fmt.Errorf("failed to remove user %q from gate group %q: %w", g.AccountID, g.Resource.URN, err) + } + if res.StatusCode != http.StatusNoContent { + return fmt.Errorf("failed to remove user %q from gate group %q: %s", g.AccountID, g.Resource.URN, res.Status) + } + default: + return fmt.Errorf("unexpected resource type: %q", g.Resource.Type) + } + + return nil +} + +func (p *provider) GetRoles(pc *domain.ProviderConfig, resourceType string) ([]*domain.Role, error) { + return pv.GetRoles(pc, resourceType) +} + +func (p *provider) GetAccountTypes() []string { + return []string{"gate_user_id"} +} + +func (p *provider) getClient(pc *domain.ProviderConfig) (*gate.Client, error) { + if p.clients[pc.URN] != nil { + return p.clients[pc.URN], nil + } + + config := &config{pc} + creds, err := config.getCredentials() + if err != nil { + return nil, fmt.Errorf("failed to get credentials: %w", err) + } + if err := creds.decrypt(p.crypto); err != nil { + return nil, fmt.Errorf("failed to decrypt credentials: %w", err) + } + + client, err := gate.NewClient(creds.Host, gate.WithAPIKey(creds.APIKey), gate.WithQueryParamAuthMethod()) + if err != nil { + return nil, fmt.Errorf("failed to initialize client: %w", err) + } + + p.mutex.Lock() + p.clients[pc.URN] = client + p.mutex.Unlock() + return client, nil +} diff --git a/plugins/providers/gate/provider_test.go b/plugins/providers/gate/provider_test.go new file mode 100644 index 000000000..b75099266 --- /dev/null +++ b/plugins/providers/gate/provider_test.go @@ -0,0 +1,103 @@ +package gate_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/goto/guardian/domain" + "github.com/goto/guardian/plugins/providers/gate" + "github.com/goto/guardian/plugins/providers/gate/mocks" + "github.com/stretchr/testify/assert" +) + +func TestGetType(t *testing.T) { + providerType := "gate" + p := gate.NewProvider(providerType, nil) + + actualType := p.GetType() + assert.Equal(t, providerType, actualType) +} + +func TestGetResources(t *testing.T) { + t.Run("should return resources returned by gate APIs", func(t *testing.T) { + mockCrypto := new(mocks.Encryptor) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + page := r.URL.Query().Get("page") + var resBody string + + switch page { + case "1": + resBody = `[ + { + "id": 1, + "name": "test-group-1", + "gid": 11, + "created_at": "2024-01-01T01:01:01.000Z", + "updated_at": "2024-01-01T01:01:01.000Z", + "deleted_by": null, + "deleted_at": null, + "description": null + } + ]` + case "2": + resBody = `[ + { + "id": 2, + "name": "test-group-2", + "gid": 22, + "created_at": "2024-01-01T01:01:01.000Z", + "updated_at": "2024-01-01T01:01:01.000Z", + "deleted_by": null, + "deleted_at": null, + "description": null + } + ]` + default: + resBody = `[]` + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(resBody)) + })) + providerConfig := &domain.ProviderConfig{ + Type: "gate", + URN: "gate.example.com", + Credentials: map[string]any{ + "host": ts.URL, + "api_key": "encrypted-api-key", + }, + Resources: []*domain.ResourceConfig{ + {Type: "group"}, + }, + } + + mockCrypto.EXPECT().Decrypt("encrypted-api-key").Return("decrypted-api-key", nil) + expectedResources := []*domain.Resource{ + { + ProviderType: providerConfig.Type, + ProviderURN: providerConfig.URN, + Type: gate.GroupResourceType, + URN: "1", + Name: "test-group-1", + GlobalURN: "urn:gate:gate.example.com:group:1", + }, + { + ProviderType: providerConfig.Type, + ProviderURN: providerConfig.URN, + Type: gate.GroupResourceType, + URN: "2", + Name: "test-group-2", + GlobalURN: "urn:gate:gate.example.com:group:2", + }, + } + + p := gate.NewProvider(domain.ProviderTypeGate, mockCrypto) + actualResources, err := p.GetResources(context.Background(), providerConfig) + + assert.NoError(t, err) + assert.Equal(t, expectedResources, actualResources) + mockCrypto.AssertExpectations(t) + }) +}