Skip to content

Commit

Permalink
refactor: add service package to prepare for split HTTP handling
Browse files Browse the repository at this point in the history
Package service exposes types to abstract services from the networking.

The idea is that we build a set of services and a set of network
endpoints (Listener). The services are then assigned to endpoints based
on the address(es) they were configured for.

Actual service to endpoint binding is not handled by the abstractions in
this package as it is protocol specific.
The general pattern is to make a "server" that wraps a service, and can
then be started on an endpoint using a `Serve` method,
similar to `http.Server`.

To support exposing multiple compatible services on a single endpoint
(example: DoH + metrics on a single port),
services can implement `Merger`.
  • Loading branch information
ThinkChaos committed Aug 30, 2024
1 parent 9da89c3 commit ea921a6
Show file tree
Hide file tree
Showing 17 changed files with 1,075 additions and 71 deletions.
60 changes: 60 additions & 0 deletions helpertest/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package helpertest

import (
"crypto/tls"
"crypto/x509"
"sync"

"github.com/0xERR0R/blocky/util"
. "github.com/onsi/gomega"
)

const tlsTestServerName = "test.blocky.invalid"

type tlsData struct {
ServerCfg *tls.Config
ClientCfg *tls.Config
}

// Lazy init
//
//nolint:gochecknoglobals
var (
initTLSData sync.Once
tlsDataStorage tlsData
)

func getTLSData() tlsData {
initTLSData.Do(func() {
cert, err := util.TLSGenerateSelfSignedCert([]string{tlsTestServerName})
Expect(err).Should(Succeed())

tlsDataStorage.ServerCfg = &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
}

certPool := x509.NewCertPool()
certPool.AddCert(cert.Leaf)

tlsDataStorage.ClientCfg = &tls.Config{
RootCAs: certPool,
ServerName: tlsTestServerName,
MinVersion: tls.VersionTLS13,
}
})

return tlsDataStorage
}

// TLSTestServerConfig returns a TLS Config for use by test servers.
func TLSTestServerConfig() *tls.Config {
return getTLSData().ServerCfg.Clone()
}

// TLSTestServerConfig returns a TLS Config for use by test clients.
//
// This is required to connect to a test TLS server, otherwise TLS verification fails.
func TLSTestClientConfig() *tls.Config {
return getTLSData().ClientCfg.Clone()
}
57 changes: 45 additions & 12 deletions server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,73 @@ import (
"net/http"
"time"

"github.com/0xERR0R/blocky/api"
"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
)

// httpMiscService implements service.HTTPService.
//
// This supports the existing single HTTP/HTTPS endpoints
// that expose everything. The goal is to split it up
// and remove it.
type httpMiscService struct {
service.HTTPInfo
}

func newHTTPMiscService(
cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler,
) *httpMiscService {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP),
service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS),
)

return &httpMiscService{
HTTPInfo: service.HTTPInfo{
Info: service.Info{
Name: "HTTP",
Endpoints: endpoints,
},

Mux: createHTTPRouter(cfg, openAPIImpl, dnsHandler),
},
}
}

func (s *httpMiscService) Merge(other service.Service) (service.Merger, error) {
return service.MergeHTTP(s, other)
}

// httpServer implements subServer for HTTP.
type httpServer struct {
inner http.Server
service.HTTPService

name string
inner http.Server
}

