From ea921a61968e74494a60579d5eb407838e11abd8 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Fri, 30 Aug 2024 14:05:55 -0400 Subject: [PATCH] refactor: add `service` package to prepare for split HTTP handling Package service exposes types to abstract services from the networking. The idea is that we build a set of services and a set of network endpoints (Listener). The services are then assigned to endpoints based on the address(es) they were configured for. Actual service to endpoint binding is not handled by the abstractions in this package as it is protocol specific. The general pattern is to make a "server" that wraps a service, and can then be started on an endpoint using a `Serve` method, similar to `http.Server`. To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port), services can implement `Merger`. --- helpertest/tls.go | 60 +++++++++++++++ server/http.go | 57 +++++++++++--- server/server.go | 138 +++++++++++++++++++++------------- server/server_endpoints.go | 18 +++-- service/endpoint.go | 61 +++++++++++++++ service/endpoint_test.go | 48 ++++++++++++ service/http.go | 94 +++++++++++++++++++++++ service/http_test.go | 88 ++++++++++++++++++++++ service/listener.go | 74 ++++++++++++++++++ service/listener_test.go | 124 ++++++++++++++++++++++++++++++ service/merge.go | 57 ++++++++++++++ service/merge_test.go | 41 ++++++++++ service/service.go | 107 ++++++++++++++++++++++++++ service/service_suite_test.go | 18 +++++ service/service_test.go | 88 ++++++++++++++++++++++ util/slices.go | 37 +++++++++ util/slices_test.go | 36 +++++++++ 17 files changed, 1075 insertions(+), 71 deletions(-) create mode 100644 helpertest/tls.go create mode 100644 service/endpoint.go create mode 100644 service/endpoint_test.go create mode 100644 service/http.go create mode 100644 service/http_test.go create mode 100644 service/listener.go create mode 100644 service/listener_test.go create mode 100644 service/merge.go create mode 100644 service/merge_test.go create mode 100644 service/service.go create mode 100644 service/service_suite_test.go create mode 100644 service/service_test.go create mode 100644 util/slices.go create mode 100644 util/slices_test.go diff --git a/helpertest/tls.go b/helpertest/tls.go new file mode 100644 index 000000000..91613cc6b --- /dev/null +++ b/helpertest/tls.go @@ -0,0 +1,60 @@ +package helpertest + +import ( + "crypto/tls" + "crypto/x509" + "sync" + + "github.com/0xERR0R/blocky/util" + . "github.com/onsi/gomega" +) + +const tlsTestServerName = "test.blocky.invalid" + +type tlsData struct { + ServerCfg *tls.Config + ClientCfg *tls.Config +} + +// Lazy init +// +//nolint:gochecknoglobals +var ( + initTLSData sync.Once + tlsDataStorage tlsData +) + +func getTLSData() tlsData { + initTLSData.Do(func() { + cert, err := util.TLSGenerateSelfSignedCert([]string{tlsTestServerName}) + Expect(err).Should(Succeed()) + + tlsDataStorage.ServerCfg = &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS13, + } + + certPool := x509.NewCertPool() + certPool.AddCert(cert.Leaf) + + tlsDataStorage.ClientCfg = &tls.Config{ + RootCAs: certPool, + ServerName: tlsTestServerName, + MinVersion: tls.VersionTLS13, + } + }) + + return tlsDataStorage +} + +// TLSTestServerConfig returns a TLS Config for use by test servers. +func TLSTestServerConfig() *tls.Config { + return getTLSData().ServerCfg.Clone() +} + +// TLSTestServerConfig returns a TLS Config for use by test clients. +// +// This is required to connect to a test TLS server, otherwise TLS verification fails. +func TLSTestClientConfig() *tls.Config { + return getTLSData().ClientCfg.Clone() +} diff --git a/server/http.go b/server/http.go index cac0e8102..7c4da3230 100644 --- a/server/http.go +++ b/server/http.go @@ -6,17 +6,55 @@ import ( "net/http" "time" + "github.com/0xERR0R/blocky/api" + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/service" + "github.com/0xERR0R/blocky/util" "github.com/go-chi/chi/v5" "github.com/go-chi/cors" ) +// httpMiscService implements service.HTTPService. +// +// This supports the existing single HTTP/HTTPS endpoints +// that expose everything. The goal is to split it up +// and remove it. +type httpMiscService struct { + service.HTTPInfo +} + +func newHTTPMiscService( + cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler, +) *httpMiscService { + endpoints := util.ConcatSlices( + service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP), + service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS), + ) + + return &httpMiscService{ + HTTPInfo: service.HTTPInfo{ + Info: service.Info{ + Name: "HTTP", + Endpoints: endpoints, + }, + + Mux: createHTTPRouter(cfg, openAPIImpl, dnsHandler), + }, + } +} + +func (s *httpMiscService) Merge(other service.Service) (service.Merger, error) { + return service.MergeHTTP(s, other) +} + +// httpServer implements subServer for HTTP. type httpServer struct { - inner http.Server + service.HTTPService - name string + inner http.Server } -func newHTTPServer(name string, handler http.Handler) *httpServer { +func newHTTPServer(svc service.HTTPService) *httpServer { const ( readHeaderTimeout = 20 * time.Second readTimeout = 20 * time.Second @@ -24,22 +62,17 @@ func newHTTPServer(name string, handler http.Handler) *httpServer { ) return &httpServer{ + HTTPService: svc, + inner: http.Server{ - ReadTimeout: readTimeout, + Handler: withCommonMiddleware(svc.Router()), ReadHeaderTimeout: readHeaderTimeout, + ReadTimeout: readTimeout, WriteTimeout: writeTimeout, - - Handler: withCommonMiddleware(handler), }, - - name: name, } } -func (s *httpServer) String() string { - return s.name -} - func (s *httpServer) Serve(ctx context.Context, l net.Listener) error { go func() { <-ctx.Done() diff --git a/server/server.go b/server/server.go index 7404541ad..e5d4817b2 100644 --- a/server/server.go +++ b/server/server.go @@ -9,6 +9,7 @@ import ( "net/http" "runtime" "runtime/debug" + "slices" "strings" "time" @@ -18,6 +19,8 @@ import ( "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/redis" "github.com/0xERR0R/blocky/resolver" + "github.com/0xERR0R/blocky/service" + "golang.org/x/exp/maps" "github.com/0xERR0R/blocky/util" "github.com/google/uuid" @@ -40,7 +43,14 @@ type Server struct { queryResolver resolver.ChainedResolver cfg *config.Config - servers map[net.Listener]*httpServer + services map[service.Listener]service.Service +} + +type subServer interface { + fmt.Stringer + service.Service + + Serve(context.Context, net.Listener) error } func logger() *logrus.Entry { @@ -99,8 +109,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) { } // NewServer creates new server instance with passed config -// -//nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { var tlsCfg *tls.Config @@ -116,7 +124,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, fmt.Errorf("server creation failed: %w", err) } - httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg) + listeners, err := createListeners(ctx, cfg, tlsCfg) if err != nil { return nil, err } @@ -145,39 +153,41 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err dnsServers: dnsServers, queryResolver: queryResolver, cfg: cfg, - - servers: make(map[net.Listener]*httpServer), } server.printConfiguration() server.registerDNSHandlers(ctx) - openAPIImpl, err := server.createOpenAPIInterfaceImpl() + services, err := server.createServices() if err != nil { return nil, err } - httpRouter := createHTTPRouter(cfg, openAPIImpl) - server.registerDoHEndpoints(httpRouter) + server.services, err = service.GroupByListener(services, listeners) + if err != nil { + return nil, err + } - if len(cfg.Ports.HTTP) != 0 { - srv := newHTTPServer("http", httpRouter) + return server, err +} - for _, l := range httpListeners { - server.servers[l] = srv - } +func (s *Server) createServices() ([]service.Service, error) { + openAPIImpl, err := s.createOpenAPIInterfaceImpl() + if err != nil { + return nil, err } - if len(cfg.Ports.HTTPS) != 0 { - srv := newHTTPServer("https", httpRouter) - - for _, l := range httpsListeners { - server.servers[l] = srv - } + res := []service.Service{ + newHTTPMiscService(s.cfg, openAPIImpl, s.handleReq), } - return server, err + // Remove services the user has not enabled + res = slices.DeleteFunc(res, func(svc service.Service) bool { + return len(svc.ExposeOn()) == 0 + }) + + return res, nil } func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) { @@ -208,48 +218,51 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error return dnsServers, err.ErrorOrNil() } -func createHTTPListeners( - cfg *config.Config, tlsCfg *tls.Config, -) (httpListeners, httpsListeners []net.Listener, err error) { - httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP) - if err != nil { - return nil, nil, err +func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) { + res := make(map[string]service.Listener) + + listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) { + return service.ListenTLS(ctx, endpoint, tlsCfg) } - httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg) + err := errors.Join( + newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res), + newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res), + ) if err != nil { - return nil, nil, err + return nil, err } - return httpListeners, httpsListeners, nil + return maps.Values(res), nil } -func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) { - listeners := make([]net.Listener, 0, len(addresses)) - - for _, address := range addresses { - listener, err := net.Listen("tcp", address) - if err != nil { - return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err) +type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error) + +func newListeners[T service.Listener]( + ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener, +) error { + for _, addr := range addrs { + key := fmt.Sprintf("%s:%s", proto, addr) + if _, ok := out[key]; ok { + // Avoid "address already in use" + // We instead try to merge services, see services.GroupByListener + continue } - listeners = append(listeners, listener) - } - - return listeners, nil -} + endpoint := service.Endpoint{ + Protocol: proto, + AddrConf: addr, + } -func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) { - listeners, err := newTCPListeners(proto, addresses) - if err != nil { - return nil, err - } + l, err := listen(ctx, endpoint) + if err != nil { + return err // already has all info + } - for i, inner := range listeners { - listeners[i] = tls.NewListener(inner, tlsCfg) + out[key] = l } - return listeners, nil + return nil } func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) { @@ -385,6 +398,16 @@ func toMB(b uint64) uint64 { return b / bytesInKB / bytesInKB } +func newSubServer(svc service.Service) (subServer, error) { + switch svc := svc.(type) { + case service.HTTPService: + return newHTTPServer(svc), nil + + default: + return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc) + } +} + // Start starts the server func (s *Server) Start(ctx context.Context, errCh chan<- error) { logger().Info("Starting server") @@ -399,11 +422,18 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) { }() } - for listener, srv := range s.servers { - listener, srv := listener, srv + for listener, svc := range s.services { + listener, svc := listener, svc + + srv, err := newSubServer(svc) + if err != nil { + errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err) + + return + } go func() { - logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr()) + logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes()) err := srv.Serve(ctx, listener) if err != nil { @@ -506,6 +536,8 @@ type msgWriter interface { WriteMsg(msg *dns.Msg) error } +type dnsHandler func(context.Context, *model.Request, msgWriter) + func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) { response, err := s.resolve(ctx, request) if err != nil { diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 1fb3db602..cac287f43 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -52,9 +52,11 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil } -func (s *Server) registerDoHEndpoints(router *chi.Mux) { +func registerDoHEndpoints(router *chi.Mux, dnsHandler dnsHandler) { const pathDohQuery = "/dns-query" + s := &dohServer{dnsHandler} + router.Get(pathDohQuery, s.dohGetRequestHandler) router.Get(pathDohQuery+"/", s.dohGetRequestHandler) router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler) @@ -63,7 +65,9 @@ func (s *Server) registerDoHEndpoints(router *chi.Mux) { router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler) } -func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { +type dohServer struct{ handler dnsHandler } + +func (s *dohServer) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { dnsParam, ok := req.URL.Query()["dns"] if !ok || len(dnsParam[0]) < 1 { http.Error(rw, "dns param is missing", http.StatusBadRequest) @@ -87,7 +91,7 @@ func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) s.processDohMessage(rawMsg, rw, req) } -func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) { +func (s *dohServer) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) { contentType := req.Header.Get("Content-type") if contentType != dnsContentType { http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType) @@ -111,7 +115,7 @@ func (s *Server) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request s.processDohMessage(rawMsg, rw, req) } -func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { +func (s *dohServer) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { msg := new(dns.Msg) if err := msg.Unpack(rawMsg); err != nil { logger().Error("can't deserialize message: ", err) @@ -122,7 +126,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpRe ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg) - s.handleReq(ctx, dnsReq, httpMsgWriter{rw}) + s.handler(ctx, dnsReq, httpMsgWriter{rw}) } type httpMsgWriter struct { @@ -156,7 +160,7 @@ func (s *Server) Query( return s.resolve(ctx, req) } -func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { +func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler) *chi.Mux { router := chi.NewRouter() api.RegisterOpenAPIEndpoints(router, openAPIImpl) @@ -169,6 +173,8 @@ func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) configureRootHandler(cfg, router) + registerDoHEndpoints(router, dnsHandler) + metrics.Start(router, cfg.Prometheus) return router diff --git a/service/endpoint.go b/service/endpoint.go new file mode 100644 index 000000000..b76c57a72 --- /dev/null +++ b/service/endpoint.go @@ -0,0 +1,61 @@ +package service + +import ( + "fmt" + "slices" + "strings" + + "github.com/0xERR0R/blocky/util" + "golang.org/x/exp/maps" +) + +// Endpoint is a network endpoint on which to expose a service. +type Endpoint struct { + // Protocol is the protocol to be exposed on this endpoint. + Protocol string + + // AddrConf is the network address as configured by the user. + AddrConf string +} + +func EndpointsFromAddrs(proto string, addrs []string) []Endpoint { + return util.ConvertEach(addrs, func(addr string) Endpoint { + return Endpoint{ + Protocol: proto, + AddrConf: addr, + } + }) +} + +func (e Endpoint) String() string { + addr := e.AddrConf + if strings.HasPrefix(addr, ":") { + addr = "*" + addr + } + + return fmt.Sprintf("%s://%s", e.Protocol, addr) +} + +type endpointSet map[Endpoint]struct{} + +func newEndpointSet(endpoints ...Endpoint) endpointSet { + s := make(endpointSet, len(endpoints)) + + for _, endpoint := range endpoints { + s[endpoint] = struct{}{} + } + + return s +} + +func (s endpointSet) ToSlice() []Endpoint { + return maps.Keys(s) +} + +func (s endpointSet) IntersectSlice(others []Endpoint) { + for endpoint := range s { + if !slices.Contains(others, endpoint) { + delete(s, endpoint) + } + } +} diff --git a/service/endpoint_test.go b/service/endpoint_test.go new file mode 100644 index 000000000..dfd81c592 --- /dev/null +++ b/service/endpoint_test.go @@ -0,0 +1,48 @@ +package service + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Endpoints", func() { + Describe("EndpointsFromAddrs", func() { + It("assigns the expected values", func() { + Expect(EndpointsFromAddrs("proto", []string{":1", "localhost:2"})).Should(Equal([]Endpoint{ + {"proto", ":1"}, + {"proto", "localhost:2"}, + })) + }) + }) + + Describe("Endpoint", func() { + It("strings to a URL", func() { + sut := Endpoint{"proto", "addr:port/whatever!no format \000 expected?"} + + Expect(sut.String()).Should(Equal("proto://" + sut.AddrConf)) + }) + + It("strings with explicit wildcard host", func() { + sut := Endpoint{"https", ":443"} + + Expect(sut.String()).Should(Equal("https://*:443")) + }) + }) + + Describe("endpointSet", func() { + e1 := Endpoint{"proto", ":1"} + e2 := Endpoint{"proto", ":2"} + e3 := Endpoint{"proto", ":3"} + + sut := newEndpointSet(e1, e1, e2) + + It("should contain all elements", func() { + Expect(sut.ToSlice()).Should(SatisfyAll(HaveLen(2), ContainElements(e1, e2))) + }) + + It("should intersect common values", func() { + sut.IntersectSlice([]Endpoint{e2, e3}) + Expect(sut.ToSlice()).Should(Equal([]Endpoint{e2})) + }) + }) +}) diff --git a/service/http.go b/service/http.go new file mode 100644 index 000000000..6b634c821 --- /dev/null +++ b/service/http.go @@ -0,0 +1,94 @@ +package service + +import ( + "errors" + "net/http" + "strings" + + "github.com/0xERR0R/blocky/util" + "github.com/go-chi/chi/v5" +) + +const ( + HTTPProtocol = "http" + HTTPSProtocol = "https" +) + +// HTTPService is a Service using a HTTP router. +type HTTPService interface { + Service + Merger + + // Router returns the service's router. + Router() chi.Router +} + +// HTTPInfo can be embedded in structs to help implement HTTPService. +type HTTPInfo struct { + Info + + Mux *chi.Mux +} + +func (i *HTTPInfo) Router() chi.Router { return i.Mux } + +// MergeHTTP merges two compatible HTTPServices. +// +// The second parameter is of type `Service` to make it easy to call +// from a `Merger.Merge` implementation. +func MergeHTTP(a HTTPService, b Service) (Merger, error) { + return newHTTPMerger(a).Merge(b) +} + +var _ HTTPService = (*httpMerger)(nil) + +// httpMerger can merge HTTPServices by combining their routes. +type httpMerger struct { + inner []HTTPService + router chi.Router + endpoints endpointSet +} + +func newHTTPMerger(first HTTPService) *httpMerger { + return &httpMerger{ + inner: []HTTPService{first}, + router: first.Router(), + endpoints: newEndpointSet(first.ExposeOn()...), + } +} + +func (m *httpMerger) String() string { return svcString(m) } + +func (m *httpMerger) ServiceName() string { + names := util.ConvertEach(m.inner, func(svc HTTPService) string { + return svc.ServiceName() + }) + + return strings.Join(names, " & ") +} + +func (m *httpMerger) ExposeOn() []Endpoint { return m.endpoints.ToSlice() } +func (m *httpMerger) Router() chi.Router { return m.router } + +func (m *httpMerger) Merge(other Service) (Merger, error) { + httpSvc, ok := other.(HTTPService) + if !ok { + return nil, errors.New("not an HTTPService") + } + + type middleware = func(http.Handler) http.Handler + + // Can't do `.Mount("/", ...)` otherwise we can only merge at most once since / will already be defined + _ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error { + m.router.With(middlewares...).Method(method, route, handler) + + return nil + }) + + m.inner = append(m.inner, httpSvc) + + // Don't expose any service more than it expects + m.endpoints.IntersectSlice(other.ExposeOn()) + + return m, nil +} diff --git a/service/http_test.go b/service/http_test.go new file mode 100644 index 000000000..20c1a76e7 --- /dev/null +++ b/service/http_test.go @@ -0,0 +1,88 @@ +package service + +import ( + "github.com/go-chi/chi/v5" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Service HTTP", func() { + var err error + + Describe("HTTPInfo", func() { + It("returns the expected router", func() { + endpoints := EndpointsFromAddrs("proto", []string{":1", "localhost:2"}) + sut := HTTPInfo{Info{"name", endpoints}, chi.NewMux()} + + Expect(sut.ServiceName()).Should(Equal("name")) + Expect(sut.ExposeOn()).Should(Equal(endpoints)) + }) + }) + + Describe("httpMerger", func() { + httpSvcA1 := newFakeHTTPService("HTTP A1", "a:1") + httpSvcA1_ := newFakeHTTPService("HTTP A1_", "a:1") + httpSvcB1 := newFakeHTTPService("HTTP B1", "b:1") + + sut := newHTTPMerger(httpSvcA1) + + nonHTTPSvc := &Info{"non HTTP service", EndpointsFromAddrs("proto", []string{":1"})} + + It("uses the given service", func() { + Expect(sut.String()).Should(Equal(httpSvcA1.String())) + Expect(sut.Router()).Should(BeIdenticalTo(httpSvcA1.Router())) + Expect(sut.ExposeOn()).Should(Equal(httpSvcA1.ExposeOn())) + }) + + It("can merge other HTTP services", func() { + merged, err := sut.Merge(httpSvcA1_) + Expect(err).Should(Succeed()) + Expect(merged).Should(BeIdenticalTo(sut)) + Expect(merged.String()).Should(SatisfyAll( + ContainSubstring(httpSvcA1.ServiceName()), + ContainSubstring(httpSvcA1_.ServiceName())), + ) + + By("merging the common endpoints", func() { + Expect(merged.ExposeOn()).Should(Equal(httpSvcA1.ExposeOn())) + }) + + By("merging another service again", func() { + merged, err = sut.Merge(httpSvcB1) + Expect(err).Should(Succeed()) + Expect(merged).Should(BeIdenticalTo(sut)) + }) + + By("excluding non-common endpoints", func() { + Expect(merged.ExposeOn()).Should(BeEmpty()) + }) + + By("including all HTTP routes", func() { + Expect(sut.Router().Routes()).Should(HaveLen(3)) + }) + }) + + It("cannot merge a non HTTP service", func() { + _, err = sut.Merge(nonHTTPSvc) + Expect(err).Should(MatchError(ContainSubstring("not an HTTPService"))) + }) + }) +}) + +type fakeHTTPService struct { + HTTPInfo +} + +func newFakeHTTPService(name string, addrs ...string) *fakeHTTPService { + mux := chi.NewMux() + mux.Get("/"+name, nil) + + return &fakeHTTPService{HTTPInfo{ + Info: Info{Name: name, Endpoints: EndpointsFromAddrs("http", addrs)}, + Mux: mux, + }} +} + +func (s *fakeHTTPService) Merge(other Service) (Merger, error) { + return MergeHTTP(s, other) +} diff --git a/service/listener.go b/service/listener.go new file mode 100644 index 000000000..507f6b33a --- /dev/null +++ b/service/listener.go @@ -0,0 +1,74 @@ +package service + +import ( + "context" + "crypto/tls" + "fmt" + "net" +) + +// Listener is a net.Listener that provides information about +// what protocol and address it is configured for. +type Listener interface { + fmt.Stringer + net.Listener + + // Exposes returns the endpoint for this listener. + // + // It can be used to find service(s) with matching configuration. + Exposes() Endpoint +} + +// ListenerInfo can be embedded in structs to help implement Listener. +type ListenerInfo struct { + Endpoint +} + +func (i *ListenerInfo) Exposes() Endpoint { return i.Endpoint } + +// NetListener implements Listener using an existing net.Listener. +type NetListener struct { + net.Listener + ListenerInfo +} + +func NewNetListener(endpoint Endpoint, inner net.Listener) *NetListener { + return &NetListener{ + Listener: inner, + ListenerInfo: ListenerInfo{endpoint}, + } +} + +// TCPListener is a Listener for a TCP socket. +type TCPListener struct{ NetListener } + +// ListenTCP creates a new TCPListener. +func ListenTCP(ctx context.Context, endpoint Endpoint) (*TCPListener, error) { + var lc net.ListenConfig + + l, err := lc.Listen(ctx, "tcp", endpoint.AddrConf) + if err != nil { + return nil, err // err already has all the info we could add + } + + inner := NewNetListener(endpoint, l) + + return &TCPListener{*inner}, nil +} + +// TLSListener is a Listener using TLS over TCP. +type TLSListener struct{ NetListener } + +// ListenTLS creates a new TLSListener. +func ListenTLS(ctx context.Context, endpoint Endpoint, cfg *tls.Config) (*TLSListener, error) { + tcp, err := ListenTCP(ctx, endpoint) + if err != nil { + return nil, err + } + + inner := tcp.NetListener + + inner.Listener = tls.NewListener(inner.Listener, cfg) + + return &TLSListener{inner}, nil +} diff --git a/service/listener_test.go b/service/listener_test.go new file mode 100644 index 000000000..c21c0f690 --- /dev/null +++ b/service/listener_test.go @@ -0,0 +1,124 @@ +package service + +import ( + "context" + "crypto/tls" + "net" + "time" + + "github.com/0xERR0R/blocky/helpertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Service Listener", func() { + var err error + + Describe("NetListener", func() { + It("uses the given data", func() { + var nl net.Listener + + endpoint := Endpoint{"proto", ":1"} + + sut := NewNetListener(endpoint, nl) + + var l Listener = sut + Expect(l.Exposes()).Should(Equal(endpoint)) + Expect(l.String()).Should(Equal(endpoint.String())) + }) + }) + + type entryFuncs struct { + Listen func(context.Context, Endpoint) (Listener, error) + Dial func(ctx context.Context, addr string) (net.Conn, error) + } + + DescribeTable("Listener Functions", + func(ctx context.Context, funcs entryFuncs) { + By("failing for an invalid endpoint", func() { + endpoint := Endpoint{"proto", "invalid!"} + + _, err := funcs.Listen(ctx, endpoint) + Expect(err).ShouldNot(Succeed()) + }) + + var l Listener + By("listening on a valid endpoint", func() { + endpoint := Endpoint{"proto", ":0"} + + l, err = funcs.Listen(ctx, endpoint) + Expect(err).Should(Succeed()) + DeferCleanup(l.Close) + + Expect(l.Exposes()).Should(Equal(endpoint)) + Expect(l.String()).Should(Equal(endpoint.String())) + }) + + ch := make(chan struct{}) + data := []byte("test") + + // Server goroutine + go func() { + defer GinkgoRecover() + + var ( + conn net.Conn + err error // separate var to avoid data-race + ) + By("accepting client connection", func() { + conn, err = l.Accept() + Expect(err).Should(Succeed()) + DeferCleanup(conn.Close) + }) + + By("sending data to the client", func() { + Expect(conn.Write(data)).Should(Equal(len(data))) + }) + + close(ch) + }() + + var conn net.Conn + By("connecting to server", func() { + conn, err = funcs.Dial(ctx, l.Addr().String()) + Expect(err).Should(Succeed()) + DeferCleanup(conn.Close) + }) + + By("receiving the expected data", func() { + buff := make([]byte, len(data)) + Expect(conn.Read(buff)).Should(Equal(len(data))) + Expect(buff).Should(Equal(data)) + }) + + // Ensure the server goroutine exit before the test ends + Eventually(ctx, ch).Should(BeClosed()) + }, + Entry("ListenTCP", + entryFuncs{ + Listen: func(ctx context.Context, endpoint Endpoint) (Listener, error) { + return ListenTCP(ctx, endpoint) + }, + Dial: func(ctx context.Context, addr string) (net.Conn, error) { + return new(net.Dialer).DialContext(ctx, "tcp", addr) + }, + }, + SpecTimeout(100*time.Millisecond), + ), + Entry("ListenTLS", + entryFuncs{ + Listen: func(ctx context.Context, endpoint Endpoint) (Listener, error) { + return ListenTLS(ctx, endpoint, helpertest.TLSTestServerConfig()) + }, + Dial: func(ctx context.Context, addr string) (net.Conn, error) { + d := tls.Dialer{ + Config: helpertest.TLSTestClientConfig(), + } + + return d.DialContext(ctx, "tcp", addr) + }, + }, + SpecTimeout(100*time.Millisecond), + ), + ) +}) diff --git a/service/merge.go b/service/merge.go new file mode 100644 index 000000000..7db06ef8f --- /dev/null +++ b/service/merge.go @@ -0,0 +1,57 @@ +package service + +import "errors" + +// Merger is a Service that can be merged with another compatible one. +type Merger interface { + Service + + // Merge returns the result of merging the receiver with the other Service. + // + // Neither the receiver, nor the other Service should be used directly after + // calling this method. + Merge(other Service) (Merger, error) +} + +// MergeAll merges the given services, if they are compatible. +// +// This allows using multiple compatible services with a single listener. +// +// All passed-in services must not be re-used. +func MergeAll(services ...Service) (Service, error) { + switch len(services) { + case 0: + return nil, errors.New("no services given") + + case 1: + return services[0], nil + } + + merger, err := firstMerger(services) + if err != nil { + return nil, err + } + + for _, svc := range services { + if svc == merger { + continue + } + + merger, err = merger.Merge(svc) + if err != nil { + return nil, err + } + } + + return merger, nil +} + +func firstMerger(services []Service) (Merger, error) { + for _, t := range services { + if svc, ok := t.(Merger); ok { + return svc, nil + } + } + + return nil, errors.New("no merger found") +} diff --git a/service/merge_test.go b/service/merge_test.go new file mode 100644 index 000000000..dcb57f9b1 --- /dev/null +++ b/service/merge_test.go @@ -0,0 +1,41 @@ +package service + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Service Merge", func() { + var err error + + Describe("MergeAll", func() { + httpSvcA1 := newFakeHTTPService("HTTP A1", "a:1") + httpSvcA1_ := newFakeHTTPService("HTTP A1_", "a:1") + + nonMergeableSvc := &Info{"non mergeable", EndpointsFromAddrs("http", []string{"a:1"})} + + It("fails when no services are given", func() { + _, err = MergeAll() + Expect(err).Should(MatchError(ContainSubstring("no services"))) + }) + + It("does not fail for a single non mergeable service", func() { + Expect(MergeAll(nonMergeableSvc)).Should(BeIdenticalTo(nonMergeableSvc)) + }) + + It("fails when no service is mergeable", func() { + _, err = MergeAll(nonMergeableSvc, nonMergeableSvc) + Expect(err).Should(MatchError(ContainSubstring("no merger found"))) + }) + + It("merges services", func() { + merged, err := MergeAll(httpSvcA1, httpSvcA1_) + Expect(err).Should(Succeed()) + Expect(merged.String()).Should(SatisfyAll( + ContainSubstring(httpSvcA1.ServiceName()), + ContainSubstring(httpSvcA1_.ServiceName())), + ) + Expect(merged.ExposeOn()).Should(Equal(httpSvcA1.ExposeOn())) + }) + }) +}) diff --git a/service/service.go b/service/service.go new file mode 100644 index 000000000..e0e34db3e --- /dev/null +++ b/service/service.go @@ -0,0 +1,107 @@ +// Package service exposes types to abstract services from the networking. +// +// The idea is that we build a set of services and a set of network endpoints (Listener). +// The services are then assigned to endpoints based on the address(es) they were configured for. +// +// Actual service to endpoint binding is not handled by the abstractions in this package as it is +// protocol specific. +// The general pattern is to make a "server" that wraps a service, and can then be started on an +// endpoint using a `Serve` method, similar to `http.Server`. +// +// To support exposing multiple compatible services on a single endpoint (example: DoH + metrics on a single port), +// services can implement `Merger`. +package service + +import ( + "fmt" + "slices" + "strings" + + "github.com/0xERR0R/blocky/util" +) + +// Service is a network exposed service. +// +// It contains only the logic and user configured addresses it should be exposed on. +// Is is meant to be associated to one or more sockets via those addresses. +// Actual association with a socket is protocol specific. +type Service interface { + fmt.Stringer + + // ServiceName returns the user friendly name of the service. + ServiceName() string + + // ExposeOn returns the set of endpoints the service should be exposed on. + // + // They can be used to find listener(s) with matching configuration. + ExposeOn() []Endpoint +} + +func svcString(s Service) string { + endpoints := util.ConvertEach(s.ExposeOn(), func(e Endpoint) string { return e.String() }) + + return fmt.Sprintf("%s on %s", s.ServiceName(), strings.Join(endpoints, ", ")) +} + +// Info can be embedded in structs to help implement Service. +type Info struct { + Name string + Endpoints []Endpoint +} + +func (i *Info) ServiceName() string { return i.Name } +func (i *Info) ExposeOn() []Endpoint { return i.Endpoints } +func (i *Info) String() string { return svcString(i) } + +// GroupByListener returns a map of listener and services grouped by configured address. +// +// Each input listener is a key in the map. The corresponding value is a service +// merged from all services with a matching address. +func GroupByListener(services []Service, listeners []Listener) (map[Listener]Service, error) { + res := make(map[Listener]Service, len(listeners)) + unused := slices.Clone(services) + + for _, listener := range listeners { + services := findAllCompatible(services, listener.Exposes()) + if len(services) == 0 { + return nil, fmt.Errorf("found no compatible services for listener %s", listener) + } + + svc, err := MergeAll(services...) + if err != nil { + return nil, fmt.Errorf("cannot merge services configured for listener %s: %w", listener, err) + } + + res[listener] = svc + + // Algorithmic complexity is quite high here, but we don't care about performance here, at least for now + for _, svc := range services { + if i := slices.Index(unused, svc); i != -1 { + unused = slices.Delete(unused, i, i+1) + } + } + } + + if len(unused) != 0 { + return nil, fmt.Errorf("found no compatible listener for services: %v", unused) + } + + return res, nil +} + +// findAllCompatible returns the subset of services that use the given Listener. +func findAllCompatible(services []Service, endpoint Endpoint) []Service { + res := make([]Service, 0, len(services)) + + for _, svc := range services { + if isExposedOn(svc, endpoint) { + res = append(res, svc) + } + } + + return res +} + +func isExposedOn(svc Service, endpoint Endpoint) bool { + return slices.Index(svc.ExposeOn(), endpoint) != -1 +} diff --git a/service/service_suite_test.go b/service/service_suite_test.go new file mode 100644 index 000000000..9973a4547 --- /dev/null +++ b/service/service_suite_test.go @@ -0,0 +1,18 @@ +package service + +import ( + "testing" + + "github.com/0xERR0R/blocky/log" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func init() { + log.Silence() +} + +func TestLists(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Service Suite") +} diff --git a/service/service_test.go b/service/service_test.go new file mode 100644 index 000000000..a77a52a83 --- /dev/null +++ b/service/service_test.go @@ -0,0 +1,88 @@ +package service + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Service", func() { + var err error + + Describe("Info", func() { + endpoints := EndpointsFromAddrs("proto", []string{":1", "localhost:2"}) + sut := Info{"name", endpoints} + + It("implements Service", func() { + var svc Service = &sut + + Expect(svc.ServiceName()).Should(Equal("name")) + + Expect(svc.ExposeOn()).Should(Equal(endpoints)) + + Expect(svc.String()).Should(SatisfyAll( + ContainSubstring("name"), + ContainSubstring(":1"), + ContainSubstring("localhost:2"), + )) + }) + }) + + Describe("GroupByListener", func() { + httpSvcA1 := newFakeHTTPService("HTTP A1", "a:1") + httpSvcA1_ := newFakeHTTPService("HTTP A1_", "a:1") + httpSvcA2 := newFakeHTTPService("HTTP A2", "a:2") + httpSvcB1 := newFakeHTTPService("HTTP B1", "b:1") + + httpLnrA1 := &NetListener{nil, ListenerInfo{Endpoint{"http", "a:1"}}} + httpLnrA2 := &NetListener{nil, ListenerInfo{Endpoint{"http", "a:2"}}} + httpLnrB1 := &NetListener{nil, ListenerInfo{Endpoint{"http", "b:1"}}} + + It("assigns single service to matching listener", func() { + Expect( + GroupByListener([]Service{httpSvcA1}, []Listener{httpLnrA1}), + ).Should(Equal(map[Listener]Service{httpLnrA1: httpSvcA1})) + }) + + It("assigns each service to the matching listener", func() { + Expect( + GroupByListener([]Service{httpSvcA1, httpSvcA2, httpSvcB1}, []Listener{httpLnrA1, httpLnrA2, httpLnrB1}), + ).Should(Equal(map[Listener]Service{ + httpLnrA1: httpSvcA1, + httpLnrA2: httpSvcA2, + httpLnrB1: httpSvcB1, + })) + }) + + It("merges services with a common endpoint", func() { + merged, err := MergeAll(httpSvcA1, httpSvcA1_) + Expect(err).Should(Succeed()) + + Expect( + GroupByListener([]Service{httpSvcA1, httpSvcA1_}, []Listener{httpLnrA1}), + ).Should(Equal(map[Listener]Service{httpLnrA1: merged})) + }) + + It("fails when a service has no compatible listener", func() { + _, err = GroupByListener([]Service{httpSvcA1, httpSvcA1_}, nil) + Expect(err).Should(MatchError(ContainSubstring("no compatible listener"))) + + _, err = GroupByListener([]Service{httpSvcA1, httpSvcA1_, httpSvcA2}, []Listener{httpLnrA2}) + Expect(err).Should(MatchError(ContainSubstring("no compatible listener"))) + }) + + It("fails when a listener has no compatible services", func() { + _, err = GroupByListener(nil, []Listener{httpLnrA2}) + Expect(err).Should(MatchError(ContainSubstring("no compatible services"))) + + _, err = GroupByListener([]Service{httpSvcA1, httpSvcA1_}, []Listener{httpLnrA2}) + Expect(err).Should(MatchError(ContainSubstring("no compatible services"))) + }) + + It("fails when services with a common endpoint are not mergeable", func() { + nonMergeable := &Info{"non mergeable", EndpointsFromAddrs("http", []string{"a:1"})} + + _, err = GroupByListener([]Service{httpSvcA1, nonMergeable}, []Listener{httpLnrA1}) + Expect(err).Should(MatchError(ContainSubstring("cannot merge services"))) + }) + }) +}) diff --git a/util/slices.go b/util/slices.go new file mode 100644 index 000000000..2f30d4668 --- /dev/null +++ b/util/slices.go @@ -0,0 +1,37 @@ +package util + +// ConvertEach implements the functional map operation, under a different +// name to avoid confusion with Go's map type. +func ConvertEach[T, U any](slice []T, convert func(T) U) []U { + if slice == nil { + return nil + } + + res := make([]U, 0, len(slice)) + + for _, t := range slice { + u := convert(t) + + res = append(res, u) + } + + return res +} + +// ConcatSlices returns a new slice with contents of all the inputs concatenated. +func ConcatSlices[T any](slices ...[]T) []T { + // Allocation is usually the bottleneck, so do it all at once + totalLen := 0 + + for _, slice := range slices { + totalLen += len(slice) + } + + res := make([]T, 0, totalLen) + + for _, slice := range slices { + res = append(res, slice...) + } + + return res +} diff --git a/util/slices_test.go b/util/slices_test.go new file mode 100644 index 000000000..94d760ab4 --- /dev/null +++ b/util/slices_test.go @@ -0,0 +1,36 @@ +package util_test + +import ( + "strings" + + . "github.com/0xERR0R/blocky/util" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Slices Util", func() { + Describe("ConvertEach", func() { + It("calls the converter for each element", func() { + Expect(ConvertEach([]string{"a", "b"}, strings.ToUpper)).Should(Equal([]string{"A", "B"})) + }) + + It("maps nil to nil", func() { + Expect(ConvertEach(nil, func(any) any { + Fail("converter must not be called") + + return nil + })).Should(BeNil()) + }) + }) + + Describe("ConcatSlices", func() { + It("calls the converter for each element", func() { + Expect(ConcatSlices( + []string{"a", "b"}, + []string{"c"}, + []string{}, + []string{"d", "e"}, + )).Should(Equal([]string{"a", "b", "c", "d", "e"})) + }) + }) +})