From a97ac98bfe61a37a140a98fa2b06990b63fabb91 Mon Sep 17 00:00:00 2001 From: Sword Date: Tue, 21 Nov 2023 05:19:24 +0000 Subject: [PATCH] Support custom idp entity id --- identity_provider.go | 21 ++++++++- identity_provider_test.go | 47 ++++++++++++++++--- ...esponse_response_with_custom_entity_id.xml | 27 +++++++++++ 3 files changed, 87 insertions(+), 8 deletions(-) create mode 100644 testdata/TestIDPMakeResponse_response_with_custom_entity_id.xml diff --git a/identity_provider.go b/identity_provider.go index abaaad68..135915e7 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -107,8 +107,11 @@ type IdentityProvider struct { AssertionMaker AssertionMaker SignatureMethod string ValidDuration *time.Duration + EntityIDConstructor EntityIDConstructor } +type EntityIDConstructor func() string + // Metadata returns the metadata structure for this identity provider. func (idp *IdentityProvider) Metadata() *EntityDescriptor { certStr := base64.StdEncoding.EncodeToString(idp.Certificate.Raw) @@ -121,7 +124,7 @@ func (idp *IdentityProvider) Metadata() *EntityDescriptor { } ed := &EntityDescriptor{ - EntityID: idp.MetadataURL.String(), + EntityID: idp.getEntityID(), ValidUntil: TimeNow().Add(validDuration), CacheDuration: validDuration, IDPSSODescriptors: []IDPSSODescriptor{ @@ -334,6 +337,20 @@ func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Re } } +// createDefaultEntityIDConstructor creates a function to return entityID from metadataURL. +func createDefaultEntityIDConstructor(metadataURL url.URL) func() string { + return func() string { + return metadataURL.String() + } +} + +func (idp *IdentityProvider) getEntityID() string { + if idp.EntityIDConstructor == nil { + return createDefaultEntityIDConstructor(idp.MetadataURL)() + } + return idp.EntityIDConstructor() +} + // IdpAuthnRequest is used by IdentityProvider to handle a single authentication request. type IdpAuthnRequest struct { IDP *IdentityProvider @@ -1019,7 +1036,7 @@ func (req *IdpAuthnRequest) MakeResponse() error { Version: "2.0", Issuer: &Issuer{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", - Value: req.IDP.MetadataURL.String(), + Value: req.IDP.getEntityID(), }, Status: Status{ StatusCode: StatusCode{ diff --git a/identity_provider_test.go b/identity_provider_test.go index 9d06a4bb..92729112 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -38,11 +38,12 @@ type IdentityProviderTest struct { SPCertificate *x509.Certificate SP ServiceProvider - Key crypto.PrivateKey - Signer crypto.Signer - Certificate *x509.Certificate - SessionProvider SessionProvider - IDP IdentityProvider + Key crypto.PrivateKey + Signer crypto.Signer + Certificate *x509.Certificate + SessionProvider SessionProvider + IDP IdentityProvider + ExpectedFilename string } func mustParseURL(s string) url.URL { @@ -98,6 +99,24 @@ var applySigner = idpTestOpts{ }, } +// applyEntityIDConstructor will set the entity ID constructor for the identity provider. +func applyEntityIDConstructor(c EntityIDConstructor) idpTestOpts { + return idpTestOpts{ + apply: func(_ *testing.T, test *IdentityProviderTest) { + test.IDP.EntityIDConstructor = c + }, + } +} + +// applyExpectedFilename will set the expected filename for the identity provider. +func applyExpectedFilename(filename string) idpTestOpts { + return idpTestOpts{ + apply: func(_ *testing.T, test *IdentityProviderTest) { + test.ExpectedFilename = filename + }, + } +} + func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProviderTest { test := IdentityProviderTest{} TimeNow = func() time.Time { @@ -139,6 +158,7 @@ func NewIdentityProviderTest(t *testing.T, opts ...idpTestOpts) *IdentityProvide }, }, } + test.ExpectedFilename = "TestIDPMakeResponse_response.xml" // apply the test options for _, opt := range opts { @@ -772,7 +792,7 @@ func testMakeResponse(t *testing.T, test *IdentityProviderTest) { doc.Indent(2) responseStr, err := doc.WriteToString() assert.Check(t, err) - golden.Assert(t, responseStr, "TestIDPMakeResponse_response.xml") + golden.Assert(t, responseStr, test.ExpectedFilename) } func TestIDPWriteResponse(t *testing.T) { @@ -1130,3 +1150,18 @@ func TestIDPHTTPCanHandleSSORequest(t *testing.T) { assert.Check(t, is.Equal(http.StatusBadRequest, w.Code)) } } + +func TestIdentityProviderCustomEntityID(t *testing.T) { + customEntityID := "https://idp.example.com/entity-id" + test := NewIdentityProviderTest( + t, + applyKey, + applyEntityIDConstructor(func() string { + return customEntityID + }), + applyExpectedFilename("TestIDPMakeResponse_response_with_custom_entity_id.xml"), + ) + + assert.Equal(t, customEntityID, test.IDP.Metadata().EntityID) + testMakeResponse(t, test) +} diff --git a/testdata/TestIDPMakeResponse_response_with_custom_entity_id.xml b/testdata/TestIDPMakeResponse_response_with_custom_entity_id.xml new file mode 100644 index 00000000..853ca0a1 --- /dev/null +++ b/testdata/TestIDPMakeResponse_response_with_custom_entity_id.xml @@ -0,0 +1,27 @@ + + https://idp.example.com/entity-id + + + + + + + + + + + 5bBiRThV9gjcTNlKa+y00Gnzkh8= + + + A9fzgSO00HntRcx32qCEVHoTR8YiisGk6tkeAbhRKzXoIOw3UE4nhoBIYPTYj5G+mMjnB/eEw84kuUSZ9mLV+EIAMQuR6ctJyO6xdxy65l+iC0IBSk65wqCb6C4IRB5OaxN/QC0yTJ8Ps2+s1WRJSLLcmQU6Xatpe25vzk+hQ+4= + + + MIIB7zCCAVgCCQDFzbKIp7b3MTANBgkqhkiG9w0BAQUFADA8MQswCQYDVQQGEwJVUzELMAkGA1UECAwCR0ExDDAKBgNVBAoMA2ZvbzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTEzMTAwMjAwMDg1MVoXDTE0MTAwMjAwMDg1MVowPDELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkdBMQwwCgYDVQQKDANmb28xEjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1PMHYmhZj308kWLhZVT4vOulqx/9ibm5B86fPWwUKKQ2i12MYtz07tzukPymisTDhQaqyJ8Kqb/6JjhmeMnEOdTvSPmHO8m1ZVveJU6NoKRn/mP/BD7FW52WhbrUXLSeHVSKfWkNk6S4hk9MV9TswTvyRIKvRsw0X/gfnqkroJcCAwEAATANBgkqhkiG9w0BAQUFAAOBgQCMMlIO+GNcGekevKgkakpMdAqJfs24maGb90DvTLbRZRD7Xvn1MnVBBS9hzlXiFLYOInXACMW5gcoRFfeTQLSouMM8o57h0uKjfTmuoWHLQLi6hnF+cvCsEFiJZ4AbF+DgmO6TarJ8O05t8zvnOwJlNCASPZRH/JmF8tX0hoHuAQ== + + + + + + + +