func newHTTPServer(name string, handler http.Handler) *httpServer {
func newHTTPServer(svc service.HTTPService) *httpServer {
const (
readHeaderTimeout = 20 * time.Second
readTimeout = 20 * time.Second
writeTimeout = 20 * time.Second
)

return &httpServer{
HTTPService: svc,

inner: http.Server{
ReadTimeout: readTimeout,
Handler: withCommonMiddleware(svc.Router()),
ReadHeaderTimeout: readHeaderTimeout,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,

Handler: withCommonMiddleware(handler),
},

name: name,
}
}

func (s *httpServer) String() string {
return s.name
}

func (s *httpServer) Serve(ctx context.Context, l net.Listener) error {
go func() {
<-ctx.Done()
Expand Down
138 changes: 85 additions & 53 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"runtime"
"runtime/debug"
"slices"
"strings"
"time"

Expand All @@ -18,6 +19,8 @@ import (
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/redis"
"github.com/0xERR0R/blocky/resolver"
"github.com/0xERR0R/blocky/service"
"golang.org/x/exp/maps"

"github.com/0xERR0R/blocky/util"
"github.com/google/uuid"
Expand All @@ -40,7 +43,14 @@ type Server struct {
queryResolver resolver.ChainedResolver
cfg *config.Config

servers map[net.Listener]*httpServer
services map[service.Listener]service.Service
}

type subServer interface {
fmt.Stringer
service.Service

Serve(context.Context, net.Listener) error
}

func logger() *logrus.Entry {
Expand Down Expand Up @@ -99,8 +109,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
}

// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
var tlsCfg *tls.Config

Expand All @@ -116,7 +124,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, fmt.Errorf("server creation failed: %w", err)
}

httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg)
listeners, err := createListeners(ctx, cfg, tlsCfg)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -145,39 +153,41 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
dnsServers: dnsServers,
queryResolver: queryResolver,
cfg: cfg,

servers: make(map[net.Listener]*httpServer),
}

server.printConfiguration()

server.registerDNSHandlers(ctx)

openAPIImpl, err := server.createOpenAPIInterfaceImpl()
services, err := server.createServices()
if err != nil {
return nil, err
}

httpRouter := createHTTPRouter(cfg, openAPIImpl)
server.registerDoHEndpoints(httpRouter)
server.services, err = service.GroupByListener(services, listeners)
if err != nil {
return nil, err
}

if len(cfg.Ports.HTTP) != 0 {
srv := newHTTPServer("http", httpRouter)
return server, err
}

for _, l := range httpListeners {
server.servers[l] = srv
}
func (s *Server) createServices() ([]service.Service, error) {
openAPIImpl, err := s.createOpenAPIInterfaceImpl()
if err != nil {
return nil, err
}

if len(cfg.Ports.HTTPS) != 0 {
srv := newHTTPServer("https", httpRouter)

for _, l := range httpsListeners {
server.servers[l] = srv
}
res := []service.Service{
newHTTPMiscService(s.cfg, openAPIImpl, s.handleReq),
}

return server, err
// Remove services the user has not enabled
res = slices.DeleteFunc(res, func(svc service.Service) bool {
return len(svc.ExposeOn()) == 0
})

return res, nil
}

func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error) {
Expand Down Expand Up @@ -208,48 +218,51 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
return dnsServers, err.ErrorOrNil()
}

func createHTTPListeners(
cfg *config.Config, tlsCfg *tls.Config,
) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP)
if err != nil {
return nil, nil, err
func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config) ([]service.Listener, error) {
res := make(map[string]service.Listener)

listenTLS := func(ctx context.Context, endpoint service.Endpoint) (service.Listener, error) {
return service.ListenTLS(ctx, endpoint, tlsCfg)
}

httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg)
err := errors.Join(
newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res),
newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res),
)
if err != nil {
return nil, nil, err
return nil, err
}

return httpListeners, httpsListeners, nil
return maps.Values(res), nil
}

func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
listeners := make([]net.Listener, 0, len(addresses))

for _, address := range addresses {
listener, err := net.Listen("tcp", address)
if err != nil {
return nil, fmt.Errorf("start %s listener on %s failed: %w", proto, address, err)
type listenFunc[T service.Listener] func(context.Context, service.Endpoint) (T, error)

func newListeners[T service.Listener](
ctx context.Context, proto string, addrs config.ListenConfig, listen listenFunc[T], out map[string]service.Listener,
) error {
for _, addr := range addrs {
key := fmt.Sprintf("%s:%s", proto, addr)
if _, ok := out[key]; ok {
// Avoid "address already in use"
// We instead try to merge services, see services.GroupByListener
continue
}

listeners = append(listeners, listener)
}

return listeners, nil
}
endpoint := service.Endpoint{
Protocol: proto,
AddrConf: addr,
}

func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) {
listeners, err := newTCPListeners(proto, addresses)
if err != nil {
return nil, err
}
l, err := listen(ctx, endpoint)
if err != nil {
return err // already has all info
}

for i, inner := range listeners {
listeners[i] = tls.NewListener(inner, tlsCfg)
out[key] = l
}

return listeners, nil
return nil
}

func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
Expand Down Expand Up @@ -385,6 +398,16 @@ func toMB(b uint64) uint64 {
return b / bytesInKB / bytesInKB
}

func newSubServer(svc service.Service) (subServer, error) {
switch svc := svc.(type) {
case service.HTTPService:
return newHTTPServer(svc), nil

default:
return nil, fmt.Errorf("unsupported service type: %T (%s)", svc, svc)
}
}

// Start starts the server
func (s *Server) Start(ctx context.Context, errCh chan<- error) {
logger().Info("Starting server")
Expand All @@ -399,11 +422,18 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
}()
}

for listener, srv := range s.servers {
listener, srv := listener, srv
for listener, svc := range s.services {
listener, svc := listener, svc

srv, err := newSubServer(svc)
if err != nil {
errCh <- fmt.Errorf("%s on %s: %w", svc.ServiceName(), listener.Exposes(), err)

return
}

go func() {
logger().Infof("%s server is up and running on addr/port %s", srv, listener.Addr())
logger().Infof("%s server is up and running on %s", svc.ServiceName(), listener.Exposes())

err := srv.Serve(ctx, listener)
if err != nil {
Expand Down Expand Up @@ -506,6 +536,8 @@ type msgWriter interface {
WriteMsg(msg *dns.Msg) error
}

type dnsHandler func(context.Context, *model.Request, msgWriter)

func (s *Server) handleReq(ctx context.Context, request *model.Request, w msgWriter) {
response, err := s.resolve(ctx, request)
if err != nil {
Expand Down
Loading

0 comments on commit ea921a6

Please sign in to comment.