Skip to content

Commit

Permalink
o/assertstate: add function to get enforced validation sets that are …
Browse files Browse the repository at this point in the history
…associated with a model
  • Loading branch information
andrewphelpsj authored and Meulengracht committed Feb 23, 2024
1 parent 535f048 commit b79debc
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 12 deletions.
58 changes: 46 additions & 12 deletions overlord/assertstate/validation_set_tracking.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,6 @@ func ValidationSets(st *state.State) (map[string]*ValidationSetTracking, error)
// added to the returned set and replaces validation sets with same account/name
// in case they were tracked already.
func TrackedEnforcedValidationSets(st *state.State, extraVss ...*asserts.ValidationSet) (*snapasserts.ValidationSets, error) {
valsets, err := ValidationSets(st)
if err != nil {
return nil, err
}

db := DB(st)
sets := snapasserts.NewValidationSets()

skip := make(map[string]bool, len(extraVss))
Expand All @@ -208,14 +202,54 @@ func TrackedEnforcedValidationSets(st *state.State, extraVss ...*asserts.Validat
skip[fmt.Sprintf("%s:%s", extraVs.AccountID(), extraVs.Name())] = true
}

skipSet := func(key string) bool {
// if extraVs matches an already enforced validation set, then skip that one, extraVs has been added
// before the loop.
return skip[key]
}

if err := trackedEnforcedValidationSets(st, skipSet, sets); err != nil {
return nil, err
}

return sets, nil
}

// TrackedEnforcedValidationSetsForModel returns a ValidationSets object for
// currently tracked validation sets that are in enforcing mode and also
// associated with the specified model.
func TrackedEnforcedValidationSetsForModel(st *state.State, model *asserts.Model) (*snapasserts.ValidationSets, error) {
modelSets := make(map[string]bool, len(model.ValidationSets()))
for _, vs := range model.ValidationSets() {
modelSets[fmt.Sprintf("%s:%s", vs.AccountID, vs.Name)] = true
}

skipSet := func(key string) bool {
return !modelSets[key]
}

sets := snapasserts.NewValidationSets()
if err := trackedEnforcedValidationSets(st, skipSet, sets); err != nil {
return nil, err
}

return sets, nil
}

func trackedEnforcedValidationSets(st *state.State, skipSet func(string) bool, sets *snapasserts.ValidationSets) error {
valsets, err := ValidationSets(st)
if err != nil {
return err
}

db := DB(st)

for _, vs := range valsets {
if vs.Mode != Enforce {
continue
}

// if extraVs matches an already enforced validation set, then skip that one, extraVs has been added
// before the loop.
if skip[fmt.Sprintf("%s:%s", vs.AccountID, vs.Name)] {
if skipSet(fmt.Sprintf("%s:%s", vs.AccountID, vs.Name)) {
continue
}

Expand All @@ -232,16 +266,16 @@ func TrackedEnforcedValidationSets(st *state.State, extraVss ...*asserts.Validat

as, err := db.Find(asserts.ValidationSetType, headers)
if err != nil {
return nil, err
return err
}

vsetAssert := as.(*asserts.ValidationSet)
if err := sets.Add(vsetAssert); err != nil {
return nil, err
return err
}
}

return sets, err
return err
}

// addCurrentTrackingToValidationSetsHistory stores the current state of validation-sets
Expand Down
55 changes: 55 additions & 0 deletions overlord/assertstate/validation_set_tracking_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
package assertstate_test

import (
"fmt"

. "gopkg.in/check.v1"

"github.com/snapcore/snapd/asserts"
"github.com/snapcore/snapd/asserts/assertstest"
"github.com/snapcore/snapd/asserts/snapasserts"
"github.com/snapcore/snapd/overlord/assertstate"
"github.com/snapcore/snapd/overlord/assertstate/assertstatetest"
"github.com/snapcore/snapd/overlord/snapstate/snapstatetest"
Expand Down Expand Up @@ -576,3 +579,55 @@ func (s *validationSetTrackingSuite) TestValidationSetSequence(c *C) {
tr.PinnedAt = 1
c.Check(tr.Sequence(), Equals, 1)
}

func (s *validationSetTrackingSuite) TestTrackedEnforcedValidationSets(c *C) {
s.st.Lock()
defer s.st.Unlock()

a := assertstest.FakeAssertion(map[string]interface{}{
"type": "model",
"authority-id": "my-brand",
"series": "16",
"brand-id": "my-brand",
"model": "my-model",
"architecture": "amd64",
"store": "my-brand-store",
"gadget": "gadget",
"kernel": "krnl",
"validation-sets": []interface{}{
map[string]interface{}{
"account-id": s.dev1acct.AccountID(),
"name": "foo",
"mode": "enforce",
"sequence": "9",
},
map[string]interface{}{
"account-id": s.dev1acct.AccountID(),
"name": "bar",
"mode": "prefer-enforce",
"sequence": "9",
},
},
})

model := a.(*asserts.Model)

for _, name := range []string{"foo", "bar", "baz"} {
assertstate.UpdateValidationSet(s.st, &assertstate.ValidationSetTracking{
AccountID: s.dev1acct.AccountID(),
Name: name,
Mode: assertstate.Enforce,
Current: 9,
})
vs := s.mockAssert(c, name, "9", "required")
c.Assert(assertstate.Add(s.st, vs), IsNil)
}

sets, err := assertstate.TrackedEnforcedValidationSetsForModel(s.st, model)
c.Assert(err, IsNil)

keys := sets.Keys()
c.Check(keys, testutil.Contains, snapasserts.ValidationSetKey(fmt.Sprintf("16/%s/%s/9", s.dev1acct.AccountID(), "foo")))
c.Check(keys, testutil.Contains, snapasserts.ValidationSetKey(fmt.Sprintf("16/%s/%s/9", s.dev1acct.AccountID(), "bar")))
c.Check(keys, Not(testutil.Contains), snapasserts.ValidationSetKey(fmt.Sprintf("16/%s/%s/9", s.dev1acct.AccountID(), "baz")))
}

0 comments on commit b79debc

Please sign in to comment.