Skip to content

Commit

Permalink
Allow using the packed transport in rpc.Serve
Browse files Browse the repository at this point in the history
The Serve function in the rpc module can only use the basic streaming encoding.

This commit allows setting the transport by the user through the appropriate ServeOption.
  • Loading branch information
fpetkovski committed Aug 28, 2024
1 parent 396906c commit 55c5869
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 20 deletions.
39 changes: 36 additions & 3 deletions rpc/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,51 @@ import (
"capnproto.org/go/capnp/v3"
)

// serveOpts are options for the Cap'n Proto server.
type serveOpts struct {
newTransport NewTransportFunc
}

// defaultServeOpts returns the default server opts.
func defaultServeOpts() serveOpts {
return serveOpts{
newTransport: NewStreamTransport,
}
}

type ServeOption func(*serveOpts)

// WithBasicStreamingTransport enables the streaming transport with basic encoding.
func WithBasicStreamingTransport() ServeOption {
return func(opts *serveOpts) {
opts.newTransport = NewStreamTransport
}
}

// WithPackedStreamingTransport enables the streaming transport with packed encoding.
func WithPackedStreamingTransport() ServeOption {
return func(opts *serveOpts) {
opts.newTransport = NewPackedStreamTransport
}
}

// Serve serves a Cap'n Proto RPC to incoming connections.
//
// Serve will take ownership of bootstrapClient and release it after the listener closes.
//
// Serve exits with the listener error if the listener is closed by the owner.
func Serve(lis net.Listener, boot capnp.Client) error {
func Serve(lis net.Listener, boot capnp.Client, opts ...ServeOption) error {
if !boot.IsValid() {
err := errors.New("bootstrap client is not valid")
return err
}
// Since we took ownership of the bootstrap client, release it after we're done.
defer boot.Release()

options := defaultServeOpts()
for _, o := range opts {
o(&options)
}
for {
// Accept incoming connections
conn, err := lis.Accept()
Expand All @@ -33,7 +66,7 @@ func Serve(lis net.Listener, boot capnp.Client) error {
BootstrapClient: boot.AddRef(),
}
// For each new incoming connection, create a new RPC transport connection that will serve incoming RPC requests
transport := NewStreamTransport(conn)
transport := options.newTransport(conn)
_ = NewConn(transport, &opts)
}
}
Expand All @@ -44,7 +77,7 @@ func Serve(lis net.Listener, boot capnp.Client) error {
// and "tcp" for regular TCP IP4 or IP6 connections.
//
// ListenAndServe will take ownership of bootstrapClient and release it on exit.
func ListenAndServe(ctx context.Context, network, addr string, bootstrapClient capnp.Client) error {
func ListenAndServe(ctx context.Context, network, addr string, bootstrapClient capnp.Client, opts ...ServeOption) error {

listener, err := net.Listen(network, addr)

Expand Down
56 changes: 39 additions & 17 deletions rpc/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,43 @@ func TestServeCapability(t *testing.T) {
}

func TestListenAndServe(t *testing.T) {
var err error
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
errChannel := make(chan error)

// Provide a server that listens
srv := testcp.PingPong_ServerToClient(pingPongServer{})
bootstrapClient := capnp.Client(srv)
go func() {
t.Log("Starting ListenAndServe")
err2 := rpc.ListenAndServe(ctx, "tcp", ":0", bootstrapClient)
errChannel <- err2
}()

cancelFunc()
err = <-errChannel // Will hang if server does not return.
assert.ErrorIs(t, err, net.ErrClosed)
cases := []struct {
name string
opts []rpc.ServeOption
}{
{
name: "basic encoding transport",
opts: []rpc.ServeOption{
rpc.WithBasicStreamingTransport(),
},
},
{
name: "packed encoding transport",
opts: []rpc.ServeOption{
rpc.WithPackedStreamingTransport(),
},
},
}

for _, tcase := range cases {
t.Run(tcase.name, func(t *testing.T) {
var err error
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
errChannel := make(chan error)

// Provide a server that listens
srv := testcp.PingPong_ServerToClient(pingPongServer{})
bootstrapClient := capnp.Client(srv)
go func() {
t.Log("Starting ListenAndServe")
err2 := rpc.ListenAndServe(ctx, "tcp", ":0", bootstrapClient, tcase.opts...)
errChannel <- err2
}()

cancelFunc()
err = <-errChannel // Will hang if server does not return.
assert.ErrorIs(t, err, net.ErrClosed)
})
}
}
1 change: 1 addition & 0 deletions rpc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

type Codec = transport.Codec
type Transport = transport.Transport
type NewTransportFunc func(io.ReadWriteCloser) Transport

// NewStreamTransport is an alias for as transport.NewStream
func NewStreamTransport(rwc io.ReadWriteCloser) Transport {
Expand Down

0 comments on commit 55c5869

Please sign in to comment.