diff --git a/kafka/signed_token_redeem_handler.go b/kafka/signed_token_redeem_handler.go index 0ca744d3..94c2cfc2 100644 --- a/kafka/signed_token_redeem_handler.go +++ b/kafka/signed_token_redeem_handler.go @@ -4,13 +4,14 @@ import ( "bytes" "errors" "fmt" - "github.com/brave-intl/challenge-bypass-server/utils" "strings" + "time" crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" avroSchema "github.com/brave-intl/challenge-bypass-server/avro/generated" "github.com/brave-intl/challenge-bypass-server/btd" cbpServer "github.com/brave-intl/challenge-bypass-server/server" + "github.com/brave-intl/challenge-bypass-server/utils" "github.com/rs/zerolog" "github.com/segmentio/kafka-go" ) @@ -90,17 +91,41 @@ func SignedTokenRedeemHandler( } // Create a lookup for issuers & signing keys based on public key - signedTokens, err := utils.MarshalIssuersAndSigningKeys(issuers) - if err != nil { - handlePermanentRedemptionError( - fmt.Sprintf("request %s: %e", tokenRedeemRequestSet.Request_id, err), - err, - msg, - producer, - tokenRedeemRequestSet.Request_id, - int32(avroSchema.RedeemResultStatusError), - log, - ) + signedTokens := make(map[string]SignedIssuerToken) + for _, issuer := range issuers { + if !issuer.ExpiresAt.IsZero() && issuer.ExpiresAt.Before(time.Now()) { + continue + } + + for _, issuerKey := range issuer.Keys { + // Don't use keys outside their start/end dates + if issuerTimeIsNotValid(issuerKey.StartAt, issuerKey.EndAt) { + continue + } + + signingKey := issuerKey.SigningKey + issuerPublicKey := signingKey.PublicKey() + marshaledPublicKey, mErr := issuerPublicKey.MarshalText() + // Unmarshalling failure is a data issue and is probably permanent. + if mErr != nil { + message := fmt.Sprintf("request %s: could not unmarshal issuer public key into text", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + err, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil + } + + signedTokens[string(marshaledPublicKey)] = SignedIssuerToken{ + issuer: issuer, + signingKey: signingKey, + } + } } // Iterate over requests (only one at this point but the schema can support more @@ -178,12 +203,12 @@ func SignedTokenRedeemHandler( Str("publicKey", request.Public_key). Msg("attempting token redemption verification") - issuer := signedToken.Issuer + issuer := signedToken.issuer if err := btd.VerifyTokenRedemption( &tokenPreimage, &verificationSignature, request.Binding, - []*crypto.SigningKey{signedToken.SigningKey}, + []*crypto.SigningKey{signedToken.signingKey}, ); err == nil { verified = true verifiedIssuer = &issuer @@ -351,6 +376,21 @@ func SignedTokenRedeemHandler( return nil } +func issuerTimeIsNotValid(start *time.Time, end *time.Time) bool { + if start != nil && end != nil { + now := time.Now() + + startIsNotZeroAndAfterNow := !start.IsZero() && start.After(now) + endIsNotZeroAndBeforeNow := !end.IsZero() && end.Before(now) + + return startIsNotZeroAndAfterNow || endIsNotZeroAndBeforeNow + } + + // Both times being nil is valid + bothTimesAreNil := start == nil && end == nil + return !bothTimesAreNil +} + // avroRedeemErrorResultFromError returns a ProcessingResult that is constructed from the // provided values. func avroRedeemErrorResultFromError( diff --git a/server/db.go b/server/db.go index f4efcf38..0f7cdf38 100644 --- a/server/db.go +++ b/server/db.go @@ -249,63 +249,25 @@ func incrementCounter(c prometheus.Counter) { func (c *Server) fetchIssuer(issuerID string) (*Issuer, error) { defer incrementCounter(fetchIssuerCounter) - var ( - err error - temporary = false - ) - if cached := retrieveFromCache(c.caches, "issuer", issuerID); cached != nil { if issuer, ok := cached.(*Issuer); ok { return issuer, nil } } - fetchedIssuer := issuer{} - err = c.db.Select(&fetchedIssuer, ` + var fetchedIssuer issuer + err := c.db.Select(&fetchedIssuer, ` SELECT * FROM v3_issuers WHERE issuer_id=$1 `, issuerID) if err != nil { - if !isPostgresNotFoundError(err) { - temporary = true - } - return nil, utils.ProcessingErrorFromError(errIssuerNotFound, temporary) - } - - convertedIssuer := c.convertDBIssuer(fetchedIssuer) - // get the signing keys - if convertedIssuer.Keys == nil { - convertedIssuer.Keys = []IssuerKeys{} + return nil, utils.ProcessingErrorFromError(errIssuerNotFound, !isPostgresNotFoundError(err)) } - var fetchIssuerKeys = []issuerKeys{} - err = c.db.Select( - &fetchIssuerKeys, - `SELECT * - FROM v3_issuer_keys where issuer_id=$1 and - ( - (select version from v3_issuers where issuer_id=$1)<=2 - or end_at > now() - ) - ORDER BY end_at ASC NULLS FIRST, start_at ASC`, - convertedIssuer.ID, - ) + convertedIssuer, err := c.convertIssuerAddKeys(fetchedIssuer) if err != nil { - if !isPostgresNotFoundError(err) { - c.Logger.Error("Postgres encountered temporary error") - temporary = true - } - return nil, utils.ProcessingErrorFromError(err, temporary) - } - - for _, v := range fetchIssuerKeys { - k, err := c.convertDBIssuerKeys(v) - if err != nil { - c.Logger.Error("Failed to convert issuer keys from DB") - return nil, utils.ProcessingErrorFromError(err, temporary) - } - convertedIssuer.Keys = append(convertedIssuer.Keys, *k) + return nil, err } if c.caches != nil { @@ -327,7 +289,6 @@ func (c *Server) fetchIssuersByCohort(issuerType string, issuerCohort int16) ([] } } - temporary := false var fetchedIssuers []issuer err := c.db.Select( &fetchedIssuers, @@ -337,17 +298,14 @@ func (c *Server) fetchIssuersByCohort(issuerType string, issuerCohort int16) ([] ORDER BY i.expires_at DESC NULLS FIRST, i.created_at DESC`, issuerType, issuerCohort) if err != nil { c.Logger.Error("Failed to extract issuers from DB") - if isPostgresNotFoundError(err) { - temporary = true - } - return nil, utils.ProcessingErrorFromError(err, temporary) + return nil, utils.ProcessingErrorFromError(err, isPostgresNotFoundError(err)) } if len(fetchedIssuers) < 1 { - return nil, utils.ProcessingErrorFromError(errIssuerCohortNotFound, temporary) + return nil, utils.ProcessingErrorFromError(errIssuerCohortNotFound, false) } - issuersWithKey, err := c.fetchIssuerKeys(fetchedIssuers, &temporary) + issuersWithKey, err := c.fetchIssuerKeys(fetchedIssuers) if err != nil { return nil, err } @@ -376,34 +334,13 @@ func (c *Server) fetchIssuerByType(ctx context.Context, issuerType string) (*Iss return nil, err } - convertedIssuer := c.convertDBIssuer(issuerV3) - - if convertedIssuer.Keys == nil { - convertedIssuer.Keys = []IssuerKeys{} - } - - var fetchIssuerKeys []issuerKeys - err = c.db.SelectContext(ctx, &fetchIssuerKeys, `SELECT * FROM v3_issuer_keys where issuer_id=$1 and - ( - (select version from v3_issuers where issuer_id=$1)<=2 - or end_at > now() - ) - ORDER BY end_at ASC NULLS FIRST, start_at ASC`, issuerV3.ID) + convertedIssuer, err := c.convertIssuerAddKeys(issuerV3) if err != nil { return nil, err } - for _, v := range fetchIssuerKeys { - k, err := c.convertDBIssuerKeys(v) - if err != nil { - c.Logger.Error("Failed to convert issuer keys from DB") - return nil, err - } - convertedIssuer.Keys = append(convertedIssuer.Keys, *k) - } - if c.caches != nil { - c.caches["issuer"].SetDefault(issuerType, &issuerV3) + c.caches["issuer"].SetDefault(issuerType, convertedIssuer) } return convertedIssuer, nil @@ -422,7 +359,6 @@ func (c *Server) fetchIssuers(issuerType string) ([]Issuer, error) { } } - temporary := false var err error var fetchedIssuers []issuer @@ -441,17 +377,17 @@ func (c *Server) fetchIssuers(issuerType string) ([]Issuer, error) { if err != nil { c.Logger.Error("Failed to extract issuers from DB") - if !isPostgresNotFoundError(err) { - temporary = true - } - return nil, utils.ProcessingErrorFromError(err, temporary) + return nil, utils.ProcessingErrorFromError(err, !isPostgresNotFoundError(err)) } if len(fetchedIssuers) < 1 { - return nil, utils.ProcessingErrorFromError(errIssuerNotFound, temporary) + return nil, utils.ProcessingErrorFromError(errIssuerNotFound, false) } - issuersWithKey, err := c.fetchIssuerKeys(fetchedIssuers, &temporary) + issuersWithKey, err := c.fetchIssuerKeys(fetchedIssuers) + if err != nil { + return nil, err + } if c.caches != nil { c.caches["issuers"].SetDefault(issuerType, issuersWithKey) @@ -460,51 +396,53 @@ func (c *Server) fetchIssuers(issuerType string) ([]Issuer, error) { return issuersWithKey, nil } -func (c *Server) fetchIssuerKeys(fetchedIssuers []issuer, temp *bool) ([]Issuer, error) { +func (c *Server) fetchIssuerKeys(fetchedIssuers []issuer) ([]Issuer, error) { var issuers []Issuer for _, fetchedIssuer := range fetchedIssuers { - convertedIssuer := c.convertDBIssuer(fetchedIssuer) - // get the keys for the Issuer - if convertedIssuer.Keys == nil { - convertedIssuer.Keys = []IssuerKeys{} + convertedIssuer, err := c.convertIssuerAddKeys(fetchedIssuer) + if err != nil { + return nil, err } - lteVersionTwo := "false" - if fetchedIssuer.Version <= 2 { - lteVersionTwo = "true" - } + issuers = append(issuers, *convertedIssuer) + } - var fetchIssuerKeys []issuerKeys - err := c.db.Select( - &fetchIssuerKeys, - `SELECT * - FROM v3_issuer_keys - WHERE issuer_id=$1 AND ($2 OR end_at > now()) - ORDER BY end_at ASC NULLS FIRST, start_at ASC`, - convertedIssuer.ID, lteVersionTwo, - ) - if err != nil { - if !isPostgresNotFoundError(err) { - c.Logger.Error("Issuer key was not found in DB") - isTemp := true - temp = &isTemp - } - return nil, utils.ProcessingErrorFromError(err, *temp) - } + return issuers, nil +} - for _, v := range fetchIssuerKeys { - k, err := c.convertDBIssuerKeys(v) - if err != nil { - c.Logger.Error("Failed to convert issuer keys from DB") - return nil, utils.ProcessingErrorFromError(err, *temp) - } - convertedIssuer.Keys = append(convertedIssuer.Keys, *k) +func (c *Server) convertIssuerAddKeys(fetchedIssuer issuer) (*Issuer, error) { + convertedIssuer := parseIssuer(fetchedIssuer) + // get the keys for the Issuer + if convertedIssuer.Keys == nil { + convertedIssuer.Keys = []IssuerKeys{} + } + + // ALWAYS RETURNS THE MOST RECENT KEY LAST + var keys []issuerKeys + lteVersionTwo := fetchedIssuer.Version <= 2 + err := c.db.Select( + &keys, + `SELECT * FROM v3_issuer_keys WHERE issuer_id = $1 AND $2 AND (end_at > now() or end_at is null) ORDER BY created_at ASC`, + convertedIssuer.ID, lteVersionTwo, + ) + if err != nil { + isNotPostgresNotFoundErr := !isPostgresNotFoundError(err) + if isNotPostgresNotFoundErr { + c.Logger.Error("Issuer key was not found in DB") } + return nil, utils.ProcessingErrorFromError(err, isNotPostgresNotFoundErr) + } - issuers = append(issuers, *convertedIssuer) + for _, v := range keys { + k, cErr := c.convertDBIssuerKeys(v) + if cErr != nil { + c.Logger.Error("Failed to convert issuer keys from DB") + return nil, utils.ProcessingErrorFromError(cErr, false) + } + convertedIssuer.Keys = append(convertedIssuer.Keys, *k) } - return issuers, nil + return &convertedIssuer, nil } // RotateIssuers is the function that rotates @@ -539,16 +477,13 @@ func (c *Server) rotateIssuers() error { for _, v := range fetchedIssuers { // converted - issuer := c.convertDBIssuer(v) + iss := parseIssuer(v) // populate keys in db - if err := txPopulateIssuerKeys(c.Logger, tx, *issuer); err != nil { + if err := txPopulateIssuerKeys(c.Logger, tx, iss); err != nil { return fmt.Errorf("failed to populate v3 issuer keys: %w", err) } - if _, err = tx.Exec( - `UPDATE v3_issuers SET last_rotated_at = now() where issuer_id = $1`, - issuer.ID, - ); err != nil { + if _, err = tx.Exec(`UPDATE v3_issuers SET last_rotated_at = now() where issuer_id = $1`, iss.ID); err != nil { return err } } @@ -1000,13 +935,6 @@ func (c *Server) convertDBIssuerKeys(issuerKeyToConvert issuerKeys) (*IssuerKeys return &parsedIssuerKeys, nil } -// convertDBIssuer takes an issuer from the database and returns a reference to that issuer -// Represented as an Issuer struct. -func (c *Server) convertDBIssuer(issuerToConvert issuer) *Issuer { - parsedIssuer := parseIssuer(issuerToConvert) - return &parsedIssuer -} - func parseIssuerKeys(issuerKeysToParse issuerKeys) (IssuerKeys, error) { parsedIssuerKey := IssuerKeys{ ID: issuerKeysToParse.ID, diff --git a/server/issuers.go b/server/issuers.go index b1bb13af..c3646d52 100644 --- a/server/issuers.go +++ b/server/issuers.go @@ -3,7 +3,6 @@ package server import ( "encoding/json" "errors" - "github.com/brave-intl/challenge-bypass-server/utils" "net/http" "os" "time" @@ -123,17 +122,8 @@ func (c *Server) issuerGetHandlerV1(w http.ResponseWriter, r *http.Request) *han if appErr != nil { return appErr } - expiresAt := "" - if !issuer.ExpiresAt.IsZero() { - expiresAt = issuer.ExpiresAt.Format(time.RFC3339) - } - - var publicKey *crypto.PublicKey - for _, k := range issuer.Keys { - publicKey = k.SigningKey.PublicKey() - } - err := json.NewEncoder(w).Encode(issuerResponse{issuer.ID.String(), issuer.IssuerType, publicKey, expiresAt, issuer.IssuerCohort}) + err := json.NewEncoder(w).Encode(makeIssuerResponse(issuer)) if err != nil { c.Logger.Error("Error encoding the issuer response") panic(err) @@ -159,18 +149,7 @@ func (c *Server) issuerHandlerV3(w http.ResponseWriter, r *http.Request) *handle return appErr } - expiresAt := "" - if !issuer.ExpiresAt.IsZero() { - expiresAt = issuer.ExpiresAt.Format(time.RFC3339) - } - - // get the signing public key - var publicKey *crypto.PublicKey - for _, k := range issuer.Keys { - publicKey = k.SigningKey.PublicKey() - } - - err := json.NewEncoder(w).Encode(issuerResponse{issuer.ID.String(), issuer.IssuerType, publicKey, expiresAt, issuer.IssuerCohort}) + err := json.NewEncoder(w).Encode(makeIssuerResponse(issuer)) if err != nil { c.Logger.Error("Error encoding the issuer response") panic(err) @@ -193,18 +172,9 @@ func (c *Server) issuerHandlerV2(w http.ResponseWriter, r *http.Request) *handle if appErr != nil { return appErr } - expiresAt := "" - if !issuer.ExpiresAt.IsZero() { - expiresAt = issuer.ExpiresAt.Format(time.RFC3339) - } // get the signing public key - var publicKey *crypto.PublicKey - for _, k := range issuer.Keys { - publicKey = k.SigningKey.PublicKey() - } - - err := json.NewEncoder(w).Encode(issuerResponse{issuer.ID.String(), issuer.IssuerType, publicKey, expiresAt, issuer.IssuerCohort}) + err := json.NewEncoder(w).Encode(makeIssuerResponse(issuer)) if err != nil { c.Logger.Error("Error encoding the issuer response") panic(err) @@ -226,26 +196,12 @@ func (c *Server) issuerGetAllHandler(w http.ResponseWriter, r *http.Request) *ha } } - marshalledIssuers, err := utils.MarshalIssuersAndSigningKeys(issuers) - if err != nil { - c.Logger.Error("error retrieving issuer values") - panic(err) - } - - var respIssuers []issuerResponse - for _, value := range marshalledIssuers { - iss := value.Issuer - issuerKeySigningKey := value.SigningKey - respIssuers = append(respIssuers, issuerResponse{ - iss.ID.String(), - iss.IssuerType, - issuerKeySigningKey.PublicKey(), - iss.ExpiresAt.Format(time.RFC3339), - iss.IssuerCohort, - }) + respIssuers := make([]issuerResponse, len(issuers)) + for idx, currIssuer := range issuers { + respIssuers[idx] = makeIssuerResponse(&currIssuer) } - err = json.NewEncoder(w).Encode(respIssuers) + err := json.NewEncoder(w).Encode(respIssuers) if err != nil { c.Logger.Error("Error encoding issuer") panic(err) @@ -419,6 +375,27 @@ func (c *Server) issuerCreateHandlerV1(w http.ResponseWriter, r *http.Request) * return nil } +func makeIssuerResponse(iss *Issuer) issuerResponse { + expiresAt := "" + if !iss.ExpiresAt.IsZero() { + expiresAt = iss.ExpiresAt.Format(time.RFC3339) + } + + // Last key in array is the valid one + var publicKey *crypto.PublicKey + for _, k := range iss.Keys { + publicKey = k.SigningKey.PublicKey() + } + + return issuerResponse{ + iss.ID.String(), + iss.IssuerType, + publicKey, + expiresAt, + iss.IssuerCohort, + } +} + func (c *Server) issuerRouterV1() chi.Router { r := chi.NewRouter() if os.Getenv("ENV") == "production" { diff --git a/server/server_test.go b/server/server_test.go index 510604b1..1d0a162d 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -270,24 +270,24 @@ func (suite *ServerTestSuite) TestRotateTimeAwareIssuer() { time.Sleep(2 * time.Second) myIssuer, err := suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) fmt.Println(err) - suite.Require().Equal(len(myIssuer.Keys), 1) // should be one left + suite.Require().Equal(1, len(myIssuer.Keys)) // should be one left // rotate issuers should pick up that there are some new intervals to make up buffer and populate err = suite.srv.rotateIssuersV3() suite.Require().NoError(err) myIssuer, _ = suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) - suite.Require().Equal(len(myIssuer.Keys), 3) // should be 3 now + suite.Require().Equal(3, len(myIssuer.Keys)) // should be 3 now // rotate issuers should pick up that there are some new intervals to make up buffer and populate err = suite.srv.rotateIssuersV3() suite.Require().NoError(err) - suite.Require().Equal(len(myIssuer.Keys), 3) // should be 3 now still + suite.Require().Equal(3, len(myIssuer.Keys)) // should be 3 now still // wait a few intervals after creation and check number of signing keys left time.Sleep(2 * time.Second) myIssuer, _ = suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) - suite.Require().Equal(len(myIssuer.Keys), 1) // should be one left + suite.Require().Equal(3, len(myIssuer.Keys)) // should be one left } func (suite *ServerTestSuite) TestCreateIssuerV3() { diff --git a/utils/issuers.go b/utils/issuers.go deleted file mode 100644 index 549a01c3..00000000 --- a/utils/issuers.go +++ /dev/null @@ -1,60 +0,0 @@ -package utils - -import ( - "fmt" - crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" - cbpServer "github.com/brave-intl/challenge-bypass-server/server" - "time" -) - -type SignedIssuerToken struct { - Issuer cbpServer.Issuer - SigningKey *crypto.SigningKey -} - -func MarshalIssuersAndSigningKeys(issuers []cbpServer.Issuer) (map[string]SignedIssuerToken, error) { - // Create a lookup for issuers & signing keys based on public key - signedTokens := make(map[string]SignedIssuerToken) - for _, issuer := range issuers { - if !issuer.ExpiresAt.IsZero() && issuer.ExpiresAt.Before(time.Now()) { - continue - } - - for _, issuerKey := range issuer.Keys { - // Don't use keys outside their start/end dates - if issuerTimeIsNotValid(issuerKey.StartAt, issuerKey.EndAt) { - continue - } - - signingKey := issuerKey.SigningKey - issuerPublicKey := signingKey.PublicKey() - marshaledPublicKey, mErr := issuerPublicKey.MarshalText() - // Unmarshalling failure is a data issue and is probably permanent. - if mErr != nil { - return nil, fmt.Errorf("could not unmarshal issuer public key into text: %e", mErr) - } - - signedTokens[string(marshaledPublicKey)] = SignedIssuerToken{ - Issuer: issuer, - SigningKey: signingKey, - } - } - } - - return signedTokens, nil -} - -func issuerTimeIsNotValid(start *time.Time, end *time.Time) bool { - if start != nil && end != nil { - now := time.Now() - - startIsNotZeroAndAfterNow := !start.IsZero() && start.After(now) - endIsNotZeroAndBeforeNow := !end.IsZero() && end.Before(now) - - return startIsNotZeroAndAfterNow || endIsNotZeroAndBeforeNow - } - - // Both times being nil is valid - bothTimesAreNil := start == nil && end == nil - return !bothTimesAreNil -}