diff --git a/jwks.go b/jwks.go index 2d69ff1..f6d8740 100644 --- a/jwks.go +++ b/jwks.go @@ -85,6 +85,19 @@ func (j *JWKs) EndBackground() { } } +// KIDs returns the key IDs (`kid`) for all keys in the JWKs. +func (j *JWKs) KIDs() (kids []string) { + j.mux.RLock() + defer j.mux.RUnlock() + kids = make([]string, len(j.keys)) + index := 0 + for kid := range j.keys { + kids[index] = kid + index++ + } + return kids +} + // getKey gets the jsonKey from the given KID from the JWKs. It may refresh the JWKs if configured to. func (j *JWKs) getKey(kid string) (jsonKey *jsonKey, err error) { diff --git a/jwks_test.go b/jwks_test.go index 6f1d285..5954549 100644 --- a/jwks_test.go +++ b/jwks_test.go @@ -198,6 +198,56 @@ func TestJWKs(t *testing.T) { } } +// TestKIDs confirms the JWKs.KIDs returns the key IDs (`kid`) stored in the JWKs. +func TestJWKs_KIDs(t *testing.T) { + + // Create the JWKs from JSON. + jwks, err := keyfunc.NewJSON([]byte(jwksJSON)) + if err != nil { + t.Errorf("Failed to create a JWKs from JSON.\nError: %s", err.Error()) + t.FailNow() + } + + // The expected key IDs. + expectedKIDs := []string{ + "zXew0UJ1h6Q4CCcd_9wxMzvcp5cEBifH0KWrCz2Kyxc", + "ebJxnm9B3QDBljB5XJWEu72qx6BawDaMAhwz4aKPkQ0", + "TVAAet63O3xy_KK6_bxVIu7Ra3_z1wlB543Fbwi5VaU", + "arlUxX4hh56rNO-XdIPhDT7bqBMqcBwNQuP_TnZJNGs", + "tW6ae7TomE6_2jooM-sf9N_6lWg7HNtaQXrDsElBzM4", + "Lx1FmayP2YBtxaqS1SKJRJGiXRKnw2ov5WmYIMG-BLE", + "gnmAfvmlsi3kKH3VlM1AJ85P2hekQ8ON_XvJqs3xPD8", + "CGt0ZWS4Lc5faiKSdi0tU0fjCAdvGROQRGU9iR7tV0A", + "C65q0EKQyhpd1m4fr7SKO2He_nAxgCtAdws64d2BLt8", + } + + // Get all key IDs in the JWKs. + actual := jwks.KIDs() + + // Confirm the length is the same. + actualLen := len(actual) + expectedLen := len(expectedKIDs) + if actualLen != expectedLen { + t.Errorf("The number of key IDs was not as expected.\n Expected length: %d\n Actual length: %d\n Actual key IDs: %v", expectedLen, actualLen, actual) + t.FailNow() + } + + // Confirm all expected keys are present. + var found bool + for _, expectedKID := range expectedKIDs { + found = false + for _, kid := range actual { + if kid == expectedKID { + found = true + break + } + } + if !found { + t.Errorf("Failed to find expected key ID in the slice of key IDs in the JWKs.\n Missing: %s", expectedKID) + } + } +} + // TestRateLimit performs a test to confirm the rate limiter works as expected. func TestRateLimit(t *testing.T) {