From 047639426b4a8f2d2adc1f24fae1a0c8bafbdb20 Mon Sep 17 00:00:00 2001 From: Miguel Pires Date: Thu, 4 Jul 2024 10:46:12 +0100 Subject: [PATCH] o/h/ctlcmd: support reading registry views in snapctl (#14120) Adds registry support to `snapctl get` so snaps can access registry views, if they have a matching connected plug for the view and if the registry assertion can be found. Signed-off-by: Miguel Pires --- overlord/hookstate/ctlcmd/get.go | 83 +++- overlord/hookstate/ctlcmd/get_test.go | 389 ++++++++++++++++++ .../hookstate/ctlcmd/is_connected_test.go | 3 +- overlord/registrystate/registrystate.go | 62 ++- overlord/registrystate/registrystate_test.go | 75 ++++ registry/registry.go | 22 +- registry/registry_test.go | 4 + registry/transaction.go | 12 +- registry/transaction_test.go | 62 +-- 9 files changed, 664 insertions(+), 48 deletions(-) diff --git a/overlord/hookstate/ctlcmd/get.go b/overlord/hookstate/ctlcmd/get.go index b9c1151b880..89adc7f5086 100644 --- a/overlord/hookstate/ctlcmd/get.go +++ b/overlord/hookstate/ctlcmd/get.go @@ -21,14 +21,19 @@ package ctlcmd import ( "encoding/json" + "errors" "fmt" "strings" + "github.com/snapcore/snapd/asserts" "github.com/snapcore/snapd/i18n" "github.com/snapcore/snapd/interfaces" + "github.com/snapcore/snapd/overlord/assertstate" "github.com/snapcore/snapd/overlord/configstate" "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/hookstate" + "github.com/snapcore/snapd/overlord/ifacestate/ifacerepo" + "github.com/snapcore/snapd/overlord/registrystate" "github.com/snapcore/snapd/overlord/state" ) @@ -38,6 +43,7 @@ type getCommand struct { // these two options are mutually exclusive ForceSlotSide bool `long:"slot" description:"return attribute values from the slot side of the connection"` ForcePlugSide bool `long:"plug" description:"return attribute values from the plug side of the connection"` + View bool `long:"view" description:"return registry values from the view declared in the plug"` Positional struct { PlugOrSlotSpec string `positional-args:"true" positional-arg-name:":"` @@ -103,9 +109,15 @@ func (c *getCommand) printValues(getByKey func(string) (interface{}, bool, error } } + return c.printPatch(patch) +} + +func (c *getCommand) printPatch(patch interface{}) error { var confToPrint interface{} = patch if !c.Document && len(c.Positional.Keys) == 1 { - confToPrint = patch[c.Positional.Keys[0]] + if confMap, ok := patch.(map[string]interface{}); ok { + confToPrint = confMap[c.Positional.Keys[0]] + } } if c.Typed && confToPrint == nil { @@ -155,10 +167,14 @@ func (c *getCommand) Execute(args []string) error { if snap != "" { return fmt.Errorf(`"snapctl get %s" not supported, use "snapctl get :%s" instead`, c.Positional.PlugOrSlotSpec, parts[1]) } - if len(c.Positional.Keys) == 0 { + // registry views can be read without fields + if !c.View && len(c.Positional.Keys) == 0 { return fmt.Errorf(i18n.G("get which attribute?")) } + if c.View { + return c.getRegistryView(context, name) + } return c.getInterfaceSetting(context, name) } @@ -284,8 +300,7 @@ func (c *getCommand) getInterfaceSetting(context *hookstate.Context, plugOrSlot return fmt.Errorf(i18n.G("interface attributes can only be read during the execution of interface hooks")) } - var attrsTask *state.Task - attrsTask, err = attributesTask(context) + attrsTask, err := attributesTask(context) if err != nil { return err } @@ -338,3 +353,63 @@ func (c *getCommand) getInterfaceSetting(context *hookstate.Context, plugOrSlot return nil, false, err }) } + +func (c *getCommand) getRegistryView(ctx *hookstate.Context, plugName string) error { + if c.ForcePlugSide || c.ForceSlotSide { + return fmt.Errorf(i18n.G("cannot use --plug or --slot with --view")) + } + + ctx.Lock() + defer ctx.Unlock() + repo := ifacerepo.Get(ctx.State()) + + plug := repo.Plug(ctx.InstanceName(), plugName) + if plug == nil { + return fmt.Errorf(i18n.G("no plug :%s for snap %q"), plugName, ctx.InstanceName()) + } + + if plug.Interface != "registry" { + return fmt.Errorf(i18n.G("cannot use --view with non-registry plug :%s"), plugName) + } + + var account string + if err := plug.Attr("account", &account); err != nil { + // should not be possible at this stage + return fmt.Errorf(i18n.G("internal error: cannot find \"account\" attribute in plug :%s: %w"), plugName, err) + } + + var registryView string + if err := plug.Attr("view", ®istryView); err != nil { + // should not be possible at this stage + return fmt.Errorf(i18n.G("internal error: cannot find \"view\" attribute in plug :%s: %w"), plugName, err) + } + + parts := strings.Split(registryView, "/") + registryName, viewName := parts[0], parts[1] + + registryAssert, err := assertstate.Registry(ctx.State(), account, registryName) + if err != nil { + if errors.Is(err, &asserts.NotFoundError{}) { + return fmt.Errorf(i18n.G("cannot get %s/%s: registry not found"), account, registryView) + } + return err + } + reg := registryAssert.Registry() + + view := reg.View(viewName) + if view == nil { + return fmt.Errorf(i18n.G("cannot get %s/%s: view not found"), account, registryView) + } + + tx, err := registrystate.RegistryTransaction(ctx, reg) + if err != nil { + return err + } + + res, err := registrystate.GetViaViewInTx(tx, view, c.Positional.Keys) + if err != nil { + return err + } + + return c.printPatch(res) +} diff --git a/overlord/hookstate/ctlcmd/get_test.go b/overlord/hookstate/ctlcmd/get_test.go index f2bd08d6f7f..8db5cb6d7c7 100644 --- a/overlord/hookstate/ctlcmd/get_test.go +++ b/overlord/hookstate/ctlcmd/get_test.go @@ -20,18 +20,28 @@ package ctlcmd_test import ( + "fmt" "strings" . "gopkg.in/check.v1" + "github.com/snapcore/snapd/asserts" + "github.com/snapcore/snapd/asserts/assertstest" + "github.com/snapcore/snapd/dirs" "github.com/snapcore/snapd/interfaces" + "github.com/snapcore/snapd/interfaces/ifacetest" + "github.com/snapcore/snapd/overlord/assertstate" + "github.com/snapcore/snapd/overlord/assertstate/assertstatetest" "github.com/snapcore/snapd/overlord/configstate" "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/hookstate" "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" "github.com/snapcore/snapd/overlord/hookstate/hooktest" + "github.com/snapcore/snapd/overlord/ifacestate/ifacerepo" + "github.com/snapcore/snapd/overlord/registrystate" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/snap" + "github.com/snapcore/snapd/testutil" ) type getSuite struct { @@ -436,3 +446,382 @@ func (s *getAttrSuite) TestSlotHookTests(c *C) { } } } + +type registrySuite struct { + testutil.BaseTest + + state *state.State + signingDB *assertstest.SigningDB + devAccID string + + mockContext *hookstate.Context + mockHandler *hooktest.MockHandler +} + +var _ = Suite(®istrySuite{}) + +func (s *registrySuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) + dirs.SetRootDir(c.MkDir()) + s.AddCleanup(func() { + dirs.SetRootDir("/") + }) + + s.mockHandler = hooktest.NewMockHandler() + s.state = state.New(nil) + s.state.Lock() + task := s.state.NewTask("test-task", "my test task") + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "test-hook"} + s.state.Unlock() + + var err error + s.mockContext, err = hookstate.NewContext(task, s.state, setup, s.mockHandler, "") + c.Assert(err, IsNil) + + storeSigning := assertstest.NewStoreStack("can0nical", nil) + db, err := asserts.OpenDatabase(&asserts.DatabaseConfig{ + Backstore: asserts.NewMemoryBackstore(), + Trusted: storeSigning.Trusted, + }) + c.Assert(err, IsNil) + c.Assert(db.Add(storeSigning.StoreAccountKey("")), IsNil) + + s.state.Lock() + defer s.state.Unlock() + assertstate.ReplaceDB(s.state, db) + + // add developer1's account and account-key assertions + devAcc := assertstest.NewAccount(storeSigning, "developer1", nil, "") + c.Assert(storeSigning.Add(devAcc), IsNil) + + devPrivKey, _ := assertstest.GenerateKey(752) + devAccKey := assertstest.NewAccountKey(storeSigning, devAcc, nil, devPrivKey.PublicKey(), "") + s.devAccID = devAccKey.AccountID() + + assertstatetest.AddMany(s.state, storeSigning.StoreAccountKey(""), devAcc, devAccKey) + + s.signingDB = assertstest.NewSigningDB("developer1", devPrivKey) + c.Check(s.signingDB, NotNil) + c.Assert(storeSigning.Add(devAccKey), IsNil) + + headers := map[string]interface{}{ + "authority-id": s.devAccID, + "account-id": s.devAccID, + "name": "network", + "views": map[string]interface{}{ + "read-wifi": map[string]interface{}{ + "rules": []interface{}{ + map[string]interface{}{"request": "ssid", "storage": "wifi.ssid", "access": "read"}, + map[string]interface{}{"request": "password", "storage": "wifi.psk", "access": "read"}, + }, + }, + "write-wifi": map[string]interface{}{ + "rules": []interface{}{ + map[string]interface{}{"request": "ssid", "storage": "wifi.ssid", "access": "write"}, + map[string]interface{}{"request": "password", "storage": "wifi.psk", "access": "write"}, + }, + }, + }, + "timestamp": "2030-11-06T09:16:26Z", + } + + body := []byte(`{ + "storage": { + "schema": { + "wifi": "any" + } + } +}`) + + as, err := s.signingDB.Sign(asserts.RegistryType, headers, body, "") + c.Assert(err, IsNil) + c.Assert(assertstate.Add(s.state, as), IsNil) + + repo := interfaces.NewRepository() + ifacerepo.Replace(s.state, repo) + + regIface := &ifacetest.TestInterface{InterfaceName: "registry"} + err = repo.AddInterface(regIface) + c.Assert(err, IsNil) + + snapYaml := fmt.Sprintf(`name: test-snap +type: app +version: 1 +plugs: + read-wifi: + interface: registry + account: %[1]s + view: network/read-wifi +`, s.devAccID) + info := mockInstalledSnap(c, s.state, snapYaml, "") + + appSet, err := interfaces.NewSnapAppSet(info, nil) + c.Assert(err, IsNil) + err = repo.AddAppSet(appSet) + c.Assert(err, IsNil) + + const coreYaml = `name: core +version: 1.0 +type: os +slots: + registry-slot: + interface: registry +` + info = mockInstalledSnap(c, s.state, coreYaml, "") + + coreSet, err := interfaces.NewSnapAppSet(info, nil) + c.Assert(err, IsNil) + + err = repo.AddAppSet(coreSet) + c.Assert(err, IsNil) + + ref := &interfaces.ConnRef{ + PlugRef: interfaces.PlugRef{Snap: "test-snap", Name: "read-wifi"}, + SlotRef: interfaces.SlotRef{Snap: "core", Name: "registry-slot"}, + } + _, err = repo.Connect(ref, nil, nil, nil, nil, nil) + c.Assert(err, IsNil) +} + +func (s *registrySuite) TestRegistryGetSingleView(c *C) { + s.state.Lock() + err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + "ssid": "my-ssid", + }) + s.state.Unlock() + c.Assert(err, IsNil) + + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi", "ssid"}, 0) + c.Assert(err, IsNil) + c.Check(string(stdout), Equals, "my-ssid\n") + c.Check(stderr, IsNil) +} + +func (s *registrySuite) TestRegistryGetManyViews(c *C) { + s.state.Lock() + err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + "ssid": "my-ssid", + "password": "secret", + }) + s.state.Unlock() + c.Assert(err, IsNil) + + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi", "ssid", "password"}, 0) + c.Assert(err, IsNil) + c.Check(string(stdout), Equals, `{ + "password": "secret", + "ssid": "my-ssid" +} +`) + c.Check(stderr, IsNil) +} + +func (s *registrySuite) TestRegistryGetNoRequest(c *C) { + s.state.Lock() + err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + "ssid": "my-ssid", + "password": "secret", + }) + s.state.Unlock() + c.Assert(err, IsNil) + + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) + c.Assert(err, IsNil) + c.Check(string(stdout), Equals, `{ + "password": "secret", + "ssid": "my-ssid" +} +`) + c.Check(stderr, IsNil) +} + +func (s *registrySuite) TestRegistryGetHappensTransactionally(c *C) { + s.state.Lock() + err := registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + "ssid": "my-ssid", + }) + s.state.Unlock() + c.Assert(err, IsNil) + + // registry transaction is created when snapctl runs for the first time + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) + c.Assert(err, IsNil) + c.Check(string(stdout), Equals, `{ + "ssid": "my-ssid" +} +`) + c.Check(stderr, IsNil) + + s.state.Lock() + err = registrystate.SetViaView(s.state, s.devAccID, "network", "write-wifi", map[string]interface{}{ + "ssid": "other-ssid", + }) + s.state.Unlock() + c.Assert(err, IsNil) + + // the new write wasn't reflected because it didn't run in the same transaction + stdout, stderr, err = ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) + c.Assert(err, IsNil) + c.Check(string(stdout), Equals, `{ + "ssid": "my-ssid" +} +`) + c.Check(stderr, IsNil) + + // make a new context so we get a new transaction + s.state.Lock() + task := s.state.NewTask("test-task", "my test task") + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "test-hook"} + s.mockContext, err = hookstate.NewContext(task, s.state, setup, s.mockHandler, "") + s.state.Unlock() + c.Assert(err, IsNil) + + // now we get the new data + stdout, stderr, err = ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) + c.Assert(err, IsNil) + c.Check(string(stdout), Equals, `{ + "ssid": "other-ssid" +} +`) + c.Check(stderr, IsNil) +} + +func (s *registrySuite) TestRegistryGetInvalid(c *C) { + type testcase struct { + args []string + err string + } + + tcs := []testcase{ + { + args: []string{"--slot", ":something"}, + err: `cannot use --plug or --slot with --view`, + }, + { + args: []string{"--plug", ":something"}, + err: `cannot use --plug or --slot with --view`, + }, + { + args: []string{":non-existent"}, + err: `no plug :non-existent for snap "test-snap"`, + }, + } + + for _, tc := range tcs { + stdout, stderr, err := ctlcmd.Run(s.mockContext, append([]string{"get", "--view"}, tc.args...), 0) + c.Assert(err, ErrorMatches, tc.err) + c.Check(stdout, IsNil) + c.Check(stderr, IsNil) + } +} + +func (s *registrySuite) TestRegistryGetNonRegistryPlug(c *C) { + dirs.SetRootDir(c.MkDir()) + s.AddCleanup(func() { + dirs.SetRootDir("/") + }) + + s.state.Lock() + repo := interfaces.NewRepository() + ifacerepo.Replace(s.state, repo) + + err := repo.AddInterface(&ifacetest.TestInterface{InterfaceName: "random"}) + c.Assert(err, IsNil) + + snapYaml := `name: test-snap +type: app +version: 1 +plugs: + my-plug: + interface: random +` + info := mockInstalledSnap(c, s.state, snapYaml, "") + + appSet, err := interfaces.NewSnapAppSet(info, nil) + c.Assert(err, IsNil) + err = repo.AddAppSet(appSet) + c.Assert(err, IsNil) + + const coreYaml = `name: core +version: 1.0 +type: os +slots: + my-slot: + interface: random +` + info = mockInstalledSnap(c, s.state, coreYaml, "") + + coreSet, err := interfaces.NewSnapAppSet(info, nil) + c.Assert(err, IsNil) + + err = repo.AddAppSet(coreSet) + c.Assert(err, IsNil) + + ref := &interfaces.ConnRef{ + PlugRef: interfaces.PlugRef{Snap: "test-snap", Name: "my-plug"}, + SlotRef: interfaces.SlotRef{Snap: "core", Name: "my-slot"}, + } + _, err = repo.Connect(ref, nil, nil, nil, nil, nil) + c.Assert(err, IsNil) + s.state.Unlock() + + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"get", "--view", ":my-plug"}, 0) + c.Assert(err, ErrorMatches, "cannot use --view with non-registry plug :my-plug") + c.Check(stdout, IsNil) + c.Check(stderr, IsNil) +} + +func (s *registrySuite) TestRegistryGetAssertionNotFound(c *C) { + storeSigning := assertstest.NewStoreStack("can0nical", nil) + db, err := asserts.OpenDatabase(&asserts.DatabaseConfig{ + Backstore: asserts.NewMemoryBackstore(), + Trusted: storeSigning.Trusted, + }) + c.Assert(err, IsNil) + c.Assert(db.Add(storeSigning.StoreAccountKey("")), IsNil) + + s.state.Lock() + assertstate.ReplaceDB(s.state, db) + s.state.Unlock() + + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) + c.Assert(err, ErrorMatches, fmt.Sprintf("cannot get %s/network/read-wifi: registry not found", s.devAccID)) + c.Check(stdout, IsNil) + c.Check(stderr, IsNil) +} + +func (s *registrySuite) TestRegistryGetViewNotFound(c *C) { + headers := map[string]interface{}{ + "authority-id": s.devAccID, + "account-id": s.devAccID, + "revision": "1", + "name": "network", + "views": map[string]interface{}{ + "other": map[string]interface{}{ + "rules": []interface{}{ + map[string]interface{}{"request": "a", "storage": "a"}, + }, + }, + }, + "timestamp": "2030-11-06T09:16:26Z", + } + + body := []byte(`{ + "storage": { + "schema": { + "a": "any" + } + } +}`) + + as, err := s.signingDB.Sign(asserts.RegistryType, headers, body, "") + c.Assert(err, IsNil) + s.state.Lock() + c.Assert(assertstate.Add(s.state, as), IsNil) + s.state.Unlock() + + stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"get", "--view", ":read-wifi"}, 0) + c.Assert(err, ErrorMatches, fmt.Sprintf("cannot get %s/network/read-wifi: view not found", s.devAccID)) + c.Check(stdout, IsNil) + c.Check(stderr, IsNil) +} diff --git a/overlord/hookstate/ctlcmd/is_connected_test.go b/overlord/hookstate/ctlcmd/is_connected_test.go index c6040f36148..0bbc40a45b8 100644 --- a/overlord/hookstate/ctlcmd/is_connected_test.go +++ b/overlord/hookstate/ctlcmd/is_connected_test.go @@ -135,7 +135,7 @@ var isConnectedTests = []struct { exitCode: ctlcmd.ClassicSnapCode, }} -func mockInstalledSnap(c *C, st *state.State, snapYaml, cohortKey string) { +func mockInstalledSnap(c *C, st *state.State, snapYaml, cohortKey string) *snap.Info { info := snaptest.MockSnapCurrent(c, snapYaml, &snap.SideInfo{Revision: snap.R(1)}) snapstate.Set(st, info.InstanceName(), &snapstate.SnapState{ Active: true, @@ -150,6 +150,7 @@ func mockInstalledSnap(c *C, st *state.State, snapYaml, cohortKey string) { TrackingChannel: "stable", CohortKey: cohortKey, }) + return info } func (s *isConnectedSuite) testIsConnected(c *C, context *hookstate.Context) { diff --git a/overlord/registrystate/registrystate.go b/overlord/registrystate/registrystate.go index 2131d4c6213..ccc8dd2f37e 100644 --- a/overlord/registrystate/registrystate.go +++ b/overlord/registrystate/registrystate.go @@ -22,14 +22,17 @@ import ( "errors" "github.com/snapcore/snapd/overlord/assertstate" + "github.com/snapcore/snapd/overlord/hookstate" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/registry" ) +var assertstateRegistry = assertstate.Registry + // SetViaView finds the view identified by the account, registry and view names // and sets the request fields to their respective values. func SetViaView(st *state.State, account, registryName, viewName string, requests map[string]interface{}) error { - registryAssert, err := assertstate.Registry(st, account, registryName) + registryAssert, err := assertstateRegistry(st, account, registryName) if err != nil { return err } @@ -37,9 +40,12 @@ func SetViaView(st *state.State, account, registryName, viewName string, request view := reg.View(viewName) if view == nil { - keys := make([]string, 0, len(requests)) - for k := range requests { - keys = append(keys, k) + var keys []string + if len(requests) > 0 { + keys = make([]string, 0, len(requests)) + for k := range requests { + keys = append(keys, k) + } } return ®istry.NotFoundError{ @@ -78,7 +84,7 @@ func SetViaView(st *state.State, account, registryName, viewName string, request // returned in a map of fields to their values, unless there are no fields in // which case all views are returned. func GetViaView(st *state.State, account, registryName, viewName string, fields []string) (interface{}, error) { - registryAssert, err := assertstate.Registry(st, account, registryName) + registryAssert, err := assertstateRegistry(st, account, registryName) if err != nil { return nil, err } @@ -101,6 +107,12 @@ func GetViaView(st *state.State, account, registryName, viewName string, fields return nil, err } + return GetViaViewInTx(tx, view, fields) +} + +// GetViaViewInTx uses the view to get values for the fields from the databag +// in the transaction. +func GetViaViewInTx(tx *registry.Transaction, view *registry.View, fields []string) (interface{}, error) { if len(fields) == 0 { val, err := view.Get(tx, "") if err != nil { @@ -126,10 +138,11 @@ func GetViaView(st *state.State, account, registryName, viewName string, fields } if len(results) == 0 { + account, registryName := tx.RegistryInfo() return nil, ®istry.NotFoundError{ Account: account, RegistryName: registryName, - View: viewName, + View: view.Name, Operation: "get", Requests: fields, Cause: "matching rules don't map to any values", @@ -139,15 +152,15 @@ func GetViaView(st *state.State, account, registryName, viewName string, fields return results, nil } -// newTransaction returns a transaction configured to read and write databags -// from state as needed. +// newTransaction returns a transaction configured to read and write +// databags from state as needed. func newTransaction(st *state.State, reg *registry.Registry) (*registry.Transaction, error) { getter := bagGetter(st, reg) setter := func(bag registry.JSONDataBag) error { return updateDatabags(st, bag, reg) } - tx, err := registry.NewTransaction(getter, setter, reg.Schema) + tx, err := registry.NewTransaction(reg, getter, setter) if err != nil { return nil, err } @@ -199,3 +212,34 @@ func updateDatabags(st *state.State, databag registry.JSONDataBag, reg *registry st.Set("registry-databags", databags) return nil } + +type cachedRegistryTx struct { + account string + registry string +} + +// RegistryTransaction returns the registry.Transaction cached in the context +// or creates one and caches it, if none existed. The context must be locked by +// the caller. +func RegistryTransaction(ctx *hookstate.Context, reg *registry.Registry) (*registry.Transaction, error) { + key := cachedRegistryTx{ + account: reg.Account, + registry: reg.Name, + } + tx, ok := ctx.Cached(key).(*registry.Transaction) + if ok { + return tx, nil + } + + tx, err := newTransaction(ctx.State(), reg) + if err != nil { + return nil, err + } + + ctx.OnDone(func() error { + return tx.Commit() + }) + + ctx.Cache(key, tx) + return tx, nil +} diff --git a/overlord/registrystate/registrystate_test.go b/overlord/registrystate/registrystate_test.go index e3baaae0460..417a17280ec 100644 --- a/overlord/registrystate/registrystate_test.go +++ b/overlord/registrystate/registrystate_test.go @@ -29,9 +29,12 @@ import ( "github.com/snapcore/snapd/overlord" "github.com/snapcore/snapd/overlord/assertstate" "github.com/snapcore/snapd/overlord/assertstate/assertstatetest" + "github.com/snapcore/snapd/overlord/hookstate" + "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/registrystate" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/registry" + "github.com/snapcore/snapd/snap" ) type registryTestSuite struct { @@ -291,3 +294,75 @@ func (s *registryTestSuite) TestRegistrystateGetEntireView(c *C) { }, }) } + +func (s *registryTestSuite) TestRegistryTransaction(c *C) { + mkRegistry := func(account, name string) *registry.Registry { + reg, err := registry.New(account, name, map[string]interface{}{ + "bar": map[string]interface{}{ + "rules": []interface{}{ + map[string]interface{}{"request": "foo", "storage": "foo"}, + }, + }, + }, registry.NewJSONSchema()) + c.Assert(err, IsNil) + return reg + } + + s.state.Lock() + task := s.state.NewTask("test-task", "my test task") + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "test-hook"} + s.state.Unlock() + mockHandler := hooktest.NewMockHandler() + + type testcase struct { + acc1, acc2 string + reg1, reg2 string + equals bool + } + + tcs := []testcase{ + { + // same transaction + acc1: "acc-1", reg1: "reg-1", + acc2: "acc-1", reg2: "reg-1", + equals: true, + }, + { + // different registry name, different transaction + acc1: "acc-1", reg1: "reg-1", + acc2: "acc-1", reg2: "reg-2", + }, + { + // different account, different transaction + acc1: "acc-1", reg1: "reg-1", + acc2: "acc-2", reg2: "reg-1", + }, + { + // both different, different transaction + acc1: "acc-1", reg1: "reg-1", + acc2: "acc-2", reg2: "reg-2", + }, + } + + for _, tc := range tcs { + ctx, err := hookstate.NewContext(task, task.State(), setup, mockHandler, "") + c.Assert(err, IsNil) + ctx.Lock() + + reg1 := mkRegistry(tc.acc1, tc.reg1) + reg2 := mkRegistry(tc.acc2, tc.reg2) + + tx1, err := registrystate.RegistryTransaction(ctx, reg1) + c.Assert(err, IsNil) + + tx2, err := registrystate.RegistryTransaction(ctx, reg2) + c.Assert(err, IsNil) + + if tc.equals { + c.Assert(tx1, Equals, tx2) + } else { + c.Assert(tx1, Not(Equals), tx2) + } + ctx.Unlock() + } +} diff --git a/registry/registry.go b/registry/registry.go index 43b414a74a2..1fc5452a8fa 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -70,12 +70,16 @@ type NotFoundError struct { func (e *NotFoundError) Error() string { var reqStr string - if len(e.Requests) == 1 { - reqStr = fmt.Sprintf("%q", e.Requests[0]) - } else { - reqStr = strutil.Quoted(e.Requests) + switch len(e.Requests) { + case 0: + // leave empty, so the message reflects the request gets the whole view + case 1: + reqStr = fmt.Sprintf(" %q in", e.Requests[0]) + default: + reqStr = fmt.Sprintf(" %s in", strutil.Quoted(e.Requests)) } - return fmt.Sprintf("cannot %s %s in registry view %s/%s/%s: %s", e.Operation, reqStr, e.Account, e.RegistryName, e.View, e.Cause) + + return fmt.Sprintf("cannot %s%s registry view %s/%s/%s: %s", e.Operation, reqStr, e.Account, e.RegistryName, e.View, e.Cause) } func (e *NotFoundError) Is(err error) bool { @@ -84,12 +88,17 @@ func (e *NotFoundError) Is(err error) bool { } func notFoundErrorFrom(v *View, op, request, errMsg string) *NotFoundError { + var req []string + if request != "" { + req = []string{request} + } + return &NotFoundError{ Account: v.registry.Account, RegistryName: v.registry.Name, View: v.Name, Operation: op, - Requests: []string{request}, + Requests: req, Cause: errMsg, } } @@ -980,6 +989,7 @@ func (v *View) Get(databag DataBag, request string) (interface{}, error) { } if merged == nil { + // TODO: improve this error message return nil, notFoundErrorFrom(v, "get", request, "matching rules don't map to any values") } diff --git a/registry/registry_test.go b/registry/registry_test.go index 7c0b3a50d68..a9599cefc60 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -321,6 +321,10 @@ func (s *viewSuite) TestRegistryNotFound(c *C) { c.Assert(err, testutil.ErrorIs, ®istry.NotFoundError{}) c.Assert(err, ErrorMatches, `cannot get "top-level" in registry view acc/foo/bar: matching rules don't map to any values`) + _, err = view.Get(databag, "") + c.Assert(err, testutil.ErrorIs, ®istry.NotFoundError{}) + c.Assert(err, ErrorMatches, `cannot get registry view acc/foo/bar: matching rules don't map to any values`) + err = view.Set(databag, "nested", "thing") c.Assert(err, IsNil) diff --git a/registry/transaction.go b/registry/transaction.go index 2b2370b600d..f2d126d2214 100644 --- a/registry/transaction.go +++ b/registry/transaction.go @@ -28,7 +28,7 @@ type DatabagWrite func(JSONDataBag) error // Transaction performs read and writes to a databag in an atomic way. type Transaction struct { pristine JSONDataBag - schema Schema + registry *Registry modified JSONDataBag deltas []map[string]interface{} @@ -40,7 +40,7 @@ type Transaction struct { } // NewTransaction takes a getter and setter to read and write the databag. -func NewTransaction(readDatabag DatabagRead, writeDatabag DatabagWrite, schema Schema) (*Transaction, error) { +func NewTransaction(reg *Registry, readDatabag DatabagRead, writeDatabag DatabagWrite) (*Transaction, error) { databag, err := readDatabag() if err != nil { return nil, err @@ -48,12 +48,16 @@ func NewTransaction(readDatabag DatabagRead, writeDatabag DatabagWrite, schema S return &Transaction{ pristine: databag.Copy(), - schema: schema, + registry: reg, readDatabag: readDatabag, writeDatabag: writeDatabag, }, nil } +func (t *Transaction) RegistryInfo() (account string, registryName string) { + return t.registry.Account, t.registry.Name +} + // Set sets a value in the transaction's databag. The change isn't persisted // until Commit returns without errors. func (t *Transaction) Set(path string, value interface{}) error { @@ -123,7 +127,7 @@ func (t *Transaction) Commit() error { return err } - if err := t.schema.Validate(data); err != nil { + if err := t.registry.Schema.Validate(data); err != nil { return err } diff --git a/registry/transaction_test.go b/registry/transaction_test.go index e8693110648..eb6a0805a99 100644 --- a/registry/transaction_test.go +++ b/registry/transaction_test.go @@ -27,6 +27,18 @@ import ( type transactionTestSuite struct{} +func newRegistry(c *C, schema registry.Schema) *registry.Registry { + registry, err := registry.New("my-account", "my-reg", map[string]interface{}{ + "my-view": map[string]interface{}{ + "rules": []interface{}{ + map[string]interface{}{"request": "foo", "storage": "foo"}, + }, + }, + }, schema) + c.Assert(err, IsNil) + return registry +} + var _ = Suite(&transactionTestSuite{}) type witnessReadWriter struct { @@ -50,8 +62,8 @@ func (w *witnessReadWriter) write(bag registry.JSONDataBag) error { func (s *transactionTestSuite) TestSet(c *C) { bag := registry.NewJSONDataBag() witness := &witnessReadWriter{bag: bag} - schema := registry.NewJSONSchema() - tx, err := registry.NewTransaction(witness.read, witness.write, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + tx, err := registry.NewTransaction(reg, witness.read, witness.write) c.Assert(err, IsNil) c.Assert(witness.readCalled, Equals, 1) @@ -65,8 +77,8 @@ func (s *transactionTestSuite) TestSet(c *C) { func (s *transactionTestSuite) TestCommit(c *C) { witness := &witnessReadWriter{bag: registry.NewJSONDataBag()} - schema := registry.NewJSONSchema() - tx, err := registry.NewTransaction(witness.read, witness.write, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + tx, err := registry.NewTransaction(reg, witness.read, witness.write) c.Assert(err, IsNil) c.Assert(witness.readCalled, Equals, 1) @@ -88,8 +100,8 @@ func (s *transactionTestSuite) TestCommit(c *C) { func (s *transactionTestSuite) TestGetReadsUncommitted(c *C) { databag := registry.NewJSONDataBag() witness := &witnessReadWriter{bag: databag} - schema := registry.NewJSONSchema() - tx, err := registry.NewTransaction(witness.read, witness.write, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + tx, err := registry.NewTransaction(reg, witness.read, witness.write) c.Assert(err, IsNil) err = databag.Set("foo", "bar") @@ -125,8 +137,8 @@ func (f *failingSchema) Type() registry.SchemaType { func (s *transactionTestSuite) TestRollBackOnCommitError(c *C) { databag := registry.NewJSONDataBag() witness := &witnessReadWriter{bag: databag} - schema := &failingSchema{err: errors.New("expected error")} - tx, err := registry.NewTransaction(witness.read, witness.write, schema) + reg := newRegistry(c, &failingSchema{err: errors.New("expected error")}) + tx, err := registry.NewTransaction(reg, witness.read, witness.write) c.Assert(err, IsNil) err = tx.Set("foo", "bar") @@ -148,8 +160,8 @@ func (s *transactionTestSuite) TestRollBackOnCommitError(c *C) { func (s *transactionTestSuite) TestManyWrites(c *C) { databag := registry.NewJSONDataBag() witness := &witnessReadWriter{bag: databag} - schema := registry.NewJSONSchema() - tx, err := registry.NewTransaction(witness.read, witness.write, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + tx, err := registry.NewTransaction(reg, witness.read, witness.write) c.Assert(err, IsNil) err = tx.Set("foo", "bar") @@ -172,8 +184,8 @@ func (s *transactionTestSuite) TestManyWrites(c *C) { func (s *transactionTestSuite) TestCommittedIncludesRecentWrites(c *C) { databag := registry.NewJSONDataBag() witness := &witnessReadWriter{bag: databag} - schema := registry.NewJSONSchema() - tx, err := registry.NewTransaction(witness.read, witness.write, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + tx, err := registry.NewTransaction(reg, witness.read, witness.write) c.Assert(err, IsNil) c.Assert(witness.readCalled, Equals, 1) @@ -214,11 +226,11 @@ func (s *transactionTestSuite) TestCommittedIncludesPreviousCommit(c *C) { return nil } - schema := registry.NewJSONSchema() - txOne, err := registry.NewTransaction(readBag, writeBag, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + txOne, err := registry.NewTransaction(reg, readBag, writeBag) c.Assert(err, IsNil) - txTwo, err := registry.NewTransaction(readBag, writeBag, schema) + txTwo, err := registry.NewTransaction(reg, readBag, writeBag) c.Assert(err, IsNil) err = txOne.Set("foo", "bar") @@ -259,8 +271,8 @@ func (s *transactionTestSuite) TestTransactionBagReadError(c *C) { return nil } - schema := registry.NewJSONSchema() - txOne, err := registry.NewTransaction(readBag, writeBag, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + txOne, err := registry.NewTransaction(reg, readBag, writeBag) c.Assert(err, IsNil) readErr = errors.New("expected") @@ -269,7 +281,7 @@ func (s *transactionTestSuite) TestTransactionBagReadError(c *C) { c.Assert(err, ErrorMatches, "expected") // NewTransaction()'s databag read fails - txOne, err = registry.NewTransaction(readBag, writeBag, schema) + txOne, err = registry.NewTransaction(reg, readBag, writeBag) c.Assert(err, ErrorMatches, "expected") } @@ -282,8 +294,8 @@ func (s *transactionTestSuite) TestTransactionBagWriteError(c *C) { return writeErr } - schema := registry.NewJSONSchema() - txOne, err := registry.NewTransaction(readBag, writeBag, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + txOne, err := registry.NewTransaction(reg, readBag, writeBag) c.Assert(err, IsNil) writeErr = errors.New("expected") @@ -301,8 +313,8 @@ func (s *transactionTestSuite) TestTransactionReadsIsolated(c *C) { return nil } - schema := registry.NewJSONSchema() - tx, err := registry.NewTransaction(readBag, writeBag, schema) + reg := newRegistry(c, registry.NewJSONSchema()) + tx, err := registry.NewTransaction(reg, readBag, writeBag) c.Assert(err, IsNil) err = databag.Set("foo", "bar") @@ -315,7 +327,8 @@ func (s *transactionTestSuite) TestTransactionReadsIsolated(c *C) { func (s *transactionTestSuite) TestReadDatabagsAreCopiedForIsolation(c *C) { witness := &witnessReadWriter{bag: registry.NewJSONDataBag()} schema := &failingSchema{} - tx, err := registry.NewTransaction(witness.read, witness.write, schema) + reg := newRegistry(c, schema) + tx, err := registry.NewTransaction(reg, witness.read, witness.write) c.Assert(err, IsNil) err = tx.Set("foo", "bar") @@ -342,7 +355,8 @@ func (s *transactionTestSuite) TestReadDatabagsAreCopiedForIsolation(c *C) { func (s *transactionTestSuite) TestUnset(c *C) { witness := &witnessReadWriter{bag: registry.NewJSONDataBag()} - tx, err := registry.NewTransaction(witness.read, witness.write, registry.NewJSONSchema()) + reg := newRegistry(c, registry.NewJSONSchema()) + tx, err := registry.NewTransaction(reg, witness.read, witness.write) c.Assert(err, IsNil) err = tx.Set("foo", "bar")