diff --git a/answer.go b/answer.go
index b70099f0..496823b9 100644
--- a/answer.go
+++ b/answer.go
@@ -35,6 +35,8 @@ type Promise struct {
     // - Resolved. Fulfill or Reject has finished.
     state mutex.Mutex[promiseState]
+
+    resolver Resolver[Ptr]
 }
 
 type promiseState struct {
@@ -64,11 +66,13 @@ type clientAndPromise struct {
 }
 
 // NewPromise creates a new unresolved promise. The PipelineCaller will
-// be used to make pipelined calls before the promise resolves.
-func NewPromise(m Method, pc PipelineCaller) *Promise {
+// be used to make pipelined calls before the promise resolves. If resolver
+// is not nil, the resolution (fulfill or reject) is forwarded to it.
+func NewPromise(m Method, pc PipelineCaller, resolver Resolver[Ptr]) *Promise {
     if pc == nil {
         panic("NewPromise(nil)")
     }
+
     resolved := make(chan struct{})
     p := &Promise{
         method: m,
@@ -77,6 +81,7 @@ func NewPromise(m Method, pc PipelineCaller) *Promise {
             signals: []func(){func() { close(resolved) }},
             caller:  pc,
         }),
+        resolver: resolver,
     }
     p.ans.f.promise = p
     p.ans.metadata = *NewMetadata()
@@ -152,6 +157,14 @@ func (p *Promise) Resolve(r Ptr, e error) {
             return p.clients
         })
 
+        if p.resolver != nil {
+            if e == nil {
+                p.resolver.Fulfill(r)
+            } else {
+                p.resolver.Reject(e)
+            }
+        }
+
         // Pending resolution state: wait for clients to be fulfilled
         // and calls to have answers.
         res := resolution{p.method, r, e}
diff --git a/answer_test.go b/answer_test.go
index 6f3cea28..0c22da43 100644
--- a/answer_test.go
+++ b/answer_test.go
@@ -16,7 +16,7 @@ var dummyMethod = Method{
 
 func TestPromiseReject(t *testing.T) {
     t.Run("Done", func(t *testing.T) {
-        p := NewPromise(dummyMethod, dummyPipelineCaller{})
+        p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
         done := p.Answer().Done()
         p.Reject(errors.New("omg bbq"))
         select {
@@ -27,7 +27,7 @@ func TestPromiseReject(t *testing.T) {
         }
     })
     t.Run("Struct", func(t *testing.T) {
-        p := NewPromise(dummyMethod, dummyPipelineCaller{})
+        p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
         defer p.ReleaseClients()
         ans := p.Answer()
         p.Reject(errors.New("omg bbq"))
@@ -36,7 +36,7 @@ func TestPromiseReject(t *testing.T) {
         }
     })
     t.Run("Client", func(t *testing.T) {
-        p := NewPromise(dummyMethod, dummyPipelineCaller{})
+        p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
         defer p.ReleaseClients()
         pc := p.Answer().Field(1, nil).Client()
         p.Reject(errors.New("omg bbq"))
@@ -57,7 +57,7 @@ func TestPromiseFulfill(t *testing.T) {
     t.Parallel()
 
     t.Run("Done", func(t *testing.T) {
-        p := NewPromise(dummyMethod, dummyPipelineCaller{})
+        p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
         done := p.Answer().Done()
         msg, seg, _ := NewMessage(SingleSegment(nil))
         defer msg.Release()
@@ -72,7 +72,7 @@ func TestPromiseFulfill(t *testing.T) {
         }
     })
     t.Run("Struct", func(t *testing.T) {
-        p := NewPromise(dummyMethod, dummyPipelineCaller{})
+        p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
         defer p.ReleaseClients()
         ans := p.Answer()
         msg, seg, _ := NewMessage(SingleSegment(nil))
@@ -92,7 +92,7 @@ func TestPromiseFulfill(t *testing.T) {
         }
     })
     t.Run("Client", func(t *testing.T) {
-        p := NewPromise(dummyMethod, dummyPipelineCaller{})
+        p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
         defer p.ReleaseClients()
         pc := p.Answer().Field(1, nil).Client()
 
diff --git a/answerqueue.go b/answerqueue.go
index 39a48dc6..015d3851 100644
--- a/answerqueue.go
+++ b/answerqueue.go
@@ -282,7 +282,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
             }
         }
     }
-    sr.p = NewPromise(m, pcall)
+    sr.p = NewPromise(m, pcall, nil)
     ans := sr.p.Answer()
     return ans, func() {
         <-ans.Done()
diff --git a/capability.go b/capability.go
index b8873f90..b2f6c976 100644
--- a/capability.go
+++ b/capability.go
@@ -603,6 +603,22 @@ func (cs ClientSnapshot) IsPromise() bool {
     return ret
 }
 
+// IsResolved returns true if the snapshot has resolved to its final value.
+// If IsPromise() returns false, then this will also return false. Otherwise,
+// it returns false before resolution and true afterwards.
+func (cs ClientSnapshot) IsResolved() bool {
+    if cs.hook == nil {
+        return false
+    }
+    res, ok := cs.hook.Value().resolution.Get()
+    if !ok {
+        return false
+    }
+    return mutex.With1(res, func(s *resolveState) bool {
+        return s.isResolved()
+    })
+}
+
 // Send implements ClientHook.Send
 func (cs ClientSnapshot) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
     if cs.hook == nil {
@@ -817,6 +833,9 @@ func SetClientLeakFunc(clientLeakFunc func(msg string)) {
             clientLeakFunc("leaked client created at:\n\n" + stack)
         })
     case ClientSnapshot:
+        if !c.IsValid() {
+            return
+        }
         runtime.SetFinalizer(c.hook, func(c *rc.Ref[clientHook]) {
             if !c.IsValid() {
                 return
diff --git a/capability_test.go b/capability_test.go
index d7f55bb8..915b2892 100644
--- a/capability_test.go
+++ b/capability_test.go
@@ -132,7 +132,7 @@ func TestResolve(t *testing.T) {
     }
     t.Run("Clients", func(t *testing.T) {
         test(t, "Waits for the full chain", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
-            r1.Fulfill(p2)
+            r1.Fulfill(p2.AddRef())
             ctx, cancel := context.WithTimeout(context.Background(), time.Second/10)
             defer cancel()
             require.NotNil(t, p1.Resolve(ctx), "blocks on second promise")
diff --git a/localpromise.go b/localpromise.go
index a478c169..64bc7250 100644
--- a/localpromise.go
+++ b/localpromise.go
@@ -1,9 +1,5 @@
 package capnp
 
-import (
-    "context"
-)
-
 // ClientHook for a promise that will be resolved to some other capability
 // at some point. Buffers calls in a queue until the promise is fulfilled,
 // then forwards them.
@@ -12,59 +8,30 @@ type localPromise struct {
 }
 
 // NewLocalPromise returns a client that will eventually resolve to a capability,
-// supplied via the fulfiller.
+// supplied via the resolver.
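+//
+// Internally, the returned client is now backed by a Promise that uses an
+// AnswerQueue to buffer calls until the resolver fires.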
 func NewLocalPromise[C ~ClientKind]() (C, Resolver[C]) {
-    lp := newLocalPromise()
-    p, f := NewPromisedClient(lp)
+    aq := NewAnswerQueue(Method{})
+    f := NewPromise(Method{}, aq, aq)
+    p := f.Answer().Client().AddRef()
     return C(p), localResolver[C]{
-        lp:             lp,
-        clientResolver: f,
+        p: f,
     }
 }
 
-func newLocalPromise() localPromise {
-    return localPromise{aq: NewAnswerQueue(Method{})}
-}
-
-func (lp localPromise) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
-    return lp.aq.PipelineSend(ctx, nil, s)
-}
-
-func (lp localPromise) Recv(ctx context.Context, r Recv) PipelineCaller {
-    return lp.aq.PipelineRecv(ctx, nil, r)
-}
-
-func (lp localPromise) Brand() Brand {
-    return Brand{}
-}
-
-func (lp localPromise) Shutdown() {}
-
-func (lp localPromise) String() string {
-    return "localPromise{...}"
-}
-
-func (lp localPromise) Fulfill(c Client) {
-    msg, seg := NewSingleSegmentMessage(nil)
-    capID := msg.CapTable().Add(c)
-    lp.aq.Fulfill(NewInterface(seg, capID).ToPtr())
-}
-
-func (lp localPromise) Reject(err error) {
-    lp.aq.Reject(err)
-}
-
 type localResolver[C ~ClientKind] struct {
-    lp             localPromise
-    clientResolver Resolver[Client]
+    p *Promise
 }
 
 func (lf localResolver[C]) Fulfill(c C) {
-    lf.lp.Fulfill(Client(c))
-    lf.clientResolver.Fulfill(Client(c))
+    msg, seg := NewSingleSegmentMessage(nil)
+    capID := msg.CapTable().Add(Client(c))
+    iface := NewInterface(seg, capID)
+    lf.p.Fulfill(iface.ToPtr())
+    lf.p.ReleaseClients()
+    msg.Release()
 }
 
 func (lf localResolver[C]) Reject(err error) {
-    lf.lp.Reject(err)
-    lf.clientResolver.Reject(err)
+    lf.p.Reject(err)
+    lf.p.ReleaseClients()
 }
diff --git a/rpc/answer.go b/rpc/answer.go
index 3b92418a..b4b1eea7 100644
--- a/rpc/answer.go
+++ b/rpc/answer.go
@@ -156,7 +156,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ er
 func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) {
     if !ans.flags.Contains(resultsReady) {
         ans.pcall = pcall
-        ans.promise = capnp.NewPromise(m, pcall)
+        ans.promise = capnp.NewPromise(m, pcall, nil)
     }
 }
diff --git a/rpc/export.go b/rpc/export.go
index 6119880b..2aca3488 100644
--- a/rpc/export.go
+++ b/rpc/export.go
@@ -17,9 +17,8 @@ type exportID uint32
 
 // expent is an entry in a Conn's export table.
 type expent struct {
-    snapshot  capnp.ClientSnapshot
-    wireRefs  uint32
-    isPromise bool
+    snapshot capnp.ClientSnapshot
+    wireRefs uint32
 
     // Should be called when removing this entry from the exports table:
     cancel context.CancelFunc
@@ -74,9 +73,11 @@ func (c *lockedConn) releaseExport(id exportID, count uint32) (capnp.ClientSnaps
         c.lk.exports[id] = nil
         c.lk.exportID.remove(id)
         metadata := snapshot.Metadata()
-        syncutil.With(metadata, func() {
-            c.clearExportID(metadata)
-        })
+        if metadata != nil {
+            syncutil.With(metadata, func() {
+                c.clearExportID(metadata)
+            })
+        }
         return snapshot, nil
     case count > ent.wireRefs:
         return capnp.ClientSnapshot{}, rpcerr.Failed(errors.New("export ID " + str.Utod(id) + " released too many references"))
@@ -203,7 +204,7 @@ func (c *lockedConn) sendSenderPromise(id exportID, d rpccp.CapDescriptor) {
         // Conn before trying to use it again:
         unlockedConn := (*Conn)(c)
 
-        waitErr := waitRef.Resolve(ctx)
+        waitErr := waitRef.Resolve1(ctx)
         unlockedConn.withLocked(func(c *lockedConn) {
             if len(c.lk.exports) <= int(id) || c.lk.exports[id] != ee {
                 // Export was removed from the table at some point;
@@ -366,9 +367,8 @@ func (e *embargo) Shutdown() {
 
 // senderLoopback holds the salient information for a sender-loopback
 // Disembargo message.
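+// The target may be either a promised answer or an imported cap; see the
+// switch in buildDisembargo.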
 type senderLoopback struct {
-    id        embargoID
-    question  questionID
-    transform []capnp.PipelineOp
+    id     embargoID
+    target parsedMessageTarget
 }
 
 func (sl *senderLoopback) buildDisembargo(msg rpccp.Message) error {
@@ -376,23 +376,30 @@ func (sl *senderLoopback) buildDisembargo(msg rpccp.Message) error {
     if err != nil {
         return rpcerr.WrapFailed("build disembargo", err)
     }
+    d.Context().SetSenderLoopback(uint32(sl.id))
     tgt, err := d.NewTarget()
     if err != nil {
         return rpcerr.WrapFailed("build disembargo", err)
     }
-    pa, err := tgt.NewPromisedAnswer()
-    if err != nil {
-        return rpcerr.WrapFailed("build disembargo", err)
-    }
-    oplist, err := pa.NewTransform(int32(len(sl.transform)))
-    if err != nil {
-        return rpcerr.WrapFailed("build disembargo", err)
-    }
+    switch sl.target.which {
+    case rpccp.MessageTarget_Which_promisedAnswer:
+        pa, err := tgt.NewPromisedAnswer()
+        if err != nil {
+            return rpcerr.WrapFailed("build disembargo", err)
+        }
+        oplist, err := pa.NewTransform(int32(len(sl.target.transform)))
+        if err != nil {
+            return rpcerr.WrapFailed("build disembargo", err)
+        }
 
-    d.Context().SetSenderLoopback(uint32(sl.id))
-    pa.SetQuestionId(uint32(sl.question))
-    for i, op := range sl.transform {
-        oplist.At(i).SetGetPointerField(op.Field)
+        pa.SetQuestionId(uint32(sl.target.promisedAnswer))
+        for i, op := range sl.target.transform {
+            oplist.At(i).SetGetPointerField(op.Field)
+        }
+    case rpccp.MessageTarget_Which_importedCap:
+        tgt.SetImportedCap(uint32(sl.target.importedCap))
+    default:
+        return errors.New("unknown variant for MessageTarget: " + str.Utod(sl.target.which))
     }
     return nil
 }
diff --git a/rpc/import.go b/rpc/import.go
index f098e008..8b8a86e3 100644
--- a/rpc/import.go
+++ b/rpc/import.go
@@ -45,6 +45,11 @@ type impent struct {
     // importClient's generation matches the entry's generation before
     // removing the entry from the table and sending a release message.
     generation uint64
+
+    // If resolver is non-nil, then this is a promise (received as
+    // CapDescriptor_Which_senderPromise), and when a resolve message
+    // arrives we should use this to fulfill the promise locally.
+    resolver capnp.Resolver[capnp.Client]
 }
 
 // addImport returns a client that represents the given import,
@@ -52,7 +57,7 @@ type impent struct {
 // This is separate from the reference counting that capnp.Client does.
 //
 // The caller must be holding onto c.mu.
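+//
+// If isPromise is true, the import was received as senderPromise and the
+// returned client is created with NewPromisedClient; the corresponding
+// resolver is stored in the import table entry so a later resolve message
+// can fulfill or reject it.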
-func (c *lockedConn) addImport(id importID) capnp.Client {
+func (c *lockedConn) addImport(id importID, isPromise bool) capnp.Client {
     if ent := c.lk.imports[id]; ent != nil {
         ent.wireRefs++
         client, ok := ent.wc.AddRef()
@@ -67,13 +72,23 @@ func (c *lockedConn) addImport(id importID) capnp.Client {
         }
         return client
     }
-    client := capnp.NewClient(&importClient{
+    hook := &importClient{
         c:  (*Conn)(c),
         id: id,
-    })
+    }
+    var (
+        client   capnp.Client
+        resolver capnp.Resolver[capnp.Client]
+    )
+    if isPromise {
+        client, resolver = capnp.NewPromisedClient(hook)
+    } else {
+        client = capnp.NewClient(hook)
+    }
     c.lk.imports[id] = &impent{
         wc:       client.WeakRef(),
         wireRefs: 1,
+        resolver: resolver,
     }
     return client
 }
diff --git a/rpc/localpromise_test.go b/rpc/localpromise_test.go
index f29af258..e67cbfaf 100644
--- a/rpc/localpromise_test.go
+++ b/rpc/localpromise_test.go
@@ -70,6 +70,13 @@ func TestLocalPromiseFulfill(t *testing.T) {
     assert.Equal(t, int64(3), res3.N())
 }
 
+func echoNum(ctx context.Context, pp testcapnp.PingPong, n int64) (testcapnp.PingPong_echoNum_Results_Future, capnp.ReleaseFunc) {
+    return pp.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error {
+        p.SetN(n)
+        return nil
+    })
+}
+
 func TestLocalPromiseReject(t *testing.T) {
     t.Parallel()
 
@@ -77,24 +84,15 @@ func TestLocalPromiseReject(t *testing.T) {
     p, r := capnp.NewLocalPromise[testcapnp.PingPong]()
     defer p.Release()
 
-    fut1, rel1 := p.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error {
-        p.SetN(1)
-        return nil
-    })
+    fut1, rel1 := echoNum(ctx, p, 1)
     defer rel1()
 
-    fut2, rel2 := p.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error {
-        p.SetN(2)
-        return nil
-    })
+    fut2, rel2 := echoNum(ctx, p, 2)
     defer rel2()
 
     r.Reject(errors.New("Promise rejected"))
 
-    fut3, rel3 := p.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error {
-        p.SetN(3)
-        return nil
-    })
+    fut3, rel3 := echoNum(ctx, p, 3)
     defer rel3()
 
     _, err := fut1.Struct()
diff --git a/rpc/question.go b/rpc/question.go
index 46953dfb..847973f0 100644
--- a/rpc/question.go
+++ b/rpc/question.go
@@ -55,7 +55,7 @@ func (c *lockedConn) newQuestion(method capnp.Method) *question {
         release:       func() {},
         finishMsgSend: make(chan struct{}),
     }
-    q.p = capnp.NewPromise(method, q) // TODO(someday): customize error message for bootstrap
+    q.p = capnp.NewPromise(method, q, nil) // TODO(someday): customize error message for bootstrap
     c.setAnswerQuestion(q.p.Answer(), q)
     if int(q.id) == len(c.lk.questions) {
         c.lk.questions = append(c.lk.questions, q)
diff --git a/rpc/rpc.go b/rpc/rpc.go
index d8e4b62e..839db35a 100644
--- a/rpc/rpc.go
+++ b/rpc/rpc.go
@@ -458,9 +458,11 @@ func (c *lockedConn) releaseExports(dq *deferred.Queue, exports []*expent) {
     for _, e := range exports {
         if e != nil {
             metadata := e.snapshot.Metadata()
-            syncutil.With(metadata, func() {
-                c.clearExportID(metadata)
-            })
+            if metadata != nil {
+                syncutil.With(metadata, func() {
+                    c.clearExportID(metadata)
+                })
+            }
             dq.Defer(e.snapshot.Release)
         }
     }
@@ -603,7 +605,11 @@ func (c *Conn) receive(ctx context.Context) func() error {
                 return err
             }
 
-        // TODO: handle resolve.
+        case rpccp.Message_Which_resolve:
+            if err := c.handleResolve(ctx, in); err != nil {
+                return err
+            }
+
         case rpccp.Message_Which_accept, rpccp.Message_Which_provide:
             if c.network != nil {
                 panic("TODO: 3PH")
@@ -1221,9 +1227,12 @@ func (c *lockedConn) parseReturn(dq *deferred.Queue, ret rpccp.Return, called []
             embargoCaps.add(uint(i))
             disembargoes = append(disembargoes, senderLoopback{
-                id:        id,
-                question:  questionID(ret.AnswerId()),
-                transform: xform,
+                id: id,
+                target: parsedMessageTarget{
+                    which:          rpccp.MessageTarget_Which_promisedAnswer,
+                    promisedAnswer: answerID(ret.AnswerId()),
+                    transform:      xform,
+                },
             })
         }
     return parsedReturn{
@@ -1319,20 +1328,10 @@ func (c *lockedConn) recvCap(d rpccp.CapDescriptor) (capnp.Client, error) {
         return capnp.Client{}, nil
     case rpccp.CapDescriptor_Which_senderHosted:
         id := importID(d.SenderHosted())
-        return c.addImport(id), nil
+        return c.addImport(id, false), nil
     case rpccp.CapDescriptor_Which_senderPromise:
-        // We do the same thing as senderHosted, above. @kentonv suggested this on
-        // issue #2; this lets messages be delivered properly, although it's a bit
-        // of a hack, and as Kenton describes, it has some disadvantages:
-        //
-        // > * Apps sometimes want to wait for promise resolution, and to find out if
-        // >   it resolved to an exception. You won't be able to provide that API. But,
-        // >   usually, it isn't needed.
-        // > * If the promise resolves to a capability hosted on the receiver,
-        // >   messages sent to it will uselessly round-trip over the network
-        // >   rather than being delivered locally.
         id := importID(d.SenderPromise())
-        return c.addImport(id), nil
+        return c.addImport(id, true), nil
     case rpccp.CapDescriptor_Which_thirdPartyHosted:
         if c.network == nil {
             // We can't do third-party handoff without a network, so instead of
@@ -1346,7 +1345,7 @@ func (c *lockedConn) recvCap(d rpccp.CapDescriptor) (capnp.Client, error) {
             )
         }
         id := importID(thirdPartyDesc.VineId())
-        return c.addImport(id), nil
+        return c.addImport(id, false), nil
     }
     panic("TODO: 3PH")
 case rpccp.CapDescriptor_Which_receiverHosted:
@@ -1585,68 +1584,42 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag
         e.lift()
 
     case rpccp.Disembargo_context_Which_senderLoopback:
-        var (
-            imp    *importClient
-            client capnp.Client
-        )
+        snapshot, err := withLockedConn2(c, func(c *lockedConn) (_ capnp.ClientSnapshot, err error) {
+            switch tgt.which {
+            case rpccp.MessageTarget_Which_promisedAnswer:
+                return c.getAnswerSnapshot(
+                    tgt.promisedAnswer,
+                    tgt.transform,
+                )
+            case rpccp.MessageTarget_Which_importedCap:
+                ent := c.findExport(tgt.importedCap)
+                if ent == nil {
+                    err = rpcerr.Failed(errors.New("sender loopback: no such export: " +
+                        str.Utod(tgt.importedCap)))
+                    return
+                }
+                if !ent.snapshot.IsPromise() {
+                    err = rpcerr.Failed(errors.New(
+                        "sender loopback: target export " +
+                            str.Utod(tgt.importedCap) +
+                            " is not a promise"))
+                    return
+                }
 
-        c.withLocked(func(c *lockedConn) {
-            if tgt.which != rpccp.MessageTarget_Which_promisedAnswer {
+                if !ent.snapshot.IsResolved() {
+                    err = errors.New("target for receiver loopback is an unresolved promise")
+                    return
+                }
+                snapshot := ent.snapshot.AddRef()
+                err = snapshot.Resolve1(context.Background())
+                if err != nil {
+                    panic("error resolving snapshot: " + err.Error())
+                }
+                return snapshot, nil
+            default:
                 err = rpcerr.Failed(errors.New("incoming disembargo: sender loopback: target is not a promised answer"))
                 return
             }
-
-            ans := c.lk.answers[tgt.promisedAnswer]
-            if ans == nil {
-                err = rpcerr.Failed(errors.New(
-                    "incoming disembargo: unknown answer ID " +
-                        str.Utod(tgt.promisedAnswer),
-                ))
-                return
-            }
-            if !ans.flags.Contains(returnSent) {
-                err = rpcerr.Failed(errors.New(
-                    "incoming disembargo: answer ID " +
-                        str.Utod(tgt.promisedAnswer) + " has not sent return",
-                ))
-                return
-            }
-
-            if ans.err != nil {
-                err = rpcerr.Failed(errors.New(
-                    "incoming disembargo: answer ID " +
-                        str.Utod(tgt.promisedAnswer) + " returned exception",
-                ))
-                return
-            }
-
-            var content capnp.Ptr
-            if content, err = ans.returner.results.Content(); err != nil {
-                err = rpcerr.Failed(errors.New(
-                    "incoming disembargo: read answer ID " +
-                        str.Utod(tgt.promisedAnswer) + ": " + err.Error(),
-                ))
-                return
-            }
-
-            var ptr capnp.Ptr
-            if ptr, err = capnp.Transform(content, tgt.transform); err != nil {
-                err = rpcerr.Failed(errors.New(
-                    "incoming disembargo: read answer ID " +
-                        str.Utod(tgt.promisedAnswer) + ": " + err.Error(),
-                ))
-                return
-            }
-
-            iface := ptr.Interface()
-            if !ans.returner.results.Message().CapTable().Contains(iface) {
-                err = rpcerr.Failed(errors.New(
-                    "incoming disembargo: sender loopback requested on a capability that is not an import",
-                ))
-                return
-            }
-
-            client = iface.Client()
         })
 
         if err != nil {
@@ -1654,39 +1627,59 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag
             return err
         }
 
-        snapshot := client.Snapshot()
-        defer snapshot.Release()
-        imp, ok := snapshot.Brand().Value.(*importClient)
-        if !ok || imp.c != c {
-            client.Release()
-            return rpcerr.Failed(errors.New(
-                "incoming disembargo: sender loopback requested on a capability that is not an import",
-            ))
-        }
-        // TODO(maybe): check generation?
-
-        // Since this Cap'n Proto RPC implementation does not send imports
-        // unless they are fully dequeued, we can just immediately loop back.
+        // FIXME: we're sending the disembargo right away, which I (zenhack)
+        // *think* is fine, and definitely was before we actually did anything
+        // with promises. But this is contingent on making sure that all of the
+        // relevant ClientHook implementations queue up their call messages before
+        // returning from .Recv(); if this invariant holds then this is fine,
+        // because anything ahead of it is already on the wire. But we need to
+        // actually check this invariant.
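+        // Reply by echoing the embargo ID back in a receiver-loopback
+        // disembargo aimed at whatever the snapshot points to on this
+        // connection (a question or an import).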
         id := d.Context().SenderLoopback()
+
         c.withLocked(func(c *lockedConn) {
             c.sendMessage(ctx, func(m rpccp.Message) error {
                 d, err := m.NewDisembargo()
                 if err != nil {
                     return err
                 }
-
+                d.Context().SetReceiverLoopback(id)
                 tgt, err := d.NewTarget()
                 if err != nil {
                     return err
                 }
 
-                tgt.SetImportedCap(uint32(imp.id))
-                d.Context().SetReceiverLoopback(id)
-                return nil
+                brand := snapshot.Brand()
+                if pc, ok := brand.Value.(capnp.PipelineClient); ok {
+                    if q, ok := c.getAnswerQuestion(pc.Answer()); ok {
+                        if q.c == (*Conn)(c) {
+                            pa, err := tgt.NewPromisedAnswer()
+                            if err != nil {
+                                return err
+                            }
+                            pa.SetQuestionId(uint32(q.id))
+                            pcTrans := pc.Transform()
+                            trans, err := pa.NewTransform(int32(len(pcTrans)))
+                            if err != nil {
+                                return err
+                            }
+                            for i, op := range pcTrans {
+                                trans.At(i).SetGetPointerField(op.Field)
+                            }
+                        }
+                        return nil
+                    }
+                }
+
+                imp, ok := brand.Value.(*importClient)
+                if ok && imp.c == (*Conn)(c) {
+                    tgt.SetImportedCap(uint32(imp.id))
+                    return nil
+                }
+                return errors.New("target for receiver loopback does not point to the right connection")
             }, func(err error) {
                 defer in.Release()
-                defer client.Release()
+                defer snapshot.Release()
 
                 if err != nil {
                     c.er.ReportError(rpcerr.Annotate(err, "incoming disembargo: send receiver loopback"))
@@ -1720,6 +1713,147 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag
     return nil
 }
 
+func (c *lockedConn) getAnswerSnapshot(
+    id answerID,
+    transform []capnp.PipelineOp,
+) (_ capnp.ClientSnapshot, err error) {
+    ans := c.lk.answers[id]
+    if ans == nil {
+        err = rpcerr.Failed(errors.New(
+            "incoming disembargo: unknown answer ID " +
+                str.Utod(id)))
+        return
+    }
+    if !ans.flags.Contains(returnSent) {
+        err = rpcerr.Failed(errors.New(
+            "incoming disembargo: answer ID " +
+                str.Utod(id) + " has not sent return",
+        ))
+        return
+    }
+
+    if ans.err != nil {
+        err = rpcerr.Failed(errors.New(
+            "incoming disembargo: answer ID " +
+                str.Utod(id) + " returned exception",
+        ))
+        return
+    }
+
+    var content capnp.Ptr
+    if content, err = ans.returner.results.Content(); err != nil {
+        err = rpcerr.Failed(errors.New(
+            "incoming disembargo: read answer ID " +
+                str.Utod(id) + ": " + err.Error(),
+        ))
+        return
+    }
+
+    var ptr capnp.Ptr
+    if ptr, err = capnp.Transform(content, transform); err != nil {
+        err = rpcerr.Failed(errors.New(
+            "incoming disembargo: read answer ID " +
+                str.Utod(id) + ": " + err.Error(),
+        ))
+        return
+    }
+
+    iface := ptr.Interface()
+    if !ans.returner.results.Message().CapTable().Contains(iface) {
+        err = rpcerr.Failed(errors.New(
+            "incoming disembargo: sender loopback requested on a capability that is not an import",
+        ))
+        return
+    }
+    caps := ans.returner.resultsCapTable
+    capID := iface.Capability()
+    if int(capID) >= len(caps) {
+        return capnp.ClientSnapshot{}, nil
+    }
+
+    return caps[capID].AddRef(), nil
+}
+
+func (c *Conn) handleResolve(ctx context.Context, in transport.IncomingMessage) error {
+    dq := &deferred.Queue{}
+    defer dq.Run()
+
+    resolve, err := in.Message().Resolve()
+    if err != nil {
+        in.Release()
+        c.er.ReportError(exc.WrapError("read resolve", err))
+        return nil
+    }
+
+    promiseID := importID(resolve.PromiseId())
+    err = withLockedConn1(c, func(c *lockedConn) error {
+        imp, ok := c.lk.imports[promiseID]
+        if !ok {
+            return errors.New(
+                "incoming resolve: no such import ID: " + str.Utod(promiseID),
+            )
+        }
+        if imp.resolver == nil {
+            return errors.New(
+                "incoming resolve: import ID " +
+                    str.Utod(promiseID) +
+                    " is not a promise",
+            )
+        }
+        switch resolve.Which() {
+        case rpccp.Resolve_Which_cap:
+            desc, err := resolve.Cap()
+            if err != nil {
+                return exc.WrapError("reading cap from resolve message", err)
+            }
+            client, err := c.recvCap(desc)
+            if err != nil {
+                return err
+            }
+            if c.isLocalClient(client) {
+                var id embargoID
+                id, client = c.embargo(client)
+                disembargo := senderLoopback{
+                    id: id,
+                    target: parsedMessageTarget{
+                        which:       rpccp.MessageTarget_Which_importedCap,
+                        importedCap: exportID(promiseID),
+                    },
+                }
+                c.sendMessage(ctx, disembargo.buildDisembargo, func(err error) {
+                    if err != nil {
+                        c.er.ReportError(
+                            exc.WrapError(
+                                "incoming resolve: send disembargo",
+                                err,
+                            ),
+                        )
+                    }
+                })
+            }
+            dq.Defer(func() {
+                imp.resolver.Fulfill(client)
+                client.Release()
+            })
+        case rpccp.Resolve_Which_exception:
+            ex, err := resolve.Exception()
+            if err != nil {
+                err = exc.WrapError("reading exception from resolve message", err)
+            } else {
+                err = ex.ToError()
+            }
+            dq.Defer(func() {
+                imp.resolver.Reject(err)
+            })
+        }
+        return nil
+    })
+    if err != nil {
+        c.er.ReportError(err)
+    }
+    return err
+}
+
 func (c *Conn) handleUnknownMessageType(ctx context.Context, in transport.IncomingMessage) {
     err := errors.New("unknown message type " + in.Message().Which().String() + " from remote")
     c.er.ReportError(err)
diff --git a/rpc/senderpromise_test.go b/rpc/senderpromise_test.go
index 664373ad..aa4174be 100644
--- a/rpc/senderpromise_test.go
+++ b/rpc/senderpromise_test.go
@@ -2,6 +2,7 @@ package rpc_test
 
 import (
     "context"
+    "fmt"
     "testing"
 
     "capnproto.org/go/capnp/v3"
@@ -10,6 +11,7 @@ import (
     "capnproto.org/go/capnp/v3/rpc/transport"
     rpccp "capnproto.org/go/capnp/v3/std/capnp/rpc"
     "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
 )
 
 func TestSenderPromiseFulfill(t *testing.T) {
@@ -214,3 +216,207 @@ type emptyShutdowner struct {
 func (s emptyShutdowner) Shutdown() {
     close(s.onShutdown)
 }
+
+// Tests fulfilling a senderPromise with something hosted on the receiver.
+func TestDisembargoSenderPromise(t *testing.T) {
+    t.Parallel()
+
+    ctx := context.Background()
+    p, r := capnp.NewLocalPromise[capnp.Client]()
+
+    left, right := transport.NewPipe(1)
+    p1, p2 := rpc.NewTransport(left), rpc.NewTransport(right)
+
+    conn := rpc.NewConn(p1, &rpc.Options{
+        ErrorReporter:   testErrorReporter{tb: t},
+        BootstrapClient: capnp.Client(p),
+    })
+    defer finishTest(t, conn, p2)
+
+    // Send bootstrap.
+    {
+        msg := &rpcMessage{
+            Which:     rpccp.Message_Which_bootstrap,
+            Bootstrap: &rpcBootstrap{QuestionID: 0},
+        }
+        assert.NoError(t, sendMessage(ctx, p2, msg))
+    }
+    // Receive return.
+    var theirBootstrapID uint32
+    {
+        rmsg, release, err := recvMessage(ctx, p2)
+        assert.NoError(t, err)
+        defer release()
+        assert.Equal(t, rpccp.Message_Which_return, rmsg.Which)
+        assert.Equal(t, uint32(0), rmsg.Return.AnswerID)
+        assert.Equal(t, rpccp.Return_Which_results, rmsg.Return.Which)
+        assert.Equal(t, 1, len(rmsg.Return.Results.CapTable))
+        desc := rmsg.Return.Results.CapTable[0]
+        assert.Equal(t, rpccp.CapDescriptor_Which_senderPromise, desc.Which)
+        theirBootstrapID = desc.SenderPromise
+    }
+
+    // For convenience, we use the other peer's bootstrap interface as the thing
+    // to resolve to.
+    bsClient := conn.Bootstrap(ctx)
+    defer bsClient.Release()
+
+    // Receive bootstrap, send return.
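+    // myBootstrapID is an arbitrary capability ID chosen by the test peer
+    // for the senderHosted cap it returns below.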
+    myBootstrapID := uint32(12)
+    var incomingBSQid uint32
+    {
+        rmsg, release, err := recvMessage(ctx, p2)
+        assert.NoError(t, err)
+        defer release()
+        assert.Equal(t, rpccp.Message_Which_bootstrap, rmsg.Which)
+        incomingBSQid = rmsg.Bootstrap.QuestionID
+
+        outMsg, err := p2.NewMessage()
+        assert.NoError(t, err)
+        iface := capnp.NewInterface(outMsg.Message().Segment(), 0)
+
+        assert.NoError(t, sendMessage(ctx, p2, &rpcMessage{
+            Which: rpccp.Message_Which_return,
+            Return: &rpcReturn{
+                AnswerID: incomingBSQid,
+                Which:    rpccp.Return_Which_results,
+                Results: &rpcPayload{
+                    Content: iface.ToPtr(),
+                    CapTable: []rpcCapDescriptor{
+                        {
+                            Which:        rpccp.CapDescriptor_Which_senderHosted,
+                            SenderHosted: myBootstrapID,
+                        },
+                    },
+                },
+            },
+        }))
+    }
+    // Accept return
+    assert.NoError(t, bsClient.Resolve(ctx))
+
+    // Receive Finish
+    {
+        rmsg, release, err := recvMessage(ctx, p2)
+        assert.NoError(t, err)
+        defer release()
+        assert.Equal(t, rpccp.Message_Which_finish, rmsg.Which)
+        assert.Equal(t, incomingBSQid, rmsg.Finish.QuestionID)
+    }
+
+    // Resolve bootstrap
+    r.Fulfill(bsClient)
+
+    // Receive resolve.
+    {
+        rmsg, release, err := recvMessage(ctx, p2)
+        assert.NoError(t, err)
+        defer release()
+        assert.Equal(t, rpccp.Message_Which_resolve, rmsg.Which)
+        assert.Equal(t, theirBootstrapID, rmsg.Resolve.PromiseID)
+        assert.Equal(t, rpccp.Resolve_Which_cap, rmsg.Resolve.Which)
+        desc := rmsg.Resolve.Cap
+        assert.Equal(t, rpccp.CapDescriptor_Which_receiverHosted, desc.Which)
+        assert.Equal(t, myBootstrapID, desc.ReceiverHosted)
+    }
+    // Send disembargo:
+    embargoID := uint32(7)
+    {
+        assert.NoError(t, sendMessage(ctx, p2, &rpcMessage{
+            Which: rpccp.Message_Which_disembargo,
+            Disembargo: &rpcDisembargo{
+                Context: rpcDisembargoContext{
+                    Which:          rpccp.Disembargo_context_Which_senderLoopback,
+                    SenderLoopback: embargoID,
+                },
+                Target: rpcMessageTarget{
+                    Which:       rpccp.MessageTarget_Which_importedCap,
+                    ImportedCap: theirBootstrapID,
+                },
+            },
+        }))
+    }
+    // Receive disembargo:
+    {
+        rmsg, release, err := recvMessage(ctx, p2)
+        assert.NoError(t, err)
+        defer release()
+        assert.Equal(t, rpccp.Message_Which_disembargo, rmsg.Which)
+        d := rmsg.Disembargo
+        assert.Equal(t, rpccp.Disembargo_context_Which_receiverLoopback, d.Context.Which)
+        assert.Equal(t, embargoID, d.Context.ReceiverLoopback)
+        tgt := d.Target
+        assert.Equal(t, rpccp.MessageTarget_Which_importedCap, tgt.Which)
+        assert.Equal(t, myBootstrapID, tgt.ImportedCap)
+    }
+}
+
+// Tests that E-order is respected when fulfilling a promise with something on
+// the remote peer.
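+// ("E-order" here: calls made on the promise must be delivered to the
+// eventual target in the order they were sent, even across resolution.)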
+func TestPromiseOrdering(t *testing.T) {
+    t.Parallel()
+
+    ctx := context.Background()
+    p, r := capnp.NewLocalPromise[testcapnp.PingPong]()
+    defer p.Release()
+
+    left, right := transport.NewPipe(1)
+    p1, p2 := rpc.NewTransport(left), rpc.NewTransport(right)
+
+    c1 := rpc.NewConn(p1, &rpc.Options{
+        ErrorReporter:   testErrorReporter{tb: t},
+        BootstrapClient: capnp.Client(p),
+    })
+    ord := &echoNumOrderChecker{
+        t: t,
+    }
+    c2 := rpc.NewConn(p2, &rpc.Options{
+        ErrorReporter:   testErrorReporter{tb: t},
+        BootstrapClient: capnp.Client(testcapnp.PingPong_ServerToClient(ord)),
+    })
+
+    remotePromise := testcapnp.PingPong(c2.Bootstrap(ctx))
+    defer remotePromise.Release()
+
+    // Send a whole bunch of calls to the promise:
+    var (
+        futures []testcapnp.PingPong_echoNum_Results_Future
+        rels    []capnp.ReleaseFunc
+    )
+    numCalls := 1024
+    for i := 0; i < numCalls; i++ {
+        fut, rel := echoNum(ctx, remotePromise, int64(i))
+        futures = append(futures, fut)
+        rels = append(rels, rel)
+
+        // At some arbitrary point in the middle, fulfill the promise
+        // with the other bootstrap interface:
+        if i == 100 {
+            go func() {
+                r.Fulfill(testcapnp.PingPong(c1.Bootstrap(ctx)))
+            }()
+        }
+    }
+    for i, fut := range futures {
+        // Verify that all the results are as expected. The server
+        // will verify that they came in the right order.
+        res, err := fut.Struct()
+        require.NoError(t, err, fmt.Sprintf("call #%d should succeed", i))
+        require.Equal(t, int64(i), res.N())
+    }
+    for _, rel := range rels {
+        rel()
+    }
+
+    require.NoError(t, remotePromise.Resolve(ctx))
+    // Shut down the connections, and make sure we can still send
+    // calls. This ensures that we've successfully shortened the path to
+    // cut out the remote peer:
+    c1.Close()
+    c2.Close()
+    fut, rel := echoNum(ctx, remotePromise, int64(numCalls))
+    defer rel()
+    res, err := fut.Struct()
+    require.NoError(t, err)
+    require.Equal(t, int64(numCalls), res.N())
+}
diff --git a/std/capnp/rpc/exception.go b/std/capnp/rpc/exception.go
index 0d80a52d..1272a513 100644
--- a/std/capnp/rpc/exception.go
+++ b/std/capnp/rpc/exception.go
@@ -8,3 +8,24 @@ func (e Exception) MarshalError(err error) error {
     e.SetType(Exception_Type(exc.TypeOf(err)))
     return e.SetReason(err.Error())
 }
+
+// ToError converts the exception to an error. If accessing the reason field
+// returns an error, the exception's type field will still be returned by
+// exc.TypeOf, but the message will be replaced by something describing the
+// read error.
+func (e Exception) ToError() error {
+    // TODO: rework this so that exc.Type and Exception_Type
+    // are aliases somehow. For now we rely on the values being
+    // identical:
+    typ := exc.Type(e.Type())
+
+    reason, err := e.Reason()
+    if err != nil {
+        return &exc.Exception{
+            Type:   typ,
+            Prefix: "failed to read reason",
+            Cause:  err,
+        }
+    }
+    return exc.New(typ, "", reason)
+}
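Usage note (reviewer sketch, not part of the patch): the public contract of NewLocalPromise is unchanged by this rewrite: calls made before the resolver fires are still buffered and flushed in order once it does. A minimal sketch of the intended usage, assuming the testcapnp.PingPong interface from the test suite; pingPongServer is a hypothetical stand-in for any PingPong server implementation.

func exampleLocalPromise(ctx context.Context) error {
    // Promise client plus its resolver; calls are buffered until Fulfill.
    p, r := capnp.NewLocalPromise[testcapnp.PingPong]()
    defer p.Release()

    // This call is queued in the promise's AnswerQueue.
    fut, rel := p.EchoNum(ctx, func(params testcapnp.PingPong_echoNum_Params) error {
        params.SetN(42)
        return nil
    })
    defer rel()

    // Fulfill flushes the queued call to the real capability.
    // pingPongServer is a stand-in, not part of this patch.
    r.Fulfill(testcapnp.PingPong_ServerToClient(pingPongServer{}))

    res, err := fut.Struct()
    if err != nil {
        return err
    }
    _ = res.N() // echoes 42
    return nil
}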