Skip to content

Commit

Permalink
Merge pull request #489 from capnproto/cleanup/message-reset
Browse files Browse the repository at this point in the history
Clean up Message.Reset
  • Loading branch information
zenhack authored Mar 28, 2023
2 parents 3926d66 + 098f255 commit 1a829fd
Show file tree
Hide file tree
Showing 13 changed files with 226 additions and 181 deletions.
11 changes: 8 additions & 3 deletions answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ func TestPromiseReject(t *testing.T) {
}

func TestPromiseFulfill(t *testing.T) {
t.Parallel()

t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
done := p.Answer().Done()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Reset(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
p.Fulfill(res.ToPtr())
select {
Expand All @@ -73,7 +76,8 @@ func TestPromiseFulfill(t *testing.T) {
defer p.ReleaseClients()
ans := p.Answer()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Reset(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
res.SetUint32(0, 0xdeadbeef)
p.Fulfill(res.ToPtr())
Expand All @@ -96,7 +100,8 @@ func TestPromiseFulfill(t *testing.T) {
c := NewClient(h)
defer c.Release()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Reset(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{PointerCount: 3})
res.SetPtr(1, NewInterface(seg, msg.AddCap(c.AddRef())).ToPtr())

Expand Down
30 changes: 14 additions & 16 deletions answerqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import (
//
// An AnswerQueue can be in one of three states:
//
// 1) Queueing. Incoming method calls will be added to the queue.
// 2) Draining, entered by calling Fulfill or Reject. Queued method
// calls will be delivered in sequence, and new incoming method calls
// will block until the AnswerQueue enters the Drained state.
// 3) Drained, entered once all queued methods have been delivered.
// Incoming methods are passthrough.
// 1. Queueing. Incoming method calls will be added to the queue.
// 2. Draining, entered by calling Fulfill or Reject. Queued method
// calls will be delivered in sequence, and new incoming method calls
// will block until the AnswerQueue enters the Drained state.
// 3. Drained, entered once all queued methods have been delivered.
// Incoming methods are passthrough.
type AnswerQueue struct {
method Method
draining chan struct{} // closed while exiting queueing state
Expand Down Expand Up @@ -154,8 +154,9 @@ func (qc queueCaller) PipelineRecv(ctx context.Context, transform []PipelineOp,
func (qc queueCaller) PipelineSend(ctx context.Context, transform []PipelineOp, s Send) (*Answer, ReleaseFunc) {
ret := new(StructReturner)
r := Recv{
Method: s.Method,
Returner: ret,
Method: s.Method,
Returner: ret,
ReleaseArgs: func() {},
}
if s.PlaceArgs != nil {
var err error
Expand All @@ -167,12 +168,9 @@ func (qc queueCaller) PipelineSend(ctx context.Context, transform []PipelineOp,
if err = s.PlaceArgs(r.Args); err != nil {
return ErrorAnswer(s.Method, err), func() {}
}
r.ReleaseArgs = func() {
r.Args.Message().Reset(nil)
}
} else {
r.ReleaseArgs = func() {}
r.ReleaseArgs = r.Args.Message().Release
}

pcall := qc.PipelineRecv(ctx, transform, r)
return ret.Answer(s.Method, pcall)
}
Expand Down Expand Up @@ -258,7 +256,7 @@ func (sr *StructReturner) ReleaseResults() {
return
}
if err != nil && msg != nil {
msg.Reset(nil)
msg.Release()
}
}

Expand All @@ -280,7 +278,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
sr.result = Struct{}
sr.mu.Unlock()
if msg != nil {
msg.Reset(nil)
msg.Release()
}
}
}
Expand All @@ -294,7 +292,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
sr.mu.Unlock()
sr.p.ReleaseClients()
if msg != nil {
msg.Reset(nil)
msg.Release()
}
}
}
15 changes: 13 additions & 2 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,16 @@ func NewMultiSegmentMessage(b [][]byte) (msg *Message, first *Segment) {
return msg, first
}

// Release is syntactic sugar for Message.Reset(nil). See
// docstring for Reset for an important warning.
func (m *Message) Release() {
m.Reset(nil)
}

// Reset the message to use a different arena, allowing it
// to be reused. This invalidates any existing pointers in
// the Message, and releases all clients in the cap table,
// so use with caution.
// the Message, releases all clients in the cap table, and
// releases the current Arena, so use with caution.
func (m *Message) Reset(arena Arena) (first *Segment, err error) {
for _, c := range m.CapTable {
c.Release()
Expand All @@ -107,10 +113,15 @@ func (m *Message) Reset(arena Arena) (first *Segment, err error) {
delete(m.segs, k)
}

if m.Arena != nil {
m.Arena.Release()
}

*m = Message{
Arena: arena,
TraverseLimit: m.TraverseLimit,
DepthLimit: m.DepthLimit,
CapTable: m.CapTable[:0],
segs: m.segs,
}

Expand Down
17 changes: 7 additions & 10 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ func (r *Request) Future() *Future {

// Release resources associated with the request. In particular:
//
// * Release the arguments if they have not yet been released.
// * If the request has been sent, wait for the result and release
// the results.
// - Release the arguments if they have not yet been released.
// - If the request has been sent, wait for the result and release
// the results.
func (r *Request) Release() {
r.releaseArgs()
rel := r.releaseResponse
Expand All @@ -91,12 +91,9 @@ func (r *Request) Release() {
}

func (r *Request) releaseArgs() {
if r.args.IsValid() {
return
if !r.args.IsValid() {
msg := r.args.Message()
r.args = Struct{}
msg.Release()
}
msg := r.args.Message()
r.args = Struct{}
arena := msg.Arena
msg.Reset(nil)
arena.Release()
}
2 changes: 1 addition & 1 deletion rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel
if err != nil {
return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err)
}
ret, err := outMsg.Message.NewReturn()
ret, err := outMsg.Message().NewReturn()
if err != nil {
outMsg.Release()
return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err)
Expand Down
36 changes: 23 additions & 13 deletions rpc/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,36 @@ func (t *measuringTransport) RecvMessage() (transport.IncomingMessage, error) {
return inMsg, err
}

size, err := capnp.Struct(inMsg.Message).Message().TotalSize()
size, err := inMsg.Message().Message().TotalSize()
if err != nil {
return inMsg, err
}

t.mu.Lock()
t.inUse += size
if t.inUse > t.maxInUse {
defer t.mu.Unlock()

if t.inUse += size; t.inUse > t.maxInUse {
t.maxInUse = t.inUse
}
t.mu.Unlock()

oldRelease := inMsg.Release
inMsg.Release = capnp.ReleaseFunc(func() {
oldRelease()
t.mu.Lock()
defer t.mu.Unlock()
t.inUse -= size
})
return inMsg, err

return releaseHook{
t: t,
IncomingMessage: inMsg,
}, nil
}

type releaseHook struct {
t *measuringTransport
size uint64
transport.IncomingMessage
}

func (rh releaseHook) Release() {
rh.IncomingMessage.Release()

rh.t.mu.Lock()
rh.t.inUse -= rh.size
rh.t.mu.Lock()
}

// Test that attaching a fixed-size FlowLimiter results in actually limiting the
Expand Down
42 changes: 21 additions & 21 deletions rpc/level0_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ func TestSendBootstrapCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -416,12 +416,12 @@ func TestSendBootstrapCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
resp, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
resp, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
resp.SetUint64(0, 0xdeadbeef)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -530,8 +530,8 @@ func TestSendBootstrapCallException(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -755,12 +755,12 @@ func TestSendBootstrapPipelineCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
resp, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
resp, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
resp.SetUint64(0, 0xdeadbeef)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -963,12 +963,12 @@ func TestRecvBootstrapCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
params.SetUint32(0, 0x2a2b)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_call,
Call: &rpcCall{
QuestionID: callQID,
Expand Down Expand Up @@ -1114,12 +1114,12 @@ func TestRecvBootstrapCallException(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
params.SetUint32(0, 0x2a2b)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_call,
Call: &rpcCall{
QuestionID: callQID,
Expand Down Expand Up @@ -1258,12 +1258,12 @@ func TestRecvBootstrapPipelineCall(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
params, err := capnp.NewStruct(outMsg.Message.Segment(), capnp.ObjectSize{DataSize: 8})
params, err := capnp.NewStruct(outMsg.Message().Segment(), capnp.ObjectSize{DataSize: 8})
if err != nil {
t.Fatal("capnp.NewStruct:", err)
}
params.SetUint32(0, 0x2a2b)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_call,
Call: &rpcCall{
QuestionID: callQID,
Expand Down Expand Up @@ -1452,8 +1452,8 @@ func TestCallOnClosedConn(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: qid,
Expand Down Expand Up @@ -1593,7 +1593,7 @@ func TestRecvCancel(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_call,
Call: &rpcCall{
QuestionID: callQID,
Expand Down Expand Up @@ -1732,8 +1732,8 @@ func TestSendCancel(t *testing.T) {
if err != nil {
t.Fatal("p2.NewMessage():", err)
}
iptr := capnp.NewInterface(outMsg.Message.Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), &rpcMessage{
iptr := capnp.NewInterface(outMsg.Message().Segment(), 0)
err = pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), &rpcMessage{
Which: rpccp.Message_Which_return,
Return: &rpcReturn{
AnswerID: bootQID,
Expand Down Expand Up @@ -2046,7 +2046,7 @@ func sendMessage(ctx context.Context, t rpc.Transport, msg *rpcMessage) error {
return fmt.Errorf("send message: %v", err)
}
defer outMsg.Release()
if err := pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message), msg); err != nil {
if err := pogs.Insert(rpccp.Message_TypeID, capnp.Struct(outMsg.Message()), msg); err != nil {
return fmt.Errorf("send message: %v", err)
}
if err := outMsg.Send(); err != nil {
Expand All @@ -2061,7 +2061,7 @@ func recvMessage(ctx context.Context, t rpc.Transport) (*rpcMessage, capnp.Relea
return nil, nil, err
}
r := new(rpcMessage)
if err := pogs.Extract(r, rpccp.Message_TypeID, capnp.Struct(inMsg.Message)); err != nil {
if err := pogs.Extract(r, rpccp.Message_TypeID, capnp.Struct(inMsg.Message())); err != nil {
inMsg.Release()
return nil, nil, fmt.Errorf("extract RPC message: %v", err)
}
Expand Down
Loading

0 comments on commit 1a829fd

Please sign in to comment.