diff --git a/README.md b/README.md index 22a71db..0332844 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,8 @@ import ( ) func main() { - // Create the health service. + // Create a health service. The health service is used to check the status + // of services running within the server. healthSvc := health.NewServer() healthSvc.SetServingStatus("example.up.Service", healthpb.HealthCheckResponse_SERVING) healthSvc.SetServingStatus("example.down.Service", healthpb.HealthCheckResponse_NOT_SERVING) @@ -68,7 +69,7 @@ func main() { // - websocket /v1/healthz -> grpc.health.v1.Health.Watch health.AddHealthz(serviceConfig) - // Mux implements http.Handler, use by itself to serve only HTTP endpoints. + // Mux implements http.Handler and serves both gRPC and HTTP connections. mux, err := larking.NewMux( larking.ServiceConfigOption(serviceConfig), ) @@ -78,8 +79,8 @@ func main() { // RegisterHealthServer registers a HealthServer to the mux. healthpb.RegisterHealthServer(mux, healthSvc) - // Server is a gRPC server that serves both gRPC and HTTP endpoints. - svr, err := larking.NewServer(mux, larking.InsecureServerOption()) + // NewServer creates a *http.Server. + svr, err := larking.NewServer(mux) if err != nil { log.Fatal(err) } @@ -97,7 +98,6 @@ func main() { if err := svr.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) } -} ``` Running the service we can check the health endpoints with curl: diff --git a/benchmarks/README.md b/benchmarks/README.md index 152c50d..11d153b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -33,3 +33,23 @@ Compares speed with writing the annotations binding by hand, useful for comparisons ## Twirp [Twirp](https://github.com/twitchtv/twirp) is a simple RPC protocol based on HTTP and Protocol Buffers (proto). + + +## gRPC + +Compares gRPC server benchmarks against `mux.ServeHTTP`. +We use an altered version of grpc-go's [benchmain](https://github.com/grpc/grpc-go/blob/master/Documentation/benchmark.md) +tool to run the benchmark and compare it with gRPC's internal server. + +``` +go run benchmain/main.go -benchtime=10s -workloads=all \ + -compression=gzip -maxConcurrentCalls=1 -trace=off \ + -reqSizeBytes=1,1048576 -respSizeBytes=1,1048576 -networkMode=Local \ + -cpuProfile=cpuProf -memProfile=memProf -memProfileRate=10000 -resultFile=result.bin +``` + +``` +go run google.golang.org/grpc/benchmark/benchresult grpc_result.bin result.bin +``` + +See `grpc-bench.txt` for gRPC results. 
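+ +The baseline `grpc_result.bin` can be produced, for example, by running the upstream grpc-go benchmain with the same flags as above: + +``` +go run google.golang.org/grpc/benchmark/benchmain -benchtime=10s -workloads=all \ + -compression=gzip -maxConcurrentCalls=1 -trace=off \ + -reqSizeBytes=1,1048576 -respSizeBytes=1,1048576 -networkMode=Local \ + -resultFile=grpc_result.bin +``` 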
diff --git a/benchmarks/bench.txt b/benchmarks/bench.txt index 8fb69c5..7c081df 100644 --- a/benchmarks/bench.txt +++ b/benchmarks/bench.txt @@ -1,33 +1,33 @@ goos: darwin goarch: arm64 pkg: larking.io/benchmarks -BenchmarkLarking/GRPC_GetBook-8 16065 73217 ns/op 25753 B/op 234 allocs/op -BenchmarkLarking/HTTP_GetBook-8 25358 47049 ns/op 9179 B/op 146 allocs/op -BenchmarkLarking/HTTP_UpdateBook-8 24842 48419 ns/op 11191 B/op 174 allocs/op -BenchmarkLarking/HTTP_DeleteBook-8 30992 38817 ns/op 7958 B/op 98 allocs/op -BenchmarkLarking/HTTP_GetBook+pb-8 30486 39433 ns/op 8302 B/op 104 allocs/op -BenchmarkLarking/HTTP_UpdateBook+pb-8 26541 44876 ns/op 9898 B/op 135 allocs/op -BenchmarkLarking/HTTP_DeleteBook+pb-8 31981 38090 ns/op 7721 B/op 95 allocs/op -BenchmarkGRPCGateway/GRPC_GetBook-8 32599 36980 ns/op 9495 B/op 178 allocs/op -BenchmarkGRPCGateway/HTTP_GetBook-8 24849 48138 ns/op 11202 B/op 179 allocs/op -BenchmarkGRPCGateway/HTTP_UpdateBook-8 22648 53060 ns/op 16661 B/op 230 allocs/op -BenchmarkGRPCGateway/HTTP_DeleteBook-8 30973 38945 ns/op 9158 B/op 119 allocs/op -BenchmarkEnvoyGRPC/GRPC_GetBook-8 8965 133720 ns/op 10959 B/op 177 allocs/op -BenchmarkEnvoyGRPC/HTTP_GetBook-8 7647 148356 ns/op 9968 B/op 163 allocs/op -BenchmarkEnvoyGRPC/HTTP_UpdateBook-8 7760 155771 ns/op 10648 B/op 166 allocs/op -BenchmarkEnvoyGRPC/HTTP_DeleteBook-8 8828 133543 ns/op 9023 B/op 126 allocs/op -BenchmarkGorillaMux/HTTP_GetBook-8 26178 45236 ns/op 9940 B/op 146 allocs/op -BenchmarkGorillaMux/HTTP_UpdateBook-8 25618 45267 ns/op 11880 B/op 170 allocs/op -BenchmarkGorillaMux/HTTP_DeleteBook-8 32311 36810 ns/op 8209 B/op 92 allocs/op -BenchmarkConnectGo/GRPC_GetBook-8 18517 64640 ns/op 13419 B/op 194 allocs/op -BenchmarkConnectGo/HTTP_GetBook-8 24274 50934 ns/op 10994 B/op 176 allocs/op -BenchmarkConnectGo/HTTP_UpdateBook-8 25250 47564 ns/op 11234 B/op 183 allocs/op -BenchmarkConnectGo/HTTP_DeleteBook-8 28860 41592 ns/op 9294 B/op 122 allocs/op -BenchmarkConnectGo/Connect_GetBook-8 15196 80045 ns/op 67857 B/op 150 allocs/op -BenchmarkConnectGo/Connect_UpdateBook-8 16398 73632 ns/op 72733 B/op 150 allocs/op -BenchmarkConnectGo/Connect_DeleteBook-8 17664 66962 ns/op 66902 B/op 140 allocs/op -BenchmarkTwirp/HTTP_GetBook-8 23032 51052 ns/op 12680 B/op 196 allocs/op -BenchmarkTwirp/HTTP_UpdateBook-8 22718 52466 ns/op 13436 B/op 214 allocs/op -BenchmarkTwirp/HTTP_DeleteBook-8 28699 41943 ns/op 10755 B/op 140 allocs/op +BenchmarkLarking/GRPC_GetBook-8 19376 60966 ns/op 13877 B/op 193 allocs/op +BenchmarkLarking/HTTP_GetBook-8 25690 46071 ns/op 9230 B/op 144 allocs/op +BenchmarkLarking/HTTP_UpdateBook-8 24384 49030 ns/op 11228 B/op 171 allocs/op +BenchmarkLarking/HTTP_DeleteBook-8 31225 38384 ns/op 8002 B/op 96 allocs/op +BenchmarkLarking/HTTP_GetBook+pb-8 30403 39225 ns/op 8351 B/op 102 allocs/op +BenchmarkLarking/HTTP_UpdateBook+pb-8 26089 45797 ns/op 9929 B/op 132 allocs/op +BenchmarkLarking/HTTP_DeleteBook+pb-8 31750 37533 ns/op 7773 B/op 93 allocs/op +BenchmarkGRPCGateway/GRPC_GetBook-8 29191 40917 ns/op 9498 B/op 178 allocs/op +BenchmarkGRPCGateway/HTTP_GetBook-8 25154 46624 ns/op 11200 B/op 179 allocs/op +BenchmarkGRPCGateway/HTTP_UpdateBook-8 23169 51382 ns/op 16404 B/op 230 allocs/op +BenchmarkGRPCGateway/HTTP_DeleteBook-8 31365 38172 ns/op 9160 B/op 119 allocs/op +BenchmarkEnvoyGRPC/GRPC_GetBook-8 9067 131554 ns/op 10960 B/op 177 allocs/op +BenchmarkEnvoyGRPC/HTTP_GetBook-8 8090 148991 ns/op 9974 B/op 163 allocs/op +BenchmarkEnvoyGRPC/HTTP_UpdateBook-8 7807 150534 ns/op 10648 B/op 166 allocs/op 
+BenchmarkEnvoyGRPC/HTTP_DeleteBook-8 8866 133976 ns/op 9026 B/op 126 allocs/op +BenchmarkGorillaMux/HTTP_GetBook-8 26590 45013 ns/op 9874 B/op 143 allocs/op +BenchmarkGorillaMux/HTTP_UpdateBook-8 26713 44912 ns/op 11791 B/op 166 allocs/op +BenchmarkGorillaMux/HTTP_DeleteBook-8 32792 36645 ns/op 8143 B/op 89 allocs/op +BenchmarkConnectGo/GRPC_GetBook-8 18492 64467 ns/op 13437 B/op 194 allocs/op +BenchmarkConnectGo/HTTP_GetBook-8 24248 49106 ns/op 10996 B/op 176 allocs/op +BenchmarkConnectGo/HTTP_UpdateBook-8 25184 47556 ns/op 11232 B/op 183 allocs/op +BenchmarkConnectGo/HTTP_DeleteBook-8 28917 41493 ns/op 9299 B/op 122 allocs/op +BenchmarkConnectGo/Connect_GetBook-8 15196 79021 ns/op 77348 B/op 151 allocs/op +BenchmarkConnectGo/Connect_UpdateBook-8 16684 73370 ns/op 79413 B/op 151 allocs/op +BenchmarkConnectGo/Connect_DeleteBook-8 18038 66268 ns/op 69503 B/op 140 allocs/op +BenchmarkTwirp/HTTP_GetBook-8 23119 49839 ns/op 12679 B/op 196 allocs/op +BenchmarkTwirp/HTTP_UpdateBook-8 23336 52310 ns/op 13456 B/op 214 allocs/op +BenchmarkTwirp/HTTP_DeleteBook-8 28702 41848 ns/op 10766 B/op 140 allocs/op PASS -ok larking.io/benchmarks 47.466s +ok larking.io/benchmarks 47.310s diff --git a/benchmarks/benchmain/main.go b/benchmarks/benchmain/main.go new file mode 100644 index 0000000..0b3a365 --- /dev/null +++ b/benchmarks/benchmain/main.go @@ -0,0 +1,892 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* +Package main provides benchmark with setting flags. + +An example to run some benchmarks with profiling enabled: + + go run benchmark/benchmain/main.go -benchtime=10s -workloads=all \ + -compression=gzip -maxConcurrentCalls=1 -trace=off \ + -reqSizeBytes=1,1048576 -respSizeBytes=1,1048576 -networkMode=Local \ + -cpuProfile=cpuProf -memProfile=memProf -memProfileRate=10000 -resultFile=result + + go run benchmain/main.go -benchtime=10s -workloads=all \ + -compression=gzip -maxConcurrentCalls=1 -trace=off \ + -reqSizeBytes=1,1048576 -respSizeBytes=1,1048576 -networkMode=Local \ + -cpuProfile=cpuProf -memProfile=memProf -memProfileRate=10000 -resultFile=result.bin + +As a suggestion, when creating a branch, you can run this benchmark and save the result +file "-resultFile=basePerf", and later when you at the middle of the work or finish the +work, you can get the benchmark result and compare it with the base anytime. + +Assume there are two result files names as "basePerf" and "curPerf" created by adding +-resultFile=basePerf and -resultFile=curPerf. 
+ + To format the curPerf, run: + go run benchmark/benchresult/main.go curPerf + To observe how the performance changes based on a base result, run: + go run benchmark/benchresult/main.go basePerf curPerf +*/ +package main + +import ( + "context" + "encoding/gob" + "flag" + "fmt" + "io" + "log" + "net" + "os" + "reflect" + "runtime" + "runtime/pprof" + "strings" + "sync" + "sync/atomic" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/benchmark" + "google.golang.org/grpc/benchmark/flags" + "google.golang.org/grpc/benchmark/latency" + "google.golang.org/grpc/benchmark/stats" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/grpclog" + + //"google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/test/bufconn" + + testpb "google.golang.org/grpc/interop/grpc_testing" + + "larking.io/benchmarks/server" +) + +var ( + workloads = flags.StringWithAllowedValues("workloads", workloadsAll, + fmt.Sprintf("Workloads to execute - One of: %v", strings.Join(allWorkloads, ", ")), allWorkloads) + traceMode = flags.StringWithAllowedValues("trace", toggleModeOff, + fmt.Sprintf("Trace mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes) + preloaderMode = flags.StringWithAllowedValues("preloader", toggleModeOff, + fmt.Sprintf("Preloader mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes) + //channelzOn = flags.StringWithAllowedValues("channelz", toggleModeOff, + // fmt.Sprintf("Channelz mode - One of: %v", strings.Join(allToggleModes, ", ")), allToggleModes) + compressorMode = flags.StringWithAllowedValues("compression", compModeOff, + fmt.Sprintf("Compression mode - One of: %v", strings.Join(allCompModes, ", ")), allCompModes) + networkMode = flags.StringWithAllowedValues("networkMode", networkModeNone, + "Network mode includes LAN, WAN, Local and Longhaul", allNetworkModes) + readLatency = flags.DurationSlice("latency", defaultReadLatency, "Simulated one-way network latency - may be a comma-separated list") + readKbps = flags.IntSlice("kbps", defaultReadKbps, "Simulated network throughput (in kbps) - may be a comma-separated list") + readMTU = flags.IntSlice("mtu", defaultReadMTU, "Simulated network MTU (Maximum Transmission Unit) - may be a comma-separated list") + maxConcurrentCalls = flags.IntSlice("maxConcurrentCalls", defaultMaxConcurrentCalls, "Number of concurrent RPCs during benchmarks") + readReqSizeBytes = flags.IntSlice("reqSizeBytes", nil, "Request size in bytes - may be a comma-separated list") + readRespSizeBytes = flags.IntSlice("respSizeBytes", nil, "Response size in bytes - may be a comma-separated list") + reqPayloadCurveFiles = flags.StringSlice("reqPayloadCurveFiles", nil, "comma-separated list of CSV files describing the shape a random distribution of request payload sizes") + respPayloadCurveFiles = flags.StringSlice("respPayloadCurveFiles", nil, "comma-separated list of CSV files describing the shape a random distribution of response payload sizes") + benchTime = flag.Duration("benchtime", time.Second, "Configures the amount of time to run each benchmark") + memProfile = flag.String("memProfile", "", "Enables memory profiling output to the filename provided.") + memProfileRate = flag.Int("memProfileRate", 512*1024, "Configures the memory profiling rate. \n"+ + "memProfile should be set before setting profile rate. To include every allocated block in the profile, "+ + "set MemProfileRate to 1. 
To turn off profiling entirely, set MemProfileRate to 0. 512 * 1024 by default.") + cpuProfile = flag.String("cpuProfile", "", "Enables CPU profiling output to the filename provided") + benchmarkResultFile = flag.String("resultFile", "", "Save the benchmark result into a binary file") + useBufconn = flag.Bool("bufconn", false, "Use in-memory connection instead of system network I/O") + enableKeepalive = flag.Bool("enable_keepalive", false, "Enable client keepalive. \n"+ + "Keepalive.Time is set to 10s, Keepalive.Timeout is set to 1s, Keepalive.PermitWithoutStream is set to true.") + clientReadBufferSize = flags.IntSlice("clientReadBufferSize", []int{-1}, "Configures the client read buffer size in bytes. If negative, use the default - may be a a comma-separated list") + clientWriteBufferSize = flags.IntSlice("clientWriteBufferSize", []int{-1}, "Configures the client write buffer size in bytes. If negative, use the default - may be a a comma-separated list") + serverReadBufferSize = flags.IntSlice("serverReadBufferSize", []int{-1}, "Configures the server read buffer size in bytes. If negative, use the default - may be a a comma-separated list") + serverWriteBufferSize = flags.IntSlice("serverWriteBufferSize", []int{-1}, "Configures the server write buffer size in bytes. If negative, use the default - may be a a comma-separated list") + + logger = grpclog.Component("benchmark") +) + +const ( + workloadsUnary = "unary" + workloadsStreaming = "streaming" + workloadsUnconstrained = "unconstrained" + workloadsAll = "all" + // Compression modes. + compModeOff = "off" + compModeGzip = "gzip" + compModeNop = "nop" + compModeAll = "all" + // Toggle modes. + toggleModeOff = "off" + toggleModeOn = "on" + toggleModeBoth = "both" + // Network modes. + networkModeNone = "none" + networkModeLocal = "Local" + networkModeLAN = "LAN" + networkModeWAN = "WAN" + networkLongHaul = "Longhaul" + + numStatsBuckets = 10 + warmupCallCount = 10 + warmuptime = time.Second +) + +var ( + allWorkloads = []string{workloadsUnary, workloadsStreaming, workloadsUnconstrained, workloadsAll} + allCompModes = []string{compModeOff, compModeGzip, compModeNop, compModeAll} + allToggleModes = []string{toggleModeOff, toggleModeOn, toggleModeBoth} + allNetworkModes = []string{networkModeNone, networkModeLocal, networkModeLAN, networkModeWAN, networkLongHaul} + defaultReadLatency = []time.Duration{0, 40 * time.Millisecond} // if non-positive, no delay. + defaultReadKbps = []int{0, 10240} // if non-positive, infinite + defaultReadMTU = []int{0} // if non-positive, infinite + defaultMaxConcurrentCalls = []int{1, 8, 64, 512} + defaultReqSizeBytes = []int{1, 1024, 1024 * 1024} + defaultRespSizeBytes = []int{1, 1024, 1024 * 1024} + networks = map[string]latency.Network{ + networkModeLocal: latency.Local, + networkModeLAN: latency.LAN, + networkModeWAN: latency.WAN, + networkLongHaul: latency.Longhaul, + } + keepaliveTime = 10 * time.Second + keepaliveTimeout = 1 * time.Second + // This is 0.8*keepaliveTime to prevent connection issues because of server + // keepalive enforcement. + keepaliveMinTime = 8 * time.Second +) + +// runModes indicates the workloads to run. This is initialized with a call to +// `runModesFromWorkloads`, passing the workloads flag set by the user. +type runModes struct { + unary, streaming, unconstrained bool +} + +// runModesFromWorkloads determines the runModes based on the value of +// workloads flag set by the user. 
+func runModesFromWorkloads(workload string) runModes { + r := runModes{} + switch workload { + case workloadsUnary: + r.unary = true + case workloadsStreaming: + r.streaming = true + case workloadsUnconstrained: + r.unconstrained = true + case workloadsAll: + r.unary = true + r.streaming = true + r.unconstrained = true + default: + log.Fatalf("Unknown workloads setting: %v (want one of: %v)", + workloads, strings.Join(allWorkloads, ", ")) + } + return r +} + +type startFunc func(mode string, bf stats.Features) +type stopFunc func(count uint64) +type ucStopFunc func(req uint64, resp uint64) +type rpcCallFunc func(pos int) +type rpcSendFunc func(pos int) +type rpcRecvFunc func(pos int) +type rpcCleanupFunc func() + +func unaryBenchmark(start startFunc, stop stopFunc, bf stats.Features, s *stats.Stats) { + caller, cleanup := makeFuncUnary(bf) + defer cleanup() + runBenchmark(caller, start, stop, bf, s, workloadsUnary) +} + +func streamBenchmark(start startFunc, stop stopFunc, bf stats.Features, s *stats.Stats) { + caller, cleanup := makeFuncStream(bf) + defer cleanup() + runBenchmark(caller, start, stop, bf, s, workloadsStreaming) +} + +func unconstrainedStreamBenchmark(start startFunc, stop ucStopFunc, bf stats.Features) { + var sender rpcSendFunc + var recver rpcRecvFunc + var cleanup rpcCleanupFunc + if bf.EnablePreloader { + sender, recver, cleanup = makeFuncUnconstrainedStreamPreloaded(bf) + } else { + sender, recver, cleanup = makeFuncUnconstrainedStream(bf) + } + defer cleanup() + + var req, resp uint64 + go func() { + // Resets the counters once warmed up + <-time.NewTimer(warmuptime).C + atomic.StoreUint64(&req, 0) + atomic.StoreUint64(&resp, 0) + start(workloadsUnconstrained, bf) + }() + + bmEnd := time.Now().Add(bf.BenchTime + warmuptime) + var wg sync.WaitGroup + wg.Add(2 * bf.MaxConcurrentCalls) + for i := 0; i < bf.MaxConcurrentCalls; i++ { + go func(pos int) { + defer wg.Done() + for { + t := time.Now() + if t.After(bmEnd) { + return + } + sender(pos) + atomic.AddUint64(&req, 1) + } + }(i) + go func(pos int) { + defer wg.Done() + for { + t := time.Now() + if t.After(bmEnd) { + return + } + recver(pos) + atomic.AddUint64(&resp, 1) + } + }(i) + } + wg.Wait() + stop(req, resp) +} + +// makeClient returns a gRPC client for the grpc.testing.BenchmarkService +// service. The client is configured using the different options in the passed +// 'bf'. Also returns a cleanup function to close the client and release +// resources. 
+func makeClient(bf stats.Features) (testpb.BenchmarkServiceClient, func()) { + nw := &latency.Network{Kbps: bf.Kbps, Latency: bf.Latency, MTU: bf.MTU} + opts := []grpc.DialOption{} + //sopts := []grpc.ServerOption{} + //sopts := []larking.ServerOption{} + if bf.ModeCompressor == compModeNop { + //sopts = append(sopts, + // grpc.RPCCompressor(nopCompressor{}), + // grpc.RPCDecompressor(nopDecompressor{}), + //) + opts = append(opts, + grpc.WithCompressor(nopCompressor{}), + grpc.WithDecompressor(nopDecompressor{}), + ) + } + if bf.ModeCompressor == compModeGzip { + //sopts = append(sopts, + // grpc.RPCCompressor(grpc.NewGZIPCompressor()), + // grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), + //) + opts = append(opts, + grpc.WithCompressor(grpc.NewGZIPCompressor()), + grpc.WithDecompressor(grpc.NewGZIPDecompressor()), + ) + } + if bf.EnableKeepalive { + //sopts = append(sopts, + // grpc.KeepaliveParams(keepalive.ServerParameters{ + // Time: keepaliveTime, + // Timeout: keepaliveTimeout, + // }), + // grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ + // MinTime: keepaliveMinTime, + // PermitWithoutStream: true, + // }), + //) + opts = append(opts, + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: keepaliveTime, + Timeout: keepaliveTimeout, + PermitWithoutStream: true, + }), + ) + } + if bf.ClientReadBufferSize >= 0 { + opts = append(opts, grpc.WithReadBufferSize(bf.ClientReadBufferSize)) + } + if bf.ClientWriteBufferSize >= 0 { + opts = append(opts, grpc.WithWriteBufferSize(bf.ClientWriteBufferSize)) + } + //if bf.ServerReadBufferSize >= 0 { + // sopts = append(sopts, grpc.ReadBufferSize(bf.ServerReadBufferSize)) + //} + //if bf.ServerWriteBufferSize >= 0 { + // sopts = append(sopts, grpc.WriteBufferSize(bf.ServerWriteBufferSize)) + //} + + //sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(bf.MaxConcurrentCalls+1))) + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + + var lis net.Listener + if bf.UseBufConn { + bcLis := bufconn.Listen(256 * 1024) + lis = bcLis + opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) { + return nw.ContextDialer(func(context.Context, string, string) (net.Conn, error) { + return bcLis.Dial() + })(ctx, "", "") + })) + } else { + var err error + lis, err = net.Listen("tcp", "localhost:0") + if err != nil { + logger.Fatalf("Failed to listen: %v", err) + } + opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) { + return nw.ContextDialer((&net.Dialer{}).DialContext)(ctx, "tcp", lis.Addr().String()) + })) + } + lis = nw.Listener(lis) + stopper := server.StartServer(server.ServerInfo{Type: "protobuf", Listener: lis}) + target := "localhost:5051" + conn := benchmark.NewClientConn(target /* target *IS* used */, opts...) 
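+ // The returned cleanup function closes the client connection and stops the in-process benchmark server.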
+ return testpb.NewBenchmarkServiceClient(conn), func() { + conn.Close() + stopper() + } +} + +func makeFuncUnary(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) { + tc, cleanup := makeClient(bf) + return func(int) { + reqSizeBytes := bf.ReqSizeBytes + respSizeBytes := bf.RespSizeBytes + if bf.ReqPayloadCurve != nil { + reqSizeBytes = bf.ReqPayloadCurve.ChooseRandom() + } + if bf.RespPayloadCurve != nil { + respSizeBytes = bf.RespPayloadCurve.ChooseRandom() + } + unaryCaller(tc, reqSizeBytes, respSizeBytes) + }, cleanup +} + +func makeFuncStream(bf stats.Features) (rpcCallFunc, rpcCleanupFunc) { + tc, cleanup := makeClient(bf) + + streams := make([]testpb.BenchmarkService_StreamingCallClient, bf.MaxConcurrentCalls) + for i := 0; i < bf.MaxConcurrentCalls; i++ { + stream, err := tc.StreamingCall(context.Background()) + if err != nil { + logger.Fatalf("%v.StreamingCall(_) = _, %v", tc, err) + } + streams[i] = stream + } + + return func(pos int) { + reqSizeBytes := bf.ReqSizeBytes + respSizeBytes := bf.RespSizeBytes + if bf.ReqPayloadCurve != nil { + reqSizeBytes = bf.ReqPayloadCurve.ChooseRandom() + } + if bf.RespPayloadCurve != nil { + respSizeBytes = bf.RespPayloadCurve.ChooseRandom() + } + streamCaller(streams[pos], reqSizeBytes, respSizeBytes) + }, cleanup +} + +func makeFuncUnconstrainedStreamPreloaded(bf stats.Features) (rpcSendFunc, rpcRecvFunc, rpcCleanupFunc) { + streams, req, cleanup := setupUnconstrainedStream(bf) + + preparedMsg := make([]*grpc.PreparedMsg, len(streams)) + for i, stream := range streams { + preparedMsg[i] = &grpc.PreparedMsg{} + err := preparedMsg[i].Encode(stream, req) + if err != nil { + logger.Fatalf("%v.Encode(%v, %v) = %v", preparedMsg[i], req, stream, err) + } + } + + return func(pos int) { + streams[pos].SendMsg(preparedMsg[pos]) + }, func(pos int) { + streams[pos].Recv() + }, cleanup +} + +func makeFuncUnconstrainedStream(bf stats.Features) (rpcSendFunc, rpcRecvFunc, rpcCleanupFunc) { + streams, req, cleanup := setupUnconstrainedStream(bf) + + return func(pos int) { + streams[pos].Send(req) + }, func(pos int) { + streams[pos].Recv() + }, cleanup +} + +func setupUnconstrainedStream(bf stats.Features) ([]testpb.BenchmarkService_StreamingCallClient, *testpb.SimpleRequest, rpcCleanupFunc) { + tc, cleanup := makeClient(bf) + + streams := make([]testpb.BenchmarkService_StreamingCallClient, bf.MaxConcurrentCalls) + md := metadata.Pairs(benchmark.UnconstrainedStreamingHeader, "1") + ctx := metadata.NewOutgoingContext(context.Background(), md) + for i := 0; i < bf.MaxConcurrentCalls; i++ { + stream, err := tc.StreamingCall(ctx) + if err != nil { + logger.Fatalf("%v.StreamingCall(_) = _, %v", tc, err) + } + streams[i] = stream + } + + pl := benchmark.NewPayload(testpb.PayloadType_COMPRESSABLE, bf.ReqSizeBytes) + req := &testpb.SimpleRequest{ + ResponseType: pl.Type, + ResponseSize: int32(bf.RespSizeBytes), + Payload: pl, + } + + return streams, req, cleanup +} + +// Makes a UnaryCall gRPC request using the given BenchmarkServiceClient and +// request and response sizes. 
+func unaryCaller(client testpb.BenchmarkServiceClient, reqSize, respSize int) { + if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil { + logger.Fatalf("DoUnaryCall failed: %v", err) + } +} + +func streamCaller(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) { + if err := benchmark.DoStreamingRoundTrip(stream, reqSize, respSize); err != nil { + logger.Fatalf("DoStreamingRoundTrip failed: %v", err) + } +} + +func runBenchmark(caller rpcCallFunc, start startFunc, stop stopFunc, bf stats.Features, s *stats.Stats, mode string) { + // Warm up connection. + for i := 0; i < warmupCallCount; i++ { + caller(0) + } + + // Run benchmark. + start(mode, bf) + var wg sync.WaitGroup + wg.Add(bf.MaxConcurrentCalls) + bmEnd := time.Now().Add(bf.BenchTime) + var count uint64 + for i := 0; i < bf.MaxConcurrentCalls; i++ { + go func(pos int) { + defer wg.Done() + for { + t := time.Now() + if t.After(bmEnd) { + return + } + start := time.Now() + caller(pos) + elapse := time.Since(start) + atomic.AddUint64(&count, 1) + s.AddDuration(elapse) + } + }(i) + } + wg.Wait() + stop(count) +} + +// benchOpts represents all configurable options available while running this +// benchmark. This is built from the values passed as flags. +type benchOpts struct { + rModes runModes + benchTime time.Duration + memProfileRate int + memProfile string + cpuProfile string + networkMode string + benchmarkResultFile string + useBufconn bool + enableKeepalive bool + features *featureOpts +} + +// featureOpts represents options which can have multiple values. The user +// usually provides a comma-separated list of options for each of these +// features through command line flags. We generate all possible combinations +// for the provided values and run the benchmarks for each combination. +type featureOpts struct { + enableTrace []bool + readLatencies []time.Duration + readKbps []int + readMTU []int + maxConcurrentCalls []int + reqSizeBytes []int + respSizeBytes []int + reqPayloadCurves []*stats.PayloadCurve + respPayloadCurves []*stats.PayloadCurve + compModes []string + enableChannelz []bool + enablePreloader []bool + clientReadBufferSize []int + clientWriteBufferSize []int + serverReadBufferSize []int + serverWriteBufferSize []int +} + +// makeFeaturesNum returns a slice of ints of size 'maxFeatureIndex' where each +// element of the slice (indexed by 'featuresIndex' enum) contains the number +// of features to be exercised by the benchmark code. +// For example: Index 0 of the returned slice contains the number of values for +// enableTrace feature, while index 1 contains the number of value of +// readLatencies feature and so on. 
+func makeFeaturesNum(b *benchOpts) []int { + featuresNum := make([]int, stats.MaxFeatureIndex) + for i := 0; i < len(featuresNum); i++ { + switch stats.FeatureIndex(i) { + case stats.EnableTraceIndex: + featuresNum[i] = len(b.features.enableTrace) + case stats.ReadLatenciesIndex: + featuresNum[i] = len(b.features.readLatencies) + case stats.ReadKbpsIndex: + featuresNum[i] = len(b.features.readKbps) + case stats.ReadMTUIndex: + featuresNum[i] = len(b.features.readMTU) + case stats.MaxConcurrentCallsIndex: + featuresNum[i] = len(b.features.maxConcurrentCalls) + case stats.ReqSizeBytesIndex: + featuresNum[i] = len(b.features.reqSizeBytes) + case stats.RespSizeBytesIndex: + featuresNum[i] = len(b.features.respSizeBytes) + case stats.ReqPayloadCurveIndex: + featuresNum[i] = len(b.features.reqPayloadCurves) + case stats.RespPayloadCurveIndex: + featuresNum[i] = len(b.features.respPayloadCurves) + case stats.CompModesIndex: + featuresNum[i] = len(b.features.compModes) + case stats.EnableChannelzIndex: + featuresNum[i] = len(b.features.enableChannelz) + case stats.EnablePreloaderIndex: + featuresNum[i] = len(b.features.enablePreloader) + case stats.ClientReadBufferSize: + featuresNum[i] = len(b.features.clientReadBufferSize) + case stats.ClientWriteBufferSize: + featuresNum[i] = len(b.features.clientWriteBufferSize) + case stats.ServerReadBufferSize: + featuresNum[i] = len(b.features.serverReadBufferSize) + case stats.ServerWriteBufferSize: + featuresNum[i] = len(b.features.serverWriteBufferSize) + default: + log.Fatalf("Unknown feature index %v in generateFeatures. maxFeatureIndex is %v", i, stats.MaxFeatureIndex) + } + } + return featuresNum +} + +// sharedFeatures returns a bool slice which acts as a bitmask. Each item in +// the slice represents a feature, indexed by 'featureIndex' enum. The bit is +// set to 1 if the corresponding feature does not have multiple value, so is +// shared amongst all benchmarks. +func sharedFeatures(featuresNum []int) []bool { + result := make([]bool, len(featuresNum)) + for i, num := range featuresNum { + if num <= 1 { + result[i] = true + } + } + return result +} + +// generateFeatures generates all combinations of the provided feature options. +// While all the feature options are stored in the benchOpts struct, the input +// parameter 'featuresNum' is a slice indexed by 'featureIndex' enum containing +// the number of values for each feature. +// For example, let's say the user sets -workloads=all and +// -maxConcurrentCalls=1,100, this would end up with the following +// combinations: +// [workloads: unary, maxConcurrentCalls=1] +// [workloads: unary, maxConcurrentCalls=1] +// [workloads: streaming, maxConcurrentCalls=100] +// [workloads: streaming, maxConcurrentCalls=100] +// [workloads: unconstrained, maxConcurrentCalls=1] +// [workloads: unconstrained, maxConcurrentCalls=100] +func (b *benchOpts) generateFeatures(featuresNum []int) []stats.Features { + // curPos and initialPos are two slices where each value acts as an index + // into the appropriate feature slice maintained in benchOpts.features. This + // loop generates all possible combinations of features by changing one value + // at a time, and once curPos becomes equal to initialPos, we have explored + // all options. 
+ var result []stats.Features + var curPos []int + initialPos := make([]int, stats.MaxFeatureIndex) + for !reflect.DeepEqual(initialPos, curPos) { + if curPos == nil { + curPos = make([]int, stats.MaxFeatureIndex) + } + f := stats.Features{ + // These features stay the same for each iteration. + NetworkMode: b.networkMode, + UseBufConn: b.useBufconn, + EnableKeepalive: b.enableKeepalive, + BenchTime: b.benchTime, + // These features can potentially change for each iteration. + EnableTrace: b.features.enableTrace[curPos[stats.EnableTraceIndex]], + Latency: b.features.readLatencies[curPos[stats.ReadLatenciesIndex]], + Kbps: b.features.readKbps[curPos[stats.ReadKbpsIndex]], + MTU: b.features.readMTU[curPos[stats.ReadMTUIndex]], + MaxConcurrentCalls: b.features.maxConcurrentCalls[curPos[stats.MaxConcurrentCallsIndex]], + ModeCompressor: b.features.compModes[curPos[stats.CompModesIndex]], + //EnableChannelz: b.features.enableChannelz[curPos[stats.EnableChannelzIndex]], + EnablePreloader: b.features.enablePreloader[curPos[stats.EnablePreloaderIndex]], + ClientReadBufferSize: b.features.clientReadBufferSize[curPos[stats.ClientReadBufferSize]], + ClientWriteBufferSize: b.features.clientWriteBufferSize[curPos[stats.ClientWriteBufferSize]], + ServerReadBufferSize: b.features.serverReadBufferSize[curPos[stats.ServerReadBufferSize]], + ServerWriteBufferSize: b.features.serverWriteBufferSize[curPos[stats.ServerWriteBufferSize]], + } + if len(b.features.reqPayloadCurves) == 0 { + f.ReqSizeBytes = b.features.reqSizeBytes[curPos[stats.ReqSizeBytesIndex]] + } else { + f.ReqPayloadCurve = b.features.reqPayloadCurves[curPos[stats.ReqPayloadCurveIndex]] + } + if len(b.features.respPayloadCurves) == 0 { + f.RespSizeBytes = b.features.respSizeBytes[curPos[stats.RespSizeBytesIndex]] + } else { + f.RespPayloadCurve = b.features.respPayloadCurves[curPos[stats.RespPayloadCurveIndex]] + } + result = append(result, f) + addOne(curPos, featuresNum) + } + return result +} + +// addOne mutates the input slice 'features' by changing one feature, thus +// arriving at the next combination of feature values. 'featuresMaxPosition' +// provides the numbers of allowed values for each feature, indexed by +// 'featureIndex' enum. +func addOne(features []int, featuresMaxPosition []int) { + for i := len(features) - 1; i >= 0; i-- { + if featuresMaxPosition[i] == 0 { + continue + } + features[i] = (features[i] + 1) + if features[i]/featuresMaxPosition[i] == 0 { + break + } + features[i] = features[i] % featuresMaxPosition[i] + } +} + +// processFlags reads the command line flags and builds benchOpts. Specifying +// invalid values for certain flags will cause flag.Parse() to fail, and the +// program to terminate. +// This *SHOULD* be the only place where the flags are accessed. All other +// parts of the benchmark code should rely on the returned benchOpts. 
+func processFlags() *benchOpts { + flag.Parse() + if flag.NArg() != 0 { + log.Fatal("Error: unparsed arguments: ", flag.Args()) + } + + opts := &benchOpts{ + rModes: runModesFromWorkloads(*workloads), + benchTime: *benchTime, + memProfileRate: *memProfileRate, + memProfile: *memProfile, + cpuProfile: *cpuProfile, + networkMode: *networkMode, + benchmarkResultFile: *benchmarkResultFile, + useBufconn: *useBufconn, + enableKeepalive: *enableKeepalive, + features: &featureOpts{ + enableTrace: setToggleMode(*traceMode), + readLatencies: append([]time.Duration(nil), *readLatency...), + readKbps: append([]int(nil), *readKbps...), + readMTU: append([]int(nil), *readMTU...), + maxConcurrentCalls: append([]int(nil), *maxConcurrentCalls...), + reqSizeBytes: append([]int(nil), *readReqSizeBytes...), + respSizeBytes: append([]int(nil), *readRespSizeBytes...), + compModes: setCompressorMode(*compressorMode), + //enableChannelz: setToggleMode(*channelzOn), + enablePreloader: setToggleMode(*preloaderMode), + clientReadBufferSize: append([]int(nil), *clientReadBufferSize...), + clientWriteBufferSize: append([]int(nil), *clientWriteBufferSize...), + serverReadBufferSize: append([]int(nil), *serverReadBufferSize...), + serverWriteBufferSize: append([]int(nil), *serverWriteBufferSize...), + }, + } + + if len(*reqPayloadCurveFiles) == 0 { + if len(opts.features.reqSizeBytes) == 0 { + opts.features.reqSizeBytes = defaultReqSizeBytes + } + } else { + if len(opts.features.reqSizeBytes) != 0 { + log.Fatalf("you may not specify -reqPayloadCurveFiles and -reqSizeBytes at the same time") + } + for _, file := range *reqPayloadCurveFiles { + pc, err := stats.NewPayloadCurve(file) + if err != nil { + log.Fatalf("cannot load payload curve file %s: %v", file, err) + } + opts.features.reqPayloadCurves = append(opts.features.reqPayloadCurves, pc) + } + opts.features.reqSizeBytes = nil + } + if len(*respPayloadCurveFiles) == 0 { + if len(opts.features.respSizeBytes) == 0 { + opts.features.respSizeBytes = defaultRespSizeBytes + } + } else { + if len(opts.features.respSizeBytes) != 0 { + log.Fatalf("you may not specify -respPayloadCurveFiles and -respSizeBytes at the same time") + } + for _, file := range *respPayloadCurveFiles { + pc, err := stats.NewPayloadCurve(file) + if err != nil { + log.Fatalf("cannot load payload curve file %s: %v", file, err) + } + opts.features.respPayloadCurves = append(opts.features.respPayloadCurves, pc) + } + opts.features.respSizeBytes = nil + } + + // Re-write latency, kpbs and mtu if network mode is set. + if network, ok := networks[opts.networkMode]; ok { + opts.features.readLatencies = []time.Duration{network.Latency} + opts.features.readKbps = []int{network.Kbps} + opts.features.readMTU = []int{network.MTU} + } + return opts +} + +func setToggleMode(val string) []bool { + switch val { + case toggleModeOn: + return []bool{true} + case toggleModeOff: + return []bool{false} + case toggleModeBoth: + return []bool{false, true} + default: + // This should never happen because a wrong value passed to this flag would + // be caught during flag.Parse(). + return []bool{} + } +} + +func setCompressorMode(val string) []string { + switch val { + case compModeNop, compModeGzip, compModeOff: + return []string{val} + case compModeAll: + return []string{compModeNop, compModeGzip, compModeOff} + default: + // This should never happen because a wrong value passed to this flag would + // be caught during flag.Parse(). 
+ return []string{} + } +} + +func main() { + opts := processFlags() + before(opts) + + s := stats.NewStats(numStatsBuckets) + featuresNum := makeFeaturesNum(opts) + sf := sharedFeatures(featuresNum) + + var ( + start = func(mode string, bf stats.Features) { s.StartRun(mode, bf, sf) } + stop = func(count uint64) { s.EndRun(count) } + ucStop = func(req uint64, resp uint64) { s.EndUnconstrainedRun(req, resp) } + ) + + for _, bf := range opts.generateFeatures(featuresNum) { + grpc.EnableTracing = bf.EnableTrace + //if bf.EnableChannelz { + // channelz.TurnOn() + //} + if opts.rModes.unary { + unaryBenchmark(start, stop, bf, s) + } + if opts.rModes.streaming { + streamBenchmark(start, stop, bf, s) + } + if opts.rModes.unconstrained { + unconstrainedStreamBenchmark(start, ucStop, bf) + } + } + after(opts, s.GetResults()) +} + +func before(opts *benchOpts) { + if opts.memProfile != "" { + runtime.MemProfileRate = opts.memProfileRate + } + if opts.cpuProfile != "" { + f, err := os.Create(opts.cpuProfile) + if err != nil { + fmt.Fprintf(os.Stderr, "testing: %s\n", err) + return + } + if err := pprof.StartCPUProfile(f); err != nil { + fmt.Fprintf(os.Stderr, "testing: can't start cpu profile: %s\n", err) + f.Close() + return + } + } +} + +func after(opts *benchOpts, data []stats.BenchResults) { + if opts.cpuProfile != "" { + pprof.StopCPUProfile() // flushes profile to disk + } + if opts.memProfile != "" { + f, err := os.Create(opts.memProfile) + if err != nil { + fmt.Fprintf(os.Stderr, "testing: %s\n", err) + os.Exit(2) + } + runtime.GC() // materialize all statistics + if err = pprof.WriteHeapProfile(f); err != nil { + fmt.Fprintf(os.Stderr, "testing: can't write heap profile %s: %s\n", opts.memProfile, err) + os.Exit(2) + } + f.Close() + } + if opts.benchmarkResultFile != "" { + f, err := os.Create(opts.benchmarkResultFile) + if err != nil { + log.Fatalf("testing: can't write benchmark result %s: %s\n", opts.benchmarkResultFile, err) + } + dataEncoder := gob.NewEncoder(f) + dataEncoder.Encode(data) + f.Close() + } +} + +// nopCompressor is a compressor that just copies data. +type nopCompressor struct{} + +func (nopCompressor) Do(w io.Writer, p []byte) error { + n, err := w.Write(p) + if err != nil { + return err + } + if n != len(p) { + return fmt.Errorf("nopCompressor.Write: wrote %d bytes; want %d", n, len(p)) + } + return nil +} + +func (nopCompressor) Type() string { return compModeNop } + +// nopDecompressor is a decompressor that just copies data. 
+type nopDecompressor struct{} + +func (nopDecompressor) Do(r io.Reader) ([]byte, error) { return io.ReadAll(r) } +func (nopDecompressor) Type() string { return compModeNop } diff --git a/benchmarks/grpc-bench.txt b/benchmarks/grpc-bench.txt new file mode 100644 index 0000000..5535ae4 --- /dev/null +++ b/benchmarks/grpc-bench.txt @@ -0,0 +1,192 @@ +streaming-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1048576B-respSize_1B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 3132 3521 12.42% + SendOps 0 0 NaN% + RecvOps 0 0 NaN% + Bytes/op 9780220.69 6369158.76 -34.88% + Allocs/op 175.32 118.53 -32.51% + ReqT/op 2627312025.60 2953628876.80 12.42% + RespT/op 2505.60 2816.80 12.41% + 50th-Lat 3.231875ms 2.814041ms -12.93% + 90th-Lat 3.322042ms 3.076375ms -7.40% + 99th-Lat 3.464792ms 3.193459ms -7.83% + Avg-Lat 3.193236ms 2.840273ms -11.05% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +streaming-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1048576B-respSize_1048576B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 1628 1772 8.85% + SendOps 0 0 NaN% + RecvOps 0 0 NaN% + Bytes/op 18522984.53 13475306.02 -27.25% + Allocs/op 219.34 155.40 -29.18% + ReqT/op 1365665382.40 1486461337.60 8.85% + RespT/op 1365665382.40 1486461337.60 8.85% + 50th-Lat 6.136ms 5.607667ms -8.61% + 90th-Lat 6.221625ms 5.95075ms -4.35% + 99th-Lat 6.458209ms 6.177167ms -4.35% + Avg-Lat 6.142553ms 5.643576ms -8.12% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +streaming-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1B-respSize_1B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 188659 174004 -7.77% + SendOps 0 0 NaN% + RecvOps 0 0 NaN% + Bytes/op 3849.79 2525.08 -34.39% + Allocs/op 91.84 58.77 -35.93% + ReqT/op 150927.20 139203.20 -7.77% + RespT/op 150927.20 139203.20 -7.77% + 50th-Lat 51.459µs 55.75µs 8.34% + 90th-Lat 62.75µs 67.375µs 7.37% + 99th-Lat 73.458µs 90.083µs 22.63% + Avg-Lat 52.843µs 57.307µs 8.45% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +unconstrained-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1B-respSize_1B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 0 0 NaN% + SendOps 876580 1022098 16.60% + RecvOps 906289 473664 -47.74% + Bytes/op 4148.69 2249.00 -45.77% + Allocs/op 40.93 33.26 -17.10% + ReqT/op 701264.00 817678.40 16.60% + RespT/op 725031.20 378931.20 -47.74% + 50th-Lat 0s 0s NaN% + 90th-Lat 0s 0s NaN% + 99th-Lat 0s 0s NaN% + Avg-Lat 0s 0s NaN% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + 
+streaming-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1B-respSize_1048576B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 3099 3324 7.26% + SendOps 0 0 NaN% + RecvOps 0 0 NaN% + Bytes/op 9951205.20 8235786.95 -17.24% + Allocs/op 179.59 134.60 -25.06% + ReqT/op 2479.20 2659.20 7.26% + RespT/op 2599629619.20 2788373299.20 7.26% + 50th-Lat 3.243083ms 2.984541ms -7.97% + 90th-Lat 3.321083ms 3.170208ms -4.54% + 99th-Lat 3.490458ms 3.324458ms -4.76% + Avg-Lat 3.227493ms 3.007915ms -6.80% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +unary-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1048576B-respSize_1B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 3181 3449 8.43% + SendOps 0 0 NaN% + RecvOps 0 0 NaN% + Bytes/op 9463415.42 6702429.34 -29.18% + Allocs/op 308.81 311.11 0.97% + ReqT/op 2668416204.80 2893230899.20 8.43% + RespT/op 2544.80 2759.20 8.45% + 50th-Lat 3.091875ms 2.883917ms -6.73% + 90th-Lat 3.354834ms 3.146041ms -6.22% + 99th-Lat 3.551917ms 3.358792ms -5.44% + Avg-Lat 3.1441ms 2.899666ms -7.77% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +unary-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1048576B-respSize_1048576B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 1620 1711 5.62% + SendOps 0 0 NaN% + RecvOps 0 0 NaN% + Bytes/op 18533065.91 15485754.44 -16.44% + Allocs/op 361.89 364.01 0.83% + ReqT/op 1358954496.00 1435290828.80 5.62% + RespT/op 1358954496.00 1435290828.80 5.62% + 50th-Lat 6.173584ms 5.827625ms -5.60% + 90th-Lat 6.260833ms 6.137ms -1.98% + 99th-Lat 6.37275ms 6.402542ms 0.47% + Avg-Lat 6.172941ms 5.846148ms -5.29% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +unconstrained-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1048576B-respSize_1048576B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 0 0 NaN% + SendOps 3936 4106 4.32% + RecvOps 3890 4107 5.58% + Bytes/op 14870093.41 12176865.11 -18.11% + Allocs/op 176.24 145.07 -17.59% + ReqT/op 3301756108.80 3444362444.80 4.32% + RespT/op 3263168512.00 3445201305.60 5.58% + 50th-Lat 0s 0s NaN% + 90th-Lat 0s 0s NaN% + 99th-Lat 0s 0s NaN% + Avg-Lat 0s 0s NaN% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +unary-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1B-respSize_1B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 145095 92960 -35.93% + SendOps 0 0 NaN% + RecvOps 0 0 NaN% + Bytes/op 15360.03 17503.52 13.95% + Allocs/op 231.95 248.11 7.33% + ReqT/op 116076.00 
74368.00 -35.93% + RespT/op 116076.00 74368.00 -35.93% + 50th-Lat 67.708µs 102.666µs 51.63% + 90th-Lat 80.75µs 118.584µs 46.85% + 99th-Lat 94.583µs 204.333µs 116.04% + Avg-Lat 68.756µs 107.408µs 56.22% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +unary-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1B-respSize_1048576B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 2997 3046 1.63% + SendOps 0 0 NaN% + RecvOps 0 0 NaN% + Bytes/op 10052174.55 9730336.44 -3.20% + Allocs/op 325.92 361.19 11.05% + ReqT/op 2397.60 2436.80 1.63% + RespT/op 2514065817.60 2555169996.80 1.63% + 50th-Lat 3.309125ms 3.287167ms -0.66% + 90th-Lat 3.558791ms 3.464708ms -2.64% + 99th-Lat 3.696166ms 3.668708ms -0.74% + Avg-Lat 3.336849ms 3.28346ms -1.60% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +unconstrained-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1B-respSize_1048576B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 0 0 NaN% + SendOps 1097703 1119919 2.02% + RecvOps 3381 3480 2.93% + Bytes/op 53395.47 46662.17 -12.61% + Allocs/op 41.74 33.85 -19.16% + ReqT/op 878162.40 895935.20 2.02% + RespT/op 2836188364.80 2919235584.00 2.93% + 50th-Lat 0s 0s NaN% + 90th-Lat 0s 0s NaN% + 99th-Lat 0s 0s NaN% + Avg-Lat 0s 0s NaN% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + +unconstrained-networkMode_Local-bufConn_false-keepalive_false-benchTime_10s-trace_false-latency_0s-kbps_0-MTU_0-maxConcurrentCalls_1-reqSize_1048576B-respSize_1B-compressor_gzip-channelz_false-preloader_false-clientReadBufferSize_-1-clientWriteBufferSize_-1-serverReadBufferSize_-1-serverWriteBufferSize_-1- + Title Before After Percentage + TotalOps 0 0 NaN% + SendOps 3441 3609 4.88% + RecvOps 1231131 609195 -50.52% + Bytes/op 48065.87 58849.88 22.44% + Allocs/op 43.49 42.52 -2.30% + ReqT/op 2886520012.80 3027448627.20 4.88% + RespT/op 984904.80 487356.00 -50.52% + 50th-Lat 0s 0s NaN% + 90th-Lat 0s 0s NaN% + 99th-Lat 0s 0s NaN% + Avg-Lat 0s 0s NaN% + GoVersion go1.20.3 go1.20.3 + GrpcVersion 1.56.0-dev 1.54.0 + diff --git a/benchmarks/main_test.go b/benchmarks/main_test.go index 9b78559..bdf544a 100644 --- a/benchmarks/main_test.go +++ b/benchmarks/main_test.go @@ -161,7 +161,7 @@ func BenchmarkLarking(b *testing.B) { } librarypb.RegisterLibraryServiceServer(mux, svc) - ts, err := larking.NewServer(mux, larking.InsecureServerOption()) + ts, err := larking.NewServer(mux) if err != nil { b.Fatal(err) } @@ -192,12 +192,14 @@ func BenchmarkLarking(b *testing.B) { b.Fatal(err) } }() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() - cc, err := grpc.Dial( + cc, err := grpc.DialContext( + ctx, lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), - grpc.WithTimeout(time.Second), ) if err != nil { b.Fatal(err) @@ -274,11 +276,14 @@ func BenchmarkGRPCGateway(b *testing.B) { go func() { errs <- m.Serve() }() - cc, err := grpc.Dial( + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + cc, err := grpc.DialContext( + ctx, lis.Addr().String(), 
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), - grpc.WithTimeout(time.Second), ) if err != nil { b.Fatal(err) @@ -351,11 +356,11 @@ func BenchmarkEnvoyGRPC(b *testing.B) { }) defer cancel() - cc, err := grpc.Dial( + cc, err := grpc.DialContext( + ctx, envoyAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), - grpc.WithTimeout(2*time.Second), ) if err != nil { b.Fatal(err) @@ -389,7 +394,7 @@ func BenchmarkGorillaMux(b *testing.B) { newIncomingContext := func(ctx context.Context, header http.Header) (context.Context, metadata.MD) { md := make(metadata.MD, len(header)) for k, vs := range header { - md["http-"+strings.ToLower(k)] = vs + md[strings.ToLower(k)] = vs } return metadata.NewIncomingContext(ctx, md), md } @@ -563,11 +568,11 @@ func BenchmarkConnectGo(b *testing.B) { go func() { errs <- hs.Serve(lis) }() defer hs.Close() - cc, err := grpc.Dial( + cc, err := grpc.DialContext( + ctx, lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), - grpc.WithTimeout(time.Second), ) if err != nil { b.Fatal(err) diff --git a/benchmarks/server/server.go b/benchmarks/server/server.go new file mode 100644 index 0000000..0207b4a --- /dev/null +++ b/benchmarks/server/server.go @@ -0,0 +1,295 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* +Package server implements the building blocks to setup end-to-end gRPC benchmarks. +*/ +package server + +import ( + "context" + "fmt" + "io" + "log" + "net" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + testpb "google.golang.org/grpc/interop/grpc_testing" + "larking.io/larking" +) + +var logger = grpclog.Component("benchmark") + +// Allows reuse of the same testpb.Payload object. +func setPayload(p *testpb.Payload, t testpb.PayloadType, size int) { + if size < 0 { + logger.Fatalf("Requested a response with invalid length %d", size) + } + body := make([]byte, size) + switch t { + case testpb.PayloadType_COMPRESSABLE: + default: + logger.Fatalf("Unsupported payload type: %d", t) + } + p.Type = t + p.Body = body +} + +// NewPayload creates a payload with the given type and size. +func NewPayload(t testpb.PayloadType, size int) *testpb.Payload { + p := new(testpb.Payload) + setPayload(p, t, size) + return p +} + +type testServer struct { + testpb.UnimplementedBenchmarkServiceServer +} + +func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{ + Payload: NewPayload(in.ResponseType, int(in.ResponseSize)), + }, nil +} + +// UnconstrainedStreamingHeader indicates to the StreamingCall handler that its +// behavior should be unconstrained (constant send/receive in parallel) instead +// of ping-pong. 
+const UnconstrainedStreamingHeader = "unconstrained-streaming" + +func (s *testServer) StreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error { + if md, ok := metadata.FromIncomingContext(stream.Context()); ok && len(md[UnconstrainedStreamingHeader]) != 0 { + return s.UnconstrainedStreamingCall(stream) + } + response := &testpb.SimpleResponse{ + Payload: new(testpb.Payload), + } + in := new(testpb.SimpleRequest) + for { + // use ServerStream directly to reuse the same testpb.SimpleRequest object + err := stream.(grpc.ServerStream).RecvMsg(in) + if err == io.EOF { + // read done. + return nil + } + if err != nil { + return err + } + setPayload(response.Payload, in.ResponseType, int(in.ResponseSize)) + if err := stream.Send(response); err != nil { + return err + } + } +} + +func (s *testServer) UnconstrainedStreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error { + in := new(testpb.SimpleRequest) + // Receive a message to learn response type and size. + err := stream.RecvMsg(in) + if err == io.EOF { + // read done. + return nil + } + if err != nil { + return err + } + + response := &testpb.SimpleResponse{ + Payload: new(testpb.Payload), + } + setPayload(response.Payload, in.ResponseType, int(in.ResponseSize)) + + go func() { + for { + // Using RecvMsg rather than Recv to prevent reallocation of SimpleRequest. + err := stream.RecvMsg(in) + switch status.Code(err) { + case codes.Canceled: + return + case codes.OK: + default: + log.Fatalf("server recv error: %v", err) + } + } + }() + + go func() { + for { + err := stream.Send(response) + switch status.Code(err) { + case codes.Unavailable, codes.Canceled: + return + case codes.OK: + default: + log.Fatalf("server send error: %v", err) + } + } + }() + + <-stream.Context().Done() + return stream.Context().Err() +} + +// byteBufServer is a gRPC server that sends and receives byte buffer. +// The purpose is to benchmark the gRPC performance without protobuf serialization/deserialization overhead. +type byteBufServer struct { + testpb.UnimplementedBenchmarkServiceServer + respSize int32 +} + +// UnaryCall is an empty function and is not used for benchmark. +// If bytebuf UnaryCall benchmark is needed later, the function body needs to be updated. +func (s *byteBufServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{}, nil +} + +func (s *byteBufServer) StreamingCall(stream testpb.BenchmarkService_StreamingCallServer) error { + for { + var in []byte + err := stream.(grpc.ServerStream).RecvMsg(&in) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + out := make([]byte, s.respSize) + if err := stream.(grpc.ServerStream).SendMsg(&out); err != nil { + return err + } + } +} + +// ServerInfo contains the information to create a gRPC benchmark server. +type ServerInfo struct { + // Type is the type of the server. + // It should be "protobuf" or "bytebuf". + Type string + + // Metadata is an optional configuration. + // For "protobuf", it's ignored. + // For "bytebuf", it should be an int representing response size. + Metadata interface{} + + // Listener is the network listener for the server to use + Listener net.Listener +} + +// StartServer starts a gRPC server serving a benchmark service according to info. +// It returns a function to stop the server. 
+func StartServer(info ServerInfo, _ ...grpc.ServerOption) func() { + m, err := larking.NewMux() + if err != nil { + logger.Fatalf("failed to StartServer, NewMux failed: %v", err) + } + + //s := grpc.NewServer(opts...) + switch info.Type { + case "protobuf": + testpb.RegisterBenchmarkServiceServer(m, &testServer{}) + case "bytebuf": + respSize, ok := info.Metadata.(int32) + if !ok { + logger.Fatalf("failed to StartServer, invalid metadata: %v, for Type: %v", info.Metadata, info.Type) + } + testpb.RegisterBenchmarkServiceServer(m, &byteBufServer{respSize: respSize}) + default: + logger.Fatalf("failed to StartServer, unknown Type: %v", info.Type) + } + + s, err := larking.NewServer(m) + if err != nil { + logger.Fatalf("failed to StartServer, NewServer failed: %v", err) + } + + go s.Serve(info.Listener) + return func() { + s.Close() + } +} + +// DoUnaryCall performs an unary RPC with given stub and request and response sizes. +func DoUnaryCall(tc testpb.BenchmarkServiceClient, reqSize, respSize int) error { + pl := NewPayload(testpb.PayloadType_COMPRESSABLE, reqSize) + req := &testpb.SimpleRequest{ + ResponseType: pl.Type, + ResponseSize: int32(respSize), + Payload: pl, + } + if _, err := tc.UnaryCall(context.Background(), req); err != nil { + return fmt.Errorf("/BenchmarkService/UnaryCall(_, _) = _, %v, want _, ", err) + } + return nil +} + +// DoStreamingRoundTrip performs a round trip for a single streaming rpc. +func DoStreamingRoundTrip(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) error { + pl := NewPayload(testpb.PayloadType_COMPRESSABLE, reqSize) + req := &testpb.SimpleRequest{ + ResponseType: pl.Type, + ResponseSize: int32(respSize), + Payload: pl, + } + if err := stream.Send(req); err != nil { + return fmt.Errorf("/BenchmarkService/StreamingCall.Send(_) = %v, want ", err) + } + if _, err := stream.Recv(); err != nil { + // EOF is a valid error here. + if err == io.EOF { + return nil + } + return fmt.Errorf("/BenchmarkService/StreamingCall.Recv(_) = %v, want ", err) + } + return nil +} + +// DoByteBufStreamingRoundTrip performs a round trip for a single streaming rpc, using a custom codec for byte buffer. +func DoByteBufStreamingRoundTrip(stream testpb.BenchmarkService_StreamingCallClient, reqSize, respSize int) error { + out := make([]byte, reqSize) + if err := stream.(grpc.ClientStream).SendMsg(&out); err != nil { + return fmt.Errorf("/BenchmarkService/StreamingCall.(ClientStream).SendMsg(_) = %v, want ", err) + } + var in []byte + if err := stream.(grpc.ClientStream).RecvMsg(&in); err != nil { + // EOF is a valid error here. + if err == io.EOF { + return nil + } + return fmt.Errorf("/BenchmarkService/StreamingCall.(ClientStream).RecvMsg(_) = %v, want ", err) + } + return nil +} + +// NewClientConn creates a gRPC client connection to addr. +func NewClientConn(addr string, opts ...grpc.DialOption) *grpc.ClientConn { + return NewClientConnWithContext(context.Background(), addr, opts...) +} + +// NewClientConnWithContext creates a gRPC client connection to addr using ctx. +func NewClientConnWithContext(ctx context.Context, addr string, opts ...grpc.DialOption) *grpc.ClientConn { + conn, err := grpc.DialContext(ctx, addr, opts...) 
+	if err != nil {
+		logger.Fatalf("NewClientConn(%q) failed to create a ClientConn: %v", addr, err)
+	}
+	return conn
+}
diff --git a/benchmarks/twirp_test.go b/benchmarks/twirp_test.go
index 53a12fd..613e22f 100644
--- a/benchmarks/twirp_test.go
+++ b/benchmarks/twirp_test.go
@@ -24,7 +24,6 @@ func TestTwirp(t *testing.T) {
 	librarypb.RegisterLibraryServiceServer(mux, svc)
 
 	ts, err := larking.NewServer(mux,
-		larking.InsecureServerOption(),
 		larking.MuxHandleOption("/", "/twirp"),
 	)
 	if err != nil {
diff --git a/docs/main.go b/docs/main.go
index e557cb4..3da470e 100644
--- a/docs/main.go
+++ b/docs/main.go
@@ -24,7 +24,7 @@ func main() {
 	// - websocket /v1/healthz -> grpc.health.v1.Health.Watch
 	health.AddHealthz(serviceConfig)
 
-	// Mux implements http.Handler, use by itself to serve only HTTP endpoints.
+	// Mux implements http.Handler and serves both gRPC and HTTP connections.
 	mux, err := larking.NewMux(
 		larking.ServiceConfigOption(serviceConfig),
 	)
@@ -34,8 +34,8 @@ func main() {
 	// RegisterHealthServer registers a HealthServer to the mux.
 	healthpb.RegisterHealthServer(mux, healthSvc)
 
-	// Server is a gRPC server that serves both gRPC and HTTP endpoints.
-	svr, err := larking.NewServer(mux, larking.InsecureServerOption())
+	// NewServer creates an *http.Server serving the mux.
+	svr, err := larking.NewServer(mux)
 	if err != nil {
 		log.Fatal(err)
 	}
diff --git a/larking/compress.go b/larking/compress.go
index 8a10ffa..c264ef4 100644
--- a/larking/compress.go
+++ b/larking/compress.go
@@ -1,6 +1,7 @@
 package larking
 
 import (
+	"bytes"
 	"compress/gzip"
 	"io"
 	"sync"
@@ -8,6 +9,12 @@ import (
 	"google.golang.org/grpc/encoding"
 )
 
+var bufPool = sync.Pool{
+	New: func() interface{} {
+		return &bytes.Buffer{}
+	},
+}
+
 // Compressor is used to compress and decompress messages.
 // Based on grpc/encoding.
 type Compressor interface {
diff --git a/larking/grpc.go b/larking/grpc.go
new file mode 100644
index 0000000..3e70dae
--- /dev/null
+++ b/larking/grpc.go
@@ -0,0 +1,606 @@
+package larking
+
+import (
+	"bytes"
+	"context"
+	"encoding/base64"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"math"
+	"net/http"
+	"net/textproto"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/stats"
+	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/proto"
+)
+
+func isStreamError(err error) bool {
+	switch err {
+	case nil, io.EOF, context.Canceled:
+		return false
+	}
+	return true
+}
+
+func isReservedHeader(k string) bool {
+	switch k {
+	case "content-type", "user-agent", "grpc-message-type", "grpc-encoding",
+		"grpc-message", "grpc-status", "grpc-timeout",
+		"grpc-status-details", "te":
+		return true
+	default:
+		return false
+	}
+}
+func isWhitelistedHeader(k string) bool {
+	switch k {
+	case ":authority", "user-agent":
+		return true
+	default:
+		return false
+	}
+}
+
+const binHdrSuffix = "-bin"
+
+func encodeBinHeader(b []byte) string {
+	return base64.RawStdEncoding.EncodeToString(b)
+}
+
+func decodeBinHeader(v string) (s string, err error) {
+	var b []byte
+	if len(v)%4 == 0 {
+		// Input was padded, or padding was not necessary.
+		b, err = base64.StdEncoding.DecodeString(v)
+	} else {
+		b, err = base64.RawStdEncoding.DecodeString(v)
+	}
+	return string(b), err
+}
+
+func newIncomingContext(ctx context.Context, header http.Header) (context.Context, metadata.MD) {
+	md := make(metadata.MD, len(header))
+	for k, vs := range header {
+		k = strings.ToLower(k)
+		if isReservedHeader(k) && !isWhitelistedHeader(k) {
+			continue
+		}
+		if strings.HasSuffix(k, binHdrSuffix) {
+			dst := make([]string, len(vs))
+			for i, v := range vs {
+				v, err := decodeBinHeader(v)
+				if err != nil {
+					continue // TODO: log error?
+				}
+				dst[i] = v
+			}
+			vs = dst
+		}
+		md[k] = vs
+	}
+	return metadata.NewIncomingContext(ctx, md), md
+}
+
+func setOutgoingHeader(header http.Header, md metadata.MD) {
+	for k, vs := range md {
+		if isReservedHeader(k) {
+			continue
+		}
+
+		if strings.HasSuffix(k, binHdrSuffix) {
+			dst := make([]string, len(vs))
+			for i, v := range vs {
+				dst[i] = encodeBinHeader([]byte(v))
+			}
+			vs = dst
+		}
+		header[textproto.CanonicalMIMEHeaderKey(k)] = vs
+	}
+}
+
+func encodeGrpcMessage(msg string) string {
+	var (
+		sb  strings.Builder
+		pos int
+	)
+	for i := 0; i < len(msg); i++ {
+		c := msg[i]
+		if c < ' ' || c > '~' || c == '%' {
+			if pos < i {
+				sb.WriteString(msg[pos:i])
+			}
+			sb.WriteString(fmt.Sprintf("%%%02x", c))
+			pos = i + 1
+		}
+	}
+	if pos == 0 {
+		return msg
+	}
+	// Append any unescaped tail after the last escaped byte.
+	sb.WriteString(msg[pos:])
+	return sb.String()
+}
+
+func timeoutUnit(s byte) time.Duration {
+	switch s {
+	case 'H':
+		return time.Hour
+	case 'M':
+		return time.Minute
+	case 'S':
+		return time.Second
+	case 'm':
+		return time.Millisecond
+	case 'u':
+		return time.Microsecond
+	case 'n':
+		return time.Nanosecond
+	default:
+		return 0
+	}
+}
+
+func decodeTimeout(s string) (time.Duration, error) {
+	size := len(s)
+	if size < 2 {
+		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
+	}
+	if size > 9 {
+		// Spec allows for 8 digits plus the unit.
+		return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
+	}
+	d := timeoutUnit(s[size-1])
+	if d == 0 {
+		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
+	}
+	t, err := strconv.ParseInt(s[:size-1], 10, 64)
+	if err != nil {
+		return 0, err
+	}
+	const maxHours = math.MaxInt64 / int64(time.Hour)
+	if d == time.Hour && t > maxHours {
+		// This timeout would overflow math.MaxInt64; clamp it.
+ return time.Duration(math.MaxInt64), nil + } + return d * time.Duration(t), nil +} + +type streamGRPC struct { + opts muxOptions + ctx context.Context + done <-chan struct{} // ctx.Done() + wg sync.WaitGroup + handler *handler + codec Codec // both read and write + comp Compressor // both read and write + w io.Writer + r io.Reader // + wHeader http.Header + header metadata.MD + trailer metadata.MD + contentType string + messageEncoding string + sentHeader bool +} + +func (s *streamGRPC) isDone() error { + select { + case <-s.done: + return status.FromContextError(s.ctx.Err()).Err() + default: + return nil + } +} + +func (s *streamGRPC) SetHeader(md metadata.MD) error { + if s.sentHeader { + return fmt.Errorf("already sent headers") + } + s.header = metadata.Join(s.header, md) + return nil +} +func (s *streamGRPC) SendHeader(md metadata.MD) error { + s.wg.Add(1) + defer s.wg.Done() + + if err := s.isDone(); err != nil { + return err + } + + if s.sentHeader { + return fmt.Errorf("already sent headers") + } + s.header = metadata.Join(s.header, md) + h := s.wHeader + + h.Set("Content-Type", s.contentType) + if s.messageEncoding != "" { + h.Set("Grpc-Encoding", s.messageEncoding) + } + h.Add("Trailer", "Grpc-Status") + h.Add("Trailer", "Grpc-Message") + h.Add("Trailer", "Grpc-Status-Details-Bin") + + setOutgoingHeader(h, s.header) + + // don't write the header code, wait for the body. + s.sentHeader = true + + if sh := s.opts.statsHandler; sh != nil { + out := &stats.OutHeader{ + Header: s.header.Copy(), + } + if s.comp != nil { + out.Compression = s.comp.Name() + } + sh.HandleRPC(s.ctx, out) + } + return nil +} + +func (s *streamGRPC) SetTrailer(md metadata.MD) { + s.trailer = metadata.Join(s.trailer, md) +} + +func (s *streamGRPC) Context() context.Context { + sts := &serverTransportStream{s, s.handler.method} + return grpc.NewContextWithServerTransportStream(s.ctx, sts) +} +func (s *streamGRPC) compress(dst *bytes.Buffer, b []byte) error { + w, err := s.comp.Compress(dst) + if err != nil { + return err + } + defer w.Close() + if _, err := w.Write(b); err != nil { + return err + } + return nil +} + +func (s *streamGRPC) SendMsg(m interface{}) error { + s.wg.Add(1) + defer s.wg.Done() + + if err := s.isDone(); err != nil { + return err + } + + reply := m.(proto.Message) + if !s.sentHeader { + if err := s.SendHeader(nil); err != nil { + return err + } + } + + bp := bytesPool.Get().(*[]byte) + b := (*bp)[:0] + defer func() { + if cap(b) < s.opts.maxReceiveMessageSize { + *bp = b + bytesPool.Put(bp) + } + }() + if cap(b) < 5 { + b = make([]byte, 0, growcap(cap(b), 5)) + } + b = b[:5] // 1 byte compression flag, 4 bytes message length + + var err error + b, err = s.codec.MarshalAppend(b, reply) + if err != nil { + return err + } + + var size uint32 + size = uint32(len(b) - 5) + if int(size) > s.opts.maxReceiveMessageSize { + return fmt.Errorf("grpc: received message larger than max (%d vs. 
%d)", size, s.opts.maxReceiveMessageSize) + } + + b[0] = 0 // uncompressed + if s.comp != nil { + buf := bufPool.Get().(*bytes.Buffer) + buf.Reset() + if err := s.compress(buf, b[5:]); err != nil { + bufPool.Put(buf) + return err + } + bufSize := buf.Len() + if bufSize+5 > cap(b) { + b = make([]byte, 0, growcap(cap(b), bufSize+5)) + } + b = b[:bufSize+5] + b[0] = 1 // compressed + copy(b[5:], buf.Bytes()) + size = uint32(bufSize) + bufPool.Put(buf) + } + + binary.BigEndian.PutUint32(b[1:], size) + if _, err := s.w.Write(b); err != nil { + if isStreamError(err) { + msg := err.Error() + return status.Errorf(codes.Unavailable, msg) + } + return err + } + s.w.(http.Flusher).Flush() + if stats := s.opts.statsHandler; stats != nil { + // TODO: raw payload stats. + b := b[headerLen:] // shadow + stats.HandleRPC(s.ctx, outPayload(false, m, b, b, time.Now())) + } + return nil +} + +func (s *streamGRPC) decompress(dst *bytes.Buffer, b []byte) error { + src := bytes.NewReader(b) + + r, err := s.comp.Decompress(src) + if err != nil { + return err + } + if _, err := dst.ReadFrom(r); err != nil { + return err + } + return nil +} + +func (s *streamGRPC) RecvMsg(m interface{}) error { + s.wg.Add(1) + defer s.wg.Done() + + if err := s.isDone(); err != nil { + return err + } + + args := m.(proto.Message) + + bp := bytesPool.Get().(*[]byte) + b := (*bp)[:0] + defer func() { + if cap(b) < s.opts.maxReceiveMessageSize { + *bp = b + bytesPool.Put(bp) + } + }() + if cap(b) < 5 { + b = make([]byte, 0, growcap(cap(b), 5)) + } + b = b[:5] // 1 byte compression flag, 4 bytes message length + + if _, err := io.ReadFull(s.r, b); err != nil { + if isStreamError(err) { + msg := err.Error() + return status.Errorf(codes.Canceled, msg) + } + return err + } + isCompressed := b[0] == 1 + size := binary.BigEndian.Uint32(b[1:]) + if int(size) > s.opts.maxReceiveMessageSize { + return fmt.Errorf("grpc: received message larger than max (%d vs. %d)", size, s.opts.maxReceiveMessageSize) + } + + if cap(b) < int(size) { + b = make([]byte, 0, growcap(cap(b), int(size))) + } + b = b[:size] + if _, err := io.ReadFull(s.r, b); err != nil { + return err + } + + if isCompressed { + // compressed + if s.comp == nil { + return fmt.Errorf("grpc: Decompressor is not installed for grpc-encoding %q", s.messageEncoding) + } + + buf := bufPool.Get().(*bytes.Buffer) + buf.Reset() + if err := s.decompress(buf, b); err != nil { + bufPool.Put(buf) + return err + } + size = uint32(buf.Len()) + if int(size) > cap(b) { + b = make([]byte, 0, growcap(cap(b), int(size))) + } + b = b[:int(size)] + copy(b, buf.Bytes()) + bufPool.Put(buf) + } + + if err := s.codec.Unmarshal(b, args); err != nil { + return err + } + if stats := s.opts.statsHandler; stats != nil { + // TODO: raw payload stats. + b := b[headerLen:] // shadow + stats.HandleRPC(s.ctx, inPayload(false, m, b, b, time.Now())) + } + return nil +} + +func (m *Mux) grpcGetCodec(ct string) (Codec, bool) { + typ, enc, ok := strings.Cut(ct, "+") + if !ok { + enc = "proto" + } + if typ != "application/grpc" { + return nil, false + } + c, ok := m.opts.codecsByName[enc] + return c, ok +} + +func (m *Mux) grpcGetCompressor(me string) (Compressor, bool) { + if me == "" { + return nil, false + } + c, ok := m.opts.compressors[me] + return c, ok +} + +// serveGRPC serves the gRPC server. 
+func (m *Mux) serveGRPC(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor != 2 { + msg := "gRPC requires HTTP/2" + http.Error(w, msg, http.StatusBadRequest) + return + } + if r.Method != "POST" { + msg := fmt.Sprintf("invalid gRPC request method %q", r.Method) + http.Error(w, msg, http.StatusBadRequest) + return + } + flusher, ok := w.(http.Flusher) + if !ok { + msg := "Streaming unsupported" + http.Error(w, msg, http.StatusInternalServerError) + return + } + + contentType := r.Header.Get("Content-Type") + codec, ok := m.grpcGetCodec(contentType) + if !ok { + msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType) + http.Error(w, msg, http.StatusUnsupportedMediaType) + return + } + messageEncoding := r.Header.Get("Grpc-Encoding") + var compressor Compressor + if messageEncoding != "" { + comp, ok := m.grpcGetCompressor(messageEncoding) + if !ok { + msg := fmt.Sprintf("invalid gRPC request message-encoding %q", messageEncoding) + http.Error(w, msg, http.StatusUnsupportedMediaType) + return + } + compressor = comp + } + + ctx, md := newIncomingContext(r.Context(), r.Header) + + if v := r.Header.Get("grpc-timeout"); v != "" { + to, err := decodeTimeout(v) + if err != nil { + msg := fmt.Sprintf("malformed grpc-timeout: %v", err) + http.Error(w, msg, http.StatusBadRequest) + return + } + tctx, cancel := context.WithTimeout(ctx, to) + defer cancel() + ctx = tctx + } + + method := r.URL.Path + s := m.loadState() + hd, err := s.pickMethodHandler(method) + if err != nil { + msg := fmt.Sprintf("no handler for gRPC method %q", method) + http.Error(w, msg, http.StatusNotFound) + return + } + + // Handle stats. + beginTime := time.Now() + if sh := m.opts.statsHandler; sh != nil { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{ + FullMethodName: hd.method, + FailFast: false, // TODO + }) + + sh.HandleRPC(ctx, &stats.InHeader{ + FullMethod: method, + RemoteAddr: strAddr(r.RemoteAddr), + Compression: r.Header.Get("Content-Encoding"), + Header: metadata.MD(md).Copy(), + }) + + sh.HandleRPC(ctx, &stats.Begin{ + Client: false, + BeginTime: beginTime, + FailFast: false, // TODO + IsClientStream: hd.desc.IsStreamingClient(), + IsServerStream: hd.desc.IsStreamingServer(), + IsTransparentRetryAttempt: false, // TODO + }) + } + + ctx, cancel := context.WithCancel(ctx) + stream := &streamGRPC{ + ctx: ctx, + handler: hd, + opts: m.opts, + codec: codec, + comp: compressor, + done: ctx.Done(), + + // write + w: w, + wHeader: w.Header(), + + // read + r: r.Body, + contentType: contentType, + messageEncoding: messageEncoding, + //rHeader: r.Header, + } + // Sync stream return on stream methods. + defer func() { + cancel() + stream.wg.Wait() + }() + + herr := hd.handler(&m.opts, stream) + if !stream.sentHeader { + if err := stream.SendHeader(nil); err != nil { + return // ctx canceled + } + } + flusher.Flush() + r.Body.Close() + + // Write status. + st := status.Convert(herr) + + h := w.Header() + h.Set("Grpc-Status", strconv.FormatInt(int64(st.Code()), 10)) + if m := st.Message(); m != "" { + h.Set("Grpc-Message", encodeGrpcMessage(m)) + } + if p := st.Proto(); p != nil && len(p.Details) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + panic(err) + } + h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes)) + } + setOutgoingHeader(h, stream.trailer) + + if sh := m.opts.statsHandler; sh != nil { + endTime := time.Now() + + // Try to send Trailers, might not be respected. 
+ setOutgoingHeader(w.Header(), stream.trailer) + sh.HandleRPC(ctx, &stats.OutTrailer{ + Trailer: stream.trailer.Copy(), + }) + + sh.HandleRPC(ctx, &stats.End{ + Client: false, + BeginTime: beginTime, + EndTime: endTime, + Error: herr, + }) + } +} diff --git a/larking/grpc_test.go b/larking/grpc_test.go new file mode 100644 index 0000000..67fab08 --- /dev/null +++ b/larking/grpc_test.go @@ -0,0 +1,284 @@ +package larking + +import ( + "context" + "fmt" + "net" + "net/http" + "os" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/encoding" + grpc_testing "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/testing/protocmp" +) + +func TestGRPC(t *testing.T) { + // Create test server. + ts := grpc_testing.UnimplementedTestServiceServer{} + + o := new(overrides) + m, err := NewMux( + UnaryServerInterceptorOption(o.unary()), + StreamServerInterceptorOption(o.stream()), + ) + if err != nil { + t.Fatalf("failed to create mux: %v", err) + } + grpc_testing.RegisterTestServiceServer(m, ts) + + index := http.HandlerFunc(m.serveGRPC) + + h2s := &http2.Server{} + hs := &http.Server{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, // 1 MB + Handler: h2c.NewHandler(index, h2s), + } + if err := http2.ConfigureServer(hs, h2s); err != nil { + t.Fatalf("failed to configure server: %v", err) + } + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer lis.Close() + + // Start server. + go func() { + if err := hs.Serve(lis); err != nil && err != http.ErrServerClosed { + fmt.Println(err) + os.Exit(1) + } + }() + defer hs.Close() + + encoding.RegisterCompressor(&CompressorGzip{}) + + conns := []struct { + name string + opts []grpc.DialOption + }{{ + name: "insecure", + opts: []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + }, + }, { + name: "compressed", + opts: []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")), + }, + }} + + // https://github.com/grpc/grpc/blob/master/src/proto/grpc/testing/test.proto + tests := []struct { + name string + method string + desc grpc.StreamDesc + inouts []any + }{{ + name: "unary", + method: "/grpc.testing.TestService/UnaryCall", + desc: grpc.StreamDesc{}, + inouts: []any{ + in{ + msg: &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + out{ + msg: &grpc_testing.SimpleResponse{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + }, + }, { + name: "client_streaming", + method: "/grpc.testing.TestService/StreamingInputCall", + desc: grpc.StreamDesc{ + ClientStreams: true, + }, + inouts: []any{ + in{ + msg: &grpc_testing.StreamingInputCallRequest{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + in{ + msg: &grpc_testing.StreamingInputCallRequest{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + out{ + msg: &grpc_testing.StreamingInputCallResponse{ + AggregatedPayloadSize: 2, + }, + }, + }, + }, { + name: "server_streaming", + method: "/grpc.testing.TestService/StreamingOutputCall", + desc: grpc.StreamDesc{ + ServerStreams: true, + }, + inouts: []any{ + in{ + msg: &grpc_testing.StreamingOutputCallRequest{ + Payload: 
&grpc_testing.Payload{Body: []byte{0}}, + }, + }, + out{ + msg: &grpc_testing.StreamingOutputCallResponse{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + out{ + msg: &grpc_testing.StreamingOutputCallResponse{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + }, + }, { + name: "full_streaming", + method: "/grpc.testing.TestService/FullDuplexCall", + desc: grpc.StreamDesc{ + ClientStreams: true, + ServerStreams: true, + }, + inouts: []any{ + in{ + msg: &grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + out{ + msg: &grpc_testing.StreamingOutputCallResponse{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + in{ + msg: &grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + out{ + msg: &grpc_testing.StreamingOutputCallResponse{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + }, + }, { + name: "half_streaming", + method: "/grpc.testing.TestService/HalfDuplexCall", + desc: grpc.StreamDesc{ + ClientStreams: true, + ServerStreams: true, + }, + inouts: []any{ + in{ + msg: &grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + in{ + msg: &grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + out{ + msg: &grpc_testing.StreamingOutputCallResponse{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + out{ + msg: &grpc_testing.StreamingOutputCallResponse{ + Payload: &grpc_testing.Payload{Body: []byte{0}}, + }, + }, + }, + }, { + name: "large_client_streaming", + method: "/grpc.testing.TestService/StreamingInputCall", + desc: grpc.StreamDesc{ + ClientStreams: true, + }, + inouts: []any{ + in{ + msg: &grpc_testing.StreamingInputCallRequest{ + Payload: &grpc_testing.Payload{Body: make([]byte, 1024)}, + }, + }, + in{ + msg: &grpc_testing.StreamingInputCallRequest{ + Payload: &grpc_testing.Payload{Body: make([]byte, 1024)}, + }, + }, + out{ + msg: &grpc_testing.StreamingInputCallResponse{ + AggregatedPayloadSize: 2, + }, + }, + }, + }} + + opts := cmp.Options{protocmp.Transform()} + for _, tc := range conns { + t.Run(tc.name, func(t *testing.T) { + conn, err := grpc.Dial(lis.Addr().String(), tc.opts...) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + defer conn.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o.reset(t, "test", tt.inouts) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ctx = metadata.AppendToOutgoingContext(ctx, "test", tt.method) + + stream, err := conn.NewStream(ctx, &tt.desc, tt.method) + if err != nil { + t.Fatalf("failed to create stream: %v", err) + } + + for i, inout := range tt.inouts { + + switch v := inout.(type) { + case in: + t.Logf("stream.SendMsg: %d", i) + if err := stream.SendMsg(v.msg); err != nil { + t.Fatalf("failed to send msg: %v", err) + } + case out: + t.Logf("stream.RecvMsg: %d", i) + want := v.msg + got := v.msg.ProtoReflect().New().Interface() + if err := stream.RecvMsg(got); err != nil { + t.Fatalf("failed to recv msg: %v", err) + } + diff := cmp.Diff(got, want, opts...) 
+ if diff != "" { + t.Error(diff) + } + } + } + }) + } + }) + } +} diff --git a/larking/handler.go b/larking/handler.go index f584f2e..eb3bea8 100644 --- a/larking/handler.go +++ b/larking/handler.go @@ -17,9 +17,9 @@ import ( type handlerFunc func(*muxOptions, grpc.ServerStream) error type handler struct { - descriptor protoreflect.MethodDescriptor - handler handlerFunc - method string // /Service/Method + desc protoreflect.MethodDescriptor + handler handlerFunc + method string // /Service/Method } // TODO: use grpclog? @@ -74,8 +74,8 @@ func (m *Mux) registerService(gsd *grpc.ServiceDesc, ss interface{}) error { } h := &handler{ - method: method, - descriptor: md, + method: method, + desc: md, handler: func(opts *muxOptions, stream grpc.ServerStream) error { ctx := stream.Context() @@ -101,8 +101,8 @@ func (m *Mux) registerService(gsd *grpc.ServiceDesc, ss interface{}) error { } h := &handler{ - method: method, - descriptor: md, + method: method, + desc: md, handler: func(opts *muxOptions, stream grpc.ServerStream) error { info := &grpc.StreamServerInfo{ FullMethod: method, diff --git a/larking/http.go b/larking/http.go index 9a69470..78acae6 100644 --- a/larking/http.go +++ b/larking/http.go @@ -50,7 +50,9 @@ func (s *streamHTTP) SendHeader(md metadata.MD) error { return fmt.Errorf("already sent headers") } s.header = metadata.Join(s.header, md) - setOutgoingHeader(s.wHeader, s.header) + + h := s.wHeader + setOutgoingHeader(h, s.header) // don't write the header code, wait for the body. s.sentHeader = true @@ -75,9 +77,13 @@ func (s *streamHTTP) Context() context.Context { func (s *streamHTTP) writeMsg(c Codec, b []byte, contentType string) (int, error) { count := s.sendCount if count == 0 { - s.wHeader.Set("Content-Type", contentType) - setOutgoingHeader(s.wHeader, s.header, s.trailer) - s.sentHeader = true + h := s.wHeader + h.Set("Content-Type", contentType) + if !s.sentHeader { + if err := s.SendHeader(nil); err != nil { + return count, err + } + } } s.sendCount += 1 if s.method.desc.IsStreamingServer() { diff --git a/larking/mux.go b/larking/mux.go index cd7de48..6739cc0 100644 --- a/larking/mux.go +++ b/larking/mux.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "io" + "math" "math/rand" "net/http" "sort" @@ -157,6 +158,7 @@ type muxOptions struct { unaryInterceptor grpc.UnaryServerInterceptor streamInterceptor grpc.StreamServerInterceptor codecs map[string]Codec + codecsByName map[string]Codec compressors map[string]Compressor httprules ruleSelector contentTypeOffers []string @@ -215,6 +217,12 @@ func (o *muxOptions) stream(srv interface{}, ss grpc.ServerStream, info *grpc.St // MuxOption is an option for a mux. 
type MuxOption func(*muxOptions) +const ( + defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 + defaultServerMaxSendMessageSize = math.MaxInt32 + defaultServerConnectionTimeout = 120 * time.Second +) + var ( defaultMuxOptions = muxOptions{ maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, @@ -297,10 +305,9 @@ func ServiceConfigOption(sc *serviceconfig.Service) MuxOption { } type Mux struct { - opts muxOptions - state atomic.Value - services map[*grpc.ServiceDesc]interface{} - mu sync.Mutex + opts muxOptions + state atomic.Value + mu sync.Mutex } func NewMux(opts ...MuxOption) (*Mux, error) { @@ -319,6 +326,10 @@ func NewMux(opts ...MuxOption) (*Mux, error) { muxOpts.codecs[k] = v } } + muxOpts.codecsByName = make(map[string]Codec) + for _, v := range muxOpts.codecs { + muxOpts.codecsByName[v.Name()] = v + } for k := range muxOpts.codecs { muxOpts.contentTypeOffers = append(muxOpts.contentTypeOffers, k) } @@ -699,9 +710,9 @@ func createConnHandler( } return &handler{ - method: method, - descriptor: md, - handler: h, + method: method, + desc: md, + handler: h, } } else { info := &grpc.UnaryServerInfo{ @@ -736,9 +747,9 @@ func createConnHandler( } return &handler{ - method: method, - descriptor: md, - handler: h, + method: method, + desc: md, + handler: h, } } } @@ -884,8 +895,8 @@ func (m *Mux) serveHTTP(w http.ResponseWriter, r *http.Request) error { Client: false, BeginTime: beginTime, FailFast: false, // TODO - IsClientStream: hd.descriptor.IsStreamingClient(), - IsServerStream: hd.descriptor.IsStreamingServer(), + IsClientStream: hd.desc.IsStreamingClient(), + IsServerStream: hd.desc.IsStreamingServer(), IsTransparentRetryAttempt: false, // TODO }) } @@ -1012,7 +1023,23 @@ func (m *Mux) serveHTTP(w http.ResponseWriter, r *http.Request) error { return nil } +// ServeHTTP implements http.Handler. +// It supports both gRPC and HTTP requests. func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor == 2 && strings.HasPrefix( + r.Header.Get("Content-Type"), "application/grpc", + ) { + m.serveGRPC(w, r) + return + } + + if strings.HasPrefix( + r.Header.Get("Content-Type"), "application/grpc-web", + ) { + m.serveGRPCWeb(w, r) + return + } + if !strings.HasPrefix(r.URL.Path, "/") { r.URL.Path = "/" + r.URL.Path } diff --git a/larking/proxy.go b/larking/proxy.go index a920766..e5f1812 100644 --- a/larking/proxy.go +++ b/larking/proxy.go @@ -7,7 +7,6 @@ package larking import ( "io" - "golang.org/x/net/context" "google.golang.org/grpc" rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/protobuf/proto" @@ -16,34 +15,6 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" ) -func isStreamError(err error) bool { - switch err { - case nil: - return false - case io.EOF: - return false - case context.Canceled: - return false - } - return true -} - -// StreamHandler returns a gRPC stream handler to proxy gRPC requests. 
-func (m *Mux) StreamHandler() grpc.StreamHandler { - return func(srv interface{}, stream grpc.ServerStream) error { - ctx := stream.Context() - name, _ := grpc.Method(ctx) - s := m.loadState() - - hd, err := s.pickMethodHandler(name) - if err != nil { - return err - } - - return hd.handler(&m.opts, stream) - } -} - // TODO: fetch type on a per stream basis type serverReflectionServer struct { rpb.UnimplementedServerReflectionServer @@ -61,8 +32,9 @@ func (m *Mux) RegisterReflectionServer(s *grpc.Server) { }) } -//nolint:unused // fileDescEncodingByFilename finds the file descriptor for given filename, // does marshalling on it and returns the marshalled result. +// +//nolint:unused // fileDescEncodingByFilename finds the file descriptor for given filename, func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) { fd, err := protoregistry.GlobalFiles.FindFileByPath(name) if err != nil { @@ -71,9 +43,10 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte return proto.Marshal(protodesc.ToFileDescriptorProto(fd)) } -//nolint:unused // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, // does marshalling on it and returns the marshalled result. // The given symbol can be a type, a service or a method. +// +//nolint:unused // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) { fullname := protoreflect.FullName(name) d, err := protoregistry.GlobalFiles.FindDescriptorByName(fullname) @@ -84,8 +57,9 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ( return proto.Marshal(protodesc.ToFileDescriptorProto(fd)) } -//nolint:unused // fileDescEncodingContainingExtension finds the file descriptor containing given extension, // does marshalling on it and returns the marshalled result. 
+// +//nolint:unused // fileDescEncodingContainingExtension finds the file descriptor containing given extension, func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) { fullname := protoreflect.FullName(typeName) fieldnumber := protoreflect.FieldNumber(extNum) diff --git a/larking/proxy_test.go b/larking/proxy_test.go index 6da5110..ee6215f 100644 --- a/larking/proxy_test.go +++ b/larking/proxy_test.go @@ -7,6 +7,7 @@ package larking import ( "context" "net" + "net/http" "testing" "github.com/google/go-cmp/cmp" @@ -42,7 +43,9 @@ func TestGRPCProxy(t *testing.T) { var g errgroup.Group defer func() { if err := g.Wait(); err != nil { - t.Fatal(err) + if err != http.ErrServerClosed { + t.Fatal(err) + } } }() @@ -75,20 +78,18 @@ func TestGRPCProxy(t *testing.T) { } defer lisProxy.Close() - ts := grpc.NewServer( - grpc.UnknownServiceHandler(h.StreamHandler()), - ) + ts, err := NewServer(h) + if err != nil { + t.Fatal(err) + } g.Go(func() error { return ts.Serve(lisProxy) }) - defer ts.Stop() + defer ts.Close() cc, err := grpc.Dial( lisProxy.Addr().String(), - //grpc.WithTransportCredentials( - // credentials.NewTLS(transport.TLSClientConfig), - //), grpc.WithTransportCredentials(insecure.NewCredentials()), ) if err != nil { diff --git a/larking/rules.go b/larking/rules.go index 19d062e..7f301c5 100644 --- a/larking/rules.go +++ b/larking/rules.go @@ -6,12 +6,10 @@ package larking import ( "bytes" - "context" "encoding/base64" "encoding/json" "fmt" "net/http" - "net/textproto" "net/url" "sort" "strconv" @@ -20,7 +18,6 @@ import ( "google.golang.org/genproto/googleapis/api/annotations" _ "google.golang.org/genproto/googleapis/api/httpbody" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" @@ -763,28 +760,3 @@ func (p *path) match(route, verb string) (*method, params, error) { } return p.search(l.tokens(), verb) } - -const httpHeaderPrefix = "http-" - -func newIncomingContext(ctx context.Context, header http.Header) (context.Context, metadata.MD) { - md := make(metadata.MD, len(header)) - for k, vs := range header { - md[httpHeaderPrefix+strings.ToLower(k)] = vs - } - return metadata.NewIncomingContext(ctx, md), md -} - -func setOutgoingHeader(header http.Header, mds ...metadata.MD) { - for _, md := range mds { - for k, vs := range md { - if !strings.HasPrefix(k, httpHeaderPrefix) { - continue - } - k = k[len(httpHeaderPrefix):] - if len(k) == 0 { - continue - } - header[textproto.CanonicalMIMEHeaderKey(k)] = vs - } - } -} diff --git a/larking/rules_test.go b/larking/rules_test.go index 0706eaa..c744c10 100644 --- a/larking/rules_test.go +++ b/larking/rules_test.go @@ -221,10 +221,10 @@ func TestMessageServer(t *testing.T) { // TODO: compare http.Response output tests := []struct { - want want + name string inouts []any + want want req *http.Request - name string }{{ name: "first", req: httptest.NewRequest(http.MethodGet, "/v1/messages/name/hello", nil), @@ -894,7 +894,7 @@ func TestMessageServer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - o.reset(t, "http-test", tt.inouts) + o.reset(t, "test", tt.inouts) req := tt.req if len(tt.inouts) > 0 { diff --git a/larking/server.go b/larking/server.go index ef616ac..cacf390 100644 --- a/larking/server.go +++ b/larking/server.go @@ -8,21 +8,14 @@ import ( "context" "crypto/tls" "fmt" - "math" - "net" "net/http" "os" "os/signal" - 
"runtime" "strings" "time" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" - "golang.org/x/net/trace" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/encoding" ) // NewOSSignalContext tries to gracefully handle OS closure. @@ -45,21 +38,10 @@ func NewOSSignalContext(ctx context.Context) (context.Context, func()) { } } -type Server struct { - opts serverOptions - mux *Mux - - gs *grpc.Server - hs *http.Server - h2s *http2.Server - - events trace.EventLog -} - -// NewServer creates a new Proxy server. +// NewServer creates a new http.Server with http2 support. // The server is configured with the given options. -// Codecs and Compressors are registered with the grpc.Server. -func NewServer(mux *Mux, opts ...ServerOption) (*Server, error) { +// It is a convenience function for creating a new http.Server. +func NewServer(mux *Mux, opts ...ServerOption) (*http.Server, error) { if mux == nil { return nil, fmt.Errorf("invalid mux must not be nil") } @@ -70,134 +52,40 @@ func NewServer(mux *Mux, opts ...ServerOption) (*Server, error) { return nil, err } } - if svrOpts.tlsConfig == nil && !svrOpts.insecure { - return nil, fmt.Errorf("credentials must be set") - } - svrOpts.serveMux = http.NewServeMux() + h := svrOpts.serveMux + if h == nil { + h = http.NewServeMux() + } if len(svrOpts.muxPatterns) == 0 { svrOpts.muxPatterns = []string{"/"} } for _, pattern := range svrOpts.muxPatterns { prefix := strings.TrimSuffix(pattern, "/") if len(prefix) > 0 { - svrOpts.serveMux.Handle(prefix+"/", http.StripPrefix(prefix, mux)) + h.Handle(prefix+"/", http.StripPrefix(prefix, mux)) } else { - svrOpts.serveMux.Handle("/", mux) + h.Handle("/", mux) } } - // TODO: use our own flag? - // grpc.EnableTracing sets tracing for the golang.org/x/net/trace - var events trace.EventLog - if grpc.EnableTracing { - _, file, line, _ := runtime.Caller(1) - events = trace.NewEventLog("larking.Server", fmt.Sprintf("%s:%d", file, line)) - } - - var grpcOpts []grpc.ServerOption - - grpcOpts = append(grpcOpts, grpc.UnknownServiceHandler(mux.StreamHandler())) - if i := mux.opts.unaryInterceptor; i != nil { - grpcOpts = append(grpcOpts, grpc.UnaryInterceptor(i)) - } - if i := mux.opts.streamInterceptor; i != nil { - grpcOpts = append(grpcOpts, grpc.StreamInterceptor(i)) - } - if h := mux.opts.statsHandler; h != nil { - grpcOpts = append(grpcOpts, grpc.StatsHandler(h)) - } - - // Register codecs - for _, c := range mux.opts.codecs { - encoding.RegisterCodec(c) - } - // Register compressors - for _, c := range mux.opts.compressors { - if c != nil { - encoding.RegisterCompressor(c) - } - } - - // TLS termination controlled by listeners in Serve. - creds := insecure.NewCredentials() - grpcOpts = append(grpcOpts, grpc.Creds(creds)) - - gs := grpc.NewServer(grpcOpts...) 
- // Register local gRPC services - for sd, ss := range mux.services { - gs.RegisterService(sd, ss) - } - serveWeb := createGRPCWebHandler(gs) - index := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - contentType := r.Header.Get("content-type") - if strings.HasPrefix(contentType, grpcWeb) { - serveWeb(w, r) - } else if r.ProtoMajor == 2 && strings.HasPrefix(contentType, grpcBase) { - gs.ServeHTTP(w, r) - } else { - svrOpts.serveMux.ServeHTTP(w, r) - } - }) h2s := &http2.Server{} hs := &http.Server{ - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - MaxHeaderBytes: 1 << 20, // 1 MB - Handler: h2c.NewHandler(index, h2s), - TLSConfig: svrOpts.tlsConfig, + ReadHeaderTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, // 1 MB + Handler: h2c.NewHandler(h, h2s), + TLSConfig: svrOpts.tlsConfig, } if err := http2.ConfigureServer(hs, h2s); err != nil { return nil, err } - - return &Server{ - opts: svrOpts, - mux: mux, - gs: gs, - hs: hs, - h2s: h2s, - events: events, - }, nil -} - -// Serve accepts incoming connections on the listener. -// Serve will return always return a non-nil error, http.ErrServerClosed. -func (s *Server) Serve(l net.Listener) error { - if config := s.opts.tlsConfig; config != nil { - l = tls.NewListener(l, config) - } - return s.hs.Serve(l) -} - -func (s *Server) Shutdown(ctx context.Context) error { - if s.events != nil { - s.events.Finish() - s.events = nil - } - if err := s.hs.Shutdown(ctx); err != nil { - return err - } - return nil -} - -func (s *Server) Close() error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return s.Shutdown(ctx) + return hs, nil } -const ( - defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 - defaultServerMaxSendMessageSize = math.MaxInt32 - defaultServerConnectionTimeout = 120 * time.Second -) - type serverOptions struct { tlsConfig *tls.Config serveMux *http.ServeMux muxPatterns []string - insecure bool } // ServerOption is similar to grpc.ServerOption. 
@@ -210,21 +98,6 @@ func TLSCredsOption(c *tls.Config) ServerOption { } } -func InsecureServerOption() ServerOption { - return func(opts *serverOptions) error { - opts.insecure = true - return nil - } -} - -//func LarkingServerOption(threads map[string]string) ServerOption { -// return func(opts *serverOptions) error { -// opts.larkingEnabled = true -// opts.larkingThreads = threads -// return nil -// } -//} - func MuxHandleOption(patterns ...string) ServerOption { return func(opts *serverOptions) error { if opts.muxPatterns != nil { @@ -244,16 +117,3 @@ func HTTPHandlerOption(pattern string, handler http.Handler) ServerOption { return nil } } - -//func AdminOption(addr string) ServerOption { -// return func(opts *serverOptions) { -// -// } -//} - -//func (s *Server) RegisterService(desc *grpc.ServiceDesc, impl interface{}) { -// s.gs.RegisterService(desc, impl) -// if s.opts.mux != nil { -// s.opts.mux.RegisterService(desc, impl) -// } -//} diff --git a/larking/server_test.go b/larking/server_test.go index f9bffea..9ad6352 100644 --- a/larking/server_test.go +++ b/larking/server_test.go @@ -84,7 +84,7 @@ func TestServer(t *testing.T) { t.Fatal(err) } - ts, err := NewServer(mux, InsecureServerOption()) + ts, err := NewServer(mux) if err != nil { t.Fatal(err) } @@ -188,7 +188,6 @@ func TestMuxHandleOption(t *testing.T) { s, err := NewServer( mux, - InsecureServerOption(), MuxHandleOption("/", "/api/", "/pfx"), ) if err != nil { @@ -426,7 +425,7 @@ func TestTLSServer(t *testing.T) { } g := errgroup.Group{} - g.Go(func() error { return s.Serve(l) }) + g.Go(func() error { return s.ServeTLS(l, "", "") }) defer func() { if err := s.Shutdown(ctx); err != nil { t.Error(err) diff --git a/larking/web.go b/larking/web.go index 9f226bc..64a2e13 100644 --- a/larking/web.go +++ b/larking/web.go @@ -16,8 +16,6 @@ import ( "io" "net/http" "strings" - - "google.golang.org/grpc" ) const ( @@ -146,34 +144,32 @@ type readCloser struct { io.Closer } -func createGRPCWebHandler(gs *grpc.Server) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - typ, enc, ok := isWebRequest(r) - if !ok { - msg := fmt.Sprintf("invalid gRPC-Web content type: %v", r.Header.Get("Content-Type")) - http.Error(w, msg, http.StatusBadRequest) - return - } - // TODO: Check for websocket request and upgrade. - if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { - http.Error(w, "unimplemented websocket support", http.StatusInternalServerError) - return - } - - r.ProtoMajor = 2 - r.ProtoMinor = 0 +func (m *Mux) serveGRPCWeb(w http.ResponseWriter, r *http.Request) { + typ, enc, ok := isWebRequest(r) + if !ok { + msg := fmt.Sprintf("invalid gRPC-Web content type: %v", r.Header.Get("Content-Type")) + http.Error(w, msg, http.StatusBadRequest) + return + } + // TODO: Check for websocket request and upgrade. 
+ if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + http.Error(w, "unimplemented websocket support", http.StatusInternalServerError) + return + } - hdr := r.Header - hdr.Del("Content-Length") - hdr.Set("Content-Type", grpcBase+"+"+enc) + r.ProtoMajor = 2 + r.ProtoMinor = 0 - if typ == grpcWebText { - body := base64.NewDecoder(base64.StdEncoding, r.Body) - r.Body = readCloser{body, r.Body} - } + hdr := r.Header + hdr.Del("Content-Length") + hdr.Set("Content-Type", grpcBase+"+"+enc) - ww := newWebWriter(w, typ, enc) - gs.ServeHTTP(ww, r) - ww.flushWithTrailer() + if typ == grpcWebText { + body := base64.NewDecoder(base64.StdEncoding, r.Body) + r.Body = readCloser{body, r.Body} } + + ww := newWebWriter(w, typ, enc) + m.serveGRPC(ww, r) + ww.flushWithTrailer() } diff --git a/larking/web_test.go b/larking/web_test.go index 9da6c7a..f896f21 100644 --- a/larking/web_test.go +++ b/larking/web_test.go @@ -14,7 +14,6 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/genproto/googleapis/rpc/status" - "google.golang.org/grpc" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" @@ -27,10 +26,15 @@ func TestWeb(t *testing.T) { ms := &testpb.UnimplementedMessagingServer{} o := new(overrides) - gs := grpc.NewServer(o.unaryOption(), o.streamOption()) + m, err := NewMux( + UnaryServerInterceptorOption(o.unary()), + StreamServerInterceptorOption(o.stream()), + ) + if err != nil { + t.Fatalf("failed to create mux: %v", err) + } - testpb.RegisterMessagingServer(gs, ms) - h := createGRPCWebHandler(gs) + testpb.RegisterMessagingServer(m, ms) type want struct { msg proto.Message @@ -95,8 +99,7 @@ func TestWeb(t *testing.T) { req.Header["test"] = []string{tt.in.method} w := httptest.NewRecorder() - //s.gs.ServeHTTP(w, req) - h.ServeHTTP(w, req) + m.ServeHTTP(w, req) resp := w.Result() t.Log("resp", resp) diff --git a/larking/websocket_test.go b/larking/websocket_test.go index f5db914..cfb3e3a 100644 --- a/larking/websocket_test.go +++ b/larking/websocket_test.go @@ -50,7 +50,7 @@ func TestWebsocket(t *testing.T) { } mux.RegisterService(&testpb.ChatRoom_ServiceDesc, fs) - s, err := NewServer(mux, InsecureServerOption()) + s, err := NewServer(mux) if err != nil { t.Fatal(err) } @@ -120,7 +120,7 @@ func TestWebsocket(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - o.reset(t, "http-test", tt.server) + o.reset(t, "test", tt.server) ctx, cancel := context.WithTimeout(testContext(t), time.Minute) defer cancel()
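To close, a hedged sketch of exercising the gRPC-Web path through `Mux.ServeHTTP` with `httptest`, in the spirit of the `web_test.go` changes above. The `*Mux` value `m`, the method path, and `msg` (a proto-encoded request) are illustrative assumptions, as are the `bytes`, `encoding/binary`, `net/http`, and `net/http/httptest` imports; this is not part of the patch.

```go
// msg is assumed to hold a proto-encoded request message.
frame := make([]byte, 5+len(msg))                        // 5-byte gRPC prefix + payload
binary.BigEndian.PutUint32(frame[1:5], uint32(len(msg))) // frame[0] == 0: uncompressed
copy(frame[5:], msg)

req := httptest.NewRequest(http.MethodPost, "/larking.testpb.Messaging/GetMessageOne", bytes.NewReader(frame))
req.Header.Set("Content-Type", "application/grpc-web+proto")
w := httptest.NewRecorder()
m.ServeHTTP(w, req) // routed to serveGRPCWeb, which rewrites the request and reuses serveGRPC
resp := w.Result()  // gRPC-Web folds the trailers into the response body
_ = resp
```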