Skip to content

Commit

Permalink
o/registrystate: support snapctl get --pristine (canonical#14552)
Browse files Browse the repository at this point in the history
* o/registrystate: improve get/set helper names

Signed-off-by: Miguel Pires <[email protected]>
  • Loading branch information
MiguelPires authored Oct 8, 2024
1 parent ee99b92 commit 0d30934
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 83 deletions.
4 changes: 2 additions & 2 deletions daemon/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ var (
assertstateRefreshSnapAssertions = assertstate.RefreshSnapAssertions
assertstateRestoreValidationSetsTracking = assertstate.RestoreValidationSetsTracking

registrystateGetViaView = registrystate.GetViaView
registrystateSetViaView = registrystate.SetViaView
registrystateGet = registrystate.Get
registrystateSet = registrystate.Set
)

func ensureStateSoonImpl(st *state.State) {
Expand Down
4 changes: 2 additions & 2 deletions daemon/api_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func getView(c *Command, r *http.Request, _ *auth.UserState) Response {
fields = strutil.CommaSeparatedList(fieldStr)
}

results, err := registrystateGetViaView(st, account, registryName, view, fields)
results, err := registrystateGet(st, account, registryName, view, fields)
if err != nil {
return toAPIError(err)
}
Expand All @@ -86,7 +86,7 @@ func setView(c *Command, r *http.Request, _ *auth.UserState) Response {
return BadRequest("cannot decode registry request body: %v", err)
}

err := registrystateSetViaView(st, account, registryName, view, values)
err := registrystateSet(st, account, registryName, view, values)
if err != nil {
return toAPIError(err)
}
Expand Down
34 changes: 17 additions & 17 deletions daemon/api_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (s *registrySuite) TestGetView(c *C) {
{name: "map", value: map[string]int{"foo": 123}},
} {
cmt := Commentf("%s test", t.name)
restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, acc, registry, view string, fields []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, acc, registry, view string, fields []string) (interface{}, error) {
c.Check(acc, Equals, "system", cmt)
c.Check(registry, Equals, "network", cmt)
c.Check(view, Equals, "wifi-setup", cmt)
Expand All @@ -112,7 +112,7 @@ func (s *registrySuite) TestViewGetMany(c *C) {
s.setFeatureFlag(c)

var calls int
restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) {
calls++
switch calls {
case 1:
Expand All @@ -137,7 +137,7 @@ func (s *registrySuite) TestViewGetSomeFieldNotFound(c *C) {
s.setFeatureFlag(c)

var calls int
restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, acc, registry, view string, _ []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, acc, registry, view string, _ []string) (interface{}, error) {
calls++
switch calls {
case 1:
Expand All @@ -162,7 +162,7 @@ func (s *registrySuite) TestGetViewNoFieldsFound(c *C) {
s.setFeatureFlag(c)

var calls int
restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) {
calls++
switch calls {
case 1:
Expand Down Expand Up @@ -193,7 +193,7 @@ func (s *registrySuite) TestGetViewNoFieldsFound(c *C) {
func (s *registrySuite) TestViewGetDatabagNotFound(c *C) {
s.setFeatureFlag(c)

restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) {
return nil, &registry.NotFoundError{Account: "foo", RegistryName: "network", View: "wifi-setup", Operation: "get", Requests: []string{"ssid"}, Cause: "mocked"}
})
defer restore()
Expand Down Expand Up @@ -242,7 +242,7 @@ func (s *registrySuite) testViewSetMany(c *C) {
s.setFeatureFlag(c)

var calls int
restore := daemon.MockRegistrystateSetViaView(func(st *state.State, account, registryName, viewName string, requests map[string]interface{}) error {
restore := daemon.MockRegistrystateSet(func(st *state.State, account, registryName, viewName string, requests map[string]interface{}) error {
calls++
switch calls {
case 1:
Expand Down Expand Up @@ -307,7 +307,7 @@ func (s *registrySuite) TestGetViewError(c *C) {
{name: "registry not found", err: &registry.NotFoundError{}, code: 404},
{name: "internal", err: errors.New("internal"), code: 500},
} {
restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, _ []string) (interface{}, error) {
return nil, t.err
})

Expand All @@ -324,7 +324,7 @@ func (s *registrySuite) TestGetViewMisshapenQuery(c *C) {
s.setFeatureFlag(c)

var calls int
restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) {
calls++
switch calls {
case 1:
Expand Down Expand Up @@ -361,7 +361,7 @@ func (s *registrySuite) TestSetView(c *C) {
{name: "map", value: map[string]interface{}{"foo": "bar"}},
} {
cmt := Commentf("%s test", t.name)
restore := daemon.MockRegistrystateSetViaView(func(st *state.State, acc, registryName, view string, requests map[string]interface{}) error {
restore := daemon.MockRegistrystateSet(func(st *state.State, acc, registryName, view string, requests map[string]interface{}) error {
c.Check(acc, Equals, "system", cmt)
c.Check(registryName, Equals, "network", cmt)
c.Check(view, Equals, "wifi-setup", cmt)
Expand Down Expand Up @@ -412,7 +412,7 @@ func (s *registrySuite) TestSetView(c *C) {
func (s *registrySuite) TestUnsetView(c *C) {
s.setFeatureFlag(c)

restore := daemon.MockRegistrystateSetViaView(func(_ *state.State, acc, registryName, view string, requests map[string]interface{}) error {
restore := daemon.MockRegistrystateSet(func(_ *state.State, acc, registryName, view string, requests map[string]interface{}) error {
c.Check(acc, Equals, "system")
c.Check(registryName, Equals, "network")
c.Check(view, Equals, "wifi-setup")
Expand Down Expand Up @@ -452,7 +452,7 @@ func (s *registrySuite) TestSetViewError(c *C) {
{name: "not found", err: &registry.NotFoundError{}, code: 404},
{name: "internal", err: errors.New("internal"), code: 500},
} {
restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error {
restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error {
return t.err
})
cmt := Commentf("%s test", t.name)
Expand All @@ -471,7 +471,7 @@ func (s *registrySuite) TestSetViewError(c *C) {
func (s *registrySuite) TestSetViewEmptyBody(c *C) {
s.setFeatureFlag(c)

restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error {
restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error {
err := errors.New("unexpected call to registrystate.Set")
c.Error(err)
return err
Expand Down Expand Up @@ -501,7 +501,7 @@ func (s *registrySuite) TestSetViewBadRequest(c *C) {
func (s *registrySuite) TestGetBadRequest(c *C) {
s.setFeatureFlag(c)

restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, acc, registryName, view string, fields []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, acc, registryName, view string, fields []string) (interface{}, error) {
return nil, &registry.BadRequestError{
Account: "acc",
RegistryName: "reg",
Expand All @@ -525,7 +525,7 @@ func (s *registrySuite) TestGetBadRequest(c *C) {
func (s *registrySuite) TestSetBadRequest(c *C) {
s.setFeatureFlag(c)

restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error {
restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error {
return &registry.BadRequestError{
Account: "acc",
RegistryName: "reg",
Expand All @@ -549,7 +549,7 @@ func (s *registrySuite) TestSetBadRequest(c *C) {
}

func (s *registrySuite) TestSetFailUnsetFeatureFlag(c *C) {
restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error {
restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error {
err := fmt.Errorf("unexpected call to registrystate")
c.Error(err)
return err
Expand All @@ -568,7 +568,7 @@ func (s *registrySuite) TestSetFailUnsetFeatureFlag(c *C) {
}

func (s *registrySuite) TestGetFailUnsetFeatureFlag(c *C) {
restore := daemon.MockRegistrystateSetViaView(func(*state.State, string, string, string, map[string]interface{}) error {
restore := daemon.MockRegistrystateSet(func(*state.State, string, string, string, map[string]interface{}) error {
err := fmt.Errorf("unexpected call to registrystate")
c.Error(err)
return err
Expand All @@ -588,7 +588,7 @@ func (s *registrySuite) TestGetNoFields(c *C) {
s.setFeatureFlag(c)

value := map[string]interface{}{"foo": 1, "bar": "baz", "nested": map[string]interface{}{"a": []interface{}{1, 2}}}
restore := daemon.MockRegistrystateGetViaView(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) {
restore := daemon.MockRegistrystateGet(func(_ *state.State, _, _, _ string, fields []string) (interface{}, error) {
c.Check(fields, IsNil)
return value, nil
})
Expand Down
16 changes: 8 additions & 8 deletions daemon/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,19 +379,19 @@ var (
MaxReadBuflen = maxReadBuflen
)

func MockRegistrystateGetViaView(f func(_ *state.State, _, _, _ string, _ []string) (interface{}, error)) (restore func()) {
old := registrystateGetViaView
registrystateGetViaView = f
func MockRegistrystateGet(f func(_ *state.State, _, _, _ string, _ []string) (interface{}, error)) (restore func()) {
old := registrystateGet
registrystateGet = f
return func() {
registrystateGetViaView = old
registrystateGet = old
}
}

func MockRegistrystateSetViaView(f func(_ *state.State, _, _, _ string, _ map[string]interface{}) error) (restore func()) {
old := registrystateSetViaView
registrystateSetViaView = f
func MockRegistrystateSet(f func(_ *state.State, _, _, _ string, _ map[string]interface{}) error) (restore func()) {
old := registrystateSet
registrystateSet = f
return func() {
registrystateSetViaView = old
registrystateSet = old
}
}

Expand Down
10 changes: 10 additions & 0 deletions overlord/hookstate/ctlcmd/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ import (
"github.com/snapcore/snapd/osutil/user"
"github.com/snapcore/snapd/overlord/devicestate"
"github.com/snapcore/snapd/overlord/hookstate"
"github.com/snapcore/snapd/overlord/registrystate"
"github.com/snapcore/snapd/overlord/servicestate"
"github.com/snapcore/snapd/overlord/snapstate"
"github.com/snapcore/snapd/overlord/state"
"github.com/snapcore/snapd/registry"
"github.com/snapcore/snapd/snap"
"github.com/snapcore/snapd/testutil"
)
Expand Down Expand Up @@ -177,3 +179,11 @@ func MockNewStatusDecorator(f func(ctx context.Context, isGlobal bool, uid strin
newStatusDecorator = f
return restore
}

func MockRegistrystateRegistryTransaction(f func(*hookstate.Context, *registry.Registry) (*registrystate.Transaction, error)) (restore func()) {
old := registrystateRegistryTransaction
registrystateRegistryTransaction = f
return func() {
registrystateRegistryTransaction = old
}
}
19 changes: 15 additions & 4 deletions overlord/hookstate/ctlcmd/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type getCommand struct {
ForceSlotSide bool `long:"slot" description:"return attribute values from the slot side of the connection"`
ForcePlugSide bool `long:"plug" description:"return attribute values from the plug side of the connection"`
View bool `long:"view" description:"return registry values from the view declared in the plug"`
Pristine bool `long:"pristine" description:"return registry values disregarding changes from the current transaction"`

Positional struct {
PlugOrSlotSpec string `positional-args:"true" positional-arg-name:":<plug|slot>"`
Expand Down Expand Up @@ -159,6 +160,9 @@ func (c *getCommand) Execute(args []string) error {
if c.Typed && c.Document {
return fmt.Errorf("cannot use -d and -t together")
}
if c.Pristine && !c.View {
return fmt.Errorf("cannot use --pristine without --view")
}

if strings.Contains(c.Positional.PlugOrSlotSpec, ":") {
parts := strings.SplitN(c.Positional.PlugOrSlotSpec, ":", 2)
Expand All @@ -176,7 +180,7 @@ func (c *getCommand) Execute(args []string) error {

if c.View {
requests := c.Positional.Keys
return c.getRegistryValues(context, name, requests)
return c.getRegistryValues(context, name, requests, c.Pristine)
}
return c.getInterfaceSetting(context, name)
}
Expand Down Expand Up @@ -357,7 +361,9 @@ func (c *getCommand) getInterfaceSetting(context *hookstate.Context, plugOrSlot
})
}

func (c *getCommand) getRegistryValues(ctx *hookstate.Context, plugName string, requests []string) error {
var registrystateRegistryTransaction = registrystate.RegistryTransaction

func (c *getCommand) getRegistryValues(ctx *hookstate.Context, plugName string, requests []string, pristine bool) error {
if c.ForcePlugSide || c.ForceSlotSide {
return errors.New(i18n.G("cannot use --plug or --slot with --view"))
}
Expand All @@ -369,12 +375,17 @@ func (c *getCommand) getRegistryValues(ctx *hookstate.Context, plugName string,
return fmt.Errorf("cannot get registry: %v", err)
}

tx, err := registrystate.RegistryTransaction(ctx, view.Registry())
tx, err := registrystateRegistryTransaction(ctx, view.Registry())
if err != nil {
return err
}

res, err := registrystate.GetViaViewInTx(tx, view, requests)
bag := registry.DataBag(tx)
if pristine {
bag = tx.Pristine()
}

res, err := registrystate.GetViaView(bag, view, requests)
if err != nil {
return err
}
Expand Down
50 changes: 45 additions & 5 deletions overlord/hookstate/ctlcmd/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"github.com/snapcore/snapd/overlord/ifacestate/ifacerepo"
"github.com/snapcore/snapd/overlord/registrystate"
"github.com/snapcore/snapd/overlord/state"
"github.com/snapcore/snapd/registry"
"github.com/snapcore/snapd/snap"
"github.com/snapcore/snapd/testutil"
)
Expand Down Expand Up @@ -591,7 +592,7 @@ slots:

func (s *registrySuite) TestRegistryGetSingleView(c *C) {
s.state.Lock()
err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
"ssid": "my-ssid",
})
s.state.Unlock()
Expand All @@ -605,7 +606,7 @@ func (s *registrySuite) TestRegistryGetSingleView(c *C) {

func (s *registrySuite) TestRegistryGetManyViews(c *C) {
s.state.Lock()
err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
"ssid": "my-ssid",
"password": "secret",
})
Expand All @@ -624,7 +625,7 @@ func (s *registrySuite) TestRegistryGetManyViews(c *C) {

func (s *registrySuite) TestRegistryGetNoRequest(c *C) {
s.state.Lock()
err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
"ssid": "my-ssid",
"password": "secret",
})
Expand All @@ -643,7 +644,7 @@ func (s *registrySuite) TestRegistryGetNoRequest(c *C) {

func (s *registrySuite) TestRegistryGetHappensTransactionally(c *C) {
s.state.Lock()
err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
"ssid": "my-ssid",
})
s.state.Unlock()
Expand All @@ -659,7 +660,7 @@ func (s *registrySuite) TestRegistryGetHappensTransactionally(c *C) {
c.Check(stderr, IsNil)

s.state.Lock()
err = registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
err = registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
"ssid": "other-ssid",
})
s.state.Unlock()
Expand Down Expand Up @@ -847,3 +848,42 @@ func (s *registrySuite) TestRegistryGetAndSetViewNotFound(c *C) {
c.Check(stdout, IsNil)
c.Check(stderr, IsNil)
}

func (s *registrySuite) TestRegistryGetPristine(c *C) {
s.state.Lock()
defer s.state.Unlock()

err := registrystate.Set(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{
"ssid": "foo",
})
c.Assert(err, IsNil)

task := s.state.NewTask("run-hook", "")
setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "save-view-plug"}
ctx, err := hookstate.NewContext(task, s.state, setup, s.mockHandler, "")
c.Assert(err, IsNil)

tx, err := registrystate.NewTransaction(s.state, s.devAccID, "network")
c.Assert(err, IsNil)

err = tx.Set("wifi.ssid", "bar")
c.Assert(err, IsNil)

restore := ctlcmd.MockRegistrystateRegistryTransaction(func(*hookstate.Context, *registry.Registry) (*registrystate.Transaction, error) {
return tx, nil
})
defer restore()

s.state.Unlock()
defer s.state.Lock()

stdout, stderr, err := ctlcmd.Run(ctx, []string{"get", "--view", "--pristine", ":read-wifi", "ssid"}, 0)
c.Assert(err, IsNil)
c.Check(string(stdout), Equals, "foo\n")
c.Check(stderr, IsNil)

stdout, stderr, err = ctlcmd.Run(ctx, []string{"get", "--view", ":read-wifi", "ssid"}, 0)
c.Assert(err, IsNil)
c.Check(string(stdout), Equals, "bar\n")
c.Check(stderr, IsNil)
}
2 changes: 1 addition & 1 deletion overlord/hookstate/ctlcmd/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,5 @@ func setRegistryValues(ctx *hookstate.Context, plugName string, requests map[str
// TODO: once we have hooks, check that we don't set values in the wrong hooks
// (e.g., "registry-changed" hooks can only read data)

return registrystate.SetViaViewInTx(tx, view, requests)
return registrystate.SetViaView(tx, view, requests)
}
Loading

0 comments on commit 0d30934

Please sign in to comment.