From 4d2274ace8c05fbc269621422fdf984dee79388f Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Thu, 16 Mar 2023 18:59:43 -0400 Subject: [PATCH 01/16] WIP: handle incoming resolve messages. Still TODO: - We need to handle disembargos with target = (importedCap = ...). - Testing. --- rpc/export.go | 36 +++++++----- rpc/import.go | 21 ++++++- rpc/rpc.go | 110 +++++++++++++++++++++++++++++++------ std/capnp/rpc/exception.go | 21 +++++++ 4 files changed, 153 insertions(+), 35 deletions(-) diff --git a/rpc/export.go b/rpc/export.go index e8f00873..d340cf83 100644 --- a/rpc/export.go +++ b/rpc/export.go @@ -366,9 +366,8 @@ func (e *embargo) Shutdown() { // senderLoopback holds the salient information for a sender-loopback // Disembargo message. type senderLoopback struct { - id embargoID - question questionID - transform []capnp.PipelineOp + id embargoID + target parsedMessageTarget } func (sl *senderLoopback) buildDisembargo(msg rpccp.Message) error { @@ -376,23 +375,30 @@ func (sl *senderLoopback) buildDisembargo(msg rpccp.Message) error { if err != nil { return rpcerr.WrapFailed("build disembargo", err) } + d.Context().SetSenderLoopback(uint32(sl.id)) tgt, err := d.NewTarget() if err != nil { return rpcerr.WrapFailed("build disembargo", err) } - pa, err := tgt.NewPromisedAnswer() - if err != nil { - return rpcerr.WrapFailed("build disembargo", err) - } - oplist, err := pa.NewTransform(int32(len(sl.transform))) - if err != nil { - return rpcerr.WrapFailed("build disembargo", err) - } + switch sl.target.which { + case rpccp.MessageTarget_Which_promisedAnswer: + pa, err := tgt.NewPromisedAnswer() + if err != nil { + return rpcerr.WrapFailed("build disembargo", err) + } + oplist, err := pa.NewTransform(int32(len(sl.target.transform))) + if err != nil { + return rpcerr.WrapFailed("build disembargo", err) + } - d.Context().SetSenderLoopback(uint32(sl.id)) - pa.SetQuestionId(uint32(sl.question)) - for i, op := range sl.transform { - oplist.At(i).SetGetPointerField(op.Field) 
+ pa.SetQuestionId(uint32(sl.target.promisedAnswer)) + for i, op := range sl.target.transform { + oplist.At(i).SetGetPointerField(op.Field) + } + case rpccp.MessageTarget_Which_importedCap: + tgt.SetImportedCap(uint32(sl.target.importedCap)) + default: + return errors.New("unknown variant for MessageTarget: " + str.Utod(sl.target.which)) } return nil } diff --git a/rpc/import.go b/rpc/import.go index 65f327a9..6a6b5bc3 100644 --- a/rpc/import.go +++ b/rpc/import.go @@ -45,6 +45,11 @@ type impent struct { // importClient's generation matches the entry's generation before // removing the entry from the table and sending a release message. generation uint64 + + // If resolver is non-nil, then this is a promise (received as + // CapDescriptor_Which_senderPromise), and when a resolve message + // arrives we should use this to fulfill the promise locally. + resolver capnp.Resolver[capnp.Client] } // addImport returns a client that represents the given import, @@ -52,7 +57,7 @@ type impent struct { // This is separate from the reference counting that capnp.Client does. // // The caller must be holding onto c.mu. 
-func (c *lockedConn) addImport(id importID) capnp.Client { +func (c *lockedConn) addImport(id importID, isPromise bool) capnp.Client { if ent := c.lk.imports[id]; ent != nil { ent.wireRefs++ client, ok := ent.wc.AddRef() @@ -67,13 +72,23 @@ func (c *lockedConn) addImport(id importID) capnp.Client { } return client } - client := capnp.NewClient(&importClient{ + hook := &importClient{ c: (*Conn)(c), id: id, - }) + } + var ( + client capnp.Client + resolver capnp.Resolver[capnp.Client] + ) + if isPromise { + client, resolver = capnp.NewPromisedClient(hook) + } else { + client = capnp.NewClient(hook) + } c.lk.imports[id] = &impent{ wc: client.WeakRef(), wireRefs: 1, + resolver: resolver, } return client } diff --git a/rpc/rpc.go b/rpc/rpc.go index 3d1e618e..bb63eb56 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -603,7 +603,11 @@ func (c *Conn) receive(ctx context.Context) func() error { return err } - // TODO: handle resolve. + case rpccp.Message_Which_resolve: + if err := c.handleResolve(ctx, in); err != nil { + return err + } + case rpccp.Message_Which_accept, rpccp.Message_Which_provide: if c.network != nil { panic("TODO: 3PH") @@ -1221,9 +1225,12 @@ func (c *lockedConn) parseReturn(dq *deferred.Queue, ret rpccp.Return, called [] embargoCaps.add(uint(i)) disembargoes = append(disembargoes, senderLoopback{ - id: id, - question: questionID(ret.AnswerId()), - transform: xform, + id: id, + target: parsedMessageTarget{ + which: rpccp.MessageTarget_Which_promisedAnswer, + promisedAnswer: answerID(ret.AnswerId()), + transform: xform, + }, }) } return parsedReturn{ @@ -1319,20 +1326,10 @@ func (c *lockedConn) recvCap(d rpccp.CapDescriptor) (capnp.Client, error) { return capnp.Client{}, nil case rpccp.CapDescriptor_Which_senderHosted: id := importID(d.SenderHosted()) - return c.addImport(id), nil + return c.addImport(id, false), nil case rpccp.CapDescriptor_Which_senderPromise: - // We do the same thing as senderHosted, above. 
@kentonv suggested this on - // issue #2; this lets messages be delivered properly, although it's a bit - // of a hack, and as Kenton describes, it has some disadvantages: - // - // > * Apps sometimes want to wait for promise resolution, and to find out if - // > it resolved to an exception. You won't be able to provide that API. But, - // > usually, it isn't needed. - // > * If the promise resolves to a capability hosted on the receiver, - // > messages sent to it will uselessly round-trip over the network - // > rather than being delivered locally. id := importID(d.SenderPromise()) - return c.addImport(id), nil + return c.addImport(id, true), nil case rpccp.CapDescriptor_Which_thirdPartyHosted: if c.network == nil { // We can't do third-party handoff without a network, so instead of @@ -1346,7 +1343,7 @@ func (c *lockedConn) recvCap(d rpccp.CapDescriptor) (capnp.Client, error) { ) } id := importID(thirdPartyDesc.VineId()) - return c.addImport(id), nil + return c.addImport(id, false), nil } panic("TODO: 3PH") case rpccp.CapDescriptor_Which_receiverHosted: @@ -1720,6 +1717,85 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag return nil } +func (c *Conn) handleResolve(ctx context.Context, in transport.IncomingMessage) error { + dq := &deferred.Queue{} + defer dq.Run() + + resolve, err := in.Message().Resolve() + if err != nil { + in.Release() + c.er.ReportError(exc.WrapError("read resolve", err)) + return nil + } + + promiseID := importID(resolve.PromiseId()) + err = withLockedConn1(c, func(c *lockedConn) error { + imp, ok := c.lk.imports[promiseID] + if !ok { + return errors.New( + "incoming resolve: no such import ID: " + str.Utod(promiseID), + ) + } + if imp.resolver == nil { + return errors.New( + "incoming resolve: import ID " + + str.Utod(promiseID) + + "is not a promise", + ) + } + switch resolve.Which() { + case rpccp.Resolve_Which_cap: + desc, err := resolve.Cap() + if err != nil { + return exc.WrapError("reading cap from 
resolve message", err) + } + client, err := c.recvCap(desc) + if err != nil { + return err + } + if c.isLocalClient(client) { + var id embargoID + id, client = c.embargo(client) + disembargo := senderLoopback{ + id: id, + target: parsedMessageTarget{ + which: rpccp.MessageTarget_Which_importedCap, + importedCap: exportID(promiseID), + }, + } + c.sendMessage(ctx, disembargo.buildDisembargo, func(err error) { + if err != nil { + c.er.ReportError( + exc.WrapError( + "incoming resolve: send disembargo", + err, + ), + ) + } + }) + } + dq.Defer(func() { + imp.resolver.Fulfill(client) + }) + case rpccp.Resolve_Which_exception: + ex, err := resolve.Exception() + if err != nil { + err = exc.WrapError("reading exception from resolve message", err) + } else { + err = ex.ToError() + } + dq.Defer(func() { + imp.resolver.Reject(err) + }) + } + return nil + }) + if err != nil { + c.er.ReportError(err) + } + return err +} + func (c *Conn) handleUnknownMessageType(ctx context.Context, in transport.IncomingMessage) { err := errors.New("unknown message type " + in.Message().Which().String() + " from remote") c.er.ReportError(err) diff --git a/std/capnp/rpc/exception.go b/std/capnp/rpc/exception.go index 0d80a52d..1272a513 100644 --- a/std/capnp/rpc/exception.go +++ b/std/capnp/rpc/exception.go @@ -8,3 +8,24 @@ func (e Exception) MarshalError(err error) error { e.SetType(Exception_Type(exc.TypeOf(err))) return e.SetReason(err.Error()) } + +// ToError converts the exception to an error. If accessing the reason field +// returns an error, the exception's type field will still be returned by +// exc.TypeOf, but the message will be replaced by something describing the +// read erorr. +func (e Exception) ToError() error { + // TODO: rework this so that exc.Type and Exception_Type + // are aliases somehow. 
For now we rely on the values being + // identical: + typ := exc.Type(e.Type()) + + reason, err := e.Reason() + if err != nil { + return &exc.Exception{ + Type: typ, + Prefix: "failed to read reason", + Cause: err, + } + } + return exc.New(exc.Type(e.Type()), "", reason) +} From a006092e8759e20aff7be2456e9daa8850dd2ffc Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Sat, 18 Mar 2023 17:04:55 -0400 Subject: [PATCH 02/16] cleanup: remove unnecessary declaration of importClient. The := below is sufficient, since this isn't actually used before that. --- rpc/rpc.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/rpc/rpc.go b/rpc/rpc.go index bb63eb56..41799bdc 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1582,11 +1582,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag e.lift() case rpccp.Disembargo_context_Which_senderLoopback: - var ( - imp *importClient - client capnp.Client - ) - + var client capnp.Client c.withLocked(func(c *lockedConn) { if tgt.which != rpccp.MessageTarget_Which_promisedAnswer { err = rpcerr.Failed(errors.New("incoming disembargo: sender loopback: target is not a promised answer")) From ae482350f1ccd6ac4100253653e0b4e599f3cabe Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Sat, 17 Jun 2023 01:19:10 -0400 Subject: [PATCH 03/16] Factor some logic out of handleDisembargo. Also, along the way, use the snapshot from resultsCapTable instead of calling .Snapshot() on the client. 
--- rpc/rpc.go | 125 +++++++++++++++++++++++++++++------------------------ 1 file changed, 69 insertions(+), 56 deletions(-) diff --git a/rpc/rpc.go b/rpc/rpc.go index 41799bdc..83182ac1 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1582,64 +1582,19 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag e.lift() case rpccp.Disembargo_context_Which_senderLoopback: - var client capnp.Client + var ( + snapshot capnp.ClientSnapshot + err error + ) c.withLocked(func(c *lockedConn) { if tgt.which != rpccp.MessageTarget_Which_promisedAnswer { err = rpcerr.Failed(errors.New("incoming disembargo: sender loopback: target is not a promised answer")) return } - - ans := c.lk.answers[tgt.promisedAnswer] - if ans == nil { - err = rpcerr.Failed(errors.New( - "incoming disembargo: unknown answer ID " + - str.Utod(tgt.promisedAnswer), - )) - return - } - if !ans.flags.Contains(returnSent) { - err = rpcerr.Failed(errors.New( - "incoming disembargo: answer ID " + - str.Utod(tgt.promisedAnswer) + " has not sent return", - )) - return - } - - if ans.err != nil { - err = rpcerr.Failed(errors.New( - "incoming disembargo: answer ID " + - str.Utod(tgt.promisedAnswer) + " returned exception", - )) - return - } - - var content capnp.Ptr - if content, err = ans.returner.results.Content(); err != nil { - err = rpcerr.Failed(errors.New( - "incoming disembargo: read answer ID " + - str.Utod(tgt.promisedAnswer) + ": " + err.Error(), - )) - return - } - - var ptr capnp.Ptr - if ptr, err = capnp.Transform(content, tgt.transform); err != nil { - err = rpcerr.Failed(errors.New( - "incoming disembargo: read answer ID " + - str.Utod(tgt.promisedAnswer) + ": " + err.Error(), - )) - return - } - - iface := ptr.Interface() - if !ans.returner.results.Message().CapTable().Contains(iface) { - err = rpcerr.Failed(errors.New( - "incoming disembargo: sender loopback requested on a capability that is not an import", - )) - return - } - - client = iface.Client() + snapshot, err = 
c.getAnswerSnapshot( + tgt.promisedAnswer, + tgt.transform, + ) }) if err != nil { @@ -1647,11 +1602,9 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag return err } - snapshot := client.Snapshot() defer snapshot.Release() imp, ok := snapshot.Brand().Value.(*importClient) if !ok || imp.c != c { - client.Release() return rpcerr.Failed(errors.New( "incoming disembargo: sender loopback requested on a capability that is not an import", )) @@ -1679,7 +1632,6 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag }, func(err error) { defer in.Release() - defer client.Release() if err != nil { c.er.ReportError(rpcerr.Annotate(err, "incoming disembargo: send receiver loopback")) @@ -1713,6 +1665,67 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag return nil } +func (c *lockedConn) getAnswerSnapshot( + id answerID, + transform []capnp.PipelineOp, +) (_ capnp.ClientSnapshot, err error) { + ans := c.lk.answers[id] + if ans == nil { + err = rpcerr.Failed(errors.New( + "incoming disembargo: unknown answer ID " + + str.Utod(id))) + return + } + if !ans.flags.Contains(returnSent) { + err = rpcerr.Failed(errors.New( + "incoming disembargo: answer ID " + + str.Utod(id) + " has not sent return", + )) + return + } + + if ans.err != nil { + err = rpcerr.Failed(errors.New( + "incoming disembargo: answer ID " + + str.Utod(id) + " returned exception", + )) + return + } + + var content capnp.Ptr + if content, err = ans.returner.results.Content(); err != nil { + err = rpcerr.Failed(errors.New( + "incoming disembargo: read answer ID " + + str.Utod(id) + ": " + err.Error(), + )) + return + } + + var ptr capnp.Ptr + if ptr, err = capnp.Transform(content, transform); err != nil { + err = rpcerr.Failed(errors.New( + "incoming disembargo: read answer ID " + + str.Utod(id) + ": " + err.Error(), + )) + return + } + + iface := ptr.Interface() + if 
!ans.returner.results.Message().CapTable().Contains(iface) { + err = rpcerr.Failed(errors.New( + "incoming disembargo: sender loopback requested on a capability that is not an import", + )) + return + } + caps := ans.returner.resultsCapTable + capID := iface.Capability() + if int(capID) >= len(caps) { + return capnp.ClientSnapshot{}, nil + } + + return caps[capID].AddRef(), nil +} + func (c *Conn) handleResolve(ctx context.Context, in transport.IncomingMessage) error { dq := &deferred.Queue{} defer dq.Run() From e96d05466adfb1363a97239dccb904ac89cf39c3 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Sat, 17 Jun 2023 01:26:30 -0400 Subject: [PATCH 04/16] Minor cleanup --- rpc/rpc.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/rpc/rpc.go b/rpc/rpc.go index 83182ac1..d6bd7c18 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1582,19 +1582,19 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag e.lift() case rpccp.Disembargo_context_Which_senderLoopback: - var ( - snapshot capnp.ClientSnapshot - err error - ) - c.withLocked(func(c *lockedConn) { - if tgt.which != rpccp.MessageTarget_Which_promisedAnswer { - err = rpcerr.Failed(errors.New("incoming disembargo: sender loopback: target is not a promised answer")) - return + snapshot, err := withLockedConn2(c, func(c *lockedConn) (capnp.ClientSnapshot, error) { + switch tgt.which { + case rpccp.MessageTarget_Which_promisedAnswer: + return c.getAnswerSnapshot( + tgt.promisedAnswer, + tgt.transform, + ) + case rpccp.MessageTarget_Which_importedCap: + fallthrough + default: + err := rpcerr.Failed(errors.New("incoming disembargo: sender loopback: target is not a promised answer")) + return capnp.ClientSnapshot{}, err } - snapshot, err = c.getAnswerSnapshot( - tgt.promisedAnswer, - tgt.transform, - ) }) if err != nil { From 872202ca39c0872fcca8b32859d4aa9bc9c5fa75 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Sat, 17 Jun 2023 01:32:17 -0400 
Subject: [PATCH 05/16] First stab at disembargos on imports. Needs testing. --- rpc/rpc.go | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/rpc/rpc.go b/rpc/rpc.go index d6bd7c18..7d299e96 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1582,7 +1582,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag e.lift() case rpccp.Disembargo_context_Which_senderLoopback: - snapshot, err := withLockedConn2(c, func(c *lockedConn) (capnp.ClientSnapshot, error) { + snapshot, err := withLockedConn2(c, func(c *lockedConn) (_ capnp.ClientSnapshot, err error) { switch tgt.which { case rpccp.MessageTarget_Which_promisedAnswer: return c.getAnswerSnapshot( @@ -1590,10 +1590,23 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag tgt.transform, ) case rpccp.MessageTarget_Which_importedCap: - fallthrough + ent := c.findExport(tgt.importedCap) + if ent == nil { + err = rpcerr.Failed(errors.New("sender loopback: no such export: " + + str.Utod(tgt.importedCap))) + return + } + if !ent.isPromise { + err = rpcerr.Failed(errors.New( + "sender loopback: target export " + + str.Utod(tgt.importedCap) + + " is not a promise")) + return + } + return ent.snapshot.AddRef(), nil default: - err := rpcerr.Failed(errors.New("incoming disembargo: sender loopback: target is not a promised answer")) - return capnp.ClientSnapshot{}, err + err = rpcerr.Failed(errors.New("incoming disembargo: sender loopback: target is not a promised answer")) + return } }) From c3a169d147a10e9eaf9cc0be4c44b3add5197a87 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Sat, 18 Mar 2023 21:20:54 -0400 Subject: [PATCH 06/16] Add a test for disembargos on senderPromises. The test is currently failing; we're getting back an abort message complaining that the export entry is not a promise -- need to investigate. 
Aside from that failure, this caught a bug: we need to check if ClientState.Metadata is nil, which is possible if the client itself is null, and seems to happen to the promise after it is resolved. TODO: we should de-dup the logic between releaseExport and releaseExports. This is not entirely trivial though, because the latter is executed after we've wiped the exports table. --- rpc/export.go | 8 ++- rpc/rpc.go | 8 ++- rpc/senderpromise_test.go | 134 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 144 insertions(+), 6 deletions(-) diff --git a/rpc/export.go b/rpc/export.go index d340cf83..d0eecd97 100644 --- a/rpc/export.go +++ b/rpc/export.go @@ -74,9 +74,11 @@ func (c *lockedConn) releaseExport(id exportID, count uint32) (capnp.ClientSnaps c.lk.exports[id] = nil c.lk.exportID.remove(uint32(id)) metadata := snapshot.Metadata() - syncutil.With(metadata, func() { - c.clearExportID(metadata) - }) + if metadata != nil { + syncutil.With(metadata, func() { + c.clearExportID(metadata) + }) + } return snapshot, nil case count > ent.wireRefs: return capnp.ClientSnapshot{}, rpcerr.Failed(errors.New("export ID " + str.Utod(id) + " released too many references")) diff --git a/rpc/rpc.go b/rpc/rpc.go index 7d299e96..e9cd1ae2 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -458,9 +458,11 @@ func (c *lockedConn) releaseExports(dq *deferred.Queue, exports []*expent) { for _, e := range exports { if e != nil { metadata := e.snapshot.Metadata() - syncutil.With(metadata, func() { - c.clearExportID(metadata) - }) + if metadata != nil { + syncutil.With(metadata, func() { + c.clearExportID(metadata) + }) + } dq.Defer(e.snapshot.Release) } } diff --git a/rpc/senderpromise_test.go b/rpc/senderpromise_test.go index 664373ad..9be6c439 100644 --- a/rpc/senderpromise_test.go +++ b/rpc/senderpromise_test.go @@ -214,3 +214,137 @@ type emptyShutdowner struct { func (s emptyShutdowner) Shutdown() { close(s.onShutdown) } + +// Tests fulfilling a senderPromise with something hosted on the 
receiver +func TestDisembargoSenderPromise(t *testing.T) { + t.Parallel() + + ctx := context.Background() + p, r := capnp.NewLocalPromise[capnp.Client]() + + left, right := transport.NewPipe(1) + p1, p2 := rpc.NewTransport(left), rpc.NewTransport(right) + + conn := rpc.NewConn(p1, &rpc.Options{ + ErrorReporter: testErrorReporter{tb: t}, + BootstrapClient: capnp.Client(p), + }) + defer finishTest(t, conn, p2) + + // Send bootstrap. + { + msg := &rpcMessage{ + Which: rpccp.Message_Which_bootstrap, + Bootstrap: &rpcBootstrap{QuestionID: 0}, + } + assert.NoError(t, sendMessage(ctx, p2, msg)) + } + // Receive return. + var theirBootstrapID uint32 + { + rmsg, release, err := recvMessage(ctx, p2) + assert.NoError(t, err) + defer release() + assert.Equal(t, rpccp.Message_Which_return, rmsg.Which) + assert.Equal(t, uint32(0), rmsg.Return.AnswerID) + assert.Equal(t, rpccp.Return_Which_results, rmsg.Return.Which) + assert.Equal(t, 1, len(rmsg.Return.Results.CapTable)) + desc := rmsg.Return.Results.CapTable[0] + assert.Equal(t, rpccp.CapDescriptor_Which_senderPromise, desc.Which) + theirBootstrapID = desc.SenderPromise + } + + // For conveience, we use the other peer's bootstrap interface as the thing + // to resolve to. + bsClient := conn.Bootstrap(ctx) + defer bsClient.Release() + + // Receive bootstrap, send return. 
+ myBootstrapID := uint32(12) + var incomingBSQid uint32 + { + rmsg, release, err := recvMessage(ctx, p2) + assert.NoError(t, err) + defer release() + assert.Equal(t, rpccp.Message_Which_bootstrap, rmsg.Which) + incomingBSQid = rmsg.Bootstrap.QuestionID + + outMsg, err := p2.NewMessage() + assert.NoError(t, err) + iface := capnp.NewInterface(outMsg.Message().Segment(), 0) + + assert.NoError(t, sendMessage(ctx, p2, &rpcMessage{ + Which: rpccp.Message_Which_return, + Return: &rpcReturn{ + AnswerID: incomingBSQid, + Which: rpccp.Return_Which_results, + Results: &rpcPayload{ + Content: iface.ToPtr(), + CapTable: []rpcCapDescriptor{ + { + Which: rpccp.CapDescriptor_Which_senderHosted, + SenderHosted: myBootstrapID, + }, + }, + }, + }, + })) + } + // Accept return + assert.NoError(t, bsClient.Resolve(ctx)) + + // Receive Finish + { + rmsg, release, err := recvMessage(ctx, p2) + assert.NoError(t, err) + defer release() + assert.Equal(t, rpccp.Message_Which_finish, rmsg.Which) + assert.Equal(t, incomingBSQid, rmsg.Finish.QuestionID) + } + + // Resolve bootstrap + r.Fulfill(bsClient) + + // Receive resolve. 
+ { + rmsg, release, err := recvMessage(ctx, p2) + assert.NoError(t, err) + defer release() + assert.Equal(t, rpccp.Message_Which_resolve, rmsg.Which) + assert.Equal(t, theirBootstrapID, rmsg.Resolve.PromiseID) + assert.Equal(t, rpccp.Resolve_Which_cap, rmsg.Resolve.Which) + desc := rmsg.Resolve.Cap + assert.Equal(t, rpccp.CapDescriptor_Which_receiverHosted, desc.Which) + assert.Equal(t, myBootstrapID, desc.ReceiverHosted) + } + // Send disembargo: + embargoID := uint32(7) + { + assert.NoError(t, sendMessage(ctx, p2, &rpcMessage{ + Which: rpccp.Message_Which_disembargo, + Disembargo: &rpcDisembargo{ + Context: rpcDisembargoContext{ + Which: rpccp.Disembargo_context_Which_senderLoopback, + SenderLoopback: embargoID, + }, + Target: rpcMessageTarget{ + Which: rpccp.MessageTarget_Which_importedCap, + ImportedCap: theirBootstrapID, + }, + }, + })) + } + // Receive disembargo: + { + rmsg, release, err := recvMessage(ctx, p2) + assert.NoError(t, err) + defer release() + assert.Equal(t, rpccp.Message_Which_disembargo, rmsg.Which) + d := rmsg.Disembargo + assert.Equal(t, rpccp.Disembargo_context_Which_receiverLoopback, d.Context.Which) + assert.Equal(t, embargoID, d.Context.ReceiverLoopback) + tgt := d.Target + assert.Equal(t, rpccp.MessageTarget_Which_importedCap, tgt.Which) + assert.Equal(t, myBootstrapID, tgt.ImportedCap) + } +} From 78a82e2ff54d2b5db552ba9549be09ab95f36270 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Sat, 18 Mar 2023 21:51:52 -0400 Subject: [PATCH 07/16] Add another test wrt promise resolution. ...also failing (hanging) for now. 
--- rpc/localpromise_test.go | 22 ++++++------ rpc/rpc.go | 1 + rpc/senderpromise_test.go | 72 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 12 deletions(-) diff --git a/rpc/localpromise_test.go b/rpc/localpromise_test.go index f29af258..e67cbfaf 100644 --- a/rpc/localpromise_test.go +++ b/rpc/localpromise_test.go @@ -70,6 +70,13 @@ func TestLocalPromiseFulfill(t *testing.T) { assert.Equal(t, int64(3), res3.N()) } +func echoNum(ctx context.Context, pp testcapnp.PingPong, n int64) (testcapnp.PingPong_echoNum_Results_Future, capnp.ReleaseFunc) { + return pp.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error { + p.SetN(n) + return nil + }) +} + func TestLocalPromiseReject(t *testing.T) { t.Parallel() @@ -77,24 +84,15 @@ func TestLocalPromiseReject(t *testing.T) { p, r := capnp.NewLocalPromise[testcapnp.PingPong]() defer p.Release() - fut1, rel1 := p.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error { - p.SetN(1) - return nil - }) + fut1, rel1 := echoNum(ctx, p, 1) defer rel1() - fut2, rel2 := p.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error { - p.SetN(2) - return nil - }) + fut2, rel2 := echoNum(ctx, p, 2) defer rel2() r.Reject(errors.New("Promise rejected")) - fut3, rel3 := p.EchoNum(ctx, func(p testcapnp.PingPong_echoNum_Params) error { - p.SetN(3) - return nil - }) + fut3, rel3 := echoNum(ctx, p, 3) defer rel3() _, err := fut1.Struct() diff --git a/rpc/rpc.go b/rpc/rpc.go index e9cd1ae2..e2965a64 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1800,6 +1800,7 @@ func (c *Conn) handleResolve(ctx context.Context, in transport.IncomingMessage) } dq.Defer(func() { imp.resolver.Fulfill(client) + client.Release() }) case rpccp.Resolve_Which_exception: ex, err := resolve.Exception() diff --git a/rpc/senderpromise_test.go b/rpc/senderpromise_test.go index 9be6c439..1befea58 100644 --- a/rpc/senderpromise_test.go +++ b/rpc/senderpromise_test.go @@ -348,3 +348,75 @@ func TestDisembargoSenderPromise(t *testing.T) { 
assert.Equal(t, myBootstrapID, tgt.ImportedCap) } } + +// Tests that E-order is respected when fulfilling a promise with something on +// the remote peer. +func TestPromiseOrdering(t *testing.T) { + t.Parallel() + + ctx := context.Background() + p, r := capnp.NewLocalPromise[testcapnp.PingPong]() + defer p.Release() + + left, right := transport.NewPipe(1) + p1, p2 := rpc.NewTransport(left), rpc.NewTransport(right) + + c1 := rpc.NewConn(p1, &rpc.Options{ + ErrorReporter: testErrorReporter{tb: t}, + BootstrapClient: capnp.Client(p), + }) + ord := &echoNumOrderChecker{ + t: t, + } + c2 := rpc.NewConn(p2, &rpc.Options{ + ErrorReporter: testErrorReporter{tb: t}, + BootstrapClient: capnp.Client(testcapnp.PingPong_ServerToClient(ord)), + }) + + remotePromise := testcapnp.PingPong(c2.Bootstrap(ctx)) + defer remotePromise.Release() + + // Send a whole bunch of calls to the promise: + var ( + futures []testcapnp.PingPong_echoNum_Results_Future + rels []capnp.ReleaseFunc + ) + numCalls := 1024 + for i := 0; i < numCalls; i++ { + fut, rel := echoNum(ctx, remotePromise, int64(i)) + futures = append(futures, fut) + rels = append(rels, rel) + + // At some arbitrary point in the middle, fulfill the promise + // with the other bootstrap interface: + if i == 100 { + go func() { + bs := testcapnp.PingPong(c1.Bootstrap(ctx)) + defer bs.Release() + r.Fulfill(bs) + }() + } + } + for i, fut := range futures { + // Verify that all the results are as expected. The server + // Will verify that they came in the right order. + res, err := fut.Struct() + assert.NoError(t, err) + assert.Equal(t, int64(i), res.N()) + } + for _, rel := range rels { + rel() + } + + assert.NoError(t, remotePromise.Resolve(ctx)) + // Shut down the connections, and make sure we can still send + // calls. 
This ensures that we've successfully shortened the path to + // cut out the remote peer: + c1.Close() + c2.Close() + fut, rel := echoNum(ctx, remotePromise, int64(numCalls)) + defer rel() + res, err := fut.Struct() + assert.NoError(t, err) + assert.Equal(t, int64(numCalls), res.N()) +} From d42eb54105e15515c1fd093f77cab33096de68f3 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Thu, 23 Mar 2023 09:33:29 -0400 Subject: [PATCH 08/16] Correctly send receiverLoopbacks that target promisedAnswers. --- rpc/rpc.go | 44 +++++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/rpc/rpc.go b/rpc/rpc.go index e2965a64..4f4d3714 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1617,36 +1617,54 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag return err } - defer snapshot.Release() - imp, ok := snapshot.Brand().Value.(*importClient) - if !ok || imp.c != c { - return rpcerr.Failed(errors.New( - "incoming disembargo: sender loopback requested on a capability that is not an import", - )) - } - // TODO(maybe): check generation? - // Since this Cap'n Proto RPC implementation does not send imports // unless they are fully dequeued, we can just immediately loop back. 
id := d.Context().SenderLoopback() + c.withLocked(func(c *lockedConn) { c.sendMessage(ctx, func(m rpccp.Message) error { d, err := m.NewDisembargo() if err != nil { return err } - + d.Context().SetReceiverLoopback(id) tgt, err := d.NewTarget() if err != nil { return err } - tgt.SetImportedCap(uint32(imp.id)) - d.Context().SetReceiverLoopback(id) - return nil + brand := snapshot.Brand() + if pc, ok := brand.Value.(capnp.PipelineClient); ok { + if q, ok := c.getAnswerQuestion(pc.Answer()); ok { + if q.c == (*Conn)(c) { + pa, err := tgt.NewPromisedAnswer() + if err != nil { + return err + } + pa.SetQuestionId(uint32(q.id)) + pcTrans := pc.Transform() + trans, err := pa.NewTransform(int32(len(pcTrans))) + if err != nil { + return err + } + for i, op := range pcTrans { + trans.At(i).SetGetPointerField(op.Field) + } + } + return nil + } + } + + imp, ok := brand.Value.(*importClient) + if ok && imp.c == (*Conn)(c) { + tgt.SetImportedCap(uint32(imp.id)) + return nil + } + return errors.New("target for receiver loopback does not point to the right connection") }, func(err error) { defer in.Release() + defer snapshot.Release() if err != nil { c.er.ReportError(rpcerr.Annotate(err, "incoming disembargo: send receiver loopback")) From dee0565af309c4373e8962da2e8cbf95d80e354c Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Mon, 19 Jun 2023 21:45:43 -0400 Subject: [PATCH 09/16] Get rid of redundant isPromise ...which we were forgetting to set. We can query snapshot for this anyway, so just get rid of it -- single source of truth and all. --- rpc/export.go | 5 ++--- rpc/rpc.go | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/rpc/export.go b/rpc/export.go index d0eecd97..d6b8f10b 100644 --- a/rpc/export.go +++ b/rpc/export.go @@ -17,9 +17,8 @@ type exportID uint32 // expent is an entry in a Conn's export table. 
type expent struct { - snapshot capnp.ClientSnapshot - wireRefs uint32 - isPromise bool + snapshot capnp.ClientSnapshot + wireRefs uint32 // Should be called when removing this entry from the exports table: cancel context.CancelFunc diff --git a/rpc/rpc.go b/rpc/rpc.go index 4f4d3714..f0ac8097 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1598,7 +1598,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag str.Utod(tgt.importedCap))) return } - if !ent.isPromise { + if !ent.snapshot.IsPromise() { err = rpcerr.Failed(errors.New( "sender loopback: target export " + str.Utod(tgt.importedCap) + From eecd20f16913dd28eb3a5fac1da69ff4f7c4018a Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Mon, 19 Jun 2023 22:34:05 -0400 Subject: [PATCH 10/16] Fix TestDisembargoSenderPromise. When looking at the target of the promise, we want to look at its *resolution.* --- capability.go | 16 ++++++++++++++++ rpc/rpc.go | 21 ++++++++++++++++++--- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/capability.go b/capability.go index 5d14b163..1b2c442e 100644 --- a/capability.go +++ b/capability.go @@ -596,6 +596,22 @@ func (cs ClientSnapshot) IsPromise() bool { return ret } +// IsResolved returns true if the snapshot has resolved to its final value. +// If IsPromise() returns false, then this will also return false. Otherwise, +// it returns false before resolution and true afterwards. 
+func (cs ClientSnapshot) IsResolved() bool { + if cs.hook == nil { + return false + } + res, ok := cs.hook.Value().resolution.Get() + if !ok { + return false + } + return mutex.With1(res, func(s *resolveState) bool { + return s.isResolved() + }) +} + // Send implements ClientHook.Send func (cs ClientSnapshot) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) { if cs.hook == nil { diff --git a/rpc/rpc.go b/rpc/rpc.go index f0ac8097..14cbbeef 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1605,7 +1605,17 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag " is not a promise")) return } - return ent.snapshot.AddRef(), nil + + if !ent.snapshot.IsResolved() { + err = errors.New("target for receiver loopback is an unresolved promise") + return + } + snapshot := ent.snapshot.AddRef() + err = snapshot.Resolve1(context.Background()) + if err != nil { + panic("error resolving snapshot: " + err.Error()) + } + return snapshot, nil default: err = rpcerr.Failed(errors.New("incoming disembargo: sender loopback: target is not a promised answer")) return @@ -1617,8 +1627,13 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag return err } - // Since this Cap'n Proto RPC implementation does not send imports - // unless they are fully dequeued, we can just immediately loop back. + // FIXME: we're sending the disembargo right away, which I(zenhack) + *think* is fine, and definitely was before we actually did anything + with promises. But this is contingent on making sure that all of the + relevant ClientHook implementations queue up their call messages before + returning from .Recv(); if this invariant holds then this is fine + because anything ahead of it is already on the wire. But we need to + actually check this invariant.
id := d.Context().SenderLoopback() c.withLocked(func(c *lockedConn) { From a297d2ea0950f2ebf6d48ed223a2f36d7b765d9b Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Mon, 19 Jun 2023 22:52:31 -0400 Subject: [PATCH 11/16] leak detection: skip nil snapshots. Otherwise we get a runtime error sometimes. --- capability.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/capability.go b/capability.go index 1b2c442e..d44b7942 100644 --- a/capability.go +++ b/capability.go @@ -821,6 +821,9 @@ func SetClientLeakFunc(clientLeakFunc func(msg string)) { clientLeakFunc("leaked client created at:\n\n" + stack) }) case ClientSnapshot: + if !c.IsValid() { + return + } runtime.SetFinalizer(c.hook, func(c *rc.Ref[clientHook]) { if !c.IsValid() { return From 849aa0720f8711c5924453745057f6db16303b5b Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Mon, 19 Jun 2023 23:18:29 -0400 Subject: [PATCH 12/16] senderpromise: only wait for one resolution before sending resolve. --- rpc/export.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rpc/export.go b/rpc/export.go index d6b8f10b..b7d6c9cb 100644 --- a/rpc/export.go +++ b/rpc/export.go @@ -204,7 +204,7 @@ func (c *lockedConn) sendSenderPromise(id exportID, d rpccp.CapDescriptor) { // Conn before trying to use it again: unlockedConn := (*Conn)(c) - waitErr := waitRef.Resolve(ctx) + waitErr := waitRef.Resolve1(ctx) unlockedConn.withLocked(func(c *lockedConn) { if len(c.lk.exports) <= int(id) || c.lk.exports[id] != ee { // Export was removed from the table at some point; From 0d80d1fb96d1d0fb34bc9b8e8266d3302b871fda Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Mon, 19 Jun 2023 23:23:42 -0400 Subject: [PATCH 13/16] Add a bit more debugging info to test failure --- rpc/senderpromise_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rpc/senderpromise_test.go b/rpc/senderpromise_test.go index 1befea58..9e86e7e7 100644 --- a/rpc/senderpromise_test.go +++ b/rpc/senderpromise_test.go @@ -2,6 
+2,7 @@ package rpc_test import ( "context" + "fmt" "testing" "capnproto.org/go/capnp/v3" @@ -401,7 +402,7 @@ func TestPromiseOrdering(t *testing.T) { // Verify that all the results are as expected. The server // Will verify that they came in the right order. res, err := fut.Struct() - assert.NoError(t, err) + assert.NoError(t, err, fmt.Sprintf("call #%d should succeed", i)) assert.Equal(t, int64(i), res.N()) } for _, rel := range rels { From f103d94417c492ef6a9da395fad1e385e9d4e39e Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Mon, 19 Jun 2023 23:25:08 -0400 Subject: [PATCH 14/16] TestPromiseOrdering: use require, not assert. ...the latter often just results in me staring at cascading errors, which is unhelpful. --- rpc/senderpromise_test.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/rpc/senderpromise_test.go b/rpc/senderpromise_test.go index 9e86e7e7..cbedabbe 100644 --- a/rpc/senderpromise_test.go +++ b/rpc/senderpromise_test.go @@ -11,6 +11,7 @@ import ( "capnproto.org/go/capnp/v3/rpc/transport" rpccp "capnproto.org/go/capnp/v3/std/capnp/rpc" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSenderPromiseFulfill(t *testing.T) { @@ -402,14 +403,14 @@ func TestPromiseOrdering(t *testing.T) { // Verify that all the results are as expected. The server // Will verify that they came in the right order. res, err := fut.Struct() - assert.NoError(t, err, fmt.Sprintf("call #%d should succeed", i)) - assert.Equal(t, int64(i), res.N()) + require.NoError(t, err, fmt.Sprintf("call #%d should succeed", i)) + require.Equal(t, int64(i), res.N()) } for _, rel := range rels { rel() } - assert.NoError(t, remotePromise.Resolve(ctx)) + require.NoError(t, remotePromise.Resolve(ctx)) // Shut down the connections, and make sure we can still send // calls. 
This ensures that we've successfully shortened the path to // cut out the remote peer: @@ -418,6 +419,6 @@ func TestPromiseOrdering(t *testing.T) { fut, rel := echoNum(ctx, remotePromise, int64(numCalls)) defer rel() res, err := fut.Struct() - assert.NoError(t, err) - assert.Equal(t, int64(numCalls), res.N()) + require.NoError(t, err) + require.Equal(t, int64(numCalls), res.N()) } From 51facf744f838606044ffe408e57c712a11895c9 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Wed, 29 Mar 2023 23:56:16 -0400 Subject: [PATCH 15/16] Clean up the way local promises work. Push the logic for flushing the answerqueue into the Promise type itself. This is much cleaner, and avoids some racy logic that I'm not sure was correct. --- answer.go | 17 ++++++++++++-- answer_test.go | 12 +++++----- answerqueue.go | 2 +- localpromise.go | 60 +++++++++++-------------------------------------- rpc/answer.go | 2 +- rpc/question.go | 2 +- 6 files changed, 37 insertions(+), 58 deletions(-) diff --git a/answer.go b/answer.go index b70099f0..496823b9 100644 --- a/answer.go +++ b/answer.go @@ -35,6 +35,8 @@ type Promise struct { // - Resolved. Fulfill or Reject has finished. state mutex.Mutex[promiseState] + + resolver Resolver[Ptr] } type promiseState struct { @@ -64,11 +66,13 @@ type clientAndPromise struct { } // NewPromise creates a new unresolved promise. The PipelineCaller will -// be used to make pipelined calls before the promise resolves. -func NewPromise(m Method, pc PipelineCaller) *Promise { +// be used to make pipelined calls before the promise resolves. If resolver +// is not nil, calls to Fulfill will be forwarded to it. 
+func NewPromise(m Method, pc PipelineCaller, resolver Resolver[Ptr]) *Promise { if pc == nil { panic("NewPromise(nil)") } + resolved := make(chan struct{}) p := &Promise{ method: m, @@ -77,6 +81,7 @@ func NewPromise(m Method, pc PipelineCaller) *Promise { signals: []func(){func() { close(resolved) }}, caller: pc, }), + resolver: resolver, } p.ans.f.promise = p p.ans.metadata = *NewMetadata() @@ -152,6 +157,14 @@ func (p *Promise) Resolve(r Ptr, e error) { return p.clients }) + if p.resolver != nil { + if e == nil { + p.resolver.Fulfill(r) + } else { + p.resolver.Reject(e) + } + } + // Pending resolution state: wait for clients to be fulfilled // and calls to have answers. res := resolution{p.method, r, e} diff --git a/answer_test.go b/answer_test.go index 6f3cea28..0c22da43 100644 --- a/answer_test.go +++ b/answer_test.go @@ -16,7 +16,7 @@ var dummyMethod = Method{ func TestPromiseReject(t *testing.T) { t.Run("Done", func(t *testing.T) { - p := NewPromise(dummyMethod, dummyPipelineCaller{}) + p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil) done := p.Answer().Done() p.Reject(errors.New("omg bbq")) select { @@ -27,7 +27,7 @@ func TestPromiseReject(t *testing.T) { } }) t.Run("Struct", func(t *testing.T) { - p := NewPromise(dummyMethod, dummyPipelineCaller{}) + p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil) defer p.ReleaseClients() ans := p.Answer() p.Reject(errors.New("omg bbq")) @@ -36,7 +36,7 @@ func TestPromiseReject(t *testing.T) { } }) t.Run("Client", func(t *testing.T) { - p := NewPromise(dummyMethod, dummyPipelineCaller{}) + p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil) defer p.ReleaseClients() pc := p.Answer().Field(1, nil).Client() p.Reject(errors.New("omg bbq")) @@ -57,7 +57,7 @@ func TestPromiseFulfill(t *testing.T) { t.Parallel() t.Run("Done", func(t *testing.T) { - p := NewPromise(dummyMethod, dummyPipelineCaller{}) + p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil) done := p.Answer().Done() msg, seg, _ := 
NewMessage(SingleSegment(nil)) defer msg.Release() @@ -72,7 +72,7 @@ func TestPromiseFulfill(t *testing.T) { } }) t.Run("Struct", func(t *testing.T) { - p := NewPromise(dummyMethod, dummyPipelineCaller{}) + p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil) defer p.ReleaseClients() ans := p.Answer() msg, seg, _ := NewMessage(SingleSegment(nil)) @@ -92,7 +92,7 @@ func TestPromiseFulfill(t *testing.T) { } }) t.Run("Client", func(t *testing.T) { - p := NewPromise(dummyMethod, dummyPipelineCaller{}) + p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil) defer p.ReleaseClients() pc := p.Answer().Field(1, nil).Client() diff --git a/answerqueue.go b/answerqueue.go index 39a48dc6..015d3851 100644 --- a/answerqueue.go +++ b/answerqueue.go @@ -282,7 +282,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea } } } - sr.p = NewPromise(m, pcall) + sr.p = NewPromise(m, pcall, nil) ans := sr.p.Answer() return ans, func() { <-ans.Done() diff --git a/localpromise.go b/localpromise.go index a478c169..fc462cf7 100644 --- a/localpromise.go +++ b/localpromise.go @@ -1,9 +1,5 @@ package capnp -import ( - "context" -) - // ClientHook for a promise that will be resolved to some other capability // at some point. Buffers calls in a queue until the promsie is fulfilled, // then forwards them. @@ -12,59 +8,29 @@ type localPromise struct { } // NewLocalPromise returns a client that will eventually resolve to a capability, -// supplied via the fulfiller. +// supplied via the resolver. 
func NewLocalPromise[C ~ClientKind]() (C, Resolver[C]) { - lp := newLocalPromise() - p, f := NewPromisedClient(lp) + aq := NewAnswerQueue(Method{}) + f := NewPromise(Method{}, aq, aq) + p := f.Answer().Client().AddRef() return C(p), localResolver[C]{ - lp: lp, - clientResolver: f, + p: f, } } -func newLocalPromise() localPromise { - return localPromise{aq: NewAnswerQueue(Method{})} -} - -func (lp localPromise) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) { - return lp.aq.PipelineSend(ctx, nil, s) -} - -func (lp localPromise) Recv(ctx context.Context, r Recv) PipelineCaller { - return lp.aq.PipelineRecv(ctx, nil, r) -} - -func (lp localPromise) Brand() Brand { - return Brand{} -} - -func (lp localPromise) Shutdown() {} - -func (lp localPromise) String() string { - return "localPromise{...}" -} - -func (lp localPromise) Fulfill(c Client) { - msg, seg := NewSingleSegmentMessage(nil) - capID := msg.CapTable().Add(c) - lp.aq.Fulfill(NewInterface(seg, capID).ToPtr()) -} - -func (lp localPromise) Reject(err error) { - lp.aq.Reject(err) -} - type localResolver[C ~ClientKind] struct { - lp localPromise - clientResolver Resolver[Client] + p *Promise } func (lf localResolver[C]) Fulfill(c C) { - lf.lp.Fulfill(Client(c)) - lf.clientResolver.Fulfill(Client(c)) + msg, seg := NewSingleSegmentMessage(nil) + capID := msg.CapTable().Add(Client(c)) + iface := NewInterface(seg, capID) + lf.p.Fulfill(iface.ToPtr()) + lf.p.ReleaseClients() } func (lf localResolver[C]) Reject(err error) { - lf.lp.Reject(err) - lf.clientResolver.Reject(err) + lf.p.Reject(err) + lf.p.ReleaseClients() } diff --git a/rpc/answer.go b/rpc/answer.go index 3b92418a..b4b1eea7 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -156,7 +156,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ er func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) { if !ans.flags.Contains(resultsReady) { ans.pcall = pcall - ans.promise = capnp.NewPromise(m, pcall) 
+ ans.promise = capnp.NewPromise(m, pcall, nil) } } diff --git a/rpc/question.go b/rpc/question.go index 458ba229..d24aa2bd 100644 --- a/rpc/question.go +++ b/rpc/question.go @@ -55,7 +55,7 @@ func (c *lockedConn) newQuestion(method capnp.Method) *question { release: func() {}, finishMsgSend: make(chan struct{}), } - q.p = capnp.NewPromise(method, q) // TODO(someday): customize error message for bootstrap + q.p = capnp.NewPromise(method, q, nil) // TODO(someday): customize error message for bootstrap c.setAnswerQuestion(q.p.Answer(), q) if int(q.id) == len(c.lk.questions) { c.lk.questions = append(c.lk.questions, q) From 4f8d2d8bc59c6f91475c3143d1738c32f68f444d Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Thu, 22 Jun 2023 02:05:36 -0400 Subject: [PATCH 16/16] localPromise.Fulfill: steal argument --- capability_test.go | 2 +- localpromise.go | 1 + rpc/senderpromise_test.go | 4 +--- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/capability_test.go b/capability_test.go index d7f55bb8..915b2892 100644 --- a/capability_test.go +++ b/capability_test.go @@ -132,7 +132,7 @@ func TestResolve(t *testing.T) { } t.Run("Clients", func(t *testing.T) { test(t, "Waits for the full chain", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) { - r1.Fulfill(p2) + r1.Fulfill(p2.AddRef()) ctx, cancel := context.WithTimeout(context.Background(), time.Second/10) defer cancel() require.NotNil(t, p1.Resolve(ctx), "blocks on second promise") diff --git a/localpromise.go b/localpromise.go index fc462cf7..64bc7250 100644 --- a/localpromise.go +++ b/localpromise.go @@ -28,6 +28,7 @@ func (lf localResolver[C]) Fulfill(c C) { iface := NewInterface(seg, capID) lf.p.Fulfill(iface.ToPtr()) lf.p.ReleaseClients() + msg.Release() } func (lf localResolver[C]) Reject(err error) { diff --git a/rpc/senderpromise_test.go b/rpc/senderpromise_test.go index cbedabbe..aa4174be 100644 --- a/rpc/senderpromise_test.go +++ b/rpc/senderpromise_test.go @@ -393,9 +393,7 @@ func 
TestPromiseOrdering(t *testing.T) { // with the other bootstrap interface: if i == 100 { go func() { - bs := testcapnp.PingPong(c1.Bootstrap(ctx)) - defer bs.Release() - r.Fulfill(bs) + r.Fulfill(testcapnp.PingPong(c1.Bootstrap(ctx))) }() } }