diff --git a/rpc/rpc.go b/rpc/rpc.go index 72805baa..be8fc746 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -1389,20 +1389,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, d rpccp.Disembargo, release return } - client := iface.Client() - - var ok bool - syncutil.Without(&c.lk, func() { - imp, ok = client.State().Brand.Value.(*importClient) - }) - - if !ok || imp.c != c { - client.Release() - err = rpcerr.Failedf("incoming disembargo: sender loopback requested on a capability that is not an import") - return - } - - // TODO(maybe): check generation? + client = iface.Client() }) if err != nil { @@ -1410,43 +1397,52 @@ func (c *Conn) handleDisembargo(ctx context.Context, d rpccp.Disembargo, release return err } + imp, ok := client.State().Brand.Value.(*importClient) + if !ok || imp.c != c { + client.Release() + return rpcerr.Failedf("incoming disembargo: sender loopback requested on a capability that is not an import") + } + // TODO(maybe): check generation? + // Since this Cap'n Proto RPC implementation does not send imports // unless they are fully dequeued, we can just immediately loop back. id := d.Context().SenderLoopback() - c.sendMessage(ctx, func(m rpccp.Message) error { - defer release() - defer client.Release() - - d, err := m.NewDisembargo() - if err != nil { - return err - } + syncutil.With(&c.lk, func() { + c.sendMessage(ctx, func(m rpccp.Message) error { + d, err := m.NewDisembargo() + if err != nil { + return err + } - tgt, err := d.NewTarget() - if err != nil { - return err - } + tgt, err := d.NewTarget() + if err != nil { + return err + } - tgt.SetImportedCap(uint32(imp.id)) - d.Context().SetReceiverLoopback(id) - return nil + tgt.SetImportedCap(uint32(imp.id)) + d.Context().SetReceiverLoopback(id) + return nil - }, func(err error) { - c.er.ReportError(rpcerr.Annotatef(err, "incoming disembargo: send receiver loopback")) + }, func(err error) { + defer release() + defer client.Release() + c.er.ReportError(rpcerr.Annotatef(err, "incoming disembargo: send receiver loopback")) + }) }) default: c.er.ReportError(fmt.Errorf("incoming disembargo: context %v not implemented", d.Context().Which())) - c.sendMessage(ctx, func(m rpccp.Message) (err error) { - defer release() - - if m, err = m.NewUnimplemented(); err == nil { - err = m.SetDisembargo(d) - } + syncutil.With(&c.lk, func() { + c.sendMessage(ctx, func(m rpccp.Message) (err error) { + if m, err = m.NewUnimplemented(); err == nil { + err = m.SetDisembargo(d) + } - return - }, func(err error) { - c.er.ReportError(rpcerr.Annotate(err, "incoming disembargo: send unimplemented")) + return + }, func(err error) { + defer release() + c.er.ReportError(rpcerr.Annotate(err, "incoming disembargo: send unimplemented")) + }) }) }