Skip to content

Commit

Permalink
Merge pull request #532 from zenhack/generic-idgen
Browse files Browse the repository at this point in the history
Make rpc.idgen a generic type.
  • Loading branch information
zenhack authored Jun 26, 2023
2 parents b6db31e + d90c545 commit d151ede
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 21 deletions.
6 changes: 3 additions & 3 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (c *lockedConn) releaseExport(id exportID, count uint32) (capnp.ClientSnaps
defer ent.cancel()
snapshot := ent.snapshot
c.lk.exports[id] = nil
c.lk.exportID.remove(uint32(id))
c.lk.exportID.remove(id)
metadata := snapshot.Metadata()
syncutil.With(metadata, func() {
c.clearExportID(metadata)
Expand Down Expand Up @@ -170,7 +170,7 @@ func (c *lockedConn) sendCap(d rpccp.CapDescriptor, snapshot capnp.ClientSnapsho
wireRefs: 1,
cancel: func() {},
}
id = exportID(c.lk.exportID.next())
id = c.lk.exportID.next()
if int64(id) == int64(len(c.lk.exports)) {
c.lk.exports = append(c.lk.exports, ee)
} else {
Expand Down Expand Up @@ -308,7 +308,7 @@ func (e embargo) String() string {
//
// The caller must be holding onto c.mu.
func (c *lockedConn) embargo(client capnp.Client) (embargoID, capnp.Client) {
id := embargoID(c.lk.embargoID.next())
id := c.lk.embargoID.next()
e := newEmbargo(client)
if int64(id) == int64(len(c.lk.embargoes)) {
c.lk.embargoes = append(c.lk.embargoes, e)
Expand Down
10 changes: 5 additions & 5 deletions rpc/idgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ package rpc
// idgen returns a sequence of monotonically increasing IDs with
// support for replacement. The zero value is a generator that
// starts at zero.
type idgen struct {
type idgen[T ~uint32] struct {
i uint32
free uintSet
}

func (gen *idgen) next() uint32 {
func (gen *idgen[T]) next() T {
if first, ok := gen.free.min(); ok {
gen.free.remove(first)
return uint32(first)
return T(first)
}
i := gen.i
if i == ^uint32(0) {
Expand All @@ -22,10 +22,10 @@ func (gen *idgen) next() uint32 {
panic("overflow ID")
}
gen.i++
return i
return T(i)
}

func (gen *idgen) remove(i uint32) {
func (gen *idgen[T]) remove(i T) {
gen.free.add(uint(i))
}

Expand Down
4 changes: 2 additions & 2 deletions rpc/idgen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

func TestIDGen(t *testing.T) {
t.Run("NoReplacement", func(t *testing.T) {
var gen idgen
var gen idgen[uint32]
for i := uint32(0); i <= 128; i++ {
got := gen.next()
if got != i {
Expand All @@ -16,7 +16,7 @@ func TestIDGen(t *testing.T) {
}
})
t.Run("Replacement", func(t *testing.T) {
var gen idgen
var gen idgen[uint32]
for i := 0; i < 64; i++ {
gen.next()
}
Expand Down
2 changes: 1 addition & 1 deletion rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer,
})
q.p.Reject(rpcerr.WrapFailed("send message", err))
syncutil.With(&ic.c.lk, func() {
ic.c.lk.questionID.remove(uint32(q.id))
ic.c.lk.questionID.remove(q.id)
})
return
}
Expand Down
4 changes: 2 additions & 2 deletions rpc/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (flags questionFlags) Contains(flag questionFlags) bool {
func (c *lockedConn) newQuestion(method capnp.Method) *question {
q := &question{
c: (*Conn)(c),
id: questionID(c.lk.questionID.next()),
id: c.lk.questionID.next(),
release: func() {},
finishMsgSend: make(chan struct{}),
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func (q *question) PipelineSend(ctx context.Context, transform []capnp.PipelineO
})
q2.p.Reject(rpcerr.WrapFailed("send message", err))
syncutil.With(&q.c.lk, func() {
q.c.lk.questionID.remove(uint32(q2.id))
q.c.lk.questionID.remove(q2.id)
})
return
}
Expand Down
16 changes: 8 additions & 8 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ type Conn struct {

// Tables
questions []*question
questionID idgen
questionID idgen[questionID]
answers map[answerID]*ansent
exports []*expent
exportID idgen
exportID idgen[exportID]
imports map[importID]*impent
embargoes []*embargo
embargoID idgen
embargoID idgen[embargoID]
}
}

Expand Down Expand Up @@ -330,7 +330,7 @@ func (c *Conn) Bootstrap(ctx context.Context) (bc capnp.Client) {
})
q.p.Reject(exc.Annotate("rpc", "bootstrap", err))
syncutil.With(&c.lk, func() {
c.lk.questionID.remove(uint32(q.id))
c.lk.questionID.remove(q.id)
})
return
}
Expand Down Expand Up @@ -1112,7 +1112,7 @@ func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) e
select {
case <-q.finishMsgSend:
if q.flags.Contains(finishSent) {
c.lk.questionID.remove(uint32(qid))
c.lk.questionID.remove(qid)
}
dq.Defer(in.Release)
default:
Expand All @@ -1123,7 +1123,7 @@ func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) e
<-q.finishMsgSend
c.withLocked(func(c *lockedConn) {
if q.flags.Contains(finishSent) {
c.lk.questionID.remove(uint32(qid))
c.lk.questionID.remove(qid)
}
})
}()
Expand Down Expand Up @@ -1183,7 +1183,7 @@ func (c *Conn) handleReturn(ctx context.Context, in transport.IncomingMessage) e
c.er.ReportError(err)
} else {
q.flags |= finishSent
c.lk.questionID.remove(uint32(qid))
c.lk.questionID.remove(qid)
}
})
})
Expand Down Expand Up @@ -1574,7 +1574,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag
if e != nil {
// TODO(soon): verify target matches the right import.
c.lk.embargoes[id] = nil
c.lk.embargoID.remove(uint32(id))
c.lk.embargoID.remove(id)
}
})
if e == nil {
Expand Down

0 comments on commit d151ede

Please sign in to comment.