Skip to content

Commit

Permalink
Merge pull request #87 from qmuntal/headers
Browse files Browse the repository at this point in the history
Validate generic headers
  • Loading branch information
yogeshbdeshpande authored Jul 11, 2022
2 parents ca272d4 + 99e6660 commit 33ff2e9
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 191 deletions.
6 changes: 3 additions & 3 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func ExampleSignMessage() {
// create a signature holder
sigHolder := cose.NewSignature()
sigHolder.Headers.Protected.SetAlgorithm(cose.AlgorithmES512)
sigHolder.Headers.Unprotected[cose.HeaderLabelKeyID] = 1
sigHolder.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1")

// create message to be signed
msgToSign := cose.NewSignMessage()
Expand Down Expand Up @@ -84,7 +84,7 @@ func ExampleSign1Message() {
msgToSign := cose.NewSign1Message()
msgToSign.Payload = []byte("hello world")
msgToSign.Headers.Protected.SetAlgorithm(cose.AlgorithmES512)
msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = 1
msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1")

// create a signer
privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
Expand Down Expand Up @@ -157,7 +157,7 @@ func ExampleSign1() {
cose.HeaderLabelAlgorithm: cose.AlgorithmES512,
},
Unprotected: cose.UnprotectedHeader{
cose.HeaderLabelKeyID: 1,
cose.HeaderLabelKeyID: []byte("1"),
},
}
sig, err := cose.Sign1(rand.Reader, signer, headers, []byte("hello world"), nil)
Expand Down
181 changes: 111 additions & 70 deletions headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,8 @@ func (h ProtectedHeader) MarshalCBOR() ([]byte, error) {
if len(h) == 0 {
encoded = []byte{}
} else {
err := validateHeaderLabel(h)
err := validateHeaderParameters(h, true)
if err != nil {
return nil, err
}
if err = h.ensureCritical(); err != nil {
return nil, err
}
if err = ensureHeaderIV(h); err != nil {
return nil, fmt.Errorf("protected header: %w", err)
}
encoded, err = encMode.Marshal(map[interface{}]interface{}(h))
Expand Down Expand Up @@ -85,10 +79,7 @@ func (h *ProtectedHeader) UnmarshalCBOR(data []byte) error {
return err
}
candidate := ProtectedHeader(header)
if err := candidate.ensureCritical(); err != nil {
return err
}
if err := ensureHeaderIV(candidate); err != nil {
if err := validateHeaderParameters(candidate, true); err != nil {
return fmt.Errorf("protected header: %w", err)
}

Expand Down Expand Up @@ -140,29 +131,28 @@ func (h ProtectedHeader) Critical() ([]interface{}, error) {
if !ok {
return nil, nil
}
criticalLabels, ok := value.([]interface{})
if !ok {
return nil, errors.New("invalid crit header")
}
// if present, the array MUST have at least one value in it.
if len(criticalLabels) == 0 {
return nil, errors.New("empty crit header")
err := ensureCritical(value, h)
if err != nil {
return nil, err
}
return criticalLabels, nil
return value.([]interface{}), nil
}

// ensureCritical ensures all critical headers are present in the protected bucket.
func (h ProtectedHeader) ensureCritical() error {
labels, err := h.Critical()
if err != nil {
return err
func ensureCritical(value interface{}, headers map[interface{}]interface{}) error {
labels, ok := value.([]interface{})
if !ok {
return errors.New("invalid crit header")
}
// if present, the array MUST have at least one value in it.
if len(labels) == 0 {
return errors.New("empty crit header")
}
for _, label := range labels {
_, ok := normalizeLabel(label)
if !ok {
return fmt.Errorf("critical header label: require int / tstr type, got '%T': %v", label, label)
if !canInt(label) && !canTstr(label) {
return fmt.Errorf("require int / tstr type, got '%T': %v", label, label)
}
if _, ok := h[label]; !ok {
if _, ok := headers[label]; !ok {
return fmt.Errorf("missing critical header: %v", label)
}
}
Expand All @@ -179,13 +169,7 @@ func (h UnprotectedHeader) MarshalCBOR() ([]byte, error) {
if len(h) == 0 {
return []byte{0xa0}, nil
}
if err := validateHeaderLabel(h); err != nil {
return nil, err
}
if err := ensureNoCritical(h); err != nil {
return nil, fmt.Errorf("unprotected header: %w", err)
}
if err := ensureHeaderIV(h); err != nil {
if err := validateHeaderParameters(h, false); err != nil {
return nil, fmt.Errorf("unprotected header: %w", err)
}
return encMode.Marshal(map[interface{}]interface{}(h))
Expand Down Expand Up @@ -214,10 +198,7 @@ func (h *UnprotectedHeader) UnmarshalCBOR(data []byte) error {
if err := decMode.Unmarshal(data, &header); err != nil {
return err
}
if err := ensureNoCritical(header); err != nil {
return fmt.Errorf("unprotected header: %w", err)
}
if err := ensureHeaderIV(header); err != nil {
if err := validateHeaderParameters(header, false); err != nil {
return fmt.Errorf("unprotected header: %w", err)
}
*h = header
Expand Down Expand Up @@ -397,48 +378,108 @@ func hasLabel(h map[interface{}]interface{}, label interface{}) bool {
return ok
}

// ensureHeaderIV ensures IV and Partial IV are not both present in the header.
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
func ensureHeaderIV(h map[interface{}]interface{}) error {
if hasLabel(h, HeaderLabelIV) && hasLabel(h, HeaderLabelPartialIV) {
return errors.New("IV and PartialIV parameters must not both be present")
}
return nil
}

// ensureNoCritical ensures crit parameter is not present in the header.
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
func ensureNoCritical(h map[interface{}]interface{}) error {
if hasLabel(h, HeaderLabelCritical) {
return errors.New("unexpected crit parameter found")
}
return nil
}

// validateHeaderLabel validates if all header labels are integers or strings.
//
// label = int / tstr
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4
func validateHeaderLabel(h map[interface{}]interface{}) error {
existing := make(map[interface{}]struct{})
for label := range h {
var ok bool
label, ok = normalizeLabel(label)
// validateHeaderParameters validates all headers conform to the spec.
func validateHeaderParameters(h map[interface{}]interface{}, protected bool) error {
existing := make(map[interface{}]struct{}, len(h))
for label, value := range h {
// Validate that all header labels are integers or strings.
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4
label, ok := normalizeLabel(label)
if !ok {
return errors.New("cbor: header label: require int / tstr type")
return errors.New("header label: require int / tstr type")
}

// Validate that there are no duplicated labels.
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3
if _, ok := existing[label]; ok {
return fmt.Errorf("cbor: header label: duplicated label: %v", label)
return fmt.Errorf("header label: duplicated label: %v", label)
} else {
existing[label] = struct{}{}
}

// Validate the generic parameters.
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
switch label {
case HeaderLabelAlgorithm:
_, is_alg := value.(Algorithm)
if !is_alg && !canInt(value) && !canTstr(value) {
return errors.New("header parameter: alg: require int / tstr type")
}
case HeaderLabelCritical:
if !protected {
return errors.New("header parameter: crit: not allowed")
}
if err := ensureCritical(value, h); err != nil {
return fmt.Errorf("header parameter: crit: %w", err)
}
case HeaderLabelContentType:
if !canTstr(value) && !canUint(value) {
return errors.New("header parameter: content type: require tstr / uint type")
}
case HeaderLabelKeyID:
if !canBstr(value) {
return errors.New("header parameter: kid: require bstr type")
}
case HeaderLabelIV:
if !canBstr(value) {
return errors.New("header parameter: IV: require bstr type")
}
if hasLabel(h, HeaderLabelPartialIV) {
return errors.New("header parameter: IV and PartialIV: parameters must not both be present")
}
case HeaderLabelPartialIV:
if !canBstr(value) {
return errors.New("header parameter: Partial IV: require bstr type")
}
if hasLabel(h, HeaderLabelIV) {
return errors.New("header parameter: IV and PartialIV: parameters must not both be present")
}
}
}
return nil
}

// canUint reports whether v can be used as a CBOR uint type.
func canUint(v interface{}) bool {
switch v := v.(type) {
case uint, uint8, uint16, uint32, uint64:
return true
case int:
return v >= 0
case int8:
return v >= 0
case int16:
return v >= 0
case int32:
return v >= 0
case int64:
return v >= 0
}
return false
}

// canInt reports whether v can be used as a CBOR int type.
func canInt(v interface{}) bool {
switch v.(type) {
case int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64:
return true
}
return false
}

// canTstr reports whether v can be used as a CBOR tstr type.
func canTstr(v interface{}) bool {
_, ok := v.(string)
return ok
}

// canBstr reports whether v can be used as a CBOR bstr type.
func canBstr(v interface{}) bool {
_, ok := v.([]byte)
return ok
}

// normalizeLabel tries to cast label into a int64 or a string.
// Returns (nil, false) if the label type is not valid.
func normalizeLabel(label interface{}) (interface{}, bool) {
Expand Down
Loading

0 comments on commit 33ff2e9

Please sign in to comment.