diff --git a/rpc/export.go b/rpc/export.go index e8f00873..6119880b 100644 --- a/rpc/export.go +++ b/rpc/export.go @@ -72,7 +72,7 @@ func (c *lockedConn) releaseExport(id exportID, count uint32) (capnp.ClientSnaps defer ent.cancel() snapshot := ent.snapshot c.lk.exports[id] = nil - c.lk.exportID.remove(uint32(id)) + c.lk.exportID.remove(id) metadata := snapshot.Metadata() syncutil.With(metadata, func() { c.clearExportID(metadata) @@ -170,7 +170,7 @@ func (c *lockedConn) sendCap(d rpccp.CapDescriptor, snapshot capnp.ClientSnapsho wireRefs: 1, cancel: func() {}, } - id = exportID(c.lk.exportID.next()) + id = c.lk.exportID.next() if int64(id) == int64(len(c.lk.exports)) { c.lk.exports = append(c.lk.exports, ee) } else { @@ -308,7 +308,7 @@ func (e embargo) String() string { // // The caller must be holding onto c.mu. func (c *lockedConn) embargo(client capnp.Client) (embargoID, capnp.Client) { - id := embargoID(c.lk.embargoID.next()) + id := c.lk.embargoID.next() e := newEmbargo(client) if int64(id) == int64(len(c.lk.embargoes)) { c.lk.embargoes = append(c.lk.embargoes, e) diff --git a/rpc/idgen.go b/rpc/idgen.go index 56ff0bab..acdc72e4 100644 --- a/rpc/idgen.go +++ b/rpc/idgen.go @@ -3,15 +3,15 @@ package rpc // idgen returns a sequence of monotonically increasing IDs with // support for replacement. The zero value is a generator that // starts at zero. -type idgen struct { +type idgen[T ~uint32] struct { i uint32 free uintSet } -func (gen *idgen) next() uint32 { +func (gen *idgen[T]) next() T { if first, ok := gen.free.min(); ok { gen.free.remove(first) - return uint32(first) + return T(first) } i := gen.i if i == ^uint32(0) { @@ -22,10 +22,10 @@ func (gen *idgen) next() uint32 { panic("overflow ID") } gen.i++ - return i + return T(i) } -func (gen *idgen) remove(i uint32) { +func (gen *idgen[T]) remove(i T) { gen.free.add(uint(i)) } diff --git a/rpc/idgen_test.go b/rpc/idgen_test.go index 3b710b36..8bd83e37 100644 --- a/rpc/idgen_test.go +++ b/rpc/idgen_test.go @@ -7,7 +7,7 @@ import ( func TestIDGen(t *testing.T) { t.Run("NoReplacement", func(t *testing.T) { - var gen idgen + var gen idgen[uint32] for i := uint32(0); i <= 128; i++ { got := gen.next() if got != i { @@ -16,7 +16,7 @@ func TestIDGen(t *testing.T) { } }) t.Run("Replacement", func(t *testing.T) { - var gen idgen + var gen idgen[uint32] for i := 0; i < 64; i++ { gen.next() } diff --git a/rpc/import.go b/rpc/import.go index 65f327a9..f098e008 100644 --- a/rpc/import.go +++ b/rpc/import.go @@ -111,7 +111,7 @@ func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, }) q.p.Reject(rpcerr.WrapFailed("send message", err)) syncutil.With(&ic.c.lk, func() { - ic.c.lk.questionID.remove(uint32(q.id)) + ic.c.lk.questionID.remove(q.id) }) return } diff --git a/rpc/question.go b/rpc/question.go index 458ba229..46953dfb 100644 --- a/rpc/question.go +++ b/rpc/question.go @@ -51,7 +51,7 @@ func (flags questionFlags) Contains(flag questionFlags) bool { func (c *lockedConn) newQuestion(method capnp.Method) *question { q := &question{ c: (*Conn)(c), - id: questionID(c.lk.questionID.next()), + id: c.lk.questionID.next(), release: func() {}, finishMsgSend: make(chan struct{}), } @@ -156,7 +156,7 @@ func (q *question) PipelineSend(ctx context.Context, transform []capnp.PipelineO }) q2.p.Reject(rpcerr.WrapFailed("send message", err)) syncutil.With(&q.c.lk, func() { - q.c.lk.questionID.remove(uint32(q2.id)) + q.c.lk.questionID.remove(q2.id) }) return } diff --git a/rpc/rpc.go b/rpc/rpc.go index 3d1e618e..d8e4b62e 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -122,13 +122,13 @@ type Conn struct { // Tables questions []*question - questionID idgen + questionID idgen[questionID] answers map[answerID]*ansent exports []*expent - exportID idgen + exportID idgen[exportID] imports map[importID]*impent embargoes []*embargo - embargoID idgen + embargoID idgen[embargoID] } } @@ -330,7 +330,7 @@ func (c *Conn) Bootstrap(ctx context.Context) (bc capnp.Client) { }) q.p.Reject(exc.Annotate("rpc", "bootstrap", err)) syncutil.With(&c.lk, func() { - c.lk.questionID.remove(uint32(q.id)) + c.lk.questionID.remove(q.id) }) return } @@ -1112,7 +1112,7 @@ func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) e select { case <-q.finishMsgSend: if q.flags.Contains(finishSent) { - c.lk.questionID.remove(uint32(qid)) + c.lk.questionID.remove(qid) } dq.Defer(in.Release) default: @@ -1123,7 +1123,7 @@ func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) e <-q.finishMsgSend c.withLocked(func(c *lockedConn) { if q.flags.Contains(finishSent) { - c.lk.questionID.remove(uint32(qid)) + c.lk.questionID.remove(qid) } }) }() @@ -1183,7 +1183,7 @@ func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) e c.er.ReportError(err) } else { q.flags |= finishSent - c.lk.questionID.remove(uint32(qid)) + c.lk.questionID.remove(qid) } }) }) @@ -1574,7 +1574,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag if e != nil { // TODO(soon): verify target matches the right import. c.lk.embargoes[id] = nil - c.lk.embargoID.remove(uint32(id)) + c.lk.embargoID.remove(id) } }) if e == nil {