diff --git a/pkg/webhook/config/claimstypes.go b/pkg/webhook/config/claimstypes.go new file mode 100644 index 0000000000..26541266a0 --- /dev/null +++ b/pkg/webhook/config/claimstypes.go @@ -0,0 +1,23 @@ +package config + +type ClaimsMutatingHook struct { + // Name is the name of the webhook + Name string `json:"name"` + Type HookType `json:"type"` + AcceptedClaims []string `json:"claims"` + Config *WebhookConfig +} + +type ClaimsValidatingHook struct { + // Name is the name of the webhook + Name string `json:"name"` + // To be modified to enum? + Type HookType `json:"type"` + AcceptedClaims []string `json:"claims"` + Config *WebhookConfig `json:"config"` +} + +type TokenClaimsHooks struct { + MutatingHooks []ClaimsMutatingHook `json:"mutatingHooks"` + ValidatingHooks []ClaimsValidatingHook `json:"validatingHooks"` +} diff --git a/pkg/webhook/config/common.go b/pkg/webhook/config/common.go new file mode 100644 index 0000000000..b093b43c14 --- /dev/null +++ b/pkg/webhook/config/common.go @@ -0,0 +1,14 @@ +package config + +type WebhookConfig struct { + URL string `json:"url"` + InsecureSkipVerify bool `json:"insecureSkipVerify"` + TLSRootCAFile string `json:"tlsRootCAFile"` + ClientAuthentication *ClientAuthentication `json:"clientAuthentication"` +} + +type ClientAuthentication struct { + ClientCertificateFile string `json:"clientCertificateFile"` + ClientKeyFile string `json:"clientKeyFile"` + ClientCAFile string `json:"clientCAFile"` +} diff --git a/pkg/webhook/config/connectortypes.go b/pkg/webhook/config/connectortypes.go new file mode 100644 index 0000000000..9642d59392 --- /dev/null +++ b/pkg/webhook/config/connectortypes.go @@ -0,0 +1,24 @@ +package config + +// HookRequestScope is the context of the request +type HookRequestScope struct { + // Headers is the headers of the request + Headers []string `json:"headers"` + // Params is the params of the request + Params []string `json:"params"` +} + +type ConnectorFilterHook struct { + // Name is the name of the webhook + Name string `json:"name"` + // To be modified to enum? + Type HookType `json:"type"` + // RequestScope is the context of the request + RequestScope *HookRequestScope `json:"requestContext"` + // Config is the configuration of the webhook + Config *WebhookConfig `json:"config"` +} + +type ConnectorFilterHooks struct { + FilterHooks []*ConnectorFilterHook `json:"filterHooks"` +} diff --git a/pkg/webhook/config/consts.go b/pkg/webhook/config/consts.go new file mode 100644 index 0000000000..c641c55710 --- /dev/null +++ b/pkg/webhook/config/consts.go @@ -0,0 +1,7 @@ +package config + +type HookType string + +const ( + External HookType = "external" +) diff --git a/pkg/webhook/helpers/helpers.go b/pkg/webhook/helpers/helpers.go new file mode 100644 index 0000000000..0d339c6dd6 --- /dev/null +++ b/pkg/webhook/helpers/helpers.go @@ -0,0 +1,135 @@ +//go:generate go run -mod mod go.uber.org/mock/mockgen -destination=mock_helpers.go -package=helpers --source=helpers.go WebhookHTTPHelpers +package helpers + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + + "github.com/dexidp/dex/pkg/webhook/config" +) + +type WebhookHTTPHelpers interface { + CallWebhook(jsonData []byte) ([]byte, error) +} + +type webhookHTTPHelpersImpl struct { + transport *http.Transport + url string +} + +var _ WebhookHTTPHelpers = &webhookHTTPHelpersImpl{} + +func NewWebhookHTTPHelpers(cfg *config.WebhookConfig) (WebhookHTTPHelpers, error) { + if cfg == nil { + return nil, errors.New("webhook config is nil") + } + if cfg.URL == "" { + return nil, errors.New("webhook url is empty") + } + transport, err := createTransport(cfg) + if err != nil { + return nil, err + } + return &webhookHTTPHelpersImpl{ + transport: transport, + url: cfg.URL, + }, nil +} + +func (h *webhookHTTPHelpersImpl) CallWebhook(jsonData []byte) ([]byte, error) { + req, err := http.NewRequest("POST", h.url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Transport: h.transport} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("could not read response body: %v", err) + } + + return body, nil +} + +func createTransport(cfg *config.WebhookConfig) (*http.Transport, error) { + p, err := url.Parse(cfg.URL) + if err != nil { + return nil, fmt.Errorf("could not parse url: %v", err) + } + switch p.Scheme { + case "http": + return &http.Transport{}, nil + case "https": + return createHTTPSTransport(cfg) + default: + return nil, fmt.Errorf("unsupported scheme: %s", p.Scheme) + } +} + +func createHTTPSTransport(cfg *config.WebhookConfig) (*http.Transport, error) { + var err error + rootCertPool := x509.NewCertPool() + if cfg.TLSRootCAFile != "" { + rootCertPool, err = readCACert(cfg.TLSRootCAFile) + if err != nil { + return nil, fmt.Errorf("failed to read file %q: %w", cfg.TLSRootCAFile, err) + } + } + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCertPool, + InsecureSkipVerify: cfg.InsecureSkipVerify, + MinVersion: tls.VersionTLS13, + }, + } + + if cfg.ClientAuthentication != nil { + clientCert, err := ReadCertificate(cfg.ClientAuthentication.ClientCertificateFile, + cfg.ClientAuthentication.ClientKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to read certificate: %w", err) + } + tr.TLSClientConfig.Certificates = []tls.Certificate{*clientCert} + } + + return tr, nil +} + +func readCACert(caPath string) (*x509.CertPool, error) { + caCertPool := x509.NewCertPool() + // Load CA cert + caCert, err := os.ReadFile(caPath) + if err != nil { + return nil, err + } + caCertPool.AppendCertsFromPEM(caCert) + return caCertPool, nil +} + +func ReadCertificate(certPath, keyPath string) (*tls.Certificate, error) { + cer, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + return &cer, nil +} diff --git a/pkg/webhook/helpers/helpers_test.go b/pkg/webhook/helpers/helpers_test.go new file mode 100644 index 0000000000..b9a3ff1f83 --- /dev/null +++ b/pkg/webhook/helpers/helpers_test.go @@ -0,0 +1,200 @@ +package helpers + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/pkg/webhook/config" +) + +func TestNewWebhookHTTPHelpers_InsecureSkip(t *testing.T) { + h, err := NewWebhookHTTPHelpers(&config.WebhookConfig{ + URL: "https://test.com", + InsecureSkipVerify: true, + }) + assert.NoError(t, err) + assert.Equal(t, h.(*webhookHTTPHelpersImpl).url, "https://test.com") + assert.Equal(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.InsecureSkipVerify, true) +} + +func TestNewWebhookHTTPHelpers_TLSRootCAFile(t *testing.T) { + dir, err := os.MkdirTemp("", "prefix") + require.NoError(t, err) + defer os.RemoveAll(dir) + caCertPem, rootCAPool, _, _, err := generateCA() + require.NoError(t, err) + require.NotNil(t, caCertPem) + filePath := filepath.Join(dir, "ca.crt") + + err = os.WriteFile(filePath, caCertPem, 0o644) + require.NoError(t, err) + h, err := NewWebhookHTTPHelpers(&config.WebhookConfig{ + URL: "https://test.com", + InsecureSkipVerify: false, + TLSRootCAFile: filePath, + }) + assert.NoError(t, err) + assert.Equal(t, h.(*webhookHTTPHelpersImpl).url, "https://test.com") + assert.Equal(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.InsecureSkipVerify, false) + assert.NotNil(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.RootCAs) + assert.True(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.RootCAs.Equal(rootCAPool)) +} + +func TestNewWebhookHTTPHelpers_TLSClientCertFile(t *testing.T) { + dir, err := os.MkdirTemp("", "prefix") + require.NoError(t, err) + defer os.RemoveAll(dir) + caCertPem, rootCAPool, caTemplate, caPrivateKey, err := generateCA() + require.NoError(t, err) + require.NotNil(t, caCertPem) + filePath := filepath.Join(dir, "ca.crt") + err = os.WriteFile(filePath, caCertPem, 0o644) + require.NoError(t, err) + + certPEM, cert, keyPEM, key, err := generateTestCertificates(true, *caTemplate, caPrivateKey) + require.NoError(t, err) + certFilePath := filepath.Join(dir, "cert.pem") + keyFilePath := filepath.Join(dir, "key.pem") + err = os.WriteFile(certFilePath, certPEM, 0o644) + require.NoError(t, err) + err = os.WriteFile(keyFilePath, keyPEM, 0o644) + require.NoError(t, err) + + h, err := NewWebhookHTTPHelpers(&config.WebhookConfig{ + URL: "https://test.com", + InsecureSkipVerify: false, + TLSRootCAFile: filePath, + ClientAuthentication: &config.ClientAuthentication{ + ClientCertificateFile: certFilePath, + ClientKeyFile: keyFilePath, + ClientCAFile: filePath, + }, + }) + + assert.NoError(t, err) + assert.Equal(t, h.(*webhookHTTPHelpersImpl).url, "https://test.com") + assert.Equal(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.InsecureSkipVerify, false) + assert.NotNil(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.RootCAs) + assert.True(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.RootCAs.Equal(rootCAPool)) + assert.NotNil(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.Certificates) + assert.Equal(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.Certificates[0], *cert) + assert.NotNil(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.Certificates[0].PrivateKey) + assert.Equal(t, h.(*webhookHTTPHelpersImpl).transport.TLSClientConfig.Certificates[0].PrivateKey, key) +} + +func generateTestCertificates(clientCert bool, caTemplate x509.Certificate, + caPrivateKey *ecdsa.PrivateKey, +) ([]byte, *tls.Certificate, []byte, *ecdsa.PrivateKey, error) { + // Generate a new private key (ECDSA P-256) + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, nil, nil, err + } + + keyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + if clientCert { + keyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} + } + + // Create a template for the certificate + certTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "localhost", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 365), // Valid for one year + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + BasicConstraintsValid: true, + IsCA: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + ExtKeyUsage: keyUsage, + DNSNames: []string{"client.example.com", "alt.example.com"}, // SAN field for DNS names + + } + + // Generate the certificate signed by CA + certDERBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &caTemplate, &privateKey.PublicKey, caPrivateKey) + if err != nil { + return nil, nil, nil, nil, err + } + + // PEM encode the certificate + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDERBytes, + }) + + // PEM encode the private key + keyPEM, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return nil, nil, nil, nil, err + } + keyPEMBlock := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyPEM, + }) + + // Create a TLS certificate using the PEM-encoded certificate and private key + tlsCert, err := tls.X509KeyPair(certPEM, keyPEMBlock) + if err != nil { + return nil, nil, nil, nil, err + } + + return certPEM, &tlsCert, keyPEMBlock, privateKey, nil +} + +func generateCA() ([]byte, *x509.CertPool, *x509.Certificate, *ecdsa.PrivateKey, error) { + // Generate a new private key (ECDSA P-256) + caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, nil, nil, err + } + + // Create a template for the CA certificate + caTemplate := x509.Certificate{ + DNSNames: []string{"test.com"}, + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "My CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 365 * 10), // Valid for 10 years + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + } + + // Generate the CA certificate + caCertDERBytes, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return nil, nil, nil, nil, err + } + + // PEM encode the CA certificate + caCertPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: caCertDERBytes, + }) + + caCertPool := x509.NewCertPool() + if ok := caCertPool.AppendCertsFromPEM(caCertPEM); !ok { + panic("Failed to append CA certificate to CertPool") + } + + return caCertPEM, caCertPool, &caTemplate, caPrivateKey, nil +} diff --git a/pkg/webhook/helpers/mock_helpers.go b/pkg/webhook/helpers/mock_helpers.go new file mode 100644 index 0000000000..8a878b2e8b --- /dev/null +++ b/pkg/webhook/helpers/mock_helpers.go @@ -0,0 +1,53 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: helpers.go +// +// Generated by this command: +// +// mockgen -destination=mock_helpers.go -package=helpers --source=helpers.go WebhookHTTPHelpers +// +// Package helpers is a generated GoMock package. +package helpers + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockWebhookHTTPHelpers is a mock of WebhookHTTPHelpers interface. +type MockWebhookHTTPHelpers struct { + ctrl *gomock.Controller + recorder *MockWebhookHTTPHelpersMockRecorder +} + +// MockWebhookHTTPHelpersMockRecorder is the mock recorder for MockWebhookHTTPHelpers. +type MockWebhookHTTPHelpersMockRecorder struct { + mock *MockWebhookHTTPHelpers +} + +// NewMockWebhookHTTPHelpers creates a new mock instance. +func NewMockWebhookHTTPHelpers(ctrl *gomock.Controller) *MockWebhookHTTPHelpers { + mock := &MockWebhookHTTPHelpers{ctrl: ctrl} + mock.recorder = &MockWebhookHTTPHelpersMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWebhookHTTPHelpers) EXPECT() *MockWebhookHTTPHelpersMockRecorder { + return m.recorder +} + +// CallWebhook mocks base method. +func (m *MockWebhookHTTPHelpers) CallWebhook(jsonData []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CallWebhook", jsonData) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CallWebhook indicates an expected call of CallWebhook. +func (mr *MockWebhookHTTPHelpersMockRecorder) CallWebhook(jsonData any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallWebhook", reflect.TypeOf((*MockWebhookHTTPHelpers)(nil).CallWebhook), jsonData) +}