Skip to content

Commit

Permalink
fix: try without start at
Browse files Browse the repository at this point in the history
  • Loading branch information
IanKrieger committed Aug 11, 2023
1 parent beb1e26 commit d6e6c05
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 257 deletions.
68 changes: 54 additions & 14 deletions kafka/signed_token_redeem_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"bytes"
"errors"
"fmt"
"github.com/brave-intl/challenge-bypass-server/utils"
"strings"
"time"

crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi"
avroSchema "github.com/brave-intl/challenge-bypass-server/avro/generated"
"github.com/brave-intl/challenge-bypass-server/btd"
cbpServer "github.com/brave-intl/challenge-bypass-server/server"
"github.com/brave-intl/challenge-bypass-server/utils"
"github.com/rs/zerolog"
"github.com/segmentio/kafka-go"
)
Expand Down Expand Up @@ -90,17 +91,41 @@ func SignedTokenRedeemHandler(
}

// Create a lookup for issuers & signing keys based on public key
signedTokens, err := utils.MarshalIssuersAndSigningKeys(issuers)
if err != nil {
handlePermanentRedemptionError(
fmt.Sprintf("request %s: %e", tokenRedeemRequestSet.Request_id, err),
err,
msg,
producer,
tokenRedeemRequestSet.Request_id,
int32(avroSchema.RedeemResultStatusError),
log,
)
signedTokens := make(map[string]SignedIssuerToken)
for _, issuer := range issuers {
if !issuer.ExpiresAt.IsZero() && issuer.ExpiresAt.Before(time.Now()) {
continue
}

for _, issuerKey := range issuer.Keys {
// Don't use keys outside their start/end dates
if issuerTimeIsNotValid(issuerKey.StartAt, issuerKey.EndAt) {
continue
}

signingKey := issuerKey.SigningKey
issuerPublicKey := signingKey.PublicKey()
marshaledPublicKey, mErr := issuerPublicKey.MarshalText()
// Unmarshalling failure is a data issue and is probably permanent.
if mErr != nil {
message := fmt.Sprintf("request %s: could not unmarshal issuer public key into text", tokenRedeemRequestSet.Request_id)
handlePermanentRedemptionError(
message,
err,
msg,
producer,
tokenRedeemRequestSet.Request_id,
int32(avroSchema.RedeemResultStatusError),
log,
)
return nil
}

signedTokens[string(marshaledPublicKey)] = SignedIssuerToken{
issuer: issuer,
signingKey: signingKey,
}
}
}

// Iterate over requests (only one at this point but the schema can support more
Expand Down Expand Up @@ -178,12 +203,12 @@ func SignedTokenRedeemHandler(
Str("publicKey", request.Public_key).
Msg("attempting token redemption verification")

issuer := signedToken.Issuer
issuer := signedToken.issuer
if err := btd.VerifyTokenRedemption(
&tokenPreimage,
&verificationSignature,
request.Binding,
[]*crypto.SigningKey{signedToken.SigningKey},
[]*crypto.SigningKey{signedToken.signingKey},
); err == nil {
verified = true
verifiedIssuer = &issuer
Expand Down Expand Up @@ -351,6 +376,21 @@ func SignedTokenRedeemHandler(
return nil
}

func issuerTimeIsNotValid(start *time.Time, end *time.Time) bool {
if start != nil && end != nil {
now := time.Now()

startIsNotZeroAndAfterNow := !start.IsZero() && start.After(now)
endIsNotZeroAndBeforeNow := !end.IsZero() && end.Before(now)

return startIsNotZeroAndAfterNow || endIsNotZeroAndBeforeNow
}

// Both times being nil is valid
bothTimesAreNil := start == nil && end == nil
return !bothTimesAreNil
}

// avroRedeemErrorResultFromError returns a ProcessingResult that is constructed from the
// provided values.
func avroRedeemErrorResultFromError(
Expand Down
184 changes: 56 additions & 128 deletions server/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,63 +249,25 @@ func incrementCounter(c prometheus.Counter) {
func (c *Server) fetchIssuer(issuerID string) (*Issuer, error) {
defer incrementCounter(fetchIssuerCounter)

var (
err error
temporary = false
)

if cached := retrieveFromCache(c.caches, "issuer", issuerID); cached != nil {
if issuer, ok := cached.(*Issuer); ok {
return issuer, nil
}
}

fetchedIssuer := issuer{}
err = c.db.Select(&fetchedIssuer, `
var fetchedIssuer issuer
err := c.db.Select(&fetchedIssuer, `
SELECT * FROM v3_issuers
WHERE issuer_id=$1
`, issuerID)

if err != nil {
if !isPostgresNotFoundError(err) {
temporary = true
}
return nil, utils.ProcessingErrorFromError(errIssuerNotFound, temporary)
}

convertedIssuer := c.convertDBIssuer(fetchedIssuer)
// get the signing keys
if convertedIssuer.Keys == nil {
convertedIssuer.Keys = []IssuerKeys{}
return nil, utils.ProcessingErrorFromError(errIssuerNotFound, !isPostgresNotFoundError(err))
}

var fetchIssuerKeys = []issuerKeys{}
err = c.db.Select(
&fetchIssuerKeys,
`SELECT *
FROM v3_issuer_keys where issuer_id=$1 and
(
(select version from v3_issuers where issuer_id=$1)<=2
or end_at > now()
)
ORDER BY end_at ASC NULLS FIRST, start_at ASC`,
convertedIssuer.ID,
)
convertedIssuer, err := c.convertIssuerAddKeys(fetchedIssuer)
if err != nil {
if !isPostgresNotFoundError(err) {
c.Logger.Error("Postgres encountered temporary error")
temporary = true
}
return nil, utils.ProcessingErrorFromError(err, temporary)
}

for _, v := range fetchIssuerKeys {
k, err := c.convertDBIssuerKeys(v)
if err != nil {
c.Logger.Error("Failed to convert issuer keys from DB")
return nil, utils.ProcessingErrorFromError(err, temporary)
}
convertedIssuer.Keys = append(convertedIssuer.Keys, *k)
return nil, err
}

if c.caches != nil {
Expand All @@ -327,7 +289,6 @@ func (c *Server) fetchIssuersByCohort(issuerType string, issuerCohort int16) ([]
}
}

temporary := false
var fetchedIssuers []issuer
err := c.db.Select(
&fetchedIssuers,
Expand All @@ -337,17 +298,14 @@ func (c *Server) fetchIssuersByCohort(issuerType string, issuerCohort int16) ([]
ORDER BY i.expires_at DESC NULLS FIRST, i.created_at DESC`, issuerType, issuerCohort)
if err != nil {
c.Logger.Error("Failed to extract issuers from DB")
if isPostgresNotFoundError(err) {
temporary = true
}
return nil, utils.ProcessingErrorFromError(err, temporary)
return nil, utils.ProcessingErrorFromError(err, isPostgresNotFoundError(err))
}

if len(fetchedIssuers) < 1 {
return nil, utils.ProcessingErrorFromError(errIssuerCohortNotFound, temporary)
return nil, utils.ProcessingErrorFromError(errIssuerCohortNotFound, false)
}

issuersWithKey, err := c.fetchIssuerKeys(fetchedIssuers, &temporary)
issuersWithKey, err := c.fetchIssuerKeys(fetchedIssuers)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -376,34 +334,13 @@ func (c *Server) fetchIssuerByType(ctx context.Context, issuerType string) (*Iss
return nil, err
}

convertedIssuer := c.convertDBIssuer(issuerV3)

if convertedIssuer.Keys == nil {
convertedIssuer.Keys = []IssuerKeys{}
}

var fetchIssuerKeys []issuerKeys
err = c.db.SelectContext(ctx, &fetchIssuerKeys, `SELECT * FROM v3_issuer_keys where issuer_id=$1 and
(
(select version from v3_issuers where issuer_id=$1)<=2
or end_at > now()
)
ORDER BY end_at ASC NULLS FIRST, start_at ASC`, issuerV3.ID)
convertedIssuer, err := c.convertIssuerAddKeys(issuerV3)
if err != nil {
return nil, err
}

for _, v := range fetchIssuerKeys {
k, err := c.convertDBIssuerKeys(v)
if err != nil {
c.Logger.Error("Failed to convert issuer keys from DB")
return nil, err
}
convertedIssuer.Keys = append(convertedIssuer.Keys, *k)
}

if c.caches != nil {
c.caches["issuer"].SetDefault(issuerType, &issuerV3)
c.caches["issuer"].SetDefault(issuerType, convertedIssuer)
}

return convertedIssuer, nil
Expand All @@ -422,7 +359,6 @@ func (c *Server) fetchIssuers(issuerType string) ([]Issuer, error) {
}
}

temporary := false
var err error
var fetchedIssuers []issuer

Expand All @@ -441,17 +377,17 @@ func (c *Server) fetchIssuers(issuerType string) ([]Issuer, error) {

if err != nil {
c.Logger.Error("Failed to extract issuers from DB")
if !isPostgresNotFoundError(err) {
temporary = true
}
return nil, utils.ProcessingErrorFromError(err, temporary)
return nil, utils.ProcessingErrorFromError(err, !isPostgresNotFoundError(err))
}

if len(fetchedIssuers) < 1 {
return nil, utils.ProcessingErrorFromError(errIssuerNotFound, temporary)
return nil, utils.ProcessingErrorFromError(errIssuerNotFound, false)
}

issuersWithKey, err := c.fetchIssuerKeys(fetchedIssuers, &temporary)
issuersWithKey, err := c.fetchIssuerKeys(fetchedIssuers)
if err != nil {
return nil, err
}

if c.caches != nil {
c.caches["issuers"].SetDefault(issuerType, issuersWithKey)
Expand All @@ -460,51 +396,53 @@ func (c *Server) fetchIssuers(issuerType string) ([]Issuer, error) {
return issuersWithKey, nil
}

func (c *Server) fetchIssuerKeys(fetchedIssuers []issuer, temp *bool) ([]Issuer, error) {
func (c *Server) fetchIssuerKeys(fetchedIssuers []issuer) ([]Issuer, error) {
var issuers []Issuer
for _, fetchedIssuer := range fetchedIssuers {
convertedIssuer := c.convertDBIssuer(fetchedIssuer)
// get the keys for the Issuer
if convertedIssuer.Keys == nil {
convertedIssuer.Keys = []IssuerKeys{}
convertedIssuer, err := c.convertIssuerAddKeys(fetchedIssuer)
if err != nil {
return nil, err
}

lteVersionTwo := "false"
if fetchedIssuer.Version <= 2 {
lteVersionTwo = "true"
}
issuers = append(issuers, *convertedIssuer)
}

var fetchIssuerKeys []issuerKeys
err := c.db.Select(
&fetchIssuerKeys,
`SELECT *
FROM v3_issuer_keys
WHERE issuer_id=$1 AND ($2 OR end_at > now())
ORDER BY end_at ASC NULLS FIRST, start_at ASC`,
convertedIssuer.ID, lteVersionTwo,
)
if err != nil {
if !isPostgresNotFoundError(err) {
c.Logger.Error("Issuer key was not found in DB")
isTemp := true
temp = &isTemp
}
return nil, utils.ProcessingErrorFromError(err, *temp)
}
return issuers, nil
}

for _, v := range fetchIssuerKeys {
k, err := c.convertDBIssuerKeys(v)
if err != nil {
c.Logger.Error("Failed to convert issuer keys from DB")
return nil, utils.ProcessingErrorFromError(err, *temp)
}
convertedIssuer.Keys = append(convertedIssuer.Keys, *k)
func (c *Server) convertIssuerAddKeys(fetchedIssuer issuer) (*Issuer, error) {
convertedIssuer := parseIssuer(fetchedIssuer)
// get the keys for the Issuer
if convertedIssuer.Keys == nil {
convertedIssuer.Keys = []IssuerKeys{}
}

// ALWAYS RETURNS THE MOST RECENT KEY LAST
var keys []issuerKeys
lteVersionTwo := fetchedIssuer.Version <= 2
err := c.db.Select(
&keys,
`SELECT * FROM v3_issuer_keys WHERE issuer_id = $1 AND $2 AND (end_at > now() or end_at is null) ORDER BY created_at ASC`,
convertedIssuer.ID, lteVersionTwo,
)
if err != nil {
isNotPostgresNotFoundErr := !isPostgresNotFoundError(err)
if isNotPostgresNotFoundErr {
c.Logger.Error("Issuer key was not found in DB")
}
return nil, utils.ProcessingErrorFromError(err, isNotPostgresNotFoundErr)
}

issuers = append(issuers, *convertedIssuer)
for _, v := range keys {
k, cErr := c.convertDBIssuerKeys(v)
if cErr != nil {
c.Logger.Error("Failed to convert issuer keys from DB")
return nil, utils.ProcessingErrorFromError(cErr, false)
}
convertedIssuer.Keys = append(convertedIssuer.Keys, *k)
}

return issuers, nil
return &convertedIssuer, nil
}

// RotateIssuers is the function that rotates
Expand Down Expand Up @@ -539,16 +477,13 @@ func (c *Server) rotateIssuers() error {

for _, v := range fetchedIssuers {
// converted
issuer := c.convertDBIssuer(v)
iss := parseIssuer(v)
// populate keys in db
if err := txPopulateIssuerKeys(c.Logger, tx, *issuer); err != nil {
if err := txPopulateIssuerKeys(c.Logger, tx, iss); err != nil {
return fmt.Errorf("failed to populate v3 issuer keys: %w", err)
}

if _, err = tx.Exec(
`UPDATE v3_issuers SET last_rotated_at = now() where issuer_id = $1`,
issuer.ID,
); err != nil {
if _, err = tx.Exec(`UPDATE v3_issuers SET last_rotated_at = now() where issuer_id = $1`, iss.ID); err != nil {
return err
}
}
Expand Down Expand Up @@ -1000,13 +935,6 @@ func (c *Server) convertDBIssuerKeys(issuerKeyToConvert issuerKeys) (*IssuerKeys
return &parsedIssuerKeys, nil
}

// convertDBIssuer takes an issuer from the database and returns a reference to that issuer
// Represented as an Issuer struct.
func (c *Server) convertDBIssuer(issuerToConvert issuer) *Issuer {
parsedIssuer := parseIssuer(issuerToConvert)
return &parsedIssuer
}

func parseIssuerKeys(issuerKeysToParse issuerKeys) (IssuerKeys, error) {
parsedIssuerKey := IssuerKeys{
ID: issuerKeysToParse.ID,
Expand Down
Loading

0 comments on commit d6e6c05

Please sign in to comment.