From 0c437b5416b6877f22f05341ebada023d918e981 Mon Sep 17 00:00:00 2001 From: Dave Shanley Date: Mon, 31 Jul 2023 15:59:42 -0400 Subject: [PATCH] reverted `request` backward on model. This is a breaking change that makes ranch compatible with transport. I want a better name, but this will have to do! Also formatted code. Signed-off-by: Dave Shanley --- bridge/bridge_client.go | 418 +++---- bridge/bridge_client_subscription.go | 36 +- bridge/bridge_client_test.go | 70 +- bridge/broker_connector_test.go | 2 +- bridge/broker_connector_tls_test.go | 2 +- bridge/connection.go | 326 ++--- bridge/example_connector_broker_tcp_test.go | 134 +- bridge/example_connector_broker_ws_test.go | 124 +- bridge/subscription.go | 60 +- bus/channel.go | 256 ++-- bus/channel_manager.go | 250 ++-- bus/channel_manager_test.go | 390 +++--- bus/channel_test.go | 386 +++--- bus/eventbus_test.go | 1106 ++++++++--------- bus/example_galactic_channels_test.go | 160 +-- bus/fabric_endpoint_test.go | 630 +++++----- bus/message_handler.go | 80 +- bus/message_test.go | 26 +- bus/store.go | 752 +++++------ bus/store_manager.go | 230 ++-- bus/store_sync_service.go | 462 +++---- bus/store_sync_service_test.go | 964 +++++++------- bus/transaction.go | 348 +++--- bus/transaction_test.go | 406 +++--- model/request.go | 60 +- .../pkg/middleware/basic_security_headers.go | 16 +- plank/pkg/middleware/cache_control.go | 160 +-- plank/pkg/middleware/middleware_manager.go | 264 ++-- plank/pkg/server/base_error.go | 36 +- plank/pkg/server/base_error_test.go | 38 +- plank/pkg/server/core_models.go | 118 +- plank/pkg/server/flag_helper.go | 274 ++-- plank/pkg/server/flag_helper_test.go | 98 +- plank/pkg/server/helpers.go | 308 ++--- plank/pkg/server/helpers_test.go | 236 ++-- .../initialize_rest_bridge_override_test.go | 192 +-- plank/pkg/server/server_smoke_test.go | 320 ++--- plank/pkg/server/spa_config.go | 82 +- plank/pkg/server/test_suite_harness.go | 292 ++--- plank/pkg/server/test_suite_harness_test.go | 66 +- plank/services/ping-pong-service.go | 164 +-- plank/utils/cli.go | 186 +-- service/fabric_core_test.go | 556 ++++----- service/fabric_error.go | 22 +- service/fabric_service.go | 8 +- service/rest_service.go | 358 +++--- service/rest_service_test.go | 656 +++++----- service/service_lifecycle_manager.go | 138 +- service/service_lifecycle_manager_test.go | 168 +-- service/service_registry_test.go | 260 ++-- stompserver/stomp_connection.go | 748 +++++------ stompserver/stomp_connection_test.go | 1014 +++++++-------- stompserver/websocket_connection_listener.go | 268 ++-- 53 files changed, 7362 insertions(+), 7362 deletions(-) diff --git a/bridge/bridge_client.go b/bridge/bridge_client.go index fca4ca8..8f2d52f 100644 --- a/bridge/bridge_client.go +++ b/bridge/bridge_client.go @@ -4,260 +4,260 @@ package bridge import ( - "bufio" - "bytes" - "errors" - "fmt" - "github.com/go-stomp/stomp/v3" - "github.com/go-stomp/stomp/v3/frame" - "github.com/google/uuid" - "github.com/gorilla/websocket" - "github.com/pb33f/ranch/model" - "log" - "net/url" - "os" - "strconv" - "sync" + "bufio" + "bytes" + "errors" + "fmt" + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/frame" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/pb33f/ranch/model" + "log" + "net/url" + "os" + "strconv" + "sync" ) // BridgeClient encapsulates all subscriptions and io to and from brokers. type BridgeClient struct { - WSc *websocket.Conn // WebSocket connection - TCPc *stomp.Conn // STOMP TCP Connection - ConnectedChan chan bool - disconnectedChan chan bool - connected bool - inboundChan chan *frame.Frame - stompConnected bool - Subscriptions map[string]*BridgeClientSub - logger *log.Logger - lock sync.Mutex - sendLock sync.Mutex + WSc *websocket.Conn // WebSocket connection + TCPc *stomp.Conn // STOMP TCP Connection + ConnectedChan chan bool + disconnectedChan chan bool + connected bool + inboundChan chan *frame.Frame + stompConnected bool + Subscriptions map[string]*BridgeClientSub + logger *log.Logger + lock sync.Mutex + sendLock sync.Mutex } // NewBridgeWsClient Create a new WebSocket client. func NewBridgeWsClient(enableLogging bool) *BridgeClient { - return newBridgeWsClient(enableLogging) + return newBridgeWsClient(enableLogging) } func newBridgeWsClient(enableLogging bool) *BridgeClient { - var l *log.Logger = nil - if enableLogging { - l = log.New(os.Stderr, "WebSocket Client: ", 2) - } - return &BridgeClient{ - WSc: nil, - TCPc: nil, - stompConnected: false, - connected: false, - logger: l, - lock: sync.Mutex{}, - sendLock: sync.Mutex{}, - Subscriptions: make(map[string]*BridgeClientSub), - ConnectedChan: make(chan bool), - disconnectedChan: make(chan bool), - inboundChan: make(chan *frame.Frame)} + var l *log.Logger = nil + if enableLogging { + l = log.New(os.Stderr, "WebSocket Client: ", 2) + } + return &BridgeClient{ + WSc: nil, + TCPc: nil, + stompConnected: false, + connected: false, + logger: l, + lock: sync.Mutex{}, + sendLock: sync.Mutex{}, + Subscriptions: make(map[string]*BridgeClientSub), + ConnectedChan: make(chan bool), + disconnectedChan: make(chan bool), + inboundChan: make(chan *frame.Frame)} } // Connect to broker endpoint. func (ws *BridgeClient) Connect(url *url.URL, config *BrokerConnectorConfig) error { - ws.lock.Lock() - defer ws.lock.Unlock() - if ws.logger != nil { - ws.logger.Printf("connecting to fabric endpoint over %s", url.String()) - } - - dialer := websocket.DefaultDialer - if config.WebSocketConfig.UseTLS { - - // if the cert and key are not set, we're acting as a client, not a server so we have to - // allow these values to be empty when connecting vs serving over TLS. - if config.WebSocketConfig.CertFile != "" || config.WebSocketConfig.KeyFile != "" { - if err := config.WebSocketConfig.LoadX509KeyPairFromFiles( - config.WebSocketConfig.CertFile, - config.WebSocketConfig.KeyFile); err != nil { - return err - } - } - dialer.TLSClientConfig = config.WebSocketConfig.TLSConfig - } - - c, _, err := dialer.Dial(url.String(), config.HttpHeader) - if err != nil { - return err - } - ws.WSc = c - - // handle incoming STOMP frames. - go ws.handleIncomingSTOMPFrames() - - // go listen to the websocket - go ws.listenSocket() - - stompHeaders := []string{ - frame.AcceptVersion, - string(stomp.V12), - frame.Login, - config.Username, - frame.Passcode, - config.Password, - frame.HeartBeat, - fmt.Sprintf("%d,%d", config.HeartBeatOut.Milliseconds(), config.HeartBeatIn.Milliseconds())} - for key, value := range config.STOMPHeader { - stompHeaders = append(stompHeaders, key, value) - } - - // send connect frame. - ws.SendFrame(frame.New(frame.CONNECT, stompHeaders...)) - - // wait to be connected - <-ws.ConnectedChan - return nil + ws.lock.Lock() + defer ws.lock.Unlock() + if ws.logger != nil { + ws.logger.Printf("connecting to fabric endpoint over %s", url.String()) + } + + dialer := websocket.DefaultDialer + if config.WebSocketConfig.UseTLS { + + // if the cert and key are not set, we're acting as a client, not a server so we have to + // allow these values to be empty when connecting vs serving over TLS. + if config.WebSocketConfig.CertFile != "" || config.WebSocketConfig.KeyFile != "" { + if err := config.WebSocketConfig.LoadX509KeyPairFromFiles( + config.WebSocketConfig.CertFile, + config.WebSocketConfig.KeyFile); err != nil { + return err + } + } + dialer.TLSClientConfig = config.WebSocketConfig.TLSConfig + } + + c, _, err := dialer.Dial(url.String(), config.HttpHeader) + if err != nil { + return err + } + ws.WSc = c + + // handle incoming STOMP frames. + go ws.handleIncomingSTOMPFrames() + + // go listen to the websocket + go ws.listenSocket() + + stompHeaders := []string{ + frame.AcceptVersion, + string(stomp.V12), + frame.Login, + config.Username, + frame.Passcode, + config.Password, + frame.HeartBeat, + fmt.Sprintf("%d,%d", config.HeartBeatOut.Milliseconds(), config.HeartBeatIn.Milliseconds())} + for key, value := range config.STOMPHeader { + stompHeaders = append(stompHeaders, key, value) + } + + // send connect frame. + ws.SendFrame(frame.New(frame.CONNECT, stompHeaders...)) + + // wait to be connected + <-ws.ConnectedChan + return nil } // Disconnect from broker endpoint func (ws *BridgeClient) Disconnect() error { - if ws.WSc != nil { - defer ws.WSc.Close() - ws.disconnectedChan <- true - } else { - return fmt.Errorf("cannot disconnect, no connection defined") - } - return nil + if ws.WSc != nil { + defer ws.WSc.Close() + ws.disconnectedChan <- true + } else { + return fmt.Errorf("cannot disconnect, no connection defined") + } + return nil } // Subscribe to destination func (ws *BridgeClient) Subscribe(destination string) *BridgeClientSub { - ws.lock.Lock() - defer ws.lock.Unlock() - id := uuid.New() - s := &BridgeClientSub{ - C: make(chan *model.Message), - Id: &id, - Client: ws, - Destination: destination, - subscribed: true} - - ws.Subscriptions[destination] = s - - // create subscription frame. - subscribeFrame := frame.New(frame.SUBSCRIBE, - frame.Id, id.String(), - frame.Destination, destination, - frame.Ack, stomp.AckAuto.String()) - - // send subscription frame. - ws.SendFrame(subscribeFrame) - return s + ws.lock.Lock() + defer ws.lock.Unlock() + id := uuid.New() + s := &BridgeClientSub{ + C: make(chan *model.Message), + Id: &id, + Client: ws, + Destination: destination, + subscribed: true} + + ws.Subscriptions[destination] = s + + // create subscription frame. + subscribeFrame := frame.New(frame.SUBSCRIBE, + frame.Id, id.String(), + frame.Destination, destination, + frame.Ack, stomp.AckAuto.String()) + + // send subscription frame. + ws.SendFrame(subscribeFrame) + return s } // Send a payload to a destination func (ws *BridgeClient) Send(destination, contentType string, payload []byte, opts ...func(fr *frame.Frame) error) { - ws.lock.Lock() - defer ws.lock.Unlock() + ws.lock.Lock() + defer ws.lock.Unlock() - // create send frame. - sendFrame := frame.New(frame.SEND, - frame.Destination, destination, - frame.ContentLength, strconv.Itoa(len(payload)), - frame.ContentType, contentType) + // create send frame. + sendFrame := frame.New(frame.SEND, + frame.Destination, destination, + frame.ContentLength, strconv.Itoa(len(payload)), + frame.ContentType, contentType) - // apply extra frame options such as adding extra headers - for _, frameOpt := range opts { - _ = frameOpt(sendFrame) - } - // add payload - sendFrame.Body = payload + // apply extra frame options such as adding extra headers + for _, frameOpt := range opts { + _ = frameOpt(sendFrame) + } + // add payload + sendFrame.Body = payload - // send frame - go ws.SendFrame(sendFrame) + // send frame + go ws.SendFrame(sendFrame) } // SendFrame fire a STOMP frame down the WebSocket func (ws *BridgeClient) SendFrame(f *frame.Frame) { - ws.sendLock.Lock() - defer ws.sendLock.Unlock() - var b bytes.Buffer - br := bufio.NewWriter(&b) - sw := frame.NewWriter(br) + ws.sendLock.Lock() + defer ws.sendLock.Unlock() + var b bytes.Buffer + br := bufio.NewWriter(&b) + sw := frame.NewWriter(br) - // write frame to buffer - sw.Write(f) - w, _ := ws.WSc.NextWriter(websocket.TextMessage) - defer w.Close() + // write frame to buffer + sw.Write(f) + w, _ := ws.WSc.NextWriter(websocket.TextMessage) + defer w.Close() - w.Write(b.Bytes()) + w.Write(b.Bytes()) } func (ws *BridgeClient) listenSocket() { - for { - // read each incoming message from websocket - _, p, err := ws.WSc.ReadMessage() - b := bytes.NewReader(p) - sr := frame.NewReader(b) - f, _ := sr.Read() - - if err != nil { - break // socket can't be read anymore, exit. - } - if f != nil { - ws.inboundChan <- f - } - } + for { + // read each incoming message from websocket + _, p, err := ws.WSc.ReadMessage() + b := bytes.NewReader(p) + sr := frame.NewReader(b) + f, _ := sr.Read() + + if err != nil { + break // socket can't be read anymore, exit. + } + if f != nil { + ws.inboundChan <- f + } + } } func (ws *BridgeClient) handleIncomingSTOMPFrames() { - for { - select { - case <-ws.disconnectedChan: - return - case f := <-ws.inboundChan: - switch f.Command { - case frame.CONNECTED: - if ws.logger != nil { - ws.logger.Printf("STOMP Client connected") - } - ws.stompConnected = true - ws.connected = true - ws.ConnectedChan <- true - - case frame.MESSAGE: - for _, sub := range ws.Subscriptions { - if sub.Destination == f.Header.Get(frame.Destination) { - c := &model.MessageConfig{Payload: f.Body, Destination: sub.Destination} - sub.lock.RLock() - if sub.subscribed { - ws.sendResponseSafe(sub.C, model.GenerateResponse(c)) - } - sub.lock.RUnlock() - } - } - - case frame.ERROR: - if ws.logger != nil { - ws.logger.Printf("STOMP ErrorDir received") - } - - for _, sub := range ws.Subscriptions { - if sub.Destination == f.Header.Get(frame.Destination) { - c := &model.MessageConfig{Payload: f.Body, Err: errors.New("STOMP ErrorDir " + string(f.Body))} - sub.E <- model.GenerateError(c) - } - } - } - } - } + for { + select { + case <-ws.disconnectedChan: + return + case f := <-ws.inboundChan: + switch f.Command { + case frame.CONNECTED: + if ws.logger != nil { + ws.logger.Printf("STOMP Client connected") + } + ws.stompConnected = true + ws.connected = true + ws.ConnectedChan <- true + + case frame.MESSAGE: + for _, sub := range ws.Subscriptions { + if sub.Destination == f.Header.Get(frame.Destination) { + c := &model.MessageConfig{Payload: f.Body, Destination: sub.Destination} + sub.lock.RLock() + if sub.subscribed { + ws.sendResponseSafe(sub.C, model.GenerateResponse(c)) + } + sub.lock.RUnlock() + } + } + + case frame.ERROR: + if ws.logger != nil { + ws.logger.Printf("STOMP ErrorDir received") + } + + for _, sub := range ws.Subscriptions { + if sub.Destination == f.Header.Get(frame.Destination) { + c := &model.MessageConfig{Payload: f.Body, Err: errors.New("STOMP ErrorDir " + string(f.Body))} + sub.E <- model.GenerateError(c) + } + } + } + } + } } func (ws *BridgeClient) sendResponseSafe(C chan *model.Message, m *model.Message) { - defer func() { - if r := recover(); r != nil { - if ws.logger != nil { - ws.logger.Println("channel is closed, message undeliverable to closed channel.") - } - } - }() - C <- m + defer func() { + if r := recover(); r != nil { + if ws.logger != nil { + ws.logger.Println("channel is closed, message undeliverable to closed channel.") + } + } + }() + C <- m } diff --git a/bridge/bridge_client_subscription.go b/bridge/bridge_client_subscription.go index cf287f4..21e007c 100644 --- a/bridge/bridge_client_subscription.go +++ b/bridge/bridge_client_subscription.go @@ -4,31 +4,31 @@ package bridge import ( - "github.com/go-stomp/stomp/v3/frame" - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "sync" + "github.com/go-stomp/stomp/v3/frame" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "sync" ) // BridgeClientSub is a client subscription that encapsulates message and error channels for a subscription type BridgeClientSub struct { - C chan *model.Message // MESSAGE payloads - E chan *model.Message // ERROR payloads. - Id *uuid.UUID - Destination string - Client *BridgeClient - subscribed bool - lock sync.RWMutex + C chan *model.Message // MESSAGE payloads + E chan *model.Message // ERROR payloads. + Id *uuid.UUID + Destination string + Client *BridgeClient + subscribed bool + lock sync.RWMutex } // Send an UNSUBSCRIBE frame for subscription destination. func (cs *BridgeClientSub) Unsubscribe() { - cs.lock.Lock() - cs.subscribed = false - cs.lock.Unlock() - unsubscribeFrame := frame.New(frame.UNSUBSCRIBE, - frame.Id, cs.Id.String(), - frame.Destination, cs.Destination) + cs.lock.Lock() + cs.subscribed = false + cs.lock.Unlock() + unsubscribeFrame := frame.New(frame.UNSUBSCRIBE, + frame.Id, cs.Id.String(), + frame.Destination, cs.Destination) - cs.Client.SendFrame(unsubscribeFrame) + cs.Client.SendFrame(unsubscribeFrame) } diff --git a/bridge/bridge_client_test.go b/bridge/bridge_client_test.go index 3b0b23f..1331c05 100644 --- a/bridge/bridge_client_test.go +++ b/bridge/bridge_client_test.go @@ -4,46 +4,46 @@ package bridge import ( - "github.com/go-stomp/stomp/v3/frame" - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "log" - "os" - "sync" - "testing" + "github.com/go-stomp/stomp/v3/frame" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "log" + "os" + "sync" + "testing" ) func TestBridgeClient_Disconnect(t *testing.T) { - bc := new(BridgeClient) - e := bc.Disconnect() - assert.NotNil(t, e) + bc := new(BridgeClient) + e := bc.Disconnect() + assert.NotNil(t, e) } func TestBridgeClient_handleCommands(t *testing.T) { - d := "rainbow-land" - bc := new(BridgeClient) - i := make(chan *frame.Frame, 1) - e := make(chan *model.Message, 1) - - l := log.New(os.Stderr, "WebSocket Client: ", 2) - bc.logger = l - bc.inboundChan = i - bc.Subscriptions = make(map[string]*BridgeClientSub) - s := &BridgeClientSub{E: e, Destination: d} - bc.Subscriptions[d] = s - - go bc.handleIncomingSTOMPFrames() - wg := sync.WaitGroup{} - - var sendError = func() { - f := frame.New(frame.ERROR, frame.Destination, d) - bc.inboundChan <- f - wg.Done() - } - wg.Add(1) - go sendError() - wg.Wait() - m := <-s.E - assert.Error(t, m.Error) + d := "rainbow-land" + bc := new(BridgeClient) + i := make(chan *frame.Frame, 1) + e := make(chan *model.Message, 1) + + l := log.New(os.Stderr, "WebSocket Client: ", 2) + bc.logger = l + bc.inboundChan = i + bc.Subscriptions = make(map[string]*BridgeClientSub) + s := &BridgeClientSub{E: e, Destination: d} + bc.Subscriptions[d] = s + + go bc.handleIncomingSTOMPFrames() + wg := sync.WaitGroup{} + + var sendError = func() { + f := frame.New(frame.ERROR, frame.Destination, d) + bc.inboundChan <- f + wg.Done() + } + wg.Add(1) + go sendError() + wg.Wait() + m := <-s.E + assert.Error(t, m.Error) } diff --git a/bridge/broker_connector_test.go b/bridge/broker_connector_test.go index 82f4e98..e6206e7 100644 --- a/bridge/broker_connector_test.go +++ b/bridge/broker_connector_test.go @@ -71,7 +71,7 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { } } -//var srv Server +// var srv Server var testBrokerAddress = ":51581" var httpServer *httptest.Server var tcpServer net.Listener diff --git a/bridge/broker_connector_tls_test.go b/bridge/broker_connector_tls_test.go index 9682580..8ef3619 100644 --- a/bridge/broker_connector_tls_test.go +++ b/bridge/broker_connector_tls_test.go @@ -20,7 +20,7 @@ import ( var webSocketURLChanTLS = make(chan string) var websocketURLTLS string -//var srv Server +// var srv Server var testTLS = &tls.Config{ InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, diff --git a/bridge/connection.go b/bridge/connection.go index a120a76..bd4d13f 100644 --- a/bridge/connection.go +++ b/bridge/connection.go @@ -4,60 +4,60 @@ package bridge import ( - "fmt" - "github.com/go-stomp/stomp/v3" - "github.com/go-stomp/stomp/v3/frame" - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "log" - "sync" + "fmt" + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/frame" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "log" + "sync" ) type Connection interface { - GetId() *uuid.UUID - Subscribe(destination string) (Subscription, error) - SubscribeReplyDestination(destination string) (Subscription, error) - Disconnect() (err error) - SendJSONMessage(destination string, payload []byte, opts ...func(*frame.Frame) error) error - SendMessage(destination, contentType string, payload []byte, opts ...func(*frame.Frame) error) error - SendMessageWithReplyDestination(destination, replyDestination, contentType string, payload []byte, opts ...func(*frame.Frame) error) error + GetId() *uuid.UUID + Subscribe(destination string) (Subscription, error) + SubscribeReplyDestination(destination string) (Subscription, error) + Disconnect() (err error) + SendJSONMessage(destination string, payload []byte, opts ...func(*frame.Frame) error) error + SendMessage(destination, contentType string, payload []byte, opts ...func(*frame.Frame) error) error + SendMessageWithReplyDestination(destination, replyDestination, contentType string, payload []byte, opts ...func(*frame.Frame) error) error } // Connection represents a Connection to a message broker. type connection struct { - id *uuid.UUID - useWs bool - conn *stomp.Conn - wsConn *BridgeClient - disconnectChan chan bool - subscriptions map[string]Subscription - connLock sync.Mutex + id *uuid.UUID + useWs bool + conn *stomp.Conn + wsConn *BridgeClient + disconnectChan chan bool + subscriptions map[string]Subscription + connLock sync.Mutex } func (c *connection) GetId() *uuid.UUID { - return c.id + return c.id } // Subscribe to a destination, only one subscription can exist for a destination func (c *connection) Subscribe(destination string) (Subscription, error) { - // check if the subscription exists, if so, return it. - if c == nil { - return nil, fmt.Errorf("cannot subscribe to '%s', no connection to broker", destination) - } - - c.connLock.Lock() - if sub, ok := c.subscriptions[destination]; ok { - c.connLock.Unlock() - return sub, nil - } - c.connLock.Unlock() - - // use websocket, if not use stomp TCP. - if c.useWs { - return c.subscribeWs(destination) - } - - return c.subscribeTCP(destination) + // check if the subscription exists, if so, return it. + if c == nil { + return nil, fmt.Errorf("cannot subscribe to '%s', no connection to broker", destination) + } + + c.connLock.Lock() + if sub, ok := c.subscriptions[destination]; ok { + c.connLock.Unlock() + return sub, nil + } + c.connLock.Unlock() + + // use websocket, if not use stomp TCP. + if c.useWs { + return c.subscribeWs(destination) + } + + return c.subscribeTCP(destination) } // SubscribeReplyDestination subscribe to a reply destination (this will create an internal subscription to the @@ -65,160 +65,160 @@ func (c *connection) Subscribe(destination string) (Subscription, error) { // queues, the destinations are dynamic. The raw socket is send responses that are to a destination that // does not actually exist when using reply-to so this will allow that imaginary destination to operate. func (c *connection) SubscribeReplyDestination(destination string) (Subscription, error) { - // check if the subscription exists, if so, return it. - if c == nil { - return nil, fmt.Errorf("cannot subscribe to '%s', no connection to broker", destination) - } - - c.connLock.Lock() - if sub, ok := c.subscriptions[destination]; ok { - c.connLock.Unlock() - return sub, nil - } - c.connLock.Unlock() - - // use websocket, if not use stomp TCP. - if c.useWs { - return c.subscribeWs(destination) - } - - return c.subscribeTCPUsingReplyDestination(destination) + // check if the subscription exists, if so, return it. + if c == nil { + return nil, fmt.Errorf("cannot subscribe to '%s', no connection to broker", destination) + } + + c.connLock.Lock() + if sub, ok := c.subscriptions[destination]; ok { + c.connLock.Unlock() + return sub, nil + } + c.connLock.Unlock() + + // use websocket, if not use stomp TCP. + if c.useWs { + return c.subscribeWs(destination) + } + + return c.subscribeTCPUsingReplyDestination(destination) } // Disconnect from broker, will close all channels func (c *connection) Disconnect() (err error) { - if c == nil { - return fmt.Errorf("cannot disconnect, not connected") - } - if c.useWs { - if c.wsConn != nil && c.wsConn.connected { - defer c.cleanUpConnection() - err = c.wsConn.Disconnect() - } - } else { - if c.conn != nil { - defer c.cleanUpConnection() - err = c.conn.Disconnect() - } - } - return err + if c == nil { + return fmt.Errorf("cannot disconnect, not connected") + } + if c.useWs { + if c.wsConn != nil && c.wsConn.connected { + defer c.cleanUpConnection() + err = c.wsConn.Disconnect() + } + } else { + if c.conn != nil { + defer c.cleanUpConnection() + err = c.conn.Disconnect() + } + } + return err } func (c *connection) cleanUpConnection() { - if c.conn != nil { - c.conn = nil - } - if c.wsConn != nil { - c.wsConn = nil - } + if c.conn != nil { + c.conn = nil + } + if c.wsConn != nil { + c.wsConn = nil + } } func (c *connection) subscribeWs(destination string) (Subscription, error) { - c.connLock.Lock() - defer c.connLock.Unlock() - if c.wsConn != nil { - wsSub := c.wsConn.Subscribe(destination) - sub := &subscription{wsStompSub: wsSub, id: wsSub.Id, c: wsSub.C, destination: destination} - c.subscriptions[destination] = sub - return sub, nil - } - return nil, fmt.Errorf("cannot subscribe, websocket not connected / established") + c.connLock.Lock() + defer c.connLock.Unlock() + if c.wsConn != nil { + wsSub := c.wsConn.Subscribe(destination) + sub := &subscription{wsStompSub: wsSub, id: wsSub.Id, c: wsSub.C, destination: destination} + c.subscriptions[destination] = sub + return sub, nil + } + return nil, fmt.Errorf("cannot subscribe, websocket not connected / established") } func (c *connection) subscribeTCP(destination string) (Subscription, error) { - c.connLock.Lock() - defer c.connLock.Unlock() - if c.conn != nil { - sub, _ := c.conn.Subscribe(destination, stomp.AckAuto) - id := uuid.New() - destChan := make(chan *model.Message) - go c.listenTCPFrames(sub.C, destChan) - bcSub := &subscription{stompTCPSub: sub, id: &id, c: destChan} - c.subscriptions[destination] = bcSub - return bcSub, nil - } - return nil, fmt.Errorf("no STOMP TCP connection established") + c.connLock.Lock() + defer c.connLock.Unlock() + if c.conn != nil { + sub, _ := c.conn.Subscribe(destination, stomp.AckAuto) + id := uuid.New() + destChan := make(chan *model.Message) + go c.listenTCPFrames(sub.C, destChan) + bcSub := &subscription{stompTCPSub: sub, id: &id, c: destChan} + c.subscriptions[destination] = bcSub + return bcSub, nil + } + return nil, fmt.Errorf("no STOMP TCP connection established") } func (c *connection) subscribeTCPUsingReplyDestination(destination string) (Subscription, error) { - c.connLock.Lock() - defer c.connLock.Unlock() - if c.conn != nil { - - var reply = func(f *frame.Frame) error { - f.Header.Add("reply-to", destination) - return nil - } - - sub, _ := c.conn.Subscribe(destination, stomp.AckAuto, reply) - id := uuid.New() - destChan := make(chan *model.Message) - go c.listenTCPFrames(sub.C, destChan) - bcSub := &subscription{stompTCPSub: sub, id: &id, c: destChan} - c.subscriptions[destination] = bcSub - return bcSub, nil - } - return nil, fmt.Errorf("no STOMP TCP connection established") + c.connLock.Lock() + defer c.connLock.Unlock() + if c.conn != nil { + + var reply = func(f *frame.Frame) error { + f.Header.Add("reply-to", destination) + return nil + } + + sub, _ := c.conn.Subscribe(destination, stomp.AckAuto, reply) + id := uuid.New() + destChan := make(chan *model.Message) + go c.listenTCPFrames(sub.C, destChan) + bcSub := &subscription{stompTCPSub: sub, id: &id, c: destChan} + c.subscriptions[destination] = bcSub + return bcSub, nil + } + return nil, fmt.Errorf("no STOMP TCP connection established") } func (c *connection) listenTCPFrames(src chan *stomp.Message, dst chan *model.Message) { - defer func() { - if r := recover(); r != nil { - log.Println("subscription is closed, message undeliverable to closed channel.") - } - }() - for { - f := <-src - var body []byte - var dest string - if f != nil && f.Body != nil { - body = f.Body - } - if f != nil && len(f.Destination) > 0 { - dest = f.Destination - } - if f != nil { - cf := &model.MessageConfig{Payload: body, Destination: dest} - - // transfer over known non-standard, but important frame headers if they are set - if replyTo, ok := f.Header.Contains("reply-to"); ok { // used by rabbitmq for temp queues - cf.Headers = []model.MessageHeader{{Label: "reply-to", Value: replyTo}} - } - - m := model.GenerateResponse(cf) - dst <- m - } - } + defer func() { + if r := recover(); r != nil { + log.Println("subscription is closed, message undeliverable to closed channel.") + } + }() + for { + f := <-src + var body []byte + var dest string + if f != nil && f.Body != nil { + body = f.Body + } + if f != nil && len(f.Destination) > 0 { + dest = f.Destination + } + if f != nil { + cf := &model.MessageConfig{Payload: body, Destination: dest} + + // transfer over known non-standard, but important frame headers if they are set + if replyTo, ok := f.Header.Contains("reply-to"); ok { // used by rabbitmq for temp queues + cf.Headers = []model.MessageHeader{{Label: "reply-to", Value: replyTo}} + } + + m := model.GenerateResponse(cf) + dst <- m + } + } } // SendJSONMessage sends a []byte payload carrying JSON data to a destination. func (c *connection) SendJSONMessage(destination string, payload []byte, opts ...func(*frame.Frame) error) error { - return c.SendMessage(destination, "application/json", payload, opts...) + return c.SendMessage(destination, "application/json", payload, opts...) } // SendMessageWithReplyDestination is the same as SendMessage, but adds in a reply-to header automatically. // This is generally used in conjunction with SubscribeReplyDestination func (c *connection) SendMessageWithReplyDestination(destination string, replyDestination, contentType string, payload []byte, opts ...func(*frame.Frame) error) error { - var headerReplyTo = func(f *frame.Frame) error { - f.Header.Add("reply-to", replyDestination) - return nil - } - opts = append(opts, headerReplyTo) - return c.SendMessage(destination, contentType, payload, opts...) + var headerReplyTo = func(f *frame.Frame) error { + f.Header.Add("reply-to", replyDestination) + return nil + } + opts = append(opts, headerReplyTo) + return c.SendMessage(destination, contentType, payload, opts...) } // SendMessage will send a []byte payload to a destination. func (c *connection) SendMessage(destination string, contentType string, payload []byte, opts ...func(*frame.Frame) error) error { - c.connLock.Lock() - defer c.connLock.Unlock() - if c != nil && !c.useWs && c.conn != nil { - c.conn.Send(destination, contentType, payload, opts...) - return nil - } - if c != nil && c.useWs && c.wsConn != nil { - c.wsConn.Send(destination, contentType, payload, opts...) - return nil - } - return fmt.Errorf("cannot send message, no connection") + c.connLock.Lock() + defer c.connLock.Unlock() + if c != nil && !c.useWs && c.conn != nil { + c.conn.Send(destination, contentType, payload, opts...) + return nil + } + if c != nil && c.useWs && c.wsConn != nil { + c.wsConn.Send(destination, contentType, payload, opts...) + return nil + } + return fmt.Errorf("cannot send message, no connection") } diff --git a/bridge/example_connector_broker_tcp_test.go b/bridge/example_connector_broker_tcp_test.go index 4e66844..48e6f48 100644 --- a/bridge/example_connector_broker_tcp_test.go +++ b/bridge/example_connector_broker_tcp_test.go @@ -4,75 +4,75 @@ package bridge_test import ( - "fmt" - "github.com/pb33f/ranch/bridge" - "github.com/pb33f/ranch/bus" + "fmt" + "github.com/pb33f/ranch/bridge" + "github.com/pb33f/ranch/bus" ) func Example_connectUsingBrokerViaTCP() { - // get a reference to the event bus. - b := bus.GetBus() - - // create a broker connector configuration, using WebSockets. - // Make sure you have a STOMP TCP server running like RabbitMQ - config := &bridge.BrokerConnectorConfig{ - Username: "guest", - Password: "guest", - ServerAddr: ":61613", - STOMPHeader: map[string]string{ - "access-token": "test", - }, - } - - // connect to broker. - c, err := b.ConnectBroker(config) - if err != nil { - fmt.Printf("unable to connect, error: %e", err) - } - defer c.Disconnect() - - // subscribe to our demo simple-stream - s, _ := c.Subscribe("/queue/sample") - - // set a counter - n := 0 - - // create a control chan - done := make(chan bool) - - // listen for messages - var consumer = func() { - for { - // listen for incoming messages from subscription. - m := <-s.GetMsgChannel() - n++ - - // get byte array. - d := m.Payload.([]byte) - - fmt.Printf("Message Received: %s\n", string(d)) - // listen for 5 messages then stop. - if n >= 5 { - break - } - } - done <- true - } - - // send messages - var producer = func() { - for i := 0; i < 5; i++ { - c.SendMessage("/queue/sample", "text/plain", []byte(fmt.Sprintf("message: %d", i))) - } - } - - // listen for incoming messages on subscription for destination /queue/sample - go consumer() - - // send some messages to the broker on destination /queue/sample - go producer() - - // wait for messages to be processed. - <-done + // get a reference to the event bus. + b := bus.GetBus() + + // create a broker connector configuration, using WebSockets. + // Make sure you have a STOMP TCP server running like RabbitMQ + config := &bridge.BrokerConnectorConfig{ + Username: "guest", + Password: "guest", + ServerAddr: ":61613", + STOMPHeader: map[string]string{ + "access-token": "test", + }, + } + + // connect to broker. + c, err := b.ConnectBroker(config) + if err != nil { + fmt.Printf("unable to connect, error: %e", err) + } + defer c.Disconnect() + + // subscribe to our demo simple-stream + s, _ := c.Subscribe("/queue/sample") + + // set a counter + n := 0 + + // create a control chan + done := make(chan bool) + + // listen for messages + var consumer = func() { + for { + // listen for incoming messages from subscription. + m := <-s.GetMsgChannel() + n++ + + // get byte array. + d := m.Payload.([]byte) + + fmt.Printf("Message Received: %s\n", string(d)) + // listen for 5 messages then stop. + if n >= 5 { + break + } + } + done <- true + } + + // send messages + var producer = func() { + for i := 0; i < 5; i++ { + c.SendMessage("/queue/sample", "text/plain", []byte(fmt.Sprintf("message: %d", i))) + } + } + + // listen for incoming messages on subscription for destination /queue/sample + go consumer() + + // send some messages to the broker on destination /queue/sample + go producer() + + // wait for messages to be processed. + <-done } diff --git a/bridge/example_connector_broker_ws_test.go b/bridge/example_connector_broker_ws_test.go index b0ea3e3..bedf202 100644 --- a/bridge/example_connector_broker_ws_test.go +++ b/bridge/example_connector_broker_ws_test.go @@ -4,70 +4,70 @@ package bridge_test import ( - "encoding/json" - "fmt" - "github.com/pb33f/ranch/bridge" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/model" + "encoding/json" + "fmt" + "github.com/pb33f/ranch/bridge" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/model" ) func Example_connectUsingBrokerViaWebSocket() { - // get a reference to the event bus. - b := bus.GetBus() - - // create a broker connector configuration, using WebSockets. - config := &bridge.BrokerConnectorConfig{ - Username: "guest", - Password: "guest", - ServerAddr: "appfabric.vmware.com", - WebSocketConfig: &bridge.WebSocketConfig{WSPath: "/fabric"}, - UseWS: true, - STOMPHeader: map[string]string{ - "access-token": "test", - }, - } - - // connect to broker. - c, err := b.ConnectBroker(config) - if err != nil { - fmt.Printf("unable to connect, error: %e", err) - } - - // subscribe to our demo simple-stream - s, _ := c.Subscribe("/topic/simple-stream") - - // set a counter - n := 0 - - // create a control chan - done := make(chan bool) - - var listener = func() { - for { - // listen for incoming messages from subscription. - m := <-s.GetMsgChannel() - - // unmarshal message. - r := &model.Response{} - d := m.Payload.([]byte) - json.Unmarshal(d, &r) - fmt.Printf("Message Received: %s\n", r.Payload.(string)) - - n++ - - // listen for 5 messages then stop. - if n >= 5 { - break - } - } - done <- true - } - - // listen for incoming messages on subscription. - go listener() - - <-done - - c.Disconnect() + // get a reference to the event bus. + b := bus.GetBus() + + // create a broker connector configuration, using WebSockets. + config := &bridge.BrokerConnectorConfig{ + Username: "guest", + Password: "guest", + ServerAddr: "appfabric.vmware.com", + WebSocketConfig: &bridge.WebSocketConfig{WSPath: "/fabric"}, + UseWS: true, + STOMPHeader: map[string]string{ + "access-token": "test", + }, + } + + // connect to broker. + c, err := b.ConnectBroker(config) + if err != nil { + fmt.Printf("unable to connect, error: %e", err) + } + + // subscribe to our demo simple-stream + s, _ := c.Subscribe("/topic/simple-stream") + + // set a counter + n := 0 + + // create a control chan + done := make(chan bool) + + var listener = func() { + for { + // listen for incoming messages from subscription. + m := <-s.GetMsgChannel() + + // unmarshal message. + r := &model.Response{} + d := m.Payload.([]byte) + json.Unmarshal(d, &r) + fmt.Printf("Message Received: %s\n", r.Payload.(string)) + + n++ + + // listen for 5 messages then stop. + if n >= 5 { + break + } + } + done <- true + } + + // listen for incoming messages on subscription. + go listener() + + <-done + + c.Disconnect() } diff --git a/bridge/subscription.go b/bridge/subscription.go index 4410f65..be116d8 100644 --- a/bridge/subscription.go +++ b/bridge/subscription.go @@ -4,55 +4,55 @@ package bridge import ( - "fmt" - "github.com/go-stomp/stomp/v3" - "github.com/google/uuid" - "github.com/pb33f/ranch/model" + "fmt" + "github.com/go-stomp/stomp/v3" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" ) type Subscription interface { - GetId() *uuid.UUID - GetMsgChannel() chan *model.Message - GetDestination() string - Unsubscribe() error + GetId() *uuid.UUID + GetMsgChannel() chan *model.Message + GetDestination() string + Unsubscribe() error } // Subscription represents a subscription to a broker destination. type subscription struct { - c chan *model.Message // listen to this for incoming messages - id *uuid.UUID - destination string // Destination of where this message was sent. - stompTCPSub *stomp.Subscription - wsStompSub *BridgeClientSub + c chan *model.Message // listen to this for incoming messages + id *uuid.UUID + destination string // Destination of where this message was sent. + stompTCPSub *stomp.Subscription + wsStompSub *BridgeClientSub } func (s *subscription) GetId() *uuid.UUID { - return s.id + return s.id } func (s *subscription) GetMsgChannel() chan *model.Message { - return s.c + return s.c } func (s *subscription) GetDestination() string { - return s.destination + return s.destination } // Unsubscribe from destination. All channels will be closed. func (s *subscription) Unsubscribe() error { - // if we're using TCP - if s.stompTCPSub != nil { - go s.stompTCPSub.Unsubscribe() // local broker hangs, so lets make sure it is non blocking. - close(s.c) - return nil - } - - // if we're using Websockets. - if s.wsStompSub != nil { - s.wsStompSub.Unsubscribe() - close(s.c) - return nil - } - return fmt.Errorf("cannot unsubscribe from destination %s, no connection", s.destination) + // if we're using TCP + if s.stompTCPSub != nil { + go s.stompTCPSub.Unsubscribe() // local broker hangs, so lets make sure it is non blocking. + close(s.c) + return nil + } + + // if we're using Websockets. + if s.wsStompSub != nil { + s.wsStompSub.Unsubscribe() + close(s.c) + return nil + } + return fmt.Errorf("cannot unsubscribe from destination %s, no connection", s.destination) } diff --git a/bus/channel.go b/bus/channel.go index 2d15f7e..a8b41f9 100644 --- a/bus/channel.go +++ b/bus/channel.go @@ -4,220 +4,220 @@ package bus import ( - "github.com/google/uuid" - "github.com/pb33f/ranch/bridge" - "github.com/pb33f/ranch/model" - "sync" - "sync/atomic" + "github.com/google/uuid" + "github.com/pb33f/ranch/bridge" + "github.com/pb33f/ranch/model" + "sync" + "sync/atomic" ) // Channel represents the stream and the subscribed event handlers waiting for ticks on the stream type Channel struct { - Name string `json:"string"` - eventHandlers []*channelEventHandler - galactic bool - galacticMappedDestination string - private bool - channelLock sync.Mutex - wg sync.WaitGroup - brokerSubs []*connectionSub - brokerConns []bridge.Connection - brokerMappedEvent chan bool + Name string `json:"string"` + eventHandlers []*channelEventHandler + galactic bool + galacticMappedDestination string + private bool + channelLock sync.Mutex + wg sync.WaitGroup + brokerSubs []*connectionSub + brokerConns []bridge.Connection + brokerMappedEvent chan bool } // Create a new Channel with the supplied Channel name. Returns a pointer to that Channel. func NewChannel(channelName string) *Channel { - c := &Channel{ - Name: channelName, - eventHandlers: []*channelEventHandler{}, - channelLock: sync.Mutex{}, - galactic: false, - private: false, - wg: sync.WaitGroup{}, - brokerMappedEvent: make(chan bool, 10), - brokerConns: []bridge.Connection{}, - brokerSubs: []*connectionSub{}} - return c + c := &Channel{ + Name: channelName, + eventHandlers: []*channelEventHandler{}, + channelLock: sync.Mutex{}, + galactic: false, + private: false, + wg: sync.WaitGroup{}, + brokerMappedEvent: make(chan bool, 10), + brokerConns: []bridge.Connection{}, + brokerSubs: []*connectionSub{}} + return c } // Mark the Channel as private func (channel *Channel) SetPrivate(private bool) { - channel.private = private + channel.private = private } // Mark the Channel as galactic func (channel *Channel) SetGalactic(mappedDestination string) { - channel.galactic = true - channel.galacticMappedDestination = mappedDestination + channel.galactic = true + channel.galacticMappedDestination = mappedDestination } // Mark the Channel as local func (channel *Channel) SetLocal() { - channel.galactic = false - channel.galacticMappedDestination = "" + channel.galactic = false + channel.galacticMappedDestination = "" } // Returns true is the Channel is marked as galactic func (channel *Channel) IsGalactic() bool { - return channel.galactic + return channel.galactic } // Returns true if the Channel is marked as private func (channel *Channel) IsPrivate() bool { - return channel.private + return channel.private } // Send a new message on this Channel, to all event handlers. func (channel *Channel) Send(message *model.Message) { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() - if eventHandlers := channel.eventHandlers; len(eventHandlers) > 0 { - - // if a handler is run once only, then the slice will be mutated mid cycle. - // copy slice to ensure that removed handler is still fired. - handlerDuplicate := make([]*channelEventHandler, 0, len(eventHandlers)) - handlerDuplicate = append(handlerDuplicate, eventHandlers...) - for n, eventHandler := range handlerDuplicate { - if eventHandler.runOnce && atomic.LoadInt64(&eventHandler.runCount) > 0 { - channel.removeEventHandler(n) // remove from slice. - continue - } - channel.wg.Add(1) - go channel.sendMessageToHandler(eventHandler, message) - } - } + channel.channelLock.Lock() + defer channel.channelLock.Unlock() + if eventHandlers := channel.eventHandlers; len(eventHandlers) > 0 { + + // if a handler is run once only, then the slice will be mutated mid cycle. + // copy slice to ensure that removed handler is still fired. + handlerDuplicate := make([]*channelEventHandler, 0, len(eventHandlers)) + handlerDuplicate = append(handlerDuplicate, eventHandlers...) + for n, eventHandler := range handlerDuplicate { + if eventHandler.runOnce && atomic.LoadInt64(&eventHandler.runCount) > 0 { + channel.removeEventHandler(n) // remove from slice. + continue + } + channel.wg.Add(1) + go channel.sendMessageToHandler(eventHandler, message) + } + } } // Check if the Channel has any registered subscribers func (channel *Channel) ContainsHandlers() bool { - return len(channel.eventHandlers) > 0 + return len(channel.eventHandlers) > 0 } // Send message to handler function func (channel *Channel) sendMessageToHandler(handler *channelEventHandler, message *model.Message) { - handler.callBackFunction(message) - atomic.AddInt64(&handler.runCount, 1) - channel.wg.Done() + handler.callBackFunction(message) + atomic.AddInt64(&handler.runCount, 1) + channel.wg.Done() } // Subscribe a new handler function. func (channel *Channel) subscribeHandler(handler *channelEventHandler) { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() - channel.eventHandlers = append(channel.eventHandlers, handler) + channel.channelLock.Lock() + defer channel.channelLock.Unlock() + channel.eventHandlers = append(channel.eventHandlers, handler) } func (channel *Channel) unsubscribeHandler(uuid *uuid.UUID) bool { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for i, handler := range channel.eventHandlers { - if handler.uuid.String() == uuid.String() { - channel.removeEventHandler(i) - return true - } - } - return false + for i, handler := range channel.eventHandlers { + if handler.uuid.String() == uuid.String() { + channel.removeEventHandler(i) + return true + } + } + return false } // Remove handler function from being subscribed to the Channel. func (channel *Channel) removeEventHandler(index int) { - numHandlers := len(channel.eventHandlers) - if numHandlers <= 0 { - return - } - if index >= numHandlers { - return - } + numHandlers := len(channel.eventHandlers) + if numHandlers <= 0 { + return + } + if index >= numHandlers { + return + } - // delete from event handler slice. - copy(channel.eventHandlers[index:], channel.eventHandlers[index+1:]) - channel.eventHandlers[numHandlers-1] = nil - channel.eventHandlers = channel.eventHandlers[:numHandlers-1] + // delete from event handler slice. + copy(channel.eventHandlers[index:], channel.eventHandlers[index+1:]) + channel.eventHandlers[numHandlers-1] = nil + channel.eventHandlers = channel.eventHandlers[:numHandlers-1] } func (channel *Channel) listenToBrokerSubscription(sub bridge.Subscription) { - for { - msg, m := <-sub.GetMsgChannel() - if m { - channel.Send(msg) - } else { - break - } - } + for { + msg, m := <-sub.GetMsgChannel() + if m { + channel.Send(msg) + } else { + break + } + } } func (channel *Channel) isBrokerSubscribed(sub bridge.Subscription) bool { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for _, cs := range channel.brokerSubs { - if sub.GetId().String() == cs.s.GetId().String() { - return true - } - } - return false + for _, cs := range channel.brokerSubs { + if sub.GetId().String() == cs.s.GetId().String() { + return true + } + } + return false } func (channel *Channel) isBrokerSubscribedToDestination(c bridge.Connection, dest string) bool { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for _, cs := range channel.brokerSubs { - if cs.s != nil && cs.s.GetDestination() == dest && cs.c != nil && cs.c.GetId() == c.GetId() { - return true - } - } - return false + for _, cs := range channel.brokerSubs { + if cs.s != nil && cs.s.GetDestination() == dest && cs.c != nil && cs.c.GetId() == c.GetId() { + return true + } + } + return false } func (channel *Channel) addBrokerConnection(c bridge.Connection) { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for _, brCon := range channel.brokerConns { - if brCon.GetId() == c.GetId() { - return - } - } + for _, brCon := range channel.brokerConns { + if brCon.GetId() == c.GetId() { + return + } + } - channel.brokerConns = append(channel.brokerConns, c) + channel.brokerConns = append(channel.brokerConns, c) } func (channel *Channel) removeBrokerConnections() { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - channel.brokerConns = []bridge.Connection{} + channel.brokerConns = []bridge.Connection{} } func (channel *Channel) addBrokerSubscription(conn bridge.Connection, sub bridge.Subscription) { - cs := &connectionSub{c: conn, s: sub} + cs := &connectionSub{c: conn, s: sub} - channel.channelLock.Lock() - channel.brokerSubs = append(channel.brokerSubs, cs) - channel.channelLock.Unlock() + channel.channelLock.Lock() + channel.brokerSubs = append(channel.brokerSubs, cs) + channel.channelLock.Unlock() - go channel.listenToBrokerSubscription(sub) + go channel.listenToBrokerSubscription(sub) } func (channel *Channel) removeBrokerSubscription(sub bridge.Subscription) { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for i, cs := range channel.brokerSubs { - if sub.GetId().String() == cs.s.GetId().String() { - channel.brokerSubs = removeSub(channel.brokerSubs, i) - } - } + for i, cs := range channel.brokerSubs { + if sub.GetId().String() == cs.s.GetId().String() { + channel.brokerSubs = removeSub(channel.brokerSubs, i) + } + } } func removeSub(s []*connectionSub, i int) []*connectionSub { - s[len(s)-1], s[i] = s[i], s[len(s)-1] - return s[:len(s)-1] + s[len(s)-1], s[i] = s[i], s[len(s)-1] + return s[:len(s)-1] } type connectionSub struct { - c bridge.Connection - s bridge.Subscription + c bridge.Connection + s bridge.Subscription } diff --git a/bus/channel_manager.go b/bus/channel_manager.go index 049a1df..000e7cd 100644 --- a/bus/channel_manager.go +++ b/bus/channel_manager.go @@ -4,206 +4,206 @@ package bus import ( - "errors" - "fmt" - "github.com/google/uuid" - "github.com/pb33f/ranch/bridge" - "github.com/pb33f/ranch/model" - "sync" + "errors" + "fmt" + "github.com/google/uuid" + "github.com/pb33f/ranch/bridge" + "github.com/pb33f/ranch/model" + "sync" ) // ChannelManager interfaces controls all access to channels vis the bus. type ChannelManager interface { - CreateChannel(channelName string) *Channel - DestroyChannel(channelName string) - CheckChannelExists(channelName string) bool - GetChannel(channelName string) (*Channel, error) - GetAllChannels() map[string]*Channel - SubscribeChannelHandler(channelName string, fn MessageHandlerFunction, runOnce bool) (*uuid.UUID, error) - UnsubscribeChannelHandler(channelName string, id *uuid.UUID) error - WaitForChannel(channelName string) error - MarkChannelAsGalactic(channelName string, brokerDestination string, connection bridge.Connection) (err error) - MarkChannelAsLocal(channelName string) (err error) + CreateChannel(channelName string) *Channel + DestroyChannel(channelName string) + CheckChannelExists(channelName string) bool + GetChannel(channelName string) (*Channel, error) + GetAllChannels() map[string]*Channel + SubscribeChannelHandler(channelName string, fn MessageHandlerFunction, runOnce bool) (*uuid.UUID, error) + UnsubscribeChannelHandler(channelName string, id *uuid.UUID) error + WaitForChannel(channelName string) error + MarkChannelAsGalactic(channelName string, brokerDestination string, connection bridge.Connection) (err error) + MarkChannelAsLocal(channelName string) (err error) } func NewBusChannelManager(bus EventBus) ChannelManager { - manager := new(busChannelManager) - manager.Channels = make(map[string]*Channel) - manager.bus = bus.(*transportEventBus) - return manager + manager := new(busChannelManager) + manager.Channels = make(map[string]*Channel) + manager.bus = bus.(*transportEventBus) + return manager } type busChannelManager struct { - Channels map[string]*Channel - bus *transportEventBus - lock sync.RWMutex + Channels map[string]*Channel + bus *transportEventBus + lock sync.RWMutex } // Create a new Channel with the supplied Channel name. Returns pointer to new Channel object func (manager *busChannelManager) CreateChannel(channelName string) *Channel { - manager.lock.Lock() - defer manager.lock.Unlock() + manager.lock.Lock() + defer manager.lock.Unlock() - channel, ok := manager.Channels[channelName] - if ok { - return channel - } + channel, ok := manager.Channels[channelName] + if ok { + return channel + } - manager.Channels[channelName] = NewChannel(channelName) - go manager.bus.SendMonitorEvent(ChannelCreatedEvt, channelName, nil) - return manager.Channels[channelName] + manager.Channels[channelName] = NewChannel(channelName) + go manager.bus.SendMonitorEvent(ChannelCreatedEvt, channelName, nil) + return manager.Channels[channelName] } // Destroy a Channel and all the handlers listening on it. func (manager *busChannelManager) DestroyChannel(channelName string) { - manager.lock.Lock() - defer manager.lock.Unlock() + manager.lock.Lock() + defer manager.lock.Unlock() - delete(manager.Channels, channelName) - go manager.bus.SendMonitorEvent(ChannelDestroyedEvt, channelName, nil) + delete(manager.Channels, channelName) + go manager.bus.SendMonitorEvent(ChannelDestroyedEvt, channelName, nil) } // Get a pointer to a Channel by name. Returns points, or error if no Channel is found. func (manager *busChannelManager) GetChannel(channelName string) (*Channel, error) { - manager.lock.RLock() - defer manager.lock.RUnlock() + manager.lock.RLock() + defer manager.lock.RUnlock() - if channel, ok := manager.Channels[channelName]; ok { - return channel, nil - } else { - return nil, errors.New("Channel does not exist: " + channelName) - } + if channel, ok := manager.Channels[channelName]; ok { + return channel, nil + } else { + return nil, errors.New("Channel does not exist: " + channelName) + } } // Get all channels currently open. Returns a map of Channel names and pointers to those Channel objects. func (manager *busChannelManager) GetAllChannels() map[string]*Channel { - return manager.Channels + return manager.Channels } // Check Channel exists, returns true if so. func (manager *busChannelManager) CheckChannelExists(channelName string) bool { - manager.lock.RLock() - defer manager.lock.RUnlock() + manager.lock.RLock() + defer manager.lock.RUnlock() - return manager.Channels[channelName] != nil + return manager.Channels[channelName] != nil } // Subscribe new handler lambda for Channel, bool flag runOnce determines if this is a single Fire handler. // Returns UUID pointer, or error if there is no Channel by that name. func (manager *busChannelManager) SubscribeChannelHandler(channelName string, fn MessageHandlerFunction, runOnce bool) (*uuid.UUID, error) { - channel, err := manager.GetChannel(channelName) - if err != nil { - return nil, err - } - id := uuid.New() - channel.subscribeHandler(&channelEventHandler{callBackFunction: fn, runOnce: runOnce, uuid: &id}) - manager.bus.SendMonitorEvent(ChannelSubscriberJoinedEvt, channelName, nil) - return &id, nil + channel, err := manager.GetChannel(channelName) + if err != nil { + return nil, err + } + id := uuid.New() + channel.subscribeHandler(&channelEventHandler{callBackFunction: fn, runOnce: runOnce, uuid: &id}) + manager.bus.SendMonitorEvent(ChannelSubscriberJoinedEvt, channelName, nil) + return &id, nil } // Unsubscribe a handler for a Channel event handler. func (manager *busChannelManager) UnsubscribeChannelHandler(channelName string, uuid *uuid.UUID) error { - channel, err := manager.GetChannel(channelName) - if err != nil { - return err - } - found := channel.unsubscribeHandler(uuid) - if !found { - return fmt.Errorf("no handler in Channel '%s' for uuid [%s]", channelName, uuid) - } - manager.bus.SendMonitorEvent(ChannelSubscriberLeftEvt, channelName, nil) - return nil + channel, err := manager.GetChannel(channelName) + if err != nil { + return err + } + found := channel.unsubscribeHandler(uuid) + if !found { + return fmt.Errorf("no handler in Channel '%s' for uuid [%s]", channelName, uuid) + } + manager.bus.SendMonitorEvent(ChannelSubscriberLeftEvt, channelName, nil) + return nil } func (manager *busChannelManager) WaitForChannel(channelName string) error { - channel, _ := manager.GetChannel(channelName) - if channel == nil { - return fmt.Errorf("no such Channel as '%s'", channelName) - } - channel.wg.Wait() - return nil + channel, _ := manager.GetChannel(channelName) + if channel == nil { + return fmt.Errorf("no such Channel as '%s'", channelName) + } + channel.wg.Wait() + return nil } // Mark a channel as Galactic. This will map this channel to the supplied broker destination, if the broker connector // is active and connected, this will result in a subscription to the broker destination being created. Returns // an error if the channel does not exist. func (manager *busChannelManager) MarkChannelAsGalactic(channelName string, dest string, conn bridge.Connection) (err error) { - channel, err := manager.GetChannel(channelName) - if err != nil { - return - } + channel, err := manager.GetChannel(channelName) + if err != nil { + return + } - // mark as galactic/ - channel.SetGalactic(dest) + // mark as galactic/ + channel.SetGalactic(dest) - // create a galactic event - pl := &galacticEvent{conn: conn, dest: dest} + // create a galactic event + pl := &galacticEvent{conn: conn, dest: dest} - manager.handleGalacticChannelEvent(channelName, pl) - return nil + manager.handleGalacticChannelEvent(channelName, pl) + return nil } // Mark a channel as Local. This will unmap the channel from the broker destination, and perform an unsubscribe // operation if the broker connector is active and connected. Returns an error if the channel does not exist. func (manager *busChannelManager) MarkChannelAsLocal(channelName string) (err error) { - channel, err := manager.GetChannel(channelName) - if err != nil { - return - } - channel.SetLocal() + channel, err := manager.GetChannel(channelName) + if err != nil { + return + } + channel.SetLocal() - // get rid of all broker connections. - channel.removeBrokerConnections() + // get rid of all broker connections. + channel.removeBrokerConnections() - manager.handleLocalChannelEvent(channelName) + manager.handleLocalChannelEvent(channelName) - return nil + return nil } func (manager *busChannelManager) handleGalacticChannelEvent(channelName string, ge *galacticEvent) { - ch, _ := manager.GetChannel(channelName) + ch, _ := manager.GetChannel(channelName) - if ge.conn == nil { - return - } + if ge.conn == nil { + return + } - // check if channel is already subscribed on this connection - if !ch.isBrokerSubscribedToDestination(ge.conn, ge.dest) { - if sub, e := ge.conn.Subscribe(ge.dest); e == nil { + // check if channel is already subscribed on this connection + if !ch.isBrokerSubscribedToDestination(ge.conn, ge.dest) { + if sub, e := ge.conn.Subscribe(ge.dest); e == nil { - // add broker connection to channel. - ch.addBrokerConnection(ge.conn) + // add broker connection to channel. + ch.addBrokerConnection(ge.conn) - m := model.GenerateResponse(&model.MessageConfig{Payload: ge.dest}) // set the mapped destination as the payload - ch.addBrokerSubscription(ge.conn, sub) - manager.bus.SendMonitorEvent(BrokerSubscribedEvt, channelName, m) - select { - case ch.brokerMappedEvent <- true: // let channel watcher know, the channel is mapped - default: // if no-one is listening, drop. - } - } - } + m := model.GenerateResponse(&model.MessageConfig{Payload: ge.dest}) // set the mapped destination as the payload + ch.addBrokerSubscription(ge.conn, sub) + manager.bus.SendMonitorEvent(BrokerSubscribedEvt, channelName, m) + select { + case ch.brokerMappedEvent <- true: // let channel watcher know, the channel is mapped + default: // if no-one is listening, drop. + } + } + } } func (manager *busChannelManager) handleLocalChannelEvent(channelName string) { - ch, _ := manager.GetChannel(channelName) - // loop through all the connections we have mapped, and subscribe! - for _, s := range ch.brokerSubs { - if e := s.s.Unsubscribe(); e == nil { - ch.removeBrokerSubscription(s.s) - m := model.GenerateResponse(&model.MessageConfig{Payload: s.s.GetDestination()}) // set the unmapped destination as the payload - manager.bus.SendMonitorEvent(BrokerUnsubscribedEvt, channelName, m) - select { - case ch.brokerMappedEvent <- false: // let channel watcher know, the channel is un-mapped - default: // if no-one is listening, drop. - } - } - } - // get rid of all broker subscriptions on this channel. - ch.removeBrokerConnections() + ch, _ := manager.GetChannel(channelName) + // loop through all the connections we have mapped, and subscribe! + for _, s := range ch.brokerSubs { + if e := s.s.Unsubscribe(); e == nil { + ch.removeBrokerSubscription(s.s) + m := model.GenerateResponse(&model.MessageConfig{Payload: s.s.GetDestination()}) // set the unmapped destination as the payload + manager.bus.SendMonitorEvent(BrokerUnsubscribedEvt, channelName, m) + select { + case ch.brokerMappedEvent <- false: // let channel watcher know, the channel is un-mapped + default: // if no-one is listening, drop. + } + } + } + // get rid of all broker subscriptions on this channel. + ch.removeBrokerConnections() } type galacticEvent struct { - conn bridge.Connection - dest string + conn bridge.Connection + dest string } diff --git a/bus/channel_manager_test.go b/bus/channel_manager_test.go index ffca84f..effe418 100644 --- a/bus/channel_manager_test.go +++ b/bus/channel_manager_test.go @@ -4,259 +4,259 @@ package bus import ( - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "sync" - "testing" - "time" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "sync" + "testing" + "time" ) var testChannelManager ChannelManager var testChannelManagerChannelName = "melody" func createManager() (ChannelManager, EventBus) { - b := newTestEventBus() - manager := NewBusChannelManager(b) - return manager, b + b := newTestEventBus() + manager := NewBusChannelManager(b) + return manager, b } func TestChannelManager_Boot(t *testing.T) { - testChannelManager, _ = createManager() - assert.Len(t, testChannelManager.GetAllChannels(), 0) + testChannelManager, _ = createManager() + assert.Len(t, testChannelManager.GetAllChannels(), 0) } func TestChannelManager_CreateChannel(t *testing.T) { - var bus EventBus - testChannelManager, bus = createManager() + var bus EventBus + testChannelManager, bus = createManager() - wg := sync.WaitGroup{} - wg.Add(1) - bus.AddMonitorEventListener( - func(monitorEvt *MonitorEvent) { - if monitorEvt.EntityName == testChannelManagerChannelName { - assert.Equal(t, monitorEvt.EventType, ChannelCreatedEvt) - wg.Done() - } - }) + wg := sync.WaitGroup{} + wg.Add(1) + bus.AddMonitorEventListener( + func(monitorEvt *MonitorEvent) { + if monitorEvt.EntityName == testChannelManagerChannelName { + assert.Equal(t, monitorEvt.EventType, ChannelCreatedEvt) + wg.Done() + } + }) - testChannelManager.CreateChannel(testChannelManagerChannelName) + testChannelManager.CreateChannel(testChannelManagerChannelName) - wg.Wait() + wg.Wait() - assert.Len(t, testChannelManager.GetAllChannels(), 1) + assert.Len(t, testChannelManager.GetAllChannels(), 1) - fetchedChannel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.NotNil(t, fetchedChannel) - assert.True(t, testChannelManager.CheckChannelExists(testChannelManagerChannelName)) + fetchedChannel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.NotNil(t, fetchedChannel) + assert.True(t, testChannelManager.CheckChannelExists(testChannelManagerChannelName)) } func TestChannelManager_GetNotExistentChannel(t *testing.T) { - testChannelManager, _ = createManager() + testChannelManager, _ = createManager() - fetchedChannel, err := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.NotNil(t, err) - assert.Nil(t, fetchedChannel) + fetchedChannel, err := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.NotNil(t, err) + assert.Nil(t, fetchedChannel) } func TestChannelManager_DestroyChannel(t *testing.T) { - testChannelManager, _ = createManager() - - testChannelManager.CreateChannel(testChannelManagerChannelName) - testChannelManager.DestroyChannel(testChannelManagerChannelName) - fetchedChannel, err := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.Len(t, testChannelManager.GetAllChannels(), 0) - assert.NotNil(t, err) - assert.Nil(t, fetchedChannel) + testChannelManager, _ = createManager() + + testChannelManager.CreateChannel(testChannelManagerChannelName) + testChannelManager.DestroyChannel(testChannelManagerChannelName) + fetchedChannel, err := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.Len(t, testChannelManager.GetAllChannels(), 0) + assert.NotNil(t, err) + assert.Nil(t, fetchedChannel) } func TestChannelManager_SubscribeChannelHandler(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel(testChannelManagerChannelName) - - handler := func(*model.Message) {} - uuid, err := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) - assert.Nil(t, err) - assert.NotNil(t, uuid) - channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.Len(t, channel.eventHandlers, 1) + testChannelManager, _ = createManager() + testChannelManager.CreateChannel(testChannelManagerChannelName) + + handler := func(*model.Message) {} + uuid, err := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) + assert.Nil(t, err) + assert.NotNil(t, uuid) + channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.Len(t, channel.eventHandlers, 1) } func TestChannelManager_SubscribeChannelHandlerMissingChannel(t *testing.T) { - testChannelManager, _ = createManager() - handler := func(*model.Message) {} - _, err := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) - assert.NotNil(t, err) + testChannelManager, _ = createManager() + handler := func(*model.Message) {} + _, err := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) + assert.NotNil(t, err) } func TestChannelManager_UnsubscribeChannelHandler(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel(testChannelManagerChannelName) + testChannelManager, _ = createManager() + testChannelManager.CreateChannel(testChannelManagerChannelName) - handler := func(*model.Message) {} - uuid, _ := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) - channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.Len(t, channel.eventHandlers, 1) + handler := func(*model.Message) {} + uuid, _ := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) + channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.Len(t, channel.eventHandlers, 1) - err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, uuid) - assert.Nil(t, err) - assert.Len(t, channel.eventHandlers, 0) + err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, uuid) + assert.Nil(t, err) + assert.Len(t, channel.eventHandlers, 0) } func TestChannelManager_UnsubscribeChannelHandlerMissingChannel(t *testing.T) { - testChannelManager, _ = createManager() - uuid := uuid.New() - err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, &uuid) - assert.NotNil(t, err) + testChannelManager, _ = createManager() + uuid := uuid.New() + err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, &uuid) + assert.NotNil(t, err) } func TestChannelManager_UnsubscribeChannelHandlerNoId(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel(testChannelManagerChannelName) - - handler := func(*model.Message) {} - testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) - channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.Len(t, channel.eventHandlers, 1) - id := uuid.New() - err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, &id) - assert.NotNil(t, err) - assert.Len(t, channel.eventHandlers, 1) + testChannelManager, _ = createManager() + testChannelManager.CreateChannel(testChannelManagerChannelName) + + handler := func(*model.Message) {} + testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) + channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.Len(t, channel.eventHandlers, 1) + id := uuid.New() + err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, &id) + assert.NotNil(t, err) + assert.Len(t, channel.eventHandlers, 1) } func TestChannelManager_TestWaitForGroupOnBadChannel(t *testing.T) { - testChannelManager, _ = createManager() - err := testChannelManager.WaitForChannel("unknown") - assert.Error(t, err, "no such Channel as 'unknown'") + testChannelManager, _ = createManager() + err := testChannelManager.WaitForChannel("unknown") + assert.Error(t, err, "no such Channel as 'unknown'") } func TestChannelManager_TestGalacticChannelOpen(t *testing.T) { - testChannelManager, _ = createManager() - galacticChannel := testChannelManager.CreateChannel(testChannelManagerChannelName) - id := uuid.New() + testChannelManager, _ = createManager() + galacticChannel := testChannelManager.CreateChannel(testChannelManagerChannelName) + id := uuid.New() - // mark channel as galactic. + // mark channel as galactic. - subId := uuid.New() - sub := &MockBridgeSubscription{ - Id: &subId, - } + subId := uuid.New() + sub := &MockBridgeSubscription{ + Id: &subId, + } - c := &MockBridgeConnection{Id: &id} - c.On("Subscribe", "/topic/testy-test").Return(sub, nil).Once() - e := testChannelManager.MarkChannelAsGalactic(testChannelManagerChannelName, "/topic/testy-test", c) + c := &MockBridgeConnection{Id: &id} + c.On("Subscribe", "/topic/testy-test").Return(sub, nil).Once() + e := testChannelManager.MarkChannelAsGalactic(testChannelManagerChannelName, "/topic/testy-test", c) - assert.Nil(t, e) - c.AssertExpectations(t) + assert.Nil(t, e) + c.AssertExpectations(t) - assert.True(t, galacticChannel.galactic) + assert.True(t, galacticChannel.galactic) - assert.Equal(t, len(galacticChannel.brokerConns), 1) - assert.Equal(t, galacticChannel.brokerConns[0], c) + assert.Equal(t, len(galacticChannel.brokerConns), 1) + assert.Equal(t, galacticChannel.brokerConns[0], c) - assert.Equal(t, len(galacticChannel.brokerSubs), 1) - assert.Equal(t, galacticChannel.brokerSubs[0].s, sub) - assert.Equal(t, galacticChannel.brokerSubs[0].c, c) + assert.Equal(t, len(galacticChannel.brokerSubs), 1) + assert.Equal(t, galacticChannel.brokerSubs[0].s, sub) + assert.Equal(t, galacticChannel.brokerSubs[0].c, c) - testChannelManager.MarkChannelAsLocal(testChannelManagerChannelName) - assert.False(t, galacticChannel.galactic) + testChannelManager.MarkChannelAsLocal(testChannelManagerChannelName) + assert.False(t, galacticChannel.galactic) - assert.Equal(t, len(galacticChannel.brokerConns), 0) - assert.Equal(t, len(galacticChannel.brokerSubs), 0) + assert.Equal(t, len(galacticChannel.brokerConns), 0) + assert.Equal(t, len(galacticChannel.brokerSubs), 0) } func TestChannelManager_TestGalacticChannelOpenError(t *testing.T) { - // channel is not open / does not exist, so this should fail. - e := testChannelManager.MarkChannelAsGalactic(evtbusTestChannelName, "/topic/testy-test", nil) - assert.Error(t, e) + // channel is not open / does not exist, so this should fail. + e := testChannelManager.MarkChannelAsGalactic(evtbusTestChannelName, "/topic/testy-test", nil) + assert.Error(t, e) } func TestChannelManager_TestGalacticChannelCloseError(t *testing.T) { - // channel is not open / does not exist, so this should fail. - e := testChannelManager.MarkChannelAsLocal(evtbusTestChannelName) - assert.Error(t, e) + // channel is not open / does not exist, so this should fail. + e := testChannelManager.MarkChannelAsLocal(evtbusTestChannelName) + assert.Error(t, e) } func TestChannelManager_TestListenToMonitorGalactic(t *testing.T) { - myChan := "mychan" + myChan := "mychan" - b := newTestEventBus() + b := newTestEventBus() - testChannelManager = b.GetChannelManager() - c := testChannelManager.CreateChannel(myChan) + testChannelManager = b.GetChannelManager() + c := testChannelManager.CreateChannel(myChan) - // mark channel as galactic. - id := uuid.New() - subId := uuid.New() - mockSub := &MockBridgeSubscription{ - Id: &subId, - Channel: make(chan *model.Message, 10), - Destination: "/queue/hiya", - } + // mark channel as galactic. + id := uuid.New() + subId := uuid.New() + mockSub := &MockBridgeSubscription{ + Id: &subId, + Channel: make(chan *model.Message, 10), + Destination: "/queue/hiya", + } - mockCon := &MockBridgeConnection{Id: &id} - mockCon.On("Subscribe", "/queue/hiya").Return(mockSub, nil).Once() + mockCon := &MockBridgeConnection{Id: &id} + mockCon.On("Subscribe", "/queue/hiya").Return(mockSub, nil).Once() - x := 0 + x := 0 - h, e := b.ListenOnce(myChan) - assert.Nil(t, e) + h, e := b.ListenOnce(myChan) + assert.Nil(t, e) - var m1 = make(chan bool) - var m2 = make(chan bool) + var m1 = make(chan bool) + var m2 = make(chan bool) - h.Handle( - func(msg *model.Message) { - x++ - m1 <- true - }, - func(err error) { + h.Handle( + func(msg *model.Message) { + x++ + m1 <- true + }, + func(err error) { - }) + }) - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon) - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon) // double up for fun - <-c.brokerMappedEvent - assert.Len(t, c.brokerConns, 1) - mockSub.GetMsgChannel() <- &model.Message{Payload: "test-message", Direction: model.ResponseDir} - <-m1 + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon) + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon) // double up for fun + <-c.brokerMappedEvent + assert.Len(t, c.brokerConns, 1) + mockSub.GetMsgChannel() <- &model.Message{Payload: "test-message", Direction: model.ResponseDir} + <-m1 - // lets add another connection to the same channel. + // lets add another connection to the same channel. - id2 := uuid.New() - subId2 := uuid.New() - mockSub2 := &MockBridgeSubscription{ - Id: &subId2, - Channel: make(chan *model.Message, 10), - Destination: "/queue/hiya", - } + id2 := uuid.New() + subId2 := uuid.New() + mockSub2 := &MockBridgeSubscription{ + Id: &subId2, + Channel: make(chan *model.Message, 10), + Destination: "/queue/hiya", + } - mockCon2 := &MockBridgeConnection{Id: &id2} - mockCon2.On("Subscribe", "/queue/hiya").Return(mockSub2, nil).Once() + mockCon2 := &MockBridgeConnection{Id: &id2} + mockCon2.On("Subscribe", "/queue/hiya").Return(mockSub2, nil).Once() - h, e = b.ListenOnce(myChan) + h, e = b.ListenOnce(myChan) - h.Handle( - func(msg *model.Message) { - x++ - m2 <- true - }, - func(err error) {}) + h.Handle( + func(msg *model.Message) { + x++ + m2 <- true + }, + func(err error) {}) - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon2) - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon2) // trigger double (should ignore) + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon2) + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon2) // trigger double (should ignore) - select { - case <-c.brokerMappedEvent: - case <-time.After(5 * time.Second): - assert.FailNow(t, "TestChannelManager_TestListenToMonitorGalactic timeout on brokerMappedEvent") - } + select { + case <-c.brokerMappedEvent: + case <-time.After(5 * time.Second): + assert.FailNow(t, "TestChannelManager_TestListenToMonitorGalactic timeout on brokerMappedEvent") + } - mockSub.GetMsgChannel() <- &model.Message{Payload: "Hi baby melody!", Direction: model.ResponseDir} + mockSub.GetMsgChannel() <- &model.Message{Payload: "Hi baby melody!", Direction: model.ResponseDir} - <-m2 - assert.Equal(t, 2, x) + <-m2 + assert.Equal(t, 2, x) } // This test performs a end to end run of the monitor. @@ -264,47 +264,47 @@ func TestChannelManager_TestListenToMonitorGalactic(t *testing.T) { // then it will unsubscribe and check that the unsubscription went through ok. func TestChannelManager_TestListenToMonitorLocal(t *testing.T) { - myChan := "mychan-local" + myChan := "mychan-local" - b := newTestEventBus() + b := newTestEventBus() - // run ws broker - testChannelManager = b.GetChannelManager() + // run ws broker + testChannelManager = b.GetChannelManager() - c := testChannelManager.CreateChannel(myChan) + c := testChannelManager.CreateChannel(myChan) - subId := uuid.New() - sub := &MockBridgeSubscription{ - Id: &subId, - } + subId := uuid.New() + sub := &MockBridgeSubscription{ + Id: &subId, + } - id := uuid.New() - mockCon := &MockBridgeConnection{Id: &id} - mockCon.On("Subscribe", "/queue/seeya").Return(sub, nil).Once() + id := uuid.New() + mockCon := &MockBridgeConnection{Id: &id} + mockCon.On("Subscribe", "/queue/seeya").Return(sub, nil).Once() - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/seeya", mockCon) - <-c.brokerMappedEvent - assert.Len(t, c.brokerConns, 1) + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/seeya", mockCon) + <-c.brokerMappedEvent + assert.Len(t, c.brokerConns, 1) - testChannelManager.MarkChannelAsLocal(myChan) - <-c.brokerMappedEvent - assert.Len(t, c.brokerConns, 0) - assert.Len(t, c.brokerSubs, 0) + testChannelManager.MarkChannelAsLocal(myChan) + <-c.brokerMappedEvent + assert.Len(t, c.brokerConns, 0) + assert.Len(t, c.brokerSubs, 0) } func TestChannelManager_TestGalacticMonitorInvalidChannel(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel("fun-chan") + testChannelManager, _ = createManager() + testChannelManager.CreateChannel("fun-chan") - err := testChannelManager.MarkChannelAsGalactic("fun-chan", "/queue/woo", nil) - assert.Nil(t, err) + err := testChannelManager.MarkChannelAsGalactic("fun-chan", "/queue/woo", nil) + assert.Nil(t, err) } func TestChannelManager_TestLocalMonitorInvalidChannel(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel("fun-chan") + testChannelManager, _ = createManager() + testChannelManager.CreateChannel("fun-chan") - err := testChannelManager.MarkChannelAsLocal("fun-chan") - assert.Nil(t, err) + err := testChannelManager.MarkChannelAsLocal("fun-chan") + assert.Nil(t, err) } diff --git a/bus/channel_test.go b/bus/channel_test.go index 684335d..bf16147 100644 --- a/bus/channel_test.go +++ b/bus/channel_test.go @@ -4,300 +4,300 @@ package bus import ( - "github.com/go-stomp/stomp/v3/frame" - "github.com/google/uuid" - "github.com/pb33f/ranch/bridge" - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "testing" + "github.com/go-stomp/stomp/v3/frame" + "github.com/google/uuid" + "github.com/pb33f/ranch/bridge" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "testing" ) var testChannelName string = "testing" func TestChannel_CheckChannelCreation(t *testing.T) { - channel := NewChannel(testChannelName) - assert.Empty(t, channel.eventHandlers) + channel := NewChannel(testChannelName) + assert.Empty(t, channel.eventHandlers) } func TestChannel_SubscribeHandler(t *testing.T) { - id := uuid.New() - channel := NewChannel(testChannelName) - handler := func(*model.Message) {} - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) + id := uuid.New() + channel := NewChannel(testChannelName) + handler := func(*model.Message) {} + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) - assert.Equal(t, 1, len(channel.eventHandlers)) + assert.Equal(t, 1, len(channel.eventHandlers)) - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) - assert.Equal(t, 2, len(channel.eventHandlers)) + assert.Equal(t, 2, len(channel.eventHandlers)) } func TestChannel_HandlerCheck(t *testing.T) { - channel := NewChannel(testChannelName) - id := uuid.New() - assert.False(t, channel.ContainsHandlers()) + channel := NewChannel(testChannelName) + id := uuid.New() + assert.False(t, channel.ContainsHandlers()) - handler := func(*model.Message) {} - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) + handler := func(*model.Message) {} + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) - assert.True(t, channel.ContainsHandlers()) + assert.True(t, channel.ContainsHandlers()) } func TestChannel_SendMessage(t *testing.T) { - id := uuid.New() - channel := NewChannel(testChannelName) - handler := func(message *model.Message) { - assert.Equal(t, message.Payload.(string), "pickled eggs") - assert.Equal(t, message.Channel, testChannelName) - } - - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) - - var message = &model.Message{ - Id: &id, - Payload: "pickled eggs", - Channel: testChannelName, - Direction: model.RequestDir} - - channel.Send(message) - channel.wg.Wait() + id := uuid.New() + channel := NewChannel(testChannelName) + handler := func(message *model.Message) { + assert.Equal(t, message.Payload.(string), "pickled eggs") + assert.Equal(t, message.Channel, testChannelName) + } + + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) + + var message = &model.Message{ + Id: &id, + Payload: "pickled eggs", + Channel: testChannelName, + Direction: model.RequestDir} + + channel.Send(message) + channel.wg.Wait() } func TestChannel_SendMessageRunOnceHasRun(t *testing.T) { - id := uuid.New() - channel := NewChannel(testChannelName) - count := 0 - handler := func(message *model.Message) { - assert.Equal(t, message.Payload.(string), "pickled eggs") - assert.Equal(t, message.Channel, testChannelName) - count++ - } - - h := &channelEventHandler{callBackFunction: handler, runOnce: true, uuid: &id} - channel.subscribeHandler(h) - - var message = &model.Message{ - Id: &id, - Payload: "pickled eggs", - Channel: testChannelName, - Direction: model.RequestDir} - - channel.Send(message) - channel.wg.Wait() - h.runCount = 1 - channel.Send(message) - assert.Len(t, channel.eventHandlers, 0) - assert.Equal(t, 1, count) + id := uuid.New() + channel := NewChannel(testChannelName) + count := 0 + handler := func(message *model.Message) { + assert.Equal(t, message.Payload.(string), "pickled eggs") + assert.Equal(t, message.Channel, testChannelName) + count++ + } + + h := &channelEventHandler{callBackFunction: handler, runOnce: true, uuid: &id} + channel.subscribeHandler(h) + + var message = &model.Message{ + Id: &id, + Payload: "pickled eggs", + Channel: testChannelName, + Direction: model.RequestDir} + + channel.Send(message) + channel.wg.Wait() + h.runCount = 1 + channel.Send(message) + assert.Len(t, channel.eventHandlers, 0) + assert.Equal(t, 1, count) } func TestChannel_SendMultipleMessages(t *testing.T) { - id := uuid.New() - channel := NewChannel(testChannelName) - var counter int32 = 0 - handler := func(message *model.Message) { - assert.Equal(t, message.Payload.(string), "chewy louie") - assert.Equal(t, message.Channel, testChannelName) - inc(&counter) - } - - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) - var message = &model.Message{ - Id: &id, - Payload: "chewy louie", - Channel: testChannelName, - Direction: model.RequestDir} - - channel.Send(message) - channel.Send(message) - channel.Send(message) - channel.wg.Wait() - assert.Equal(t, int32(3), counter) + id := uuid.New() + channel := NewChannel(testChannelName) + var counter int32 = 0 + handler := func(message *model.Message) { + assert.Equal(t, message.Payload.(string), "chewy louie") + assert.Equal(t, message.Channel, testChannelName) + inc(&counter) + } + + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) + var message = &model.Message{ + Id: &id, + Payload: "chewy louie", + Channel: testChannelName, + Direction: model.RequestDir} + + channel.Send(message) + channel.Send(message) + channel.Send(message) + channel.wg.Wait() + assert.Equal(t, int32(3), counter) } func TestChannel_MultiHandlerSingleMessage(t *testing.T) { - id := uuid.New() - channel := NewChannel(testChannelName) - var counterA, counterB, counterC int32 = 0, 0, 0 - - handlerA := func(message *model.Message) { - inc(&counterA) - } - handlerB := func(message *model.Message) { - inc(&counterB) - } - handlerC := func(message *model.Message) { - inc(&counterC) - } - - channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerA, runOnce: false, uuid: &id}) - - channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerB, runOnce: false, uuid: &id}) - - channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerC, runOnce: false, uuid: &id}) - - var message = &model.Message{ - Id: &id, - Payload: "late night munchies", - Channel: testChannelName, - Direction: model.RequestDir} - - channel.Send(message) - channel.Send(message) - channel.Send(message) - channel.wg.Wait() - value := counterA + counterB + counterC - - assert.Equal(t, int32(9), value) + id := uuid.New() + channel := NewChannel(testChannelName) + var counterA, counterB, counterC int32 = 0, 0, 0 + + handlerA := func(message *model.Message) { + inc(&counterA) + } + handlerB := func(message *model.Message) { + inc(&counterB) + } + handlerC := func(message *model.Message) { + inc(&counterC) + } + + channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerA, runOnce: false, uuid: &id}) + + channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerB, runOnce: false, uuid: &id}) + + channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerC, runOnce: false, uuid: &id}) + + var message = &model.Message{ + Id: &id, + Payload: "late night munchies", + Channel: testChannelName, + Direction: model.RequestDir} + + channel.Send(message) + channel.Send(message) + channel.Send(message) + channel.wg.Wait() + value := counterA + counterB + counterC + + assert.Equal(t, int32(9), value) } func TestChannel_Privacy(t *testing.T) { - channel := NewChannel(testChannelName) - assert.False(t, channel.private) - channel.SetPrivate(true) - assert.True(t, channel.IsPrivate()) + channel := NewChannel(testChannelName) + assert.False(t, channel.private) + channel.SetPrivate(true) + assert.True(t, channel.IsPrivate()) } func TestChannel_ChannelGalactic(t *testing.T) { - channel := NewChannel(testChannelName) - assert.False(t, channel.galactic) - channel.SetGalactic("somewhere") - assert.True(t, channel.IsGalactic()) + channel := NewChannel(testChannelName) + assert.False(t, channel.galactic) + channel.SetGalactic("somewhere") + assert.True(t, channel.IsGalactic()) } func TestChannel_RemoveEventHandler(t *testing.T) { - channel := NewChannel(testChannelName) - handlerA := func(message *model.Message) {} - handlerB := func(message *model.Message) {} + channel := NewChannel(testChannelName) + handlerA := func(message *model.Message) {} + handlerB := func(message *model.Message) {} - idA := uuid.New() - idB := uuid.New() + idA := uuid.New() + idB := uuid.New() - channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerA, runOnce: false, uuid: &idA}) + channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerA, runOnce: false, uuid: &idA}) - channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerB, runOnce: false, uuid: &idB}) + channel.subscribeHandler(&channelEventHandler{callBackFunction: handlerB, runOnce: false, uuid: &idB}) - assert.Len(t, channel.eventHandlers, 2) + assert.Len(t, channel.eventHandlers, 2) - // remove the first handler (A) - channel.removeEventHandler(0) - assert.Len(t, channel.eventHandlers, 1) - assert.Equal(t, idB.String(), channel.eventHandlers[0].uuid.String()) + // remove the first handler (A) + channel.removeEventHandler(0) + assert.Len(t, channel.eventHandlers, 1) + assert.Equal(t, idB.String(), channel.eventHandlers[0].uuid.String()) - // remove the second handler B) - channel.removeEventHandler(0) - assert.True(t, len(channel.eventHandlers) == 0) + // remove the second handler B) + channel.removeEventHandler(0) + assert.True(t, len(channel.eventHandlers) == 0) } func TestChannel_RemoveEventHandlerNoHandlers(t *testing.T) { - channel := NewChannel(testChannelName) + channel := NewChannel(testChannelName) - channel.removeEventHandler(0) - assert.Len(t, channel.eventHandlers, 0) + channel.removeEventHandler(0) + assert.Len(t, channel.eventHandlers, 0) } func TestChannel_RemoveEventHandlerOOBIndex(t *testing.T) { - channel := NewChannel(testChannelName) - id := uuid.New() - handler := func(*model.Message) {} - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) + channel := NewChannel(testChannelName) + id := uuid.New() + handler := func(*model.Message) {} + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) - channel.removeEventHandler(999) - assert.Len(t, channel.eventHandlers, 1) + channel.removeEventHandler(999) + assert.Len(t, channel.eventHandlers, 1) } func TestChannel_AddRemoveBrokerSubscription(t *testing.T) { - channel := NewChannel(testChannelName) - id := uuid.New() - sub := &MockBridgeSubscription{Id: &id} - c := &MockBridgeConnection{Id: &id} - channel.addBrokerSubscription(c, sub) - assert.Len(t, channel.brokerSubs, 1) - channel.removeBrokerSubscription(sub) - assert.Len(t, channel.brokerSubs, 0) + channel := NewChannel(testChannelName) + id := uuid.New() + sub := &MockBridgeSubscription{Id: &id} + c := &MockBridgeConnection{Id: &id} + channel.addBrokerSubscription(c, sub) + assert.Len(t, channel.brokerSubs, 1) + channel.removeBrokerSubscription(sub) + assert.Len(t, channel.brokerSubs, 0) } func TestChannel_CheckIfBrokerSubscribed(t *testing.T) { - cId := uuid.New() - sId := uuid.New() - sId2 := uuid.New() + cId := uuid.New() + sId := uuid.New() + sId2 := uuid.New() - c := &MockBridgeConnection{ - Id: &cId, - } - s := &MockBridgeSubscription{Id: &sId} - s2 := &MockBridgeSubscription{Id: &sId2} + c := &MockBridgeConnection{ + Id: &cId, + } + s := &MockBridgeSubscription{Id: &sId} + s2 := &MockBridgeSubscription{Id: &sId2} - cm := NewBusChannelManager(GetBus()) - ch := cm.CreateChannel("testing-broker-subs") - ch.addBrokerSubscription(c, s) - assert.True(t, ch.isBrokerSubscribed(s)) - assert.False(t, ch.isBrokerSubscribed(s2)) + cm := NewBusChannelManager(GetBus()) + ch := cm.CreateChannel("testing-broker-subs") + ch.addBrokerSubscription(c, s) + assert.True(t, ch.isBrokerSubscribed(s)) + assert.False(t, ch.isBrokerSubscribed(s2)) - ch.removeBrokerSubscription(s) - assert.False(t, ch.isBrokerSubscribed(s)) + ch.removeBrokerSubscription(s) + assert.False(t, ch.isBrokerSubscribed(s)) } type MockBridgeConnection struct { - mock.Mock - Id *uuid.UUID + mock.Mock + Id *uuid.UUID } func (c *MockBridgeConnection) GetId() *uuid.UUID { - return c.Id + return c.Id } func (c *MockBridgeConnection) SubscribeReplyDestination(destination string) (bridge.Subscription, error) { - args := c.MethodCalled("Subscribe", destination) - return args.Get(0).(bridge.Subscription), args.Error(1) + args := c.MethodCalled("Subscribe", destination) + return args.Get(0).(bridge.Subscription), args.Error(1) } func (c *MockBridgeConnection) Subscribe(destination string) (bridge.Subscription, error) { - args := c.MethodCalled("Subscribe", destination) - return args.Get(0).(bridge.Subscription), args.Error(1) + args := c.MethodCalled("Subscribe", destination) + return args.Get(0).(bridge.Subscription), args.Error(1) } func (c *MockBridgeConnection) Disconnect() (err error) { - return nil + return nil } func (c *MockBridgeConnection) SendJSONMessage(destination string, payload []byte, opts ...func(frame *frame.Frame) error) error { - args := c.MethodCalled("SendJSONMessage", destination, payload) - return args.Error(0) + args := c.MethodCalled("SendJSONMessage", destination, payload) + return args.Error(0) } func (c *MockBridgeConnection) SendMessage(destination, contentType string, payload []byte, opts ...func(frame *frame.Frame) error) error { - args := c.MethodCalled("SendMessage", destination, contentType, payload) - return args.Error(0) + args := c.MethodCalled("SendMessage", destination, contentType, payload) + return args.Error(0) } func (c *MockBridgeConnection) SendMessageWithReplyDestination(destination, reply, contentType string, payload []byte, opts ...func(frame *frame.Frame) error) error { - args := c.MethodCalled("SendMessage", destination, contentType, payload) - return args.Error(0) + args := c.MethodCalled("SendMessage", destination, contentType, payload) + return args.Error(0) } type MockBridgeSubscription struct { - Id *uuid.UUID - Destination string - Channel chan *model.Message + Id *uuid.UUID + Destination string + Channel chan *model.Message } func (m *MockBridgeSubscription) GetId() *uuid.UUID { - return m.Id + return m.Id } func (m *MockBridgeSubscription) GetDestination() string { - return m.Destination + return m.Destination } func (m *MockBridgeSubscription) GetMsgChannel() chan *model.Message { - return m.Channel + return m.Channel } func (m *MockBridgeSubscription) Unsubscribe() error { - return nil + return nil } diff --git a/bus/eventbus_test.go b/bus/eventbus_test.go index 6722d81..f1d5a94 100644 --- a/bus/eventbus_test.go +++ b/bus/eventbus_test.go @@ -22,748 +22,748 @@ var evtbusTestChannelName string = "test-channel" var evtbusTestManager ChannelManager type MockBrokerConnector struct { - mock.Mock + mock.Mock } func (mock *MockBrokerConnector) Connect(config *bridge.BrokerConnectorConfig, enableLogging bool) (bridge.Connection, error) { - args := mock.MethodCalled("Connect", config) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(bridge.Connection), args.Error(1) + args := mock.MethodCalled("Connect", config) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(bridge.Connection), args.Error(1) } func (mock *MockBrokerConnector) StartTCPServer(address string) error { - args := mock.MethodCalled("StartTCPServer", address) - return args.Error(0) + args := mock.MethodCalled("StartTCPServer", address) + return args.Error(0) } func newTestEventBus() EventBus { - return NewEventBusInstance() + return NewEventBusInstance() } func init() { - evtBusTest = GetBus().(*transportEventBus) + evtBusTest = GetBus().(*transportEventBus) } func createTestChannel() *Channel { - //create new bus - //bf := new(transportEventBus) - //bf.init() - //evtBusTest = bf - //busInstance = bf // set GetBus() instance to return new instance also. + //create new bus + //bf := new(transportEventBus) + //bf.init() + //evtBusTest = bf + //busInstance = bf // set GetBus() instance to return new instance also. - evtbusTestManager = evtBusTest.GetChannelManager() - return evtbusTestManager.CreateChannel(evtbusTestChannelName) + evtbusTestManager = evtBusTest.GetChannelManager() + return evtbusTestManager.CreateChannel(evtbusTestChannelName) } func inc(counter *int32) { - atomic.AddInt32(counter, 1) + atomic.AddInt32(counter, 1) } func destroyTestChannel() { - evtbusTestManager.DestroyChannel(evtbusTestChannelName) + evtbusTestManager.DestroyChannel(evtbusTestChannelName) } func TestEventBus_Boot(t *testing.T) { - bus1 := GetBus() - bus2 := GetBus() - bus3 := GetBus() + bus1 := GetBus() + bus2 := GetBus() + bus3 := GetBus() - assert.EqualValues(t, bus1.GetId(), bus2.GetId()) - assert.EqualValues(t, bus2.GetId(), bus3.GetId()) - assert.NotNil(t, evtBusTest.GetChannelManager()) + assert.EqualValues(t, bus1.GetId(), bus2.GetId()) + assert.EqualValues(t, bus2.GetId(), bus3.GetId()) + assert.NotNil(t, evtBusTest.GetChannelManager()) } func TestEventBus_SendResponseMessageNoChannel(t *testing.T) { - err := evtBusTest.SendResponseMessage("Channel-not-here", "hello melody", nil) - assert.NotNil(t, err) + err := evtBusTest.SendResponseMessage("Channel-not-here", "hello melody", nil) + assert.NotNil(t, err) } func TestEventBus_SendRequestMessageNoChannel(t *testing.T) { - err := evtBusTest.SendRequestMessage("Channel-not-here", "hello melody", nil) - assert.NotNil(t, err) + err := evtBusTest.SendRequestMessage("Channel-not-here", "hello melody", nil) + assert.NotNil(t, err) } func TestTransportEventBus_SendBroadcastMessageNoChannel(t *testing.T) { - err := evtBusTest.SendBroadcastMessage("Channel-not-here", "hello melody") - assert.NotNil(t, err) + err := evtBusTest.SendBroadcastMessage("Channel-not-here", "hello melody") + assert.NotNil(t, err) } func TestEventBus_ListenStream(t *testing.T) { - createTestChannel() - handler, err := evtBusTest.ListenStream(evtbusTestChannelName) - assert.Nil(t, err) - assert.NotNil(t, handler) - var count int32 = 0 - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "hello melody", msg.Payload.(string)) - inc(&count) - }, - func(err error) {}) - - for i := 0; i < 3; i++ { - evtBusTest.SendResponseMessage(evtbusTestChannelName, "hello melody", nil) - - // send requests to make sure we're only getting requests - //evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, nil) - evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, nil) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, int32(3), count) - destroyTestChannel() + createTestChannel() + handler, err := evtBusTest.ListenStream(evtbusTestChannelName) + assert.Nil(t, err) + assert.NotNil(t, handler) + var count int32 = 0 + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "hello melody", msg.Payload.(string)) + inc(&count) + }, + func(err error) {}) + + for i := 0; i < 3; i++ { + evtBusTest.SendResponseMessage(evtbusTestChannelName, "hello melody", nil) + + // send requests to make sure we're only getting requests + //evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, nil) + evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, nil) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, int32(3), count) + destroyTestChannel() } func TestTransportEventBus_ListenStreamForBroadcast(t *testing.T) { - createTestChannel() - handler, err := evtBusTest.ListenStream(evtbusTestChannelName) - assert.Nil(t, err) - assert.NotNil(t, handler) - var count int32 = 0 - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "hello melody", msg.Payload.(string)) - inc(&count) - }, - func(err error) {}) - - for i := 0; i < 3; i++ { - evtBusTest.SendBroadcastMessage(evtbusTestChannelName, "hello melody") - - // send requests to make sure we're only getting requests - evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, nil) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, int32(3), count) - destroyTestChannel() + createTestChannel() + handler, err := evtBusTest.ListenStream(evtbusTestChannelName) + assert.Nil(t, err) + assert.NotNil(t, handler) + var count int32 = 0 + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "hello melody", msg.Payload.(string)) + inc(&count) + }, + func(err error) {}) + + for i := 0; i < 3; i++ { + evtBusTest.SendBroadcastMessage(evtbusTestChannelName, "hello melody") + + // send requests to make sure we're only getting requests + evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, nil) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, int32(3), count) + destroyTestChannel() } func TestTransportEventBus_ListenStreamForDestination(t *testing.T) { - createTestChannel() - id := uuid.New() - handler, _ := evtBusTest.ListenStreamForDestination(evtbusTestChannelName, &id) - var count int32 = 0 - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "hello melody", msg.Payload.(string)) - inc(&count) - }, - func(err error) {}) - - for i := 0; i < 20; i++ { - evtBusTest.SendResponseMessage(evtbusTestChannelName, "hello melody", &id) - - // send requests to make sure we're only getting requests - evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, &id) - evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, &id) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, int32(20), count) - destroyTestChannel() + createTestChannel() + id := uuid.New() + handler, _ := evtBusTest.ListenStreamForDestination(evtbusTestChannelName, &id) + var count int32 = 0 + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "hello melody", msg.Payload.(string)) + inc(&count) + }, + func(err error) {}) + + for i := 0; i < 20; i++ { + evtBusTest.SendResponseMessage(evtbusTestChannelName, "hello melody", &id) + + // send requests to make sure we're only getting requests + evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, &id) + evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, &id) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, int32(20), count) + destroyTestChannel() } func TestEventBus_ListenStreamNoChannel(t *testing.T) { - _, err := evtBusTest.ListenStream("missing-Channel") - assert.NotNil(t, err) + _, err := evtBusTest.ListenStream("missing-Channel") + assert.NotNil(t, err) } func TestEventBus_ListenOnce(t *testing.T) { - createTestChannel() - handler, _ := evtBusTest.ListenOnce(evtbusTestChannelName) - count := 0 - handler.Handle( - func(msg *model.Message) { - count++ - }, - func(err error) {}) - - for i := 0; i < 10; i++ { - evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, handler.GetDestinationId()) - } - - for i := 0; i < 2; i++ { - evtBusTest.SendResponseMessage(evtbusTestChannelName, 0, handler.GetDestinationId()) - - // send requests to make sure we're only getting requests - evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, handler.GetDestinationId()) - evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, handler.GetDestinationId()) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, 1, count) - destroyTestChannel() + createTestChannel() + handler, _ := evtBusTest.ListenOnce(evtbusTestChannelName) + count := 0 + handler.Handle( + func(msg *model.Message) { + count++ + }, + func(err error) {}) + + for i := 0; i < 10; i++ { + evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, handler.GetDestinationId()) + } + + for i := 0; i < 2; i++ { + evtBusTest.SendResponseMessage(evtbusTestChannelName, 0, handler.GetDestinationId()) + + // send requests to make sure we're only getting requests + evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, handler.GetDestinationId()) + evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, handler.GetDestinationId()) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, 1, count) + destroyTestChannel() } func TestEventBus_ListenOnceForDestination(t *testing.T) { - createTestChannel() - dest := uuid.New() - handler, _ := evtBusTest.ListenOnceForDestination(evtbusTestChannelName, &dest) - count := 0 - handler.Handle( - func(msg *model.Message) { - count++ - }, - func(err error) {}) - - for i := 0; i < 300; i++ { - evtBusTest.SendResponseMessage(evtbusTestChannelName, 0, &dest) - - // send duplicate - evtBusTest.SendResponseMessage(evtbusTestChannelName, 0, &dest) - - // send random noise - evtBusTest.SendResponseMessage(evtbusTestChannelName, 0, nil) - - // send requests to make sure we're only getting requests - evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, &dest) - evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, &dest) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, 1, count) - destroyTestChannel() + createTestChannel() + dest := uuid.New() + handler, _ := evtBusTest.ListenOnceForDestination(evtbusTestChannelName, &dest) + count := 0 + handler.Handle( + func(msg *model.Message) { + count++ + }, + func(err error) {}) + + for i := 0; i < 300; i++ { + evtBusTest.SendResponseMessage(evtbusTestChannelName, 0, &dest) + + // send duplicate + evtBusTest.SendResponseMessage(evtbusTestChannelName, 0, &dest) + + // send random noise + evtBusTest.SendResponseMessage(evtbusTestChannelName, 0, nil) + + // send requests to make sure we're only getting requests + evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, &dest) + evtBusTest.SendRequestMessage(evtbusTestChannelName, 1, &dest) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, 1, count) + destroyTestChannel() } func TestEventBus_ListenOnceNoChannel(t *testing.T) { - _, err := evtBusTest.ListenOnce("missing-Channel") - assert.NotNil(t, err) + _, err := evtBusTest.ListenOnce("missing-Channel") + assert.NotNil(t, err) } func TestEventBus_ListenOnceForDestinationNoChannel(t *testing.T) { - _, err := evtBusTest.ListenOnceForDestination("missing-Channel", nil) - assert.NotNil(t, err) + _, err := evtBusTest.ListenOnceForDestination("missing-Channel", nil) + assert.NotNil(t, err) } func TestEventBus_ListenOnceForDestinationNoDestination(t *testing.T) { - createTestChannel() - _, err := evtBusTest.ListenOnceForDestination(evtbusTestChannelName, nil) - assert.NotNil(t, err) - destroyTestChannel() + createTestChannel() + _, err := evtBusTest.ListenOnceForDestination(evtbusTestChannelName, nil) + assert.NotNil(t, err) + destroyTestChannel() } func TestEventBus_ListenRequestStream(t *testing.T) { - createTestChannel() - handler, _ := evtBusTest.ListenRequestStream(evtbusTestChannelName) - var count int32 = 0 - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "hello melody", msg.Payload.(string)) - inc(&count) - }, - func(err error) {}) - - for i := 0; i < 10000; i++ { - evtBusTest.SendRequestMessage(evtbusTestChannelName, "hello melody", nil) - - // send responses to make sure we're only getting requests - evtBusTest.SendResponseMessage(evtbusTestChannelName, "will fail assertion if picked up", nil) - evtBusTest.SendResponseMessage(evtbusTestChannelName, "will fail assertion again", nil) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, count, int32(10000)) - destroyTestChannel() + createTestChannel() + handler, _ := evtBusTest.ListenRequestStream(evtbusTestChannelName) + var count int32 = 0 + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "hello melody", msg.Payload.(string)) + inc(&count) + }, + func(err error) {}) + + for i := 0; i < 10000; i++ { + evtBusTest.SendRequestMessage(evtbusTestChannelName, "hello melody", nil) + + // send responses to make sure we're only getting requests + evtBusTest.SendResponseMessage(evtbusTestChannelName, "will fail assertion if picked up", nil) + evtBusTest.SendResponseMessage(evtbusTestChannelName, "will fail assertion again", nil) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, count, int32(10000)) + destroyTestChannel() } func TestEventBus_ListenRequestStreamForDestination(t *testing.T) { - createTestChannel() - id := uuid.New() - handler, _ := evtBusTest.ListenRequestStreamForDestination(evtbusTestChannelName, &id) - var count int32 = 0 - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "hello melody", msg.Payload.(string)) - inc(&count) - }, - func(err error) {}) - - for i := 0; i < 1000; i++ { - evtBusTest.SendRequestMessage(evtbusTestChannelName, "hello melody", &id) - - // send responses to make sure we're only getting requests - evtBusTest.SendResponseMessage(evtbusTestChannelName, "will fail assertion if picked up", &id) - evtBusTest.SendResponseMessage(evtbusTestChannelName, "will fail assertion again", &id) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, count, int32(1000)) - destroyTestChannel() + createTestChannel() + id := uuid.New() + handler, _ := evtBusTest.ListenRequestStreamForDestination(evtbusTestChannelName, &id) + var count int32 = 0 + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "hello melody", msg.Payload.(string)) + inc(&count) + }, + func(err error) {}) + + for i := 0; i < 1000; i++ { + evtBusTest.SendRequestMessage(evtbusTestChannelName, "hello melody", &id) + + // send responses to make sure we're only getting requests + evtBusTest.SendResponseMessage(evtbusTestChannelName, "will fail assertion if picked up", &id) + evtBusTest.SendResponseMessage(evtbusTestChannelName, "will fail assertion again", &id) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, count, int32(1000)) + destroyTestChannel() } func TestEventBus_ListenStreamForDestinationNoChannel(t *testing.T) { - _, err := evtBusTest.ListenStreamForDestination("missing-Channel", nil) - assert.NotNil(t, err) + _, err := evtBusTest.ListenStreamForDestination("missing-Channel", nil) + assert.NotNil(t, err) } func TestEventBus_ListenStreamForDestinationNoDestination(t *testing.T) { - createTestChannel() - _, err := evtBusTest.ListenStreamForDestination(evtbusTestChannelName, nil) - assert.NotNil(t, err) + createTestChannel() + _, err := evtBusTest.ListenStreamForDestination(evtbusTestChannelName, nil) + assert.NotNil(t, err) } func TestEventBus_ListenRequestStreamForDestinationNoDestination(t *testing.T) { - createTestChannel() - _, err := evtBusTest.ListenRequestStreamForDestination(evtbusTestChannelName, nil) - assert.NotNil(t, err) + createTestChannel() + _, err := evtBusTest.ListenRequestStreamForDestination(evtbusTestChannelName, nil) + assert.NotNil(t, err) } func TestEventBus_ListenRequestStreamForDestinationNoChannel(t *testing.T) { - _, err := evtBusTest.ListenRequestStreamForDestination("nowhere", nil) - assert.NotNil(t, err) + _, err := evtBusTest.ListenRequestStreamForDestination("nowhere", nil) + assert.NotNil(t, err) } func TestEventBus_ListenRequestOnce(t *testing.T) { - createTestChannel() - handler, _ := evtBusTest.ListenRequestOnce(evtbusTestChannelName) - count := 0 - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "hello melody", msg.Payload.(string)) - count++ - }, - func(err error) {}) - - for i := 0; i < 5; i++ { - evtBusTest.SendRequestMessage(evtbusTestChannelName, "hello melody", handler.GetDestinationId()) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, 1, count) - destroyTestChannel() + createTestChannel() + handler, _ := evtBusTest.ListenRequestOnce(evtbusTestChannelName) + count := 0 + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "hello melody", msg.Payload.(string)) + count++ + }, + func(err error) {}) + + for i := 0; i < 5; i++ { + evtBusTest.SendRequestMessage(evtbusTestChannelName, "hello melody", handler.GetDestinationId()) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, 1, count) + destroyTestChannel() } func TestEventBus_ListenRequestOnceForDestination(t *testing.T) { - createTestChannel() - dest := uuid.New() - handler, _ := evtBusTest.ListenRequestOnceForDestination(evtbusTestChannelName, &dest) - count := 0 - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "hello melody", msg.Payload.(string)) - count++ - }, - func(err error) {}) - - for i := 0; i < 5; i++ { - evtBusTest.SendRequestMessage(evtbusTestChannelName, "hello melody", &dest) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, 1, count) - destroyTestChannel() + createTestChannel() + dest := uuid.New() + handler, _ := evtBusTest.ListenRequestOnceForDestination(evtbusTestChannelName, &dest) + count := 0 + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "hello melody", msg.Payload.(string)) + count++ + }, + func(err error) {}) + + for i := 0; i < 5; i++ { + evtBusTest.SendRequestMessage(evtbusTestChannelName, "hello melody", &dest) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, 1, count) + destroyTestChannel() } func TestEventBus_ListenRequestOnceNoChannel(t *testing.T) { - _, err := evtBusTest.ListenRequestOnce("missing-Channel") - assert.NotNil(t, err) + _, err := evtBusTest.ListenRequestOnce("missing-Channel") + assert.NotNil(t, err) } func TestEventBus_ListenRequestStreamNoChannel(t *testing.T) { - _, err := evtBusTest.ListenRequestStream("missing-Channel") - assert.NotNil(t, err) + _, err := evtBusTest.ListenRequestStream("missing-Channel") + assert.NotNil(t, err) } func TestEventBus_ListenRequestOnceForDestinationNoChannel(t *testing.T) { - _, err := evtBusTest.ListenRequestOnceForDestination("missing-Channel", nil) - assert.NotNil(t, err) + _, err := evtBusTest.ListenRequestOnceForDestination("missing-Channel", nil) + assert.NotNil(t, err) } func TestEventBus_ListenRequestOnceForDestinationNoDestination(t *testing.T) { - createTestChannel() - _, err := evtBusTest.ListenRequestOnceForDestination(evtbusTestChannelName, nil) - assert.NotNil(t, err) - destroyTestChannel() + createTestChannel() + _, err := evtBusTest.ListenRequestOnceForDestination(evtbusTestChannelName, nil) + assert.NotNil(t, err) + destroyTestChannel() } func TestEventBus_TestErrorMessageHandling(t *testing.T) { - createTestChannel() + createTestChannel() - err := evtBusTest.SendErrorMessage("invalid-Channel", errors.New("something went wrong"), nil) - assert.NotNil(t, err) + err := evtBusTest.SendErrorMessage("invalid-Channel", errors.New("something went wrong"), nil) + assert.NotNil(t, err) - handler, _ := evtBusTest.ListenStream(evtbusTestChannelName) - var countError int32 = 0 - handler.Handle( - func(msg *model.Message) {}, - func(err error) { - assert.Errorf(t, err, "something went wrong") - inc(&countError) - }) + handler, _ := evtBusTest.ListenStream(evtbusTestChannelName) + var countError int32 = 0 + handler.Handle( + func(msg *model.Message) {}, + func(err error) { + assert.Errorf(t, err, "something went wrong") + inc(&countError) + }) - for i := 0; i < 5; i++ { - err := errors.New("something went wrong") - evtBusTest.SendErrorMessage(evtbusTestChannelName, err, handler.GetId()) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, int32(5), countError) - destroyTestChannel() + for i := 0; i < 5; i++ { + err := errors.New("something went wrong") + evtBusTest.SendErrorMessage(evtbusTestChannelName, err, handler.GetId()) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, int32(5), countError) + destroyTestChannel() } func TestEventBus_ListenFirehose(t *testing.T) { - createTestChannel() - var counter int32 = 0 - - responseHandler, _ := evtBusTest.ListenFirehose(evtbusTestChannelName) - responseHandler.Handle( - func(msg *model.Message) { - inc(&counter) - }, - func(err error) { - inc(&counter) - }) - for i := 0; i < 5; i++ { - err := errors.New("something went wrong") - evtBusTest.SendErrorMessage(evtbusTestChannelName, err, nil) - evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, nil) - evtBusTest.SendResponseMessage(evtbusTestChannelName, 1, nil) - } - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, counter, int32(15)) - destroyTestChannel() + createTestChannel() + var counter int32 = 0 + + responseHandler, _ := evtBusTest.ListenFirehose(evtbusTestChannelName) + responseHandler.Handle( + func(msg *model.Message) { + inc(&counter) + }, + func(err error) { + inc(&counter) + }) + for i := 0; i < 5; i++ { + err := errors.New("something went wrong") + evtBusTest.SendErrorMessage(evtbusTestChannelName, err, nil) + evtBusTest.SendRequestMessage(evtbusTestChannelName, 0, nil) + evtBusTest.SendResponseMessage(evtbusTestChannelName, 1, nil) + } + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, counter, int32(15)) + destroyTestChannel() } func TestEventBus_ListenFirehoseNoChannel(t *testing.T) { - _, err := evtBusTest.ListenFirehose("missing-Channel") - assert.NotNil(t, err) + _, err := evtBusTest.ListenFirehose("missing-Channel") + assert.NotNil(t, err) } func TestEventBus_RequestOnce(t *testing.T) { - createTestChannel() - handler, _ := evtBusTest.ListenRequestStream(evtbusTestChannelName) - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "who is a pretty baby?", msg.Payload.(string)) - evtBusTest.SendResponseMessage(evtbusTestChannelName, "why melody is of course", msg.DestinationId) - }, - func(err error) {}) - - count := 0 - responseHandler, _ := evtBusTest.RequestOnce(evtbusTestChannelName, "who is a pretty baby?") - responseHandler.Handle( - func(msg *model.Message) { - assert.Equal(t, "why melody is of course", msg.Payload.(string)) - count++ - }, - func(err error) {}) - - responseHandler.Fire() - evtbusTestManager.WaitForChannel(evtbusTestChannelName) - assert.Equal(t, 1, count) - destroyTestChannel() + createTestChannel() + handler, _ := evtBusTest.ListenRequestStream(evtbusTestChannelName) + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "who is a pretty baby?", msg.Payload.(string)) + evtBusTest.SendResponseMessage(evtbusTestChannelName, "why melody is of course", msg.DestinationId) + }, + func(err error) {}) + + count := 0 + responseHandler, _ := evtBusTest.RequestOnce(evtbusTestChannelName, "who is a pretty baby?") + responseHandler.Handle( + func(msg *model.Message) { + assert.Equal(t, "why melody is of course", msg.Payload.(string)) + count++ + }, + func(err error) {}) + + responseHandler.Fire() + evtbusTestManager.WaitForChannel(evtbusTestChannelName) + assert.Equal(t, 1, count) + destroyTestChannel() } func TestEventBus_RequestOnceForDestination(t *testing.T) { - createTestChannel() - dest := uuid.New() - handler, _ := evtBusTest.ListenRequestStream(evtbusTestChannelName) - handler.Handle( - func(msg *model.Message) { - assert.Equal(t, "who is a pretty baby?", msg.Payload.(string)) - evtBusTest.SendResponseMessage(evtbusTestChannelName, "why melody is of course", msg.DestinationId) - }, - func(err error) {}) - - count := 0 - responseHandler, _ := evtBusTest.RequestOnceForDestination(evtbusTestChannelName, "who is a pretty baby?", &dest) - responseHandler.Handle( - func(msg *model.Message) { - assert.Equal(t, "why melody is of course", msg.Payload.(string)) - count++ - }, - func(err error) {}) - - responseHandler.Fire() - assert.Equal(t, 1, count) - destroyTestChannel() + createTestChannel() + dest := uuid.New() + handler, _ := evtBusTest.ListenRequestStream(evtbusTestChannelName) + handler.Handle( + func(msg *model.Message) { + assert.Equal(t, "who is a pretty baby?", msg.Payload.(string)) + evtBusTest.SendResponseMessage(evtbusTestChannelName, "why melody is of course", msg.DestinationId) + }, + func(err error) {}) + + count := 0 + responseHandler, _ := evtBusTest.RequestOnceForDestination(evtbusTestChannelName, "who is a pretty baby?", &dest) + responseHandler.Handle( + func(msg *model.Message) { + assert.Equal(t, "why melody is of course", msg.Payload.(string)) + count++ + }, + func(err error) {}) + + responseHandler.Fire() + assert.Equal(t, 1, count) + destroyTestChannel() } func TestEventBus_RequestOnceForDesintationNoChannel(t *testing.T) { - _, err := evtBusTest.RequestOnceForDestination("some-chan", nil, nil) - assert.NotNil(t, err) + _, err := evtBusTest.RequestOnceForDestination("some-chan", nil, nil) + assert.NotNil(t, err) } func TestEventBus_RequestOnceForDesintationNoDestination(t *testing.T) { - createTestChannel() - _, err := evtBusTest.RequestOnceForDestination(evtbusTestChannelName, nil, nil) - assert.NotNil(t, err) - destroyTestChannel() + createTestChannel() + _, err := evtBusTest.RequestOnceForDestination(evtbusTestChannelName, nil, nil) + assert.NotNil(t, err) + destroyTestChannel() } func TestEventBus_RequestStream(t *testing.T) { - channel := createTestChannel() - handler := func(message *model.Message) { - if message.Direction == model.RequestDir { - assert.Equal(t, "who has the cutest laugh?", message.Payload.(string)) - config := buildConfig(channel.Name, "why melody does of course", message.DestinationId) - - // fire a few times, ensure that the handler only ever picks up a single response. - for i := 0; i < 5; i++ { - channel.Send(model.GenerateResponse(config)) - } - } - } - id := uuid.New() - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) - - var count int32 = 0 - responseHandler, _ := evtBusTest.RequestStream(evtbusTestChannelName, "who has the cutest laugh?") - responseHandler.Handle( - func(msg *model.Message) { - assert.Equal(t, "why melody does of course", msg.Payload.(string)) - inc(&count) - }, - func(err error) {}) - - responseHandler.Fire() - assert.Equal(t, int32(5), count) - destroyTestChannel() + channel := createTestChannel() + handler := func(message *model.Message) { + if message.Direction == model.RequestDir { + assert.Equal(t, "who has the cutest laugh?", message.Payload.(string)) + config := buildConfig(channel.Name, "why melody does of course", message.DestinationId) + + // fire a few times, ensure that the handler only ever picks up a single response. + for i := 0; i < 5; i++ { + channel.Send(model.GenerateResponse(config)) + } + } + } + id := uuid.New() + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) + + var count int32 = 0 + responseHandler, _ := evtBusTest.RequestStream(evtbusTestChannelName, "who has the cutest laugh?") + responseHandler.Handle( + func(msg *model.Message) { + assert.Equal(t, "why melody does of course", msg.Payload.(string)) + inc(&count) + }, + func(err error) {}) + + responseHandler.Fire() + assert.Equal(t, int32(5), count) + destroyTestChannel() } func TestEventBus_RequestStreamForDesintation(t *testing.T) { - channel := createTestChannel() - dest := uuid.New() - handler := func(message *model.Message) { - if message.Direction == model.RequestDir { - assert.Equal(t, "who has the cutest laugh?", message.Payload.(string)) - config := buildConfig(channel.Name, "why melody does of course", message.DestinationId) - - // fire a few times, ensure that the handler only ever picks up a single response. - for i := 0; i < 5; i++ { - channel.Send(model.GenerateResponse(config)) - } - } - } - id := uuid.New() - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) - - var count int32 = 0 - responseHandler, _ := evtBusTest.RequestStreamForDestination(evtbusTestChannelName, "who has the cutest laugh?", &dest) - responseHandler.Handle( - func(msg *model.Message) { - assert.Equal(t, "why melody does of course", msg.Payload.(string)) - inc(&count) - }, - func(err error) {}) - - responseHandler.Fire() - assert.Equal(t, int32(5), count) - destroyTestChannel() + channel := createTestChannel() + dest := uuid.New() + handler := func(message *model.Message) { + if message.Direction == model.RequestDir { + assert.Equal(t, "who has the cutest laugh?", message.Payload.(string)) + config := buildConfig(channel.Name, "why melody does of course", message.DestinationId) + + // fire a few times, ensure that the handler only ever picks up a single response. + for i := 0; i < 5; i++ { + channel.Send(model.GenerateResponse(config)) + } + } + } + id := uuid.New() + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: false, uuid: &id}) + + var count int32 = 0 + responseHandler, _ := evtBusTest.RequestStreamForDestination(evtbusTestChannelName, "who has the cutest laugh?", &dest) + responseHandler.Handle( + func(msg *model.Message) { + assert.Equal(t, "why melody does of course", msg.Payload.(string)) + inc(&count) + }, + func(err error) {}) + + responseHandler.Fire() + assert.Equal(t, int32(5), count) + destroyTestChannel() } func TestEventBus_RequestStreamForDestinationNoChannel(t *testing.T) { - _, err := evtBusTest.RequestStreamForDestination("missing-Channel", nil, nil) - assert.NotNil(t, err) + _, err := evtBusTest.RequestStreamForDestination("missing-Channel", nil, nil) + assert.NotNil(t, err) } func TestEventBus_RequestStreamForDestinationNoDestination(t *testing.T) { - createTestChannel() - _, err := evtBusTest.RequestStreamForDestination(evtbusTestChannelName, nil, nil) - assert.NotNil(t, err) - destroyTestChannel() + createTestChannel() + _, err := evtBusTest.RequestStreamForDestination(evtbusTestChannelName, nil, nil) + assert.NotNil(t, err) + destroyTestChannel() } func TestEventBus_RequestStreamNoChannel(t *testing.T) { - _, err := evtBusTest.RequestStream("missing-Channel", nil) - assert.NotNil(t, err) + _, err := evtBusTest.RequestStream("missing-Channel", nil) + assert.NotNil(t, err) } func TestEventBus_HandleSingleRunError(t *testing.T) { - channel := createTestChannel() - handler := func(message *model.Message) { - if message.Direction == model.RequestDir { - config := buildError(channel.Name, fmt.Errorf("whoops!"), message.DestinationId) - - // fire a few times, ensure that the handler only ever picks up a single response. - for i := 0; i < 5; i++ { - channel.Send(model.GenerateError(config)) - } - } - } - id := uuid.New() - channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: true, uuid: &id}) - - count := 0 - responseHandler, _ := evtBusTest.RequestOnce(evtbusTestChannelName, 0) - responseHandler.Handle( - func(msg *model.Message) {}, - func(err error) { - assert.Error(t, err, "whoops!") - count++ - }) - - responseHandler.Fire() - assert.Equal(t, 1, count) - destroyTestChannel() + channel := createTestChannel() + handler := func(message *model.Message) { + if message.Direction == model.RequestDir { + config := buildError(channel.Name, fmt.Errorf("whoops!"), message.DestinationId) + + // fire a few times, ensure that the handler only ever picks up a single response. + for i := 0; i < 5; i++ { + channel.Send(model.GenerateError(config)) + } + } + } + id := uuid.New() + channel.subscribeHandler(&channelEventHandler{callBackFunction: handler, runOnce: true, uuid: &id}) + + count := 0 + responseHandler, _ := evtBusTest.RequestOnce(evtbusTestChannelName, 0) + responseHandler.Handle( + func(msg *model.Message) {}, + func(err error) { + assert.Error(t, err, "whoops!") + count++ + }) + + responseHandler.Fire() + assert.Equal(t, 1, count) + destroyTestChannel() } func TestEventBus_RequestOnceNoChannel(t *testing.T) { - _, err := evtBusTest.RequestOnce("missing-Channel", 0) - assert.NotNil(t, err) + _, err := evtBusTest.RequestOnce("missing-Channel", 0) + assert.NotNil(t, err) } func TestEventBus_HandlerWithoutRequestToFire(t *testing.T) { - createTestChannel() - responseHandler, _ := evtBusTest.ListenFirehose(evtbusTestChannelName) - responseHandler.Handle( - func(msg *model.Message) {}, - func(err error) {}) - err := responseHandler.Fire() - assert.Errorf(t, err, "nothing to fire, request is empty") - destroyTestChannel() + createTestChannel() + responseHandler, _ := evtBusTest.ListenFirehose(evtbusTestChannelName) + responseHandler.Handle( + func(msg *model.Message) {}, + func(err error) {}) + err := responseHandler.Fire() + assert.Errorf(t, err, "nothing to fire, request is empty") + destroyTestChannel() } func TestEventBus_GetStoreManager(t *testing.T) { - assert.NotNil(t, evtBusTest.GetStoreManager()) - store := evtBusTest.GetStoreManager().CreateStore("test") - assert.NotNil(t, store) - assert.True(t, evtBusTest.GetStoreManager().DestroyStore("test")) + assert.NotNil(t, evtBusTest.GetStoreManager()) + store := evtBusTest.GetStoreManager().CreateStore("test") + assert.NotNil(t, store) + assert.True(t, evtBusTest.GetStoreManager().DestroyStore("test")) } func TestChannelManager_TestConnectBroker(t *testing.T) { - // create new transportEventBus instance and replace the brokerConnector - // with MockBrokerConnector instance. - evtBusTest := newTestEventBus().(*transportEventBus) - evtBusTest.bc = new(MockBrokerConnector) + // create new transportEventBus instance and replace the brokerConnector + // with MockBrokerConnector instance. + evtBusTest := newTestEventBus().(*transportEventBus) + evtBusTest.bc = new(MockBrokerConnector) - // connect to broker - cf := &bridge.BrokerConnectorConfig{ - Username: "test", - Password: "test", - UseWS: true, - WebSocketConfig: &bridge.WebSocketConfig{ - WSPath: "/", - }, - ServerAddr: "broker-url"} + // connect to broker + cf := &bridge.BrokerConnectorConfig{ + Username: "test", + Password: "test", + UseWS: true, + WebSocketConfig: &bridge.WebSocketConfig{ + WSPath: "/", + }, + ServerAddr: "broker-url"} - id := uuid.New() - mockCon := &MockBridgeConnection{ - Id: &id, - } - evtBusTest.bc.(*MockBrokerConnector).On("Connect", cf).Return(mockCon, nil) + id := uuid.New() + mockCon := &MockBridgeConnection{ + Id: &id, + } + evtBusTest.bc.(*MockBrokerConnector).On("Connect", cf).Return(mockCon, nil) - c, _ := evtBusTest.ConnectBroker(cf) + c, _ := evtBusTest.ConnectBroker(cf) - assert.Equal(t, c, mockCon) - assert.Equal(t, len(evtBusTest.brokerConnections), 1) - assert.Equal(t, evtBusTest.brokerConnections[mockCon.Id], mockCon) + assert.Equal(t, c, mockCon) + assert.Equal(t, len(evtBusTest.brokerConnections), 1) + assert.Equal(t, evtBusTest.brokerConnections[mockCon.Id], mockCon) } func TestEventBus_TestCreateSyncTransaction(t *testing.T) { - tr := evtBusTest.CreateSyncTransaction() - assert.NotNil(t, tr) - assert.Equal(t, tr.(*busTransaction).transactionType, syncTransaction) + tr := evtBusTest.CreateSyncTransaction() + assert.NotNil(t, tr) + assert.Equal(t, tr.(*busTransaction).transactionType, syncTransaction) } func TestEventBus_TestCreateAsyncTransaction(t *testing.T) { - tr := evtBusTest.CreateAsyncTransaction() - assert.NotNil(t, tr) - assert.Equal(t, tr.(*busTransaction).transactionType, asyncTransaction) + tr := evtBusTest.CreateAsyncTransaction() + assert.NotNil(t, tr) + assert.Equal(t, tr.(*busTransaction).transactionType, asyncTransaction) } type MockRawConnListener struct { - stopped bool - connections chan stompserver.RawConnection - wg sync.WaitGroup + stopped bool + connections chan stompserver.RawConnection + wg sync.WaitGroup } func (cl *MockRawConnListener) Accept() (stompserver.RawConnection, error) { - cl.wg.Done() - con := <-cl.connections - return con, nil + cl.wg.Done() + con := <-cl.connections + return con, nil } func (cl *MockRawConnListener) Close() error { - cl.stopped = true - cl.wg.Done() - return nil + cl.stopped = true + cl.wg.Done() + return nil } func TestBifrostEventBus_StartFabricEndpoint(t *testing.T) { - bus := newTestEventBus().(*transportEventBus) + bus := newTestEventBus().(*transportEventBus) - connListener := &MockRawConnListener{ - connections: make(chan stompserver.RawConnection), - } + connListener := &MockRawConnListener{ + connections: make(chan stompserver.RawConnection), + } - err := bus.StartFabricEndpoint(connListener, EndpointConfig{}) - assert.EqualError(t, err, "invalid TopicPrefix") + err := bus.StartFabricEndpoint(connListener, EndpointConfig{}) + assert.EqualError(t, err, "invalid TopicPrefix") - err = bus.StartFabricEndpoint(connListener, EndpointConfig{TopicPrefix: "asd"}) - assert.EqualError(t, err, "invalid TopicPrefix") + err = bus.StartFabricEndpoint(connListener, EndpointConfig{TopicPrefix: "asd"}) + assert.EqualError(t, err, "invalid TopicPrefix") - err = bus.StartFabricEndpoint(connListener, EndpointConfig{TopicPrefix: "/topic", - AppRequestQueuePrefix: "/pub"}) - assert.EqualError(t, err, "missing UserQueuePrefix") + err = bus.StartFabricEndpoint(connListener, EndpointConfig{TopicPrefix: "/topic", + AppRequestQueuePrefix: "/pub"}) + assert.EqualError(t, err, "missing UserQueuePrefix") - connListener.wg.Add(1) - go bus.StartFabricEndpoint(connListener, EndpointConfig{TopicPrefix: "/topic"}) + connListener.wg.Add(1) + go bus.StartFabricEndpoint(connListener, EndpointConfig{TopicPrefix: "/topic"}) - connListener.wg.Wait() + connListener.wg.Wait() - err = bus.StartFabricEndpoint(connListener, EndpointConfig{TopicPrefix: "/topic"}) - assert.EqualError(t, err, "unable to start: fabric endpoint is already running") + err = bus.StartFabricEndpoint(connListener, EndpointConfig{TopicPrefix: "/topic"}) + assert.EqualError(t, err, "unable to start: fabric endpoint is already running") - connListener.wg.Add(1) - bus.StopFabricEndpoint() - connListener.wg.Wait() + connListener.wg.Add(1) + bus.StopFabricEndpoint() + connListener.wg.Wait() - assert.Nil(t, bus.fabEndpoint) - assert.True(t, connListener.stopped) + assert.Nil(t, bus.fabEndpoint) + assert.True(t, connListener.stopped) - assert.EqualError(t, bus.StopFabricEndpoint(), "unable to stop: fabric endpoint is not running") + assert.EqualError(t, bus.StopFabricEndpoint(), "unable to stop: fabric endpoint is not running") } func TestBifrostEventBus_AddMonitorEventListener(t *testing.T) { - bus := newTestEventBus() - - listener1Count := 0 - listener1 := bus.AddMonitorEventListener(func(event *MonitorEvent) { - listener1Count++ - }, ChannelCreatedEvt) - - listener2Count := 0 - listener2 := bus.AddMonitorEventListener(func(event *MonitorEvent) { - listener2Count++ - }, ChannelCreatedEvt, ChannelDestroyedEvt) - - listener3Count := 0 - listener3 := bus.AddMonitorEventListener(func(event *MonitorEvent) { - listener3Count++ - }) - - assert.NotEqual(t, listener1, listener2) - assert.NotEqual(t, listener1, listener3) - assert.NotEqual(t, listener2, listener3) - - bus.SendMonitorEvent(ChannelCreatedEvt, "test-channel", nil) - assert.Equal(t, listener1Count, 1) - assert.Equal(t, listener2Count, 1) - assert.Equal(t, listener3Count, 1) - - bus.SendMonitorEvent(ChannelDestroyedEvt, "test-channel", nil) - assert.Equal(t, listener1Count, 1) - assert.Equal(t, listener2Count, 2) - assert.Equal(t, listener3Count, 2) - - bus.SendMonitorEvent(StoreInitializedEvt, "store1", nil) - assert.Equal(t, listener1Count, 1) - assert.Equal(t, listener2Count, 2) - assert.Equal(t, listener3Count, 3) - - bus.RemoveMonitorEventListener(listener2) - - bus.SendMonitorEvent(ChannelCreatedEvt, "test-channel", nil) - assert.Equal(t, listener1Count, 2) - assert.Equal(t, listener2Count, 2) - assert.Equal(t, listener3Count, 4) - - bus.SendMonitorEvent(ChannelDestroyedEvt, "test-channel", nil) - assert.Equal(t, listener1Count, 2) - assert.Equal(t, listener2Count, 2) - assert.Equal(t, listener3Count, 5) - - bus.RemoveMonitorEventListener(listener3) - bus.SendMonitorEvent(ChannelCreatedEvt, "test-channel", nil) - assert.Equal(t, listener1Count, 3) - assert.Equal(t, listener2Count, 2) - assert.Equal(t, listener3Count, 5) + bus := newTestEventBus() + + listener1Count := 0 + listener1 := bus.AddMonitorEventListener(func(event *MonitorEvent) { + listener1Count++ + }, ChannelCreatedEvt) + + listener2Count := 0 + listener2 := bus.AddMonitorEventListener(func(event *MonitorEvent) { + listener2Count++ + }, ChannelCreatedEvt, ChannelDestroyedEvt) + + listener3Count := 0 + listener3 := bus.AddMonitorEventListener(func(event *MonitorEvent) { + listener3Count++ + }) + + assert.NotEqual(t, listener1, listener2) + assert.NotEqual(t, listener1, listener3) + assert.NotEqual(t, listener2, listener3) + + bus.SendMonitorEvent(ChannelCreatedEvt, "test-channel", nil) + assert.Equal(t, listener1Count, 1) + assert.Equal(t, listener2Count, 1) + assert.Equal(t, listener3Count, 1) + + bus.SendMonitorEvent(ChannelDestroyedEvt, "test-channel", nil) + assert.Equal(t, listener1Count, 1) + assert.Equal(t, listener2Count, 2) + assert.Equal(t, listener3Count, 2) + + bus.SendMonitorEvent(StoreInitializedEvt, "store1", nil) + assert.Equal(t, listener1Count, 1) + assert.Equal(t, listener2Count, 2) + assert.Equal(t, listener3Count, 3) + + bus.RemoveMonitorEventListener(listener2) + + bus.SendMonitorEvent(ChannelCreatedEvt, "test-channel", nil) + assert.Equal(t, listener1Count, 2) + assert.Equal(t, listener2Count, 2) + assert.Equal(t, listener3Count, 4) + + bus.SendMonitorEvent(ChannelDestroyedEvt, "test-channel", nil) + assert.Equal(t, listener1Count, 2) + assert.Equal(t, listener2Count, 2) + assert.Equal(t, listener3Count, 5) + + bus.RemoveMonitorEventListener(listener3) + bus.SendMonitorEvent(ChannelCreatedEvt, "test-channel", nil) + assert.Equal(t, listener1Count, 3) + assert.Equal(t, listener2Count, 2) + assert.Equal(t, listener3Count, 5) } diff --git a/bus/example_galactic_channels_test.go b/bus/example_galactic_channels_test.go index da76a82..71b5b5b 100644 --- a/bus/example_galactic_channels_test.go +++ b/bus/example_galactic_channels_test.go @@ -4,87 +4,87 @@ package bus_test import ( - "encoding/json" - "fmt" - "github.com/pb33f/ranch/bridge" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/model" - "log" + "encoding/json" + "fmt" + "github.com/pb33f/ranch/bridge" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/model" + "log" ) func Example_usingGalacticChannels() { - // get a pointer to the bus. - b := bus.GetBus() - - // get a pointer to the channel manager - cm := b.GetChannelManager() - - channel := "my-stream" - cm.CreateChannel(channel) - - // create done signal - var done = make(chan bool) - - // listen to stream of messages coming in on channel. - h, err := b.ListenStream(channel) - - if err != nil { - log.Panicf("unable to listen to channel stream, error: %e", err) - } - - count := 0 - - // listen for five messages and then exit, send a completed signal on channel. - h.Handle( - func(msg *model.Message) { - - // unmarshal the payload into a Response object (used by fabric services) - r := &model.Response{} - d := msg.Payload.([]byte) - json.Unmarshal(d, &r) - fmt.Printf("Stream Ticked: %s\n", r.Payload.(string)) - count++ - if count >= 5 { - done <- true - } - }, - func(err error) { - log.Panicf("error received on channel %e", err) - }) - - // create a broker connector config, in this case, we will connect to the application fabric demo endpoint. - config := &bridge.BrokerConnectorConfig{ - Username: "guest", - Password: "guest", - ServerAddr: "appfabric.vmware.com", - WebSocketConfig: &bridge.WebSocketConfig{ - WSPath: "/fabric", - }, - UseWS: true} - - // connect to broker. - c, err := b.ConnectBroker(config) - if err != nil { - log.Panicf("unable to connect to fabric, error: %e", err) - } - - // mark our local channel as galactic and map it to our connection and the /topic/simple-stream service - // running on appfabric.vmware.com - err = cm.MarkChannelAsGalactic(channel, "/topic/simple-stream", c) - if err != nil { - log.Panicf("unable to map local channel to broker destination: %e", err) - } - - // wait for done signal - <-done - - // mark channel as local (unsubscribe from all mappings) - err = cm.MarkChannelAsLocal(channel) - if err != nil { - log.Panicf("unable to unsubscribe, error: %e", err) - } - err = c.Disconnect() - if err != nil { - log.Panicf("unable to disconnect, error: %e", err) - } + // get a pointer to the bus. + b := bus.GetBus() + + // get a pointer to the channel manager + cm := b.GetChannelManager() + + channel := "my-stream" + cm.CreateChannel(channel) + + // create done signal + var done = make(chan bool) + + // listen to stream of messages coming in on channel. + h, err := b.ListenStream(channel) + + if err != nil { + log.Panicf("unable to listen to channel stream, error: %e", err) + } + + count := 0 + + // listen for five messages and then exit, send a completed signal on channel. + h.Handle( + func(msg *model.Message) { + + // unmarshal the payload into a Response object (used by fabric services) + r := &model.Response{} + d := msg.Payload.([]byte) + json.Unmarshal(d, &r) + fmt.Printf("Stream Ticked: %s\n", r.Payload.(string)) + count++ + if count >= 5 { + done <- true + } + }, + func(err error) { + log.Panicf("error received on channel %e", err) + }) + + // create a broker connector config, in this case, we will connect to the application fabric demo endpoint. + config := &bridge.BrokerConnectorConfig{ + Username: "guest", + Password: "guest", + ServerAddr: "appfabric.vmware.com", + WebSocketConfig: &bridge.WebSocketConfig{ + WSPath: "/fabric", + }, + UseWS: true} + + // connect to broker. + c, err := b.ConnectBroker(config) + if err != nil { + log.Panicf("unable to connect to fabric, error: %e", err) + } + + // mark our local channel as galactic and map it to our connection and the /topic/simple-stream service + // running on appfabric.vmware.com + err = cm.MarkChannelAsGalactic(channel, "/topic/simple-stream", c) + if err != nil { + log.Panicf("unable to map local channel to broker destination: %e", err) + } + + // wait for done signal + <-done + + // mark channel as local (unsubscribe from all mappings) + err = cm.MarkChannelAsLocal(channel) + if err != nil { + log.Panicf("unable to unsubscribe, error: %e", err) + } + err = c.Disconnect() + if err != nil { + log.Panicf("unable to disconnect, error: %e", err) + } } diff --git a/bus/fabric_endpoint_test.go b/bus/fabric_endpoint_test.go index e6eb7e3..75bace7 100644 --- a/bus/fabric_endpoint_test.go +++ b/bus/fabric_endpoint_test.go @@ -4,390 +4,390 @@ package bus import ( - "encoding/json" - "errors" - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "github.com/pb33f/ranch/stompserver" - "github.com/stretchr/testify/assert" - "sync" - "testing" + "encoding/json" + "errors" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "github.com/pb33f/ranch/stompserver" + "github.com/stretchr/testify/assert" + "sync" + "testing" ) type MockStompServerMessage struct { - Destination string `json:"destination"` - Payload []byte `json:"payload"` - conId string + Destination string `json:"destination"` + Payload []byte `json:"payload"` + conId string } type MockStompServer struct { - started bool - sentMessages []MockStompServerMessage - subscribeHandlerFunction stompserver.SubscribeHandlerFunction - connectionEventCallbacks map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent) - unsubscribeHandlerFunction stompserver.UnsubscribeHandlerFunction - applicationRequestHandlerFunction stompserver.ApplicationRequestHandlerFunction - wg *sync.WaitGroup + started bool + sentMessages []MockStompServerMessage + subscribeHandlerFunction stompserver.SubscribeHandlerFunction + connectionEventCallbacks map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent) + unsubscribeHandlerFunction stompserver.UnsubscribeHandlerFunction + applicationRequestHandlerFunction stompserver.ApplicationRequestHandlerFunction + wg *sync.WaitGroup } func (s *MockStompServer) Start() { - s.started = true + s.started = true } func (s *MockStompServer) Stop() { - s.started = false + s.started = false } func (s *MockStompServer) SendMessage(destination string, messageBody []byte) { - s.sentMessages = append(s.sentMessages, - MockStompServerMessage{Destination: destination, Payload: messageBody}) + s.sentMessages = append(s.sentMessages, + MockStompServerMessage{Destination: destination, Payload: messageBody}) - if s.wg != nil { - s.wg.Done() - } + if s.wg != nil { + s.wg.Done() + } } func (s *MockStompServer) SendMessageToClient(conId string, destination string, messageBody []byte) { - s.sentMessages = append(s.sentMessages, - MockStompServerMessage{Destination: destination, Payload: messageBody, conId: conId}) + s.sentMessages = append(s.sentMessages, + MockStompServerMessage{Destination: destination, Payload: messageBody, conId: conId}) - if s.wg != nil { - s.wg.Done() - } + if s.wg != nil { + s.wg.Done() + } } func (s *MockStompServer) OnUnsubscribeEvent(callback stompserver.UnsubscribeHandlerFunction) { - s.unsubscribeHandlerFunction = callback + s.unsubscribeHandlerFunction = callback } func (s *MockStompServer) OnApplicationRequest(callback stompserver.ApplicationRequestHandlerFunction) { - s.applicationRequestHandlerFunction = callback + s.applicationRequestHandlerFunction = callback } func (s *MockStompServer) OnSubscribeEvent(callback stompserver.SubscribeHandlerFunction) { - s.subscribeHandlerFunction = callback + s.subscribeHandlerFunction = callback } func (s *MockStompServer) SetConnectionEventCallback(connEventType stompserver.StompSessionEventType, cb func(connEvent *stompserver.ConnEvent)) { - s.connectionEventCallbacks[connEventType] = cb - cb(&stompserver.ConnEvent{ConnId: "id"}) + s.connectionEventCallbacks[connEventType] = cb + cb(&stompserver.ConnEvent{ConnId: "id"}) } func newTestFabricEndpoint(bus EventBus, config EndpointConfig) (*fabricEndpoint, *MockStompServer) { - fe := newFabricEndpoint(bus, nil, config).(*fabricEndpoint) - ms := &MockStompServer{connectionEventCallbacks: make(map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent))} + fe := newFabricEndpoint(bus, nil, config).(*fabricEndpoint) + ms := &MockStompServer{connectionEventCallbacks: make(map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent))} - fe.server = ms - fe.initHandlers() + fe.server = ms + fe.initHandlers() - return fe, ms + return fe, ms } func TestFabricEndpoint_newFabricEndpoint(t *testing.T) { - fe, _ := newTestFabricEndpoint(nil, EndpointConfig{ - TopicPrefix: "/topic", - AppRequestPrefix: "/pub", - Heartbeat: 0, - }) - - assert.NotNil(t, fe) - assert.Equal(t, fe.config.TopicPrefix, "/topic/") - assert.Equal(t, fe.config.AppRequestPrefix, "/pub/") - - fe, _ = newTestFabricEndpoint(nil, EndpointConfig{ - TopicPrefix: "/topic/", - AppRequestPrefix: "", - Heartbeat: 0, - }) - - assert.Equal(t, fe.config.TopicPrefix, "/topic/") - assert.Equal(t, fe.config.AppRequestPrefix, "") + fe, _ := newTestFabricEndpoint(nil, EndpointConfig{ + TopicPrefix: "/topic", + AppRequestPrefix: "/pub", + Heartbeat: 0, + }) + + assert.NotNil(t, fe) + assert.Equal(t, fe.config.TopicPrefix, "/topic/") + assert.Equal(t, fe.config.AppRequestPrefix, "/pub/") + + fe, _ = newTestFabricEndpoint(nil, EndpointConfig{ + TopicPrefix: "/topic/", + AppRequestPrefix: "", + Heartbeat: 0, + }) + + assert.Equal(t, fe.config.TopicPrefix, "/topic/") + assert.Equal(t, fe.config.AppRequestPrefix, "") } func TestFabricEndpoint_StartAndStop(t *testing.T) { - fe, mockServer := newTestFabricEndpoint(nil, EndpointConfig{}) - assert.Equal(t, mockServer.started, false) - fe.Start() - assert.Equal(t, mockServer.started, true) - fe.Stop() - assert.Equal(t, mockServer.started, false) + fe, mockServer := newTestFabricEndpoint(nil, EndpointConfig{}) + assert.Equal(t, mockServer.started, false) + fe.Start() + assert.Equal(t, mockServer.started, true) + fe.Stop() + assert.Equal(t, mockServer.started, false) } func TestFabricEndpoint_SubscribeEvent(t *testing.T) { - bus := newTestEventBus() - bus.GetChannelManager().CreateChannel(STOMP_SESSION_NOTIFY_CHANNEL) // used for internal channel protection test - fe, mockServer := newTestFabricEndpoint(bus, - EndpointConfig{TopicPrefix: "/topic", UserQueuePrefix: "/user/queue"}) - - bus.GetChannelManager().CreateChannel("test-service") - - monitorWg := sync.WaitGroup{} - var monitorEvents []*MonitorEvent - bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) { - monitorEvents = append(monitorEvents, monitorEvt) - monitorWg.Done() - }, FabricEndpointSubscribeEvt) - - // subscribe to invalid topic - mockServer.subscribeHandlerFunction("con1", "sub1", "/topic2/test-service", nil) - assert.Equal(t, len(fe.chanMappings), 0) - - bus.SendResponseMessage("test-service", "test-message", nil) - assert.Equal(t, len(mockServer.sentMessages), 0) - - // subscribe to valid channel - monitorWg.Add(1) - mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil) - monitorWg.Wait() - assert.Equal(t, len(monitorEvents), 1) - assert.Equal(t, monitorEvents[0].EventType, FabricEndpointSubscribeEvt) - assert.Equal(t, monitorEvents[0].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1) - assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub1"], true) - - // subscribe again to the same channel - monitorWg.Add(1) - mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil) - monitorWg.Wait() - - assert.Equal(t, len(monitorEvents), 2) - assert.Equal(t, monitorEvents[1].EventType, FabricEndpointSubscribeEvt) - assert.Equal(t, monitorEvents[1].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) - assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub2"], true) - - // subscribe to queue channel - monitorWg.Add(1) - mockServer.subscribeHandlerFunction("con1", "sub3", "/user/queue/test-service", nil) - monitorWg.Wait() - assert.Equal(t, len(monitorEvents), 3) - assert.Equal(t, monitorEvents[2].EventType, FabricEndpointSubscribeEvt) - assert.Equal(t, monitorEvents[2].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 3) - assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub3"], true) - - // attempt to subscribe to a protected destination - mockServer.subscribeHandlerFunction("con1", "sub4", "/topic/"+STOMP_SESSION_NOTIFY_CHANNEL, nil) - _, chanMapCreated := fe.chanMappings[STOMP_SESSION_NOTIFY_CHANNEL] - assert.False(t, chanMapCreated) - - mockServer.wg = &sync.WaitGroup{} - mockServer.wg.Add(1) - - bus.SendResponseMessage("test-service", "test-message", nil) - - mockServer.wg.Wait() - - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", []byte{1, 2, 3}, nil) - mockServer.wg.Wait() - - mockServer.wg.Add(1) - msg := MockStompServerMessage{Destination: "test", Payload: []byte("test-message")} - bus.SendResponseMessage("test-service", msg, nil) - mockServer.wg.Wait() - - mockServer.wg.Add(1) - bus.SendErrorMessage("test-service", errors.New("test-error"), nil) - mockServer.wg.Wait() - - assert.Equal(t, len(mockServer.sentMessages), 4) - assert.Equal(t, mockServer.sentMessages[0].Destination, "/topic/test-service") - assert.Equal(t, string(mockServer.sentMessages[0].Payload), "test-message") - assert.Equal(t, mockServer.sentMessages[1].Payload, []byte{1, 2, 3}) - - var sentMsg MockStompServerMessage - json.Unmarshal(mockServer.sentMessages[2].Payload, &sentMsg) - assert.Equal(t, msg, sentMsg) - - assert.Equal(t, string(mockServer.sentMessages[3].Payload), "test-error") - - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", model.Response{ - BrokerDestination: &model.BrokerDestinationConfig{ - Destination: "/user/queue/test-service", - ConnectionId: "con1", - }, - Payload: "test-private-message", - }, nil) - - mockServer.wg.Wait() - - assert.Equal(t, len(mockServer.sentMessages), 5) - assert.Equal(t, mockServer.sentMessages[4].Destination, "/user/queue/test-service") - var sentResponse model.Response - json.Unmarshal(mockServer.sentMessages[4].Payload, &sentResponse) - assert.Equal(t, sentResponse.Payload, "test-private-message") - - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", &model.Response{ - BrokerDestination: &model.BrokerDestinationConfig{ - Destination: "/user/queue/test-service", - ConnectionId: "con1", - }, - Payload: "test-private-message-ptr", - }, nil) - - mockServer.wg.Wait() - - assert.Equal(t, len(mockServer.sentMessages), 6) - assert.Equal(t, mockServer.sentMessages[5].Destination, "/user/queue/test-service") - json.Unmarshal(mockServer.sentMessages[5].Payload, &sentResponse) - assert.Equal(t, sentResponse.Payload, "test-private-message-ptr") + bus := newTestEventBus() + bus.GetChannelManager().CreateChannel(STOMP_SESSION_NOTIFY_CHANNEL) // used for internal channel protection test + fe, mockServer := newTestFabricEndpoint(bus, + EndpointConfig{TopicPrefix: "/topic", UserQueuePrefix: "/user/queue"}) + + bus.GetChannelManager().CreateChannel("test-service") + + monitorWg := sync.WaitGroup{} + var monitorEvents []*MonitorEvent + bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) { + monitorEvents = append(monitorEvents, monitorEvt) + monitorWg.Done() + }, FabricEndpointSubscribeEvt) + + // subscribe to invalid topic + mockServer.subscribeHandlerFunction("con1", "sub1", "/topic2/test-service", nil) + assert.Equal(t, len(fe.chanMappings), 0) + + bus.SendResponseMessage("test-service", "test-message", nil) + assert.Equal(t, len(mockServer.sentMessages), 0) + + // subscribe to valid channel + monitorWg.Add(1) + mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil) + monitorWg.Wait() + assert.Equal(t, len(monitorEvents), 1) + assert.Equal(t, monitorEvents[0].EventType, FabricEndpointSubscribeEvt) + assert.Equal(t, monitorEvents[0].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1) + assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub1"], true) + + // subscribe again to the same channel + monitorWg.Add(1) + mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil) + monitorWg.Wait() + + assert.Equal(t, len(monitorEvents), 2) + assert.Equal(t, monitorEvents[1].EventType, FabricEndpointSubscribeEvt) + assert.Equal(t, monitorEvents[1].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) + assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub2"], true) + + // subscribe to queue channel + monitorWg.Add(1) + mockServer.subscribeHandlerFunction("con1", "sub3", "/user/queue/test-service", nil) + monitorWg.Wait() + assert.Equal(t, len(monitorEvents), 3) + assert.Equal(t, monitorEvents[2].EventType, FabricEndpointSubscribeEvt) + assert.Equal(t, monitorEvents[2].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 3) + assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub3"], true) + + // attempt to subscribe to a protected destination + mockServer.subscribeHandlerFunction("con1", "sub4", "/topic/"+STOMP_SESSION_NOTIFY_CHANNEL, nil) + _, chanMapCreated := fe.chanMappings[STOMP_SESSION_NOTIFY_CHANNEL] + assert.False(t, chanMapCreated) + + mockServer.wg = &sync.WaitGroup{} + mockServer.wg.Add(1) + + bus.SendResponseMessage("test-service", "test-message", nil) + + mockServer.wg.Wait() + + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", []byte{1, 2, 3}, nil) + mockServer.wg.Wait() + + mockServer.wg.Add(1) + msg := MockStompServerMessage{Destination: "test", Payload: []byte("test-message")} + bus.SendResponseMessage("test-service", msg, nil) + mockServer.wg.Wait() + + mockServer.wg.Add(1) + bus.SendErrorMessage("test-service", errors.New("test-error"), nil) + mockServer.wg.Wait() + + assert.Equal(t, len(mockServer.sentMessages), 4) + assert.Equal(t, mockServer.sentMessages[0].Destination, "/topic/test-service") + assert.Equal(t, string(mockServer.sentMessages[0].Payload), "test-message") + assert.Equal(t, mockServer.sentMessages[1].Payload, []byte{1, 2, 3}) + + var sentMsg MockStompServerMessage + json.Unmarshal(mockServer.sentMessages[2].Payload, &sentMsg) + assert.Equal(t, msg, sentMsg) + + assert.Equal(t, string(mockServer.sentMessages[3].Payload), "test-error") + + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", model.Response{ + BrokerDestination: &model.BrokerDestinationConfig{ + Destination: "/user/queue/test-service", + ConnectionId: "con1", + }, + Payload: "test-private-message", + }, nil) + + mockServer.wg.Wait() + + assert.Equal(t, len(mockServer.sentMessages), 5) + assert.Equal(t, mockServer.sentMessages[4].Destination, "/user/queue/test-service") + var sentResponse model.Response + json.Unmarshal(mockServer.sentMessages[4].Payload, &sentResponse) + assert.Equal(t, sentResponse.Payload, "test-private-message") + + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", &model.Response{ + BrokerDestination: &model.BrokerDestinationConfig{ + Destination: "/user/queue/test-service", + ConnectionId: "con1", + }, + Payload: "test-private-message-ptr", + }, nil) + + mockServer.wg.Wait() + + assert.Equal(t, len(mockServer.sentMessages), 6) + assert.Equal(t, mockServer.sentMessages[5].Destination, "/user/queue/test-service") + json.Unmarshal(mockServer.sentMessages[5].Payload, &sentResponse) + assert.Equal(t, sentResponse.Payload, "test-private-message-ptr") } func TestFabricEndpoint_UnsubscribeEvent(t *testing.T) { - bus := newTestEventBus() - fe, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic"}) - - bus.GetChannelManager().CreateChannel("test-service") - - monitorWg := sync.WaitGroup{} - var monitorEvents []*MonitorEvent - bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) { - monitorEvents = append(monitorEvents, monitorEvt) - monitorWg.Done() - }, FabricEndpointUnsubscribeEvt) - - // subscribe to valid channel - mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil) - mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil) - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) - - mockServer.wg = &sync.WaitGroup{} - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", "test-message", nil) - mockServer.wg.Wait() - assert.Equal(t, len(mockServer.sentMessages), 1) - - mockServer.unsubscribeHandlerFunction("con1", "sub2", "/invalid-topic/test-service") - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) - - mockServer.unsubscribeHandlerFunction("invalid-con1", "sub2", "/topic/test-service") - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) - - monitorWg.Add(1) - mockServer.unsubscribeHandlerFunction("con1", "sub2", "/topic/test-service") - monitorWg.Wait() - - assert.Equal(t, len(monitorEvents), 1) - assert.Equal(t, monitorEvents[0].EventType, FabricEndpointUnsubscribeEvt) - assert.Equal(t, monitorEvents[0].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1) - - mockServer.wg = &sync.WaitGroup{} - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", "test-message", nil) - mockServer.wg.Wait() - assert.Equal(t, len(mockServer.sentMessages), 2) - - monitorWg.Add(1) - mockServer.unsubscribeHandlerFunction("con1", "sub1", "/topic/test-service") - monitorWg.Wait() - - assert.Equal(t, len(monitorEvents), 2) - assert.Equal(t, monitorEvents[1].EventType, FabricEndpointUnsubscribeEvt) - assert.Equal(t, monitorEvents[1].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 0) - bus.SendResponseMessage("test-service", "test-message", nil) - - // subscribe to non-existing channel - mockServer.subscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel", nil) - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["non-existing-channel"].subs), 1) - assert.Equal(t, fe.chanMappings["non-existing-channel"].autoCreated, true) - assert.True(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel")) - - monitorWg.Add(1) - mockServer.unsubscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel") - monitorWg.Wait() - - assert.Equal(t, len(monitorEvents), 3) - assert.Equal(t, monitorEvents[2].EventType, FabricEndpointUnsubscribeEvt) - assert.Equal(t, monitorEvents[2].EntityName, "non-existing-channel") - - assert.Equal(t, len(fe.chanMappings), 0) - assert.False(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel")) + bus := newTestEventBus() + fe, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic"}) + + bus.GetChannelManager().CreateChannel("test-service") + + monitorWg := sync.WaitGroup{} + var monitorEvents []*MonitorEvent + bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) { + monitorEvents = append(monitorEvents, monitorEvt) + monitorWg.Done() + }, FabricEndpointUnsubscribeEvt) + + // subscribe to valid channel + mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil) + mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil) + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) + + mockServer.wg = &sync.WaitGroup{} + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", "test-message", nil) + mockServer.wg.Wait() + assert.Equal(t, len(mockServer.sentMessages), 1) + + mockServer.unsubscribeHandlerFunction("con1", "sub2", "/invalid-topic/test-service") + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) + + mockServer.unsubscribeHandlerFunction("invalid-con1", "sub2", "/topic/test-service") + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) + + monitorWg.Add(1) + mockServer.unsubscribeHandlerFunction("con1", "sub2", "/topic/test-service") + monitorWg.Wait() + + assert.Equal(t, len(monitorEvents), 1) + assert.Equal(t, monitorEvents[0].EventType, FabricEndpointUnsubscribeEvt) + assert.Equal(t, monitorEvents[0].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1) + + mockServer.wg = &sync.WaitGroup{} + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", "test-message", nil) + mockServer.wg.Wait() + assert.Equal(t, len(mockServer.sentMessages), 2) + + monitorWg.Add(1) + mockServer.unsubscribeHandlerFunction("con1", "sub1", "/topic/test-service") + monitorWg.Wait() + + assert.Equal(t, len(monitorEvents), 2) + assert.Equal(t, monitorEvents[1].EventType, FabricEndpointUnsubscribeEvt) + assert.Equal(t, monitorEvents[1].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 0) + bus.SendResponseMessage("test-service", "test-message", nil) + + // subscribe to non-existing channel + mockServer.subscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel", nil) + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["non-existing-channel"].subs), 1) + assert.Equal(t, fe.chanMappings["non-existing-channel"].autoCreated, true) + assert.True(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel")) + + monitorWg.Add(1) + mockServer.unsubscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel") + monitorWg.Wait() + + assert.Equal(t, len(monitorEvents), 3) + assert.Equal(t, monitorEvents[2].EventType, FabricEndpointUnsubscribeEvt) + assert.Equal(t, monitorEvents[2].EntityName, "non-existing-channel") + + assert.Equal(t, len(fe.chanMappings), 0) + assert.False(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel")) } func TestFabricEndpoint_BridgeMessage(t *testing.T) { - bus := newTestEventBus() - _, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic", AppRequestPrefix: "/pub", - AppRequestQueuePrefix: "/pub/queue", UserQueuePrefix: "/user/queue"}) + bus := newTestEventBus() + _, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic", AppRequestPrefix: "/pub", + AppRequestQueuePrefix: "/pub/queue", UserQueuePrefix: "/user/queue"}) - bus.GetChannelManager().CreateChannel("request-channel") - mh, _ := bus.ListenRequestStream("request-channel") - assert.NotNil(t, mh) + bus.GetChannelManager().CreateChannel("request-channel") + mh, _ := bus.ListenRequestStream("request-channel") + assert.NotNil(t, mh) - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - var messages []*model.Message + var messages []*model.Message - mh.Handle(func(message *model.Message) { - messages = append(messages, message) - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) + mh.Handle(func(message *model.Message) { + messages = append(messages, message) + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) - id1 := uuid.New() - req1, _ := json.Marshal(model.Request{ - Request: "test-request", - Payload: "test-rq", - Id: &id1, - }) + id1 := uuid.New() + req1, _ := json.Marshal(model.Request{ + Request: "test-request", + Payload: "test-rq", + Id: &id1, + }) - wg.Add(1) + wg.Add(1) - mockServer.applicationRequestHandlerFunction("/pub/request-channel", req1, "con1") + mockServer.applicationRequestHandlerFunction("/pub/request-channel", req1, "con1") - mockServer.applicationRequestHandlerFunction("/pub2/request-channel", req1, "con1") - mockServer.applicationRequestHandlerFunction("/pub/request-channel-2", req1, "con1") + mockServer.applicationRequestHandlerFunction("/pub2/request-channel", req1, "con1") + mockServer.applicationRequestHandlerFunction("/pub/request-channel-2", req1, "con1") - mockServer.applicationRequestHandlerFunction("/pub/request-channel", []byte("invalid-request-json"), "con1") + mockServer.applicationRequestHandlerFunction("/pub/request-channel", []byte("invalid-request-json"), "con1") - id2 := uuid.New() - req2, _ := json.Marshal(model.Request{ - Request: "test-request2", - Payload: "test-rq2", - Id: &id2, - }) + id2 := uuid.New() + req2, _ := json.Marshal(model.Request{ + Request: "test-request2", + Payload: "test-rq2", + Id: &id2, + }) - wg.Wait() + wg.Wait() - wg.Add(1) - mockServer.applicationRequestHandlerFunction("/pub/queue/request-channel", req2, "con2") - wg.Wait() + wg.Add(1) + mockServer.applicationRequestHandlerFunction("/pub/queue/request-channel", req2, "con2") + wg.Wait() - assert.Equal(t, len(messages), 2) + assert.Equal(t, len(messages), 2) - receivedReq := messages[0].Payload.(*model.Request) + receivedReq := messages[0].Payload.(*model.Request) - assert.Equal(t, receivedReq.Request, "test-request") - assert.Equal(t, receivedReq.Payload, "test-rq") - assert.Equal(t, *receivedReq.Id, id1) - assert.Nil(t, receivedReq.BrokerDestination) + assert.Equal(t, receivedReq.Request, "test-request") + assert.Equal(t, receivedReq.Payload, "test-rq") + assert.Equal(t, *receivedReq.Id, id1) + assert.Nil(t, receivedReq.BrokerDestination) - receivedReq2 := messages[1].Payload.(*model.Request) + receivedReq2 := messages[1].Payload.(*model.Request) - assert.Equal(t, receivedReq2.Request, "test-request2") - assert.Equal(t, receivedReq2.Payload, "test-rq2") - assert.Equal(t, *receivedReq2.Id, id2) - assert.Equal(t, receivedReq2.BrokerDestination.ConnectionId, "con2") - assert.Equal(t, receivedReq2.BrokerDestination.Destination, "/user/queue/request-channel") + assert.Equal(t, receivedReq2.Request, "test-request2") + assert.Equal(t, receivedReq2.Payload, "test-rq2") + assert.Equal(t, *receivedReq2.Id, id2) + assert.Equal(t, receivedReq2.BrokerDestination.ConnectionId, "con2") + assert.Equal(t, receivedReq2.BrokerDestination.Destination, "/user/queue/request-channel") } diff --git a/bus/message_handler.go b/bus/message_handler.go index 3797d9a..e457979 100644 --- a/bus/message_handler.go +++ b/bus/message_handler.go @@ -4,10 +4,10 @@ package bus import ( - "fmt" - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "sync" + "fmt" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "sync" ) // Signature used for all functions used on bus stream APIs to Handle messages. @@ -20,59 +20,59 @@ type MessageErrorFunction func(error) // It also provides a Handle method that accepts a success and error function as handlers. // The Fire method will fire the message queued when using RequestOnce or RequestStream type MessageHandler interface { - GetId() *uuid.UUID - GetDestinationId() *uuid.UUID - Handle(successHandler MessageHandlerFunction, errorHandler MessageErrorFunction) - Fire() error - Close() + GetId() *uuid.UUID + GetDestinationId() *uuid.UUID + Handle(successHandler MessageHandlerFunction, errorHandler MessageErrorFunction) + Fire() error + Close() } type messageHandler struct { - id *uuid.UUID - destination *uuid.UUID - eventCount int64 - closed bool - channel *Channel - requestMessage *model.Message - runCount int64 - ignoreId bool - wrapperFunction MessageHandlerFunction - successHandler MessageHandlerFunction - errorHandler MessageErrorFunction - subscriptionId *uuid.UUID - invokeOnce *sync.Once - channelManager ChannelManager + id *uuid.UUID + destination *uuid.UUID + eventCount int64 + closed bool + channel *Channel + requestMessage *model.Message + runCount int64 + ignoreId bool + wrapperFunction MessageHandlerFunction + successHandler MessageHandlerFunction + errorHandler MessageErrorFunction + subscriptionId *uuid.UUID + invokeOnce *sync.Once + channelManager ChannelManager } func (msgHandler *messageHandler) Handle(successHandler MessageHandlerFunction, errorHandler MessageErrorFunction) { - msgHandler.successHandler = successHandler - msgHandler.errorHandler = errorHandler + msgHandler.successHandler = successHandler + msgHandler.errorHandler = errorHandler - msgHandler.subscriptionId, _ = msgHandler.channelManager.SubscribeChannelHandler( - msgHandler.channel.Name, msgHandler.wrapperFunction, false) + msgHandler.subscriptionId, _ = msgHandler.channelManager.SubscribeChannelHandler( + msgHandler.channel.Name, msgHandler.wrapperFunction, false) } func (msgHandler *messageHandler) Close() { - if msgHandler.subscriptionId != nil { - msgHandler.channelManager.UnsubscribeChannelHandler( - msgHandler.channel.Name, msgHandler.subscriptionId) - } + if msgHandler.subscriptionId != nil { + msgHandler.channelManager.UnsubscribeChannelHandler( + msgHandler.channel.Name, msgHandler.subscriptionId) + } } func (msgHandler *messageHandler) GetId() *uuid.UUID { - return msgHandler.id + return msgHandler.id } func (msgHandler *messageHandler) GetDestinationId() *uuid.UUID { - return msgHandler.destination + return msgHandler.destination } func (msgHandler *messageHandler) Fire() error { - if msgHandler.requestMessage != nil { - sendMessageToChannel(msgHandler.channel, msgHandler.requestMessage) - msgHandler.channel.wg.Wait() - return nil - } else { - return fmt.Errorf("nothing to fire, request is empty") - } + if msgHandler.requestMessage != nil { + sendMessageToChannel(msgHandler.channel, msgHandler.requestMessage) + msgHandler.channel.wg.Wait() + return nil + } else { + return fmt.Errorf("nothing to fire, request is empty") + } } diff --git a/bus/message_test.go b/bus/message_test.go index 9dbd432..930076c 100644 --- a/bus/message_test.go +++ b/bus/message_test.go @@ -4,20 +4,20 @@ package bus import ( - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "testing" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "testing" ) func TestMessageModel(t *testing.T) { - id := uuid.New() - var message = &model.Message{ - Id: &id, - Payload: "A new message", - Channel: "123", - Direction: model.RequestDir} - assert.Equal(t, "A new message", message.Payload) - assert.Equal(t, model.RequestDir, message.Direction) - assert.Equal(t, message.Channel, "123") + id := uuid.New() + var message = &model.Message{ + Id: &id, + Payload: "A new message", + Channel: "123", + Direction: model.RequestDir} + assert.Equal(t, "A new message", message.Payload) + assert.Equal(t, model.RequestDir, message.Direction) + assert.Equal(t, message.Channel, "123") } diff --git a/bus/store.go b/bus/store.go index 5612c34..533ec3f 100644 --- a/bus/store.go +++ b/bus/store.go @@ -4,551 +4,551 @@ package bus import ( - "encoding/json" - "fmt" - "github.com/google/uuid" - "github.com/pb33f/ranch/log" - "github.com/pb33f/ranch/model" - "reflect" - "sync" + "encoding/json" + "fmt" + "github.com/google/uuid" + "github.com/pb33f/ranch/log" + "github.com/pb33f/ranch/model" + "reflect" + "sync" ) // Describes a single store item change type StoreChange struct { - Id string // the id of the updated item - Value interface{} // the updated value of the item - State interface{} // state associated with this change - IsDeleteChange bool // true if the item was removed from the store - StoreVersion int64 // the store's version when this change was made + Id string // the id of the updated item + Value interface{} // the updated value of the item + State interface{} // state associated with this change + IsDeleteChange bool // true if the item was removed from the store + StoreVersion int64 // the store's version when this change was made } // BusStore is a stateful in memory cache for objects. All state changes (any time the cache is modified) // will broadcast that updated object to any subscribers of the BusStore for those specific objects // or all objects of a certain type and state changes. type BusStore interface { - // Get the name (the id) of the store. - GetName() string - // Add new or updates existing item in the store. - Put(id string, value interface{}, state interface{}) - // Returns an item from the store and a boolean flag - // indicating whether the item exists - Get(id string) (interface{}, bool) - // Shorten version of the Get() method, returns only the item value. - GetValue(id string) interface{} - // Remove an item from the store. Returns true if the remove operation was successful. - Remove(id string, state interface{}) bool - // Return a slice containing all store items. - AllValues() []interface{} - // Return a map with all items from the store. - AllValuesAsMap() map[string]interface{} - // Return a map with all items from the store with the current store version. - AllValuesAndVersion() (map[string]interface{}, int64) - // Subscribe to state changes for a specific object. - OnChange(id string, state ...interface{}) StoreStream - // Subscribe to state changes for all objects - OnAllChanges(state ...interface{}) StoreStream - // Notify when the store has been initialize (via populate() or initialize() - WhenReady(readyFunction func()) - // Populate the store with a map of items and their ID's. - Populate(items map[string]interface{}) error - // Mark the store as initialized and notify all watchers. - Initialize() - // Subscribe to mutation requests made via mutate() method. - OnMutationRequest(mutationType ...interface{}) MutationStoreStream - // Send a mutation request to any subscribers handling mutations. - Mutate(request interface{}, requestType interface{}, - successHandler func(interface{}), errorHandler func(interface{})) - // Removes all items from the store and change its state to uninitialized". - Reset() - // Returns true if this is galactic store. - IsGalactic() bool - // Get the item type if such is specified during the creation of the - // store - GetItemType() reflect.Type + // Get the name (the id) of the store. + GetName() string + // Add new or updates existing item in the store. + Put(id string, value interface{}, state interface{}) + // Returns an item from the store and a boolean flag + // indicating whether the item exists + Get(id string) (interface{}, bool) + // Shorten version of the Get() method, returns only the item value. + GetValue(id string) interface{} + // Remove an item from the store. Returns true if the remove operation was successful. + Remove(id string, state interface{}) bool + // Return a slice containing all store items. + AllValues() []interface{} + // Return a map with all items from the store. + AllValuesAsMap() map[string]interface{} + // Return a map with all items from the store with the current store version. + AllValuesAndVersion() (map[string]interface{}, int64) + // Subscribe to state changes for a specific object. + OnChange(id string, state ...interface{}) StoreStream + // Subscribe to state changes for all objects + OnAllChanges(state ...interface{}) StoreStream + // Notify when the store has been initialize (via populate() or initialize() + WhenReady(readyFunction func()) + // Populate the store with a map of items and their ID's. + Populate(items map[string]interface{}) error + // Mark the store as initialized and notify all watchers. + Initialize() + // Subscribe to mutation requests made via mutate() method. + OnMutationRequest(mutationType ...interface{}) MutationStoreStream + // Send a mutation request to any subscribers handling mutations. + Mutate(request interface{}, requestType interface{}, + successHandler func(interface{}), errorHandler func(interface{})) + // Removes all items from the store and change its state to uninitialized". + Reset() + // Returns true if this is galactic store. + IsGalactic() bool + // Get the item type if such is specified during the creation of the + // store + GetItemType() reflect.Type } // Internal BusStore implementation type busStore struct { - name string - itemsLock sync.RWMutex - items map[string]interface{} - storeVersion int64 - storeStreamsLock sync.RWMutex - storeStreams []*storeStream - mutationStreamsLock sync.RWMutex - mutationStreams []*mutationStoreStream - initializer sync.Once - readyC chan struct{} - isGalactic bool - galacticConf *galacticStoreConfig - bus EventBus - itemType reflect.Type - storeSynHandler MessageHandler + name string + itemsLock sync.RWMutex + items map[string]interface{} + storeVersion int64 + storeStreamsLock sync.RWMutex + storeStreams []*storeStream + mutationStreamsLock sync.RWMutex + mutationStreams []*mutationStoreStream + initializer sync.Once + readyC chan struct{} + isGalactic bool + galacticConf *galacticStoreConfig + bus EventBus + itemType reflect.Type + storeSynHandler MessageHandler } type galacticStoreConfig struct { - syncChannelConfig *storeSyncChannelConfig + syncChannelConfig *storeSyncChannelConfig } func newBusStore(name string, bus EventBus, itemType reflect.Type, galacticConf *galacticStoreConfig) BusStore { - store := new(busStore) - store.name = name - store.bus = bus - store.itemType = itemType - store.galacticConf = galacticConf + store := new(busStore) + store.name = name + store.bus = bus + store.itemType = itemType + store.galacticConf = galacticConf - initStore(store) + initStore(store) - store.isGalactic = galacticConf != nil + store.isGalactic = galacticConf != nil - if store.isGalactic { - initGalacticStore(store) - } + if store.isGalactic { + initGalacticStore(store) + } - return store + return store } func initStore(store *busStore) { - store.readyC = make(chan struct{}) - store.storeStreams = []*storeStream{} - store.mutationStreams = []*mutationStoreStream{} - store.items = make(map[string]interface{}) - store.storeVersion = 1 - store.initializer = sync.Once{} + store.readyC = make(chan struct{}) + store.storeStreams = []*storeStream{} + store.mutationStreams = []*mutationStoreStream{} + store.items = make(map[string]interface{}) + store.storeVersion = 1 + store.initializer = sync.Once{} } func initGalacticStore(store *busStore) { - syncChannelConf := store.galacticConf.syncChannelConfig - - var err error - store.storeSynHandler, err = store.bus.ListenStream(syncChannelConf.syncChannelName) - if err != nil { - return - } - - store.storeSynHandler.Handle( - func(msg *model.Message) { - d := msg.Payload.([]byte) - var storeResponse map[string]interface{} - - err := json.Unmarshal(d, &storeResponse) - if err != nil { - log.Warn("failed to unmarshal storeResponse") - return - } - - if storeResponse["storeId"] != store.GetName() { - // the response is for another store - return - } - - responseType := storeResponse["responseType"].(string) - - switch responseType { - case "storeContentResponse": - - store.itemsLock.Lock() - defer store.itemsLock.Unlock() - - store.updateVersionFromResponse(storeResponse) - items := storeResponse["items"].(map[string]interface{}) - store.items = make(map[string]interface{}) - for key, val := range items { - deserializedValue, err := store.deserializeRawValue(val) - if err != nil { - log.Warn("failed to deserialize store item value %e", err) - continue - } else { - store.items[key] = deserializedValue - } - } - store.Initialize() - case "updateStoreResponse": - - store.itemsLock.Lock() - defer store.itemsLock.Unlock() - - store.updateVersionFromResponse(storeResponse) - newItemRaw, ok := storeResponse["newItemValue"] - itemId := storeResponse["itemId"].(string) - if !ok || newItemRaw == nil { - store.removeInternal(itemId, "galacticSyncRemove") - } else { - newItemValue, err := store.deserializeRawValue(newItemRaw) - if err != nil { - log.Warn("failed to deserialize store item value %e", err) - return - } - store.putInternal(itemId, newItemValue, "galacticSyncUpdate") - } - } - }, - func(e error) { - }) - - store.sendOpenStoreRequest() + syncChannelConf := store.galacticConf.syncChannelConfig + + var err error + store.storeSynHandler, err = store.bus.ListenStream(syncChannelConf.syncChannelName) + if err != nil { + return + } + + store.storeSynHandler.Handle( + func(msg *model.Message) { + d := msg.Payload.([]byte) + var storeResponse map[string]interface{} + + err := json.Unmarshal(d, &storeResponse) + if err != nil { + log.Warn("failed to unmarshal storeResponse") + return + } + + if storeResponse["storeId"] != store.GetName() { + // the response is for another store + return + } + + responseType := storeResponse["responseType"].(string) + + switch responseType { + case "storeContentResponse": + + store.itemsLock.Lock() + defer store.itemsLock.Unlock() + + store.updateVersionFromResponse(storeResponse) + items := storeResponse["items"].(map[string]interface{}) + store.items = make(map[string]interface{}) + for key, val := range items { + deserializedValue, err := store.deserializeRawValue(val) + if err != nil { + log.Warn("failed to deserialize store item value %e", err) + continue + } else { + store.items[key] = deserializedValue + } + } + store.Initialize() + case "updateStoreResponse": + + store.itemsLock.Lock() + defer store.itemsLock.Unlock() + + store.updateVersionFromResponse(storeResponse) + newItemRaw, ok := storeResponse["newItemValue"] + itemId := storeResponse["itemId"].(string) + if !ok || newItemRaw == nil { + store.removeInternal(itemId, "galacticSyncRemove") + } else { + newItemValue, err := store.deserializeRawValue(newItemRaw) + if err != nil { + log.Warn("failed to deserialize store item value %e", err) + return + } + store.putInternal(itemId, newItemValue, "galacticSyncUpdate") + } + } + }, + func(e error) { + }) + + store.sendOpenStoreRequest() } func (store *busStore) updateVersionFromResponse(storeResponse map[string]interface{}) { - version := storeResponse["storeVersion"] - switch version.(type) { - case float64: - store.storeVersion = int64(version.(float64)) - case int64: - store.storeVersion = version.(int64) - default: - log.Warn("failed to deserialize store version") - store.storeVersion = 1 - } + version := storeResponse["storeVersion"] + switch version.(type) { + case float64: + store.storeVersion = int64(version.(float64)) + case int64: + store.storeVersion = version.(int64) + default: + log.Warn("failed to deserialize store version") + store.storeVersion = 1 + } } func (store *busStore) deserializeRawValue(rawValue interface{}) (interface{}, error) { - return model.ConvertValueToType(rawValue, store.itemType) + return model.ConvertValueToType(rawValue, store.itemType) } func (store *busStore) sendOpenStoreRequest() { - openStoreReq := map[string]string{ - "storeId": store.GetName(), - } - store.sendGalacticRequest("openStore", openStoreReq) + openStoreReq := map[string]string{ + "storeId": store.GetName(), + } + store.sendGalacticRequest("openStore", openStoreReq) } func (store *busStore) sendGalacticRequest(requestCmd string, requestPayload interface{}) { - // create request - id := uuid.New() - r := &model.Request{} - r.RequestCommand = requestCmd - r.Payload = requestPayload - r.Id = &id - jsonReq, _ := json.Marshal(r) + // create request + id := uuid.New() + r := &model.Request{} + r.RequestCommand = requestCmd + r.Payload = requestPayload + r.Id = &id + jsonReq, _ := json.Marshal(r) - syncChannelConfig := store.galacticConf.syncChannelConfig + syncChannelConfig := store.galacticConf.syncChannelConfig - // send request. - syncChannelConfig.conn.SendJSONMessage( - syncChannelConfig.pubPrefix+syncChannelConfig.syncChannelName, - jsonReq) + // send request. + syncChannelConfig.conn.SendJSONMessage( + syncChannelConfig.pubPrefix+syncChannelConfig.syncChannelName, + jsonReq) } func (store *busStore) sendCloseStoreRequest() { - closeStoreReq := map[string]string{ - "storeId": store.GetName(), - } - store.sendGalacticRequest("closeStore", closeStoreReq) + closeStoreReq := map[string]string{ + "storeId": store.GetName(), + } + store.sendGalacticRequest("closeStore", closeStoreReq) } func (store *busStore) OnDestroy() { - if store.IsGalactic() { - store.sendCloseStoreRequest() - if store.storeSynHandler != nil { - store.storeSynHandler.Close() - } - } + if store.IsGalactic() { + store.sendCloseStoreRequest() + if store.storeSynHandler != nil { + store.storeSynHandler.Close() + } + } } func (store *busStore) IsGalactic() bool { - return store.isGalactic + return store.isGalactic } func (store *busStore) GetItemType() reflect.Type { - return store.itemType + return store.itemType } func (store *busStore) GetName() string { - return store.name + return store.name } func (store *busStore) Populate(items map[string]interface{}) error { - if store.IsGalactic() { - return fmt.Errorf("populate() API is not supported for galactic stores") - } + if store.IsGalactic() { + return fmt.Errorf("populate() API is not supported for galactic stores") + } - store.itemsLock.Lock() - defer store.itemsLock.Unlock() + store.itemsLock.Lock() + defer store.itemsLock.Unlock() - if len(store.items) > 0 { - return fmt.Errorf("store items already initialized") - } + if len(store.items) > 0 { + return fmt.Errorf("store items already initialized") + } - for k, v := range items { - store.items[k] = v - } - store.Initialize() - return nil + for k, v := range items { + store.items[k] = v + } + store.Initialize() + return nil } func (store *busStore) Put(id string, value interface{}, state interface{}) { - if store.IsGalactic() { - store.putGalactic(id, value) - } else { - store.itemsLock.Lock() - defer store.itemsLock.Unlock() + if store.IsGalactic() { + store.putGalactic(id, value) + } else { + store.itemsLock.Lock() + defer store.itemsLock.Unlock() - store.putInternal(id, value, state) - } + store.putInternal(id, value, state) + } } func (store *busStore) putGalactic(id string, value interface{}) { - store.itemsLock.RLock() - clientStoreVersion := store.storeVersion - store.itemsLock.RUnlock() + store.itemsLock.RLock() + clientStoreVersion := store.storeVersion + store.itemsLock.RUnlock() - store.sendUpdateStoreRequest(id, value, clientStoreVersion) + store.sendUpdateStoreRequest(id, value, clientStoreVersion) } func (store *busStore) sendUpdateStoreRequest(id string, value interface{}, storeVersion int64) { - updateReq := map[string]interface{}{ - "storeId": store.GetName(), - "clientStoreVersion": storeVersion, - "itemId": id, - "newItemValue": value, - } + updateReq := map[string]interface{}{ + "storeId": store.GetName(), + "clientStoreVersion": storeVersion, + "itemId": id, + "newItemValue": value, + } - store.sendGalacticRequest("updateStore", updateReq) + store.sendGalacticRequest("updateStore", updateReq) } func (store *busStore) putInternal(id string, value interface{}, state interface{}) { - if !store.IsGalactic() { - store.storeVersion++ - } - store.items[id] = value + if !store.IsGalactic() { + store.storeVersion++ + } + store.items[id] = value - change := &StoreChange{ - Id: id, - State: state, - Value: value, - StoreVersion: store.storeVersion, - } + change := &StoreChange{ + Id: id, + State: state, + Value: value, + StoreVersion: store.storeVersion, + } - go store.onStoreChange(change) + go store.onStoreChange(change) } func (store *busStore) Get(id string) (interface{}, bool) { - store.itemsLock.RLock() - defer store.itemsLock.RUnlock() + store.itemsLock.RLock() + defer store.itemsLock.RUnlock() - val, ok := store.items[id] + val, ok := store.items[id] - return val, ok + return val, ok } func (store *busStore) GetValue(id string) interface{} { - val, _ := store.Get(id) - return val + val, _ := store.Get(id) + return val } func (store *busStore) Remove(id string, state interface{}) bool { - if store.IsGalactic() { - return store.removeGalactic(id) - } else { - store.itemsLock.Lock() - defer store.itemsLock.Unlock() + if store.IsGalactic() { + return store.removeGalactic(id) + } else { + store.itemsLock.Lock() + defer store.itemsLock.Unlock() - return store.removeInternal(id, state) - } + return store.removeInternal(id, state) + } } func (store *busStore) removeGalactic(id string) bool { - store.itemsLock.RLock() - _, ok := store.items[id] - storeVersion := store.storeVersion - store.itemsLock.RUnlock() + store.itemsLock.RLock() + _, ok := store.items[id] + storeVersion := store.storeVersion + store.itemsLock.RUnlock() - if ok { - store.sendUpdateStoreRequest(id, nil, storeVersion) - return true - } - return false + if ok { + store.sendUpdateStoreRequest(id, nil, storeVersion) + return true + } + return false } func (store *busStore) removeInternal(id string, state interface{}) bool { - value, ok := store.items[id] - if !ok { - return false - } + value, ok := store.items[id] + if !ok { + return false + } - if !store.IsGalactic() { - store.storeVersion++ - } - delete(store.items, id) + if !store.IsGalactic() { + store.storeVersion++ + } + delete(store.items, id) - change := &StoreChange{ - Id: id, - State: state, - Value: value, - StoreVersion: store.storeVersion, - IsDeleteChange: true, - } + change := &StoreChange{ + Id: id, + State: state, + Value: value, + StoreVersion: store.storeVersion, + IsDeleteChange: true, + } - go store.onStoreChange(change) - return true + go store.onStoreChange(change) + return true } func (store *busStore) AllValues() []interface{} { - store.itemsLock.RLock() - defer store.itemsLock.RUnlock() + store.itemsLock.RLock() + defer store.itemsLock.RUnlock() - values := make([]interface{}, 0, len(store.items)) - for _, value := range store.items { - values = append(values, value) - } + values := make([]interface{}, 0, len(store.items)) + for _, value := range store.items { + values = append(values, value) + } - return values + return values } func (store *busStore) AllValuesAsMap() map[string]interface{} { - store.itemsLock.RLock() - defer store.itemsLock.RUnlock() + store.itemsLock.RLock() + defer store.itemsLock.RUnlock() - values := make(map[string]interface{}) + values := make(map[string]interface{}) - for key, value := range store.items { - values[key] = value - } + for key, value := range store.items { + values[key] = value + } - return values + return values } func (store *busStore) AllValuesAndVersion() (map[string]interface{}, int64) { - store.itemsLock.RLock() - defer store.itemsLock.RUnlock() + store.itemsLock.RLock() + defer store.itemsLock.RUnlock() - values := make(map[string]interface{}) + values := make(map[string]interface{}) - for key, value := range store.items { - values[key] = value - } + for key, value := range store.items { + values[key] = value + } - return values, store.storeVersion + return values, store.storeVersion } func (store *busStore) OnMutationRequest(requestType ...interface{}) MutationStoreStream { - return newMutationStoreStream(store, &mutationStreamFilter{ - requestTypes: requestType, - }) + return newMutationStoreStream(store, &mutationStreamFilter{ + requestTypes: requestType, + }) } func (store *busStore) Mutate(request interface{}, requestType interface{}, - successHandler func(interface{}), errorHandler func(interface{})) { + successHandler func(interface{}), errorHandler func(interface{})) { - store.mutationStreamsLock.RLock() - defer store.mutationStreamsLock.RUnlock() + store.mutationStreamsLock.RLock() + defer store.mutationStreamsLock.RUnlock() - for _, ms := range store.mutationStreams { - ms.onMutationRequest(&MutationRequest{ - Request: request, - RequestType: requestType, - SuccessHandler: successHandler, - ErrorHandler: errorHandler, - }) - } + for _, ms := range store.mutationStreams { + ms.onMutationRequest(&MutationRequest{ + Request: request, + RequestType: requestType, + SuccessHandler: successHandler, + ErrorHandler: errorHandler, + }) + } } func (store *busStore) onStoreChange(change *StoreChange) { - store.storeStreamsLock.RLock() - defer store.storeStreamsLock.RUnlock() + store.storeStreamsLock.RLock() + defer store.storeStreamsLock.RUnlock() - for _, storeStream := range store.storeStreams { - storeStream.onStoreChange(change) - } + for _, storeStream := range store.storeStreams { + storeStream.onStoreChange(change) + } } func (store *busStore) Initialize() { - store.initializer.Do(func() { - close(store.readyC) - store.bus.SendMonitorEvent(StoreInitializedEvt, store.name, nil) - }) + store.initializer.Do(func() { + close(store.readyC) + store.bus.SendMonitorEvent(StoreInitializedEvt, store.name, nil) + }) } func (store *busStore) Reset() { - store.itemsLock.Lock() - defer store.itemsLock.Unlock() + store.itemsLock.Lock() + defer store.itemsLock.Unlock() - store.mutationStreamsLock.Lock() - defer store.mutationStreamsLock.Unlock() + store.mutationStreamsLock.Lock() + defer store.mutationStreamsLock.Unlock() - store.storeStreamsLock.Lock() - defer store.storeStreamsLock.Unlock() + store.storeStreamsLock.Lock() + defer store.storeStreamsLock.Unlock() - initStore(store) + initStore(store) - if store.IsGalactic() { - store.sendOpenStoreRequest() - } + if store.IsGalactic() { + store.sendOpenStoreRequest() + } } func (store *busStore) WhenReady(readyFunc func()) { - go func() { - <-store.readyC - readyFunc() - }() + go func() { + <-store.readyC + readyFunc() + }() } func (store *busStore) OnChange(id string, state ...interface{}) StoreStream { - return newStoreStream(store, &streamFilter{ - itemId: id, - states: state, - }) + return newStoreStream(store, &streamFilter{ + itemId: id, + states: state, + }) } func (store *busStore) OnAllChanges(state ...interface{}) StoreStream { - return newStoreStream(store, &streamFilter{ - states: state, - matchAllItems: true, - }) + return newStoreStream(store, &streamFilter{ + states: state, + matchAllItems: true, + }) } func (store *busStore) onStreamSubscribe(stream *storeStream) { - store.storeStreamsLock.Lock() - defer store.storeStreamsLock.Unlock() + store.storeStreamsLock.Lock() + defer store.storeStreamsLock.Unlock() - store.storeStreams = append(store.storeStreams, stream) + store.storeStreams = append(store.storeStreams, stream) } func (store *busStore) onMutationStreamSubscribe(stream *mutationStoreStream) { - store.mutationStreamsLock.Lock() - defer store.mutationStreamsLock.Unlock() + store.mutationStreamsLock.Lock() + defer store.mutationStreamsLock.Unlock() - store.mutationStreams = append(store.mutationStreams, stream) + store.mutationStreams = append(store.mutationStreams, stream) } func (store *busStore) onStreamUnsubscribe(stream *storeStream) { - store.storeStreamsLock.Lock() - defer store.storeStreamsLock.Unlock() + store.storeStreamsLock.Lock() + defer store.storeStreamsLock.Unlock() - var i int - var s *storeStream - for i, s = range store.storeStreams { - if s == stream { - break - } - } + var i int + var s *storeStream + for i, s = range store.storeStreams { + if s == stream { + break + } + } - if s == stream { - n := len(store.storeStreams) - store.storeStreams[i] = store.storeStreams[n-1] - store.storeStreams = store.storeStreams[:n-1] - } + if s == stream { + n := len(store.storeStreams) + store.storeStreams[i] = store.storeStreams[n-1] + store.storeStreams = store.storeStreams[:n-1] + } } func (store *busStore) onMutationStreamUnsubscribe(stream *mutationStoreStream) { - store.mutationStreamsLock.Lock() - defer store.mutationStreamsLock.Unlock() - - var i int - var s *mutationStoreStream - for i, s = range store.mutationStreams { - if s == stream { - break - } - } - - if s == stream { - n := len(store.mutationStreams) - store.mutationStreams[i] = store.mutationStreams[n-1] - store.mutationStreams = store.mutationStreams[:n-1] - } + store.mutationStreamsLock.Lock() + defer store.mutationStreamsLock.Unlock() + + var i int + var s *mutationStoreStream + for i, s = range store.mutationStreams { + if s == stream { + break + } + } + + if s == stream { + n := len(store.mutationStreams) + store.mutationStreams[i] = store.mutationStreams[n-1] + store.mutationStreams = store.mutationStreams[:n-1] + } } diff --git a/bus/store_manager.go b/bus/store_manager.go index e65feeb..b7a4fbf 100644 --- a/bus/store_manager.go +++ b/bus/store_manager.go @@ -4,172 +4,172 @@ package bus import ( - "fmt" - "github.com/go-stomp/stomp/v3/frame" - "github.com/google/uuid" - "github.com/pb33f/ranch/bridge" - "reflect" - "strings" - "sync" + "fmt" + "github.com/go-stomp/stomp/v3/frame" + "github.com/google/uuid" + "github.com/pb33f/ranch/bridge" + "reflect" + "strings" + "sync" ) // StoreManager interface controls all access to BusStores type StoreManager interface { - // Create a new Store, if the store already exists, then it will be returned. - CreateStore(name string) BusStore - // Create a new Store and use the itemType to deserialize item values when handling - // incoming UpdateStoreRequest. If the store already exists, the method will return - // the existing store instance. - CreateStoreWithType(name string, itemType reflect.Type) BusStore - // Get a reference to the existing store. Returns nil if the store doesn't exist. - GetStore(name string) BusStore - // Deletes a store. - DestroyStore(name string) bool - // Configure galactic store sync channel for a given connection. - // Should be called before OpenGalacticStore() and OpenGalacticStoreWithItemType() APIs. - ConfigureStoreSyncChannel(conn bridge.Connection, topicPrefix string, pubPrefix string) error - // Open new galactic store - OpenGalacticStore(name string, conn bridge.Connection) (BusStore, error) - // Open new galactic store and deserialize items from server to itemType - OpenGalacticStoreWithItemType(name string, conn bridge.Connection, itemType reflect.Type) (BusStore, error) + // Create a new Store, if the store already exists, then it will be returned. + CreateStore(name string) BusStore + // Create a new Store and use the itemType to deserialize item values when handling + // incoming UpdateStoreRequest. If the store already exists, the method will return + // the existing store instance. + CreateStoreWithType(name string, itemType reflect.Type) BusStore + // Get a reference to the existing store. Returns nil if the store doesn't exist. + GetStore(name string) BusStore + // Deletes a store. + DestroyStore(name string) bool + // Configure galactic store sync channel for a given connection. + // Should be called before OpenGalacticStore() and OpenGalacticStoreWithItemType() APIs. + ConfigureStoreSyncChannel(conn bridge.Connection, topicPrefix string, pubPrefix string) error + // Open new galactic store + OpenGalacticStore(name string, conn bridge.Connection) (BusStore, error) + // Open new galactic store and deserialize items from server to itemType + OpenGalacticStoreWithItemType(name string, conn bridge.Connection, itemType reflect.Type) (BusStore, error) } // Interface which is a subset of the bridge.Connection methods. // Used to mock connection objects during unit testing. type galacticStoreConnection interface { - SendJSONMessage(destination string, payload []byte, opts ...func(frame *frame.Frame) error) error - SendMessage(destination, contentType string, payload []byte, opts ...func(frame *frame.Frame) error) error + SendJSONMessage(destination string, payload []byte, opts ...func(frame *frame.Frame) error) error + SendMessage(destination, contentType string, payload []byte, opts ...func(frame *frame.Frame) error) error } type storeSyncChannelConfig struct { - topicPrefix string - pubPrefix string - syncChannelName string - conn galacticStoreConnection + topicPrefix string + pubPrefix string + syncChannelName string + conn galacticStoreConnection } type storeManager struct { - stores map[string]BusStore - storesLock sync.RWMutex - eventBus EventBus - syncChannelsLock sync.RWMutex - syncChannels map[uuid.UUID]*storeSyncChannelConfig + stores map[string]BusStore + storesLock sync.RWMutex + eventBus EventBus + syncChannelsLock sync.RWMutex + syncChannels map[uuid.UUID]*storeSyncChannelConfig } func newStoreManager(eventBus EventBus) StoreManager { - m := new(storeManager) - m.stores = make(map[string]BusStore) - m.syncChannels = make(map[uuid.UUID]*storeSyncChannelConfig) - m.eventBus = eventBus + m := new(storeManager) + m.stores = make(map[string]BusStore) + m.syncChannels = make(map[uuid.UUID]*storeSyncChannelConfig) + m.eventBus = eventBus - return m + return m } func (m *storeManager) CreateStore(name string) BusStore { - return m.CreateStoreWithType(name, nil) + return m.CreateStoreWithType(name, nil) } func (m *storeManager) CreateStoreWithType(name string, itemType reflect.Type) BusStore { - m.storesLock.Lock() - defer m.storesLock.Unlock() + m.storesLock.Lock() + defer m.storesLock.Unlock() - store, ok := m.stores[name] + store, ok := m.stores[name] - if ok { - return store - } + if ok { + return store + } - m.stores[name] = newBusStore(name, m.eventBus, itemType, nil) - go m.eventBus.SendMonitorEvent(StoreCreatedEvt, name, nil) - return m.stores[name] + m.stores[name] = newBusStore(name, m.eventBus, itemType, nil) + go m.eventBus.SendMonitorEvent(StoreCreatedEvt, name, nil) + return m.stores[name] } func (m *storeManager) GetStore(name string) BusStore { - m.storesLock.RLock() - defer m.storesLock.RUnlock() + m.storesLock.RLock() + defer m.storesLock.RUnlock() - return m.stores[name] + return m.stores[name] } func (m *storeManager) DestroyStore(name string) bool { - m.storesLock.Lock() - defer m.storesLock.Unlock() + m.storesLock.Lock() + defer m.storesLock.Unlock() - store, ok := m.stores[name] - if ok { - store.(*busStore).OnDestroy() - delete(m.stores, name) + store, ok := m.stores[name] + if ok { + store.(*busStore).OnDestroy() + delete(m.stores, name) - go m.eventBus.SendMonitorEvent(StoreDestroyedEvt, name, nil) - } - return ok + go m.eventBus.SendMonitorEvent(StoreDestroyedEvt, name, nil) + } + return ok } func (m *storeManager) ConfigureStoreSyncChannel( - conn bridge.Connection, topicPrefix string, pubPrefix string) error { + conn bridge.Connection, topicPrefix string, pubPrefix string) error { - m.syncChannelsLock.Lock() - defer m.syncChannelsLock.Unlock() + m.syncChannelsLock.Lock() + defer m.syncChannelsLock.Unlock() - _, ok := m.syncChannels[*conn.GetId()] - if ok { - return fmt.Errorf("store sync channel already configured for this connection") - } + _, ok := m.syncChannels[*conn.GetId()] + if ok { + return fmt.Errorf("store sync channel already configured for this connection") + } - if !strings.HasSuffix(topicPrefix, "/") { - topicPrefix += "/" - } - if !strings.HasSuffix(pubPrefix, "/") { - pubPrefix += "/" - } + if !strings.HasSuffix(topicPrefix, "/") { + topicPrefix += "/" + } + if !strings.HasSuffix(pubPrefix, "/") { + pubPrefix += "/" + } - syncChannel := "transport-store-sync." + conn.GetId().String() + syncChannel := "transport-store-sync." + conn.GetId().String() - storeSyncChannelConfig := &storeSyncChannelConfig{ - topicPrefix: topicPrefix, - pubPrefix: pubPrefix, - syncChannelName: syncChannel, - conn: conn, - } - m.syncChannels[*conn.GetId()] = storeSyncChannelConfig + storeSyncChannelConfig := &storeSyncChannelConfig{ + topicPrefix: topicPrefix, + pubPrefix: pubPrefix, + syncChannelName: syncChannel, + conn: conn, + } + m.syncChannels[*conn.GetId()] = storeSyncChannelConfig - m.eventBus.GetChannelManager().CreateChannel(syncChannel) - m.eventBus.GetChannelManager().MarkChannelAsGalactic(syncChannel, topicPrefix+syncChannel, conn) + m.eventBus.GetChannelManager().CreateChannel(syncChannel) + m.eventBus.GetChannelManager().MarkChannelAsGalactic(syncChannel, topicPrefix+syncChannel, conn) - return nil + return nil } func (m *storeManager) OpenGalacticStore(name string, conn bridge.Connection) (BusStore, error) { - return m.OpenGalacticStoreWithItemType(name, conn, nil) + return m.OpenGalacticStoreWithItemType(name, conn, nil) } func (m *storeManager) OpenGalacticStoreWithItemType( - name string, conn bridge.Connection, itemType reflect.Type) (BusStore, error) { - - m.syncChannelsLock.RLock() - chanConf, ok := m.syncChannels[*conn.GetId()] - m.syncChannelsLock.RUnlock() - - if !ok { - return nil, fmt.Errorf("sync channel is not configured for this connection") - } - - m.storesLock.Lock() - defer m.storesLock.Unlock() - - store, ok := m.stores[name] - - if ok { - if store.IsGalactic() { - return store, nil - } else { - return store, fmt.Errorf("cannot open galactic store: there is a local store with the same name") - } - } - - m.stores[name] = newBusStore(name, m.eventBus, itemType, &galacticStoreConfig{ - syncChannelConfig: chanConf, - }) - go m.eventBus.SendMonitorEvent(StoreCreatedEvt, name, nil) - return m.stores[name], nil + name string, conn bridge.Connection, itemType reflect.Type) (BusStore, error) { + + m.syncChannelsLock.RLock() + chanConf, ok := m.syncChannels[*conn.GetId()] + m.syncChannelsLock.RUnlock() + + if !ok { + return nil, fmt.Errorf("sync channel is not configured for this connection") + } + + m.storesLock.Lock() + defer m.storesLock.Unlock() + + store, ok := m.stores[name] + + if ok { + if store.IsGalactic() { + return store, nil + } else { + return store, fmt.Errorf("cannot open galactic store: there is a local store with the same name") + } + } + + m.stores[name] = newBusStore(name, m.eventBus, itemType, &galacticStoreConfig{ + syncChannelConfig: chanConf, + }) + go m.eventBus.SendMonitorEvent(StoreCreatedEvt, name, nil) + return m.stores[name], nil } diff --git a/bus/store_sync_service.go b/bus/store_sync_service.go index 2560baa..97d35bc 100644 --- a/bus/store_sync_service.go +++ b/bus/store_sync_service.go @@ -4,295 +4,295 @@ package bus import ( - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "strings" - "sync" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "strings" + "sync" ) const ( - openStoreRequest = "openStore" - updateStoreRequest = "updateStore" - closeStoreRequest = "closeStore" - galacticStoreSyncUpdate = "galacticStoreSyncUpdate" - galacticStoreSyncRemove = "galacticStoreSyncRemove" + openStoreRequest = "openStore" + updateStoreRequest = "updateStore" + closeStoreRequest = "closeStore" + galacticStoreSyncUpdate = "galacticStoreSyncUpdate" + galacticStoreSyncRemove = "galacticStoreSyncRemove" ) type storeSyncService struct { - bus EventBus - lock sync.Mutex - syncClients map[string]*syncClientChannel - syncStoreListeners map[string]*syncStoreListener + bus EventBus + lock sync.Mutex + syncClients map[string]*syncClientChannel + syncStoreListeners map[string]*syncStoreListener } type syncStoreListener struct { - storeStream StoreStream - clientSyncChannels map[string]bool - lock sync.RWMutex + storeStream StoreStream + clientSyncChannels map[string]bool + lock sync.RWMutex } type syncClientChannel struct { - channelName string - clientRequestListener MessageHandler - openStores map[string]bool + channelName string + clientRequestListener MessageHandler + openStores map[string]bool } func newStoreSyncService(bus EventBus) *storeSyncService { - syncService := &storeSyncService{ - bus: bus, - syncClients: make(map[string]*syncClientChannel), - syncStoreListeners: make(map[string]*syncStoreListener), - } - syncService.init() - return syncService + syncService := &storeSyncService{ + bus: bus, + syncClients: make(map[string]*syncClientChannel), + syncStoreListeners: make(map[string]*syncStoreListener), + } + syncService.init() + return syncService } func (syncService *storeSyncService) init() { - syncService.bus.AddMonitorEventListener( - func(monitorEvt *MonitorEvent) { - if !strings.HasPrefix(monitorEvt.EntityName, "transport-store-sync.") { - // not a store sync channel, ignore the message - return - } - - switch monitorEvt.EventType { - case FabricEndpointSubscribeEvt: - syncService.openNewClientSyncChannel(monitorEvt.EntityName) - case ChannelDestroyedEvt: - syncService.closeClientSyncChannel(monitorEvt.EntityName) - } - }, - FabricEndpointSubscribeEvt, ChannelDestroyedEvt) + syncService.bus.AddMonitorEventListener( + func(monitorEvt *MonitorEvent) { + if !strings.HasPrefix(monitorEvt.EntityName, "transport-store-sync.") { + // not a store sync channel, ignore the message + return + } + + switch monitorEvt.EventType { + case FabricEndpointSubscribeEvt: + syncService.openNewClientSyncChannel(monitorEvt.EntityName) + case ChannelDestroyedEvt: + syncService.closeClientSyncChannel(monitorEvt.EntityName) + } + }, + FabricEndpointSubscribeEvt, ChannelDestroyedEvt) } func (syncService *storeSyncService) openNewClientSyncChannel(channelName string) { - syncService.lock.Lock() - defer syncService.lock.Unlock() - - if _, ok := syncService.syncClients[channelName]; ok { - // channel already opened. - return - } - - syncClient := &syncClientChannel{ - channelName: channelName, - openStores: make(map[string]bool), - } - syncClient.clientRequestListener, _ = syncService.bus.ListenRequestStream(channelName) - if syncClient.clientRequestListener != nil { - syncClient.clientRequestListener.Handle( - func(message *model.Message) { - request, reqOk := message.Payload.(*model.Request) - if !reqOk || request.Payload == nil { - return - } - var storeRequest map[string]interface{} - storeRequest, ok := request.Payload.(map[string]interface{}) - if !ok { - return - } - - switch request.RequestCommand { - case openStoreRequest: - syncService.openStore(syncClient, storeRequest, request.Id) - case closeStoreRequest: - syncService.closeStore(syncClient, storeRequest, request.Id) - case updateStoreRequest: - syncService.updateStore(syncClient, storeRequest, request.Id) - } - }, func(e error) {}) - } - syncService.syncClients[channelName] = syncClient + syncService.lock.Lock() + defer syncService.lock.Unlock() + + if _, ok := syncService.syncClients[channelName]; ok { + // channel already opened. + return + } + + syncClient := &syncClientChannel{ + channelName: channelName, + openStores: make(map[string]bool), + } + syncClient.clientRequestListener, _ = syncService.bus.ListenRequestStream(channelName) + if syncClient.clientRequestListener != nil { + syncClient.clientRequestListener.Handle( + func(message *model.Message) { + request, reqOk := message.Payload.(*model.Request) + if !reqOk || request.Payload == nil { + return + } + var storeRequest map[string]interface{} + storeRequest, ok := request.Payload.(map[string]interface{}) + if !ok { + return + } + + switch request.RequestCommand { + case openStoreRequest: + syncService.openStore(syncClient, storeRequest, request.Id) + case closeStoreRequest: + syncService.closeStore(syncClient, storeRequest, request.Id) + case updateStoreRequest: + syncService.updateStore(syncClient, storeRequest, request.Id) + } + }, func(e error) {}) + } + syncService.syncClients[channelName] = syncClient } func (syncService *storeSyncService) closeClientSyncChannel(channelName string) { - syncService.lock.Lock() - defer syncService.lock.Unlock() - - syncClient, ok := syncService.syncClients[channelName] - if !ok || syncClient == nil { - // client is already closed - return - } - - for storeId := range syncClient.openStores { - listener := syncService.syncStoreListeners[storeId] - if listener != nil { - listener.removeChannel(channelName) - if listener.isEmpty() { - listener.unsubscribe() - delete(syncService.syncStoreListeners, storeId) - } - } - } - - delete(syncService.syncClients, channelName) + syncService.lock.Lock() + defer syncService.lock.Unlock() + + syncClient, ok := syncService.syncClients[channelName] + if !ok || syncClient == nil { + // client is already closed + return + } + + for storeId := range syncClient.openStores { + listener := syncService.syncStoreListeners[storeId] + if listener != nil { + listener.removeChannel(channelName) + if listener.isEmpty() { + listener.unsubscribe() + delete(syncService.syncStoreListeners, storeId) + } + } + } + + delete(syncService.syncClients, channelName) } func (syncService *storeSyncService) openStore( - syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { - - storeId, ok := getStingProperty("storeId", request) - if !ok || storeId == "" { - syncService.sendErrorResponse(syncClient.channelName, "Invalid OpenStoreRequest", reqId) - return - } - - store := syncService.bus.GetStoreManager().GetStore(storeId) - if store == nil { - syncService.sendErrorResponse( - syncClient.channelName, "Cannot open non-existing store: "+storeId, reqId) - return - } - - syncService.lock.Lock() - defer syncService.lock.Unlock() - - syncClient.openStores[storeId] = true - - storeListener, ok := syncService.syncStoreListeners[storeId] - if !ok { - storeListener = newSyncStoreListener(syncService.bus, store) - syncService.syncStoreListeners[storeId] = storeListener - } - storeListener.addChannel(syncClient.channelName) - - store.WhenReady(func() { - items, version := store.AllValuesAndVersion() - - syncService.bus.SendResponseMessage(syncClient.channelName, - model.NewStoreContentResponse(storeId, items, version), nil) - }) + syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { + + storeId, ok := getStingProperty("storeId", request) + if !ok || storeId == "" { + syncService.sendErrorResponse(syncClient.channelName, "Invalid OpenStoreRequest", reqId) + return + } + + store := syncService.bus.GetStoreManager().GetStore(storeId) + if store == nil { + syncService.sendErrorResponse( + syncClient.channelName, "Cannot open non-existing store: "+storeId, reqId) + return + } + + syncService.lock.Lock() + defer syncService.lock.Unlock() + + syncClient.openStores[storeId] = true + + storeListener, ok := syncService.syncStoreListeners[storeId] + if !ok { + storeListener = newSyncStoreListener(syncService.bus, store) + syncService.syncStoreListeners[storeId] = storeListener + } + storeListener.addChannel(syncClient.channelName) + + store.WhenReady(func() { + items, version := store.AllValuesAndVersion() + + syncService.bus.SendResponseMessage(syncClient.channelName, + model.NewStoreContentResponse(storeId, items, version), nil) + }) } func (syncService *storeSyncService) closeStore( - syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { - - storeId, ok := getStingProperty("storeId", request) - if !ok || storeId == "" { - syncService.sendErrorResponse(syncClient.channelName, "Invalid CloseStoreRequest", reqId) - return - } - - syncService.lock.Lock() - defer syncService.lock.Unlock() - - delete(syncClient.openStores, storeId) - - storeListener, ok := syncService.syncStoreListeners[storeId] - if ok && storeListener != nil { - storeListener.removeChannel(syncClient.channelName) - if storeListener.isEmpty() { - storeListener.unsubscribe() - delete(syncService.syncStoreListeners, storeId) - } - } + syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { + + storeId, ok := getStingProperty("storeId", request) + if !ok || storeId == "" { + syncService.sendErrorResponse(syncClient.channelName, "Invalid CloseStoreRequest", reqId) + return + } + + syncService.lock.Lock() + defer syncService.lock.Unlock() + + delete(syncClient.openStores, storeId) + + storeListener, ok := syncService.syncStoreListeners[storeId] + if ok && storeListener != nil { + storeListener.removeChannel(syncClient.channelName) + if storeListener.isEmpty() { + storeListener.unsubscribe() + delete(syncService.syncStoreListeners, storeId) + } + } } func (syncService *storeSyncService) updateStore( - syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { - - storeId, ok := getStingProperty("storeId", request) - if !ok || storeId == "" { - syncService.sendErrorResponse( - syncClient.channelName, "Invalid UpdateStoreRequest: missing storeId", reqId) - return - } - itemId, ok := getStingProperty("itemId", request) - if !ok || itemId == "" { - syncService.sendErrorResponse( - syncClient.channelName, "Invalid UpdateStoreRequest: missing itemId", reqId) - return - } - - store := syncService.bus.GetStoreManager().GetStore(storeId) - if store == nil { - syncService.sendErrorResponse( - syncClient.channelName, "Cannot update non-existing store: "+storeId, reqId) - return - } - - rawValue, ok := request["newItemValue"] - if rawValue == nil { - store.Remove(itemId, galacticStoreSyncRemove) - } else { - deserializedValue, err := model.ConvertValueToType(rawValue, store.GetItemType()) - if err != nil || deserializedValue == nil { - errMsg := "Cannot deserialize UpdateStoreRequest item value" - if err != nil { - errMsg = "Cannot deserialize UpdateStoreRequest item value: " + err.Error() - } - syncService.sendErrorResponse(syncClient.channelName, errMsg, reqId) - return - } - store.Put(itemId, deserializedValue, galacticStoreSyncUpdate) - } + syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { + + storeId, ok := getStingProperty("storeId", request) + if !ok || storeId == "" { + syncService.sendErrorResponse( + syncClient.channelName, "Invalid UpdateStoreRequest: missing storeId", reqId) + return + } + itemId, ok := getStingProperty("itemId", request) + if !ok || itemId == "" { + syncService.sendErrorResponse( + syncClient.channelName, "Invalid UpdateStoreRequest: missing itemId", reqId) + return + } + + store := syncService.bus.GetStoreManager().GetStore(storeId) + if store == nil { + syncService.sendErrorResponse( + syncClient.channelName, "Cannot update non-existing store: "+storeId, reqId) + return + } + + rawValue, ok := request["newItemValue"] + if rawValue == nil { + store.Remove(itemId, galacticStoreSyncRemove) + } else { + deserializedValue, err := model.ConvertValueToType(rawValue, store.GetItemType()) + if err != nil || deserializedValue == nil { + errMsg := "Cannot deserialize UpdateStoreRequest item value" + if err != nil { + errMsg = "Cannot deserialize UpdateStoreRequest item value: " + err.Error() + } + syncService.sendErrorResponse(syncClient.channelName, errMsg, reqId) + return + } + store.Put(itemId, deserializedValue, galacticStoreSyncUpdate) + } } func getStingProperty(id string, request map[string]interface{}) (string, bool) { - propValue, ok := request[id] - if !ok || propValue == nil { - return "", false - } - stringValue, ok := propValue.(string) - return stringValue, ok + propValue, ok := request[id] + if !ok || propValue == nil { + return "", false + } + stringValue, ok := propValue.(string) + return stringValue, ok } func (syncService *storeSyncService) sendErrorResponse( - clientChannel string, errorMsg string, reqId *uuid.UUID) { - - syncService.bus.SendResponseMessage(clientChannel, &model.Response{ - Id: reqId, - Error: true, - ErrorCode: 1, - ErrorMessage: errorMsg, - }, nil) + clientChannel string, errorMsg string, reqId *uuid.UUID) { + + syncService.bus.SendResponseMessage(clientChannel, &model.Response{ + Id: reqId, + Error: true, + ErrorCode: 1, + ErrorMessage: errorMsg, + }, nil) } func newSyncStoreListener(bus EventBus, store BusStore) *syncStoreListener { - listener := &syncStoreListener{ - storeStream: store.OnAllChanges(), - clientSyncChannels: make(map[string]bool), - } + listener := &syncStoreListener{ + storeStream: store.OnAllChanges(), + clientSyncChannels: make(map[string]bool), + } - listener.storeStream.Subscribe(func(change *StoreChange) { - updateStoreResp := model.NewUpdateStoreResponse( - store.GetName(), change.Id, change.Value, change.StoreVersion) - if change.IsDeleteChange { - updateStoreResp.NewItemValue = nil - } + listener.storeStream.Subscribe(func(change *StoreChange) { + updateStoreResp := model.NewUpdateStoreResponse( + store.GetName(), change.Id, change.Value, change.StoreVersion) + if change.IsDeleteChange { + updateStoreResp.NewItemValue = nil + } - listener.lock.RLock() - defer listener.lock.RUnlock() + listener.lock.RLock() + defer listener.lock.RUnlock() - for chName := range listener.clientSyncChannels { - bus.SendResponseMessage(chName, updateStoreResp, nil) - } - }) + for chName := range listener.clientSyncChannels { + bus.SendResponseMessage(chName, updateStoreResp, nil) + } + }) - return listener + return listener } func (l *syncStoreListener) unsubscribe() { - l.storeStream.Unsubscribe() + l.storeStream.Unsubscribe() } func (l *syncStoreListener) addChannel(clientChannel string) { - l.lock.Lock() - defer l.lock.Unlock() - l.clientSyncChannels[clientChannel] = true + l.lock.Lock() + defer l.lock.Unlock() + l.clientSyncChannels[clientChannel] = true } func (l *syncStoreListener) removeChannel(clientChannel string) { - l.lock.Lock() - defer l.lock.Unlock() - delete(l.clientSyncChannels, clientChannel) + l.lock.Lock() + defer l.lock.Unlock() + delete(l.clientSyncChannels, clientChannel) } func (l *syncStoreListener) isEmpty() bool { - l.lock.Lock() - defer l.lock.Unlock() - return len(l.clientSyncChannels) == 0 + l.lock.Lock() + defer l.lock.Unlock() + return len(l.clientSyncChannels) == 0 } diff --git a/bus/store_sync_service_test.go b/bus/store_sync_service_test.go index a3fcde3..d5381f2 100644 --- a/bus/store_sync_service_test.go +++ b/bus/store_sync_service_test.go @@ -4,511 +4,511 @@ package bus import ( - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "reflect" - "strings" - "sync" - "testing" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "reflect" + "strings" + "sync" + "testing" ) func testStoreSyncService() (*storeSyncService, EventBus) { - bus := newTestEventBus() - return newStoreSyncService(bus), bus + bus := newTestEventBus() + return newStoreSyncService(bus), bus } func TestStoreSyncService_NewConnection(t *testing.T) { - service, bus := testStoreSyncService() + service, bus := testStoreSyncService() - // verify that the service ignores non transport-store-sync events - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, "galactic-channel", nil) - assert.Equal(t, len(service.syncClients), 0) + // verify that the service ignores non transport-store-sync events + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, "galactic-channel", nil) + assert.Equal(t, len(service.syncClients), 0) - syncChan := "transport-store-sync.1" + syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) + bus.GetChannelManager().CreateChannel(syncChan) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - assert.Equal(t, len(service.syncClients), 1) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + assert.Equal(t, len(service.syncClients), 1) } func TestStoreSyncService_OpenStoreErrors(t *testing.T) { - _, bus := testStoreSyncService() - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - - mh, _ := bus.ListenStream(syncChan) - wg := sync.WaitGroup{} - var errors []*model.Response - mh.Handle(func(message *model.Message) { - errors = append(errors, message.Payload.(*model.Response)) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - id := uuid.New() - bus.SendRequestMessage(syncChan, "invalid-request", nil) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: "invalid-payload", - Id: &id, - }, nil) - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: make(map[string]interface{}), - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, errors[0].Id, &id) - assert.True(t, errors[0].Error) - assert.Equal(t, errors[0].ErrorMessage, "Invalid OpenStoreRequest") - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{}{"storeId": "non-existing-store"}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, errors[1].Id, &id) - assert.True(t, errors[1].Error) - assert.Equal(t, errors[1].ErrorMessage, "Cannot open non-existing store: non-existing-store") + _, bus := testStoreSyncService() + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + + mh, _ := bus.ListenStream(syncChan) + wg := sync.WaitGroup{} + var errors []*model.Response + mh.Handle(func(message *model.Message) { + errors = append(errors, message.Payload.(*model.Response)) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + id := uuid.New() + bus.SendRequestMessage(syncChan, "invalid-request", nil) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: "invalid-payload", + Id: &id, + }, nil) + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: make(map[string]interface{}), + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, errors[0].Id, &id) + assert.True(t, errors[0].Error) + assert.Equal(t, errors[0].ErrorMessage, "Invalid OpenStoreRequest") + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "non-existing-store"}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, errors[1].Id, &id) + assert.True(t, errors[1].Error) + assert.Equal(t, errors[1].ErrorMessage, "Cannot open non-existing store: non-existing-store") } func TestStoreSyncService_OpenStore(t *testing.T) { - service, bus := testStoreSyncService() - - store := bus.GetStoreManager().CreateStoreWithType( - "test-store", reflect.TypeOf(&MockStoreItem{})) - store.Populate(map[string]interface{}{ - "item1": &MockStoreItem{From: "test", Message: "test-message"}, - "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, - }) - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - - wg := sync.WaitGroup{} - var syncResp []interface{} - - mh, _ := bus.ListenStream(syncChan) - mh.Handle(func(message *model.Message) { - syncResp = append(syncResp, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - }, nil) - wg.Wait() - - assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) - assert.Equal(t, len(service.syncStoreListeners), 1) - assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan], true) - - resp := syncResp[0].(*model.StoreContentResponse) - - assert.Equal(t, resp.StoreId, "test-store") - items, version := store.AllValuesAndVersion() - - assert.Equal(t, resp.StoreVersion, version) - assert.Equal(t, resp.Items, items) - assert.Equal(t, resp.ResponseType, "storeContentResponse") - - // try subscribing to the same sync channel again - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - }, nil) - wg.Wait() - - assert.Equal(t, len(syncResp), 2) - assert.Equal(t, syncResp[1].(*model.StoreContentResponse).ResponseType, "storeContentResponse") - - syncChan2 := "transport-store-sync.2" - bus.GetChannelManager().CreateChannel(syncChan2) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) - - mh2, _ := bus.ListenStream(syncChan2) - mh2.Handle(func(message *model.Message) { - syncResp = append(syncResp, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - wg.Add(1) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - }, nil) - wg.Wait() - - assert.Equal(t, len(syncResp), 3) - assert.Equal(t, syncResp[2].(*model.StoreContentResponse).ResponseType, "storeContentResponse") - - assert.Equal(t, len(service.syncClients), 2) - assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) - assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) - assert.Equal(t, service.syncClients[syncChan2].openStores["test-store"], true) - - assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 2) - assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) - - bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan, nil) - - assert.Equal(t, len(service.syncClients), 1) - assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) - assert.Equal(t, service.syncClients[syncChan2].openStores["test-store"], true) - assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 1) - assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) - - bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan2, nil) - - assert.Equal(t, len(service.syncClients), 0) - assert.Equal(t, len(service.syncStoreListeners), 0) - - // try closing the syncChan2 again - bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan2, nil) + service, bus := testStoreSyncService() + + store := bus.GetStoreManager().CreateStoreWithType( + "test-store", reflect.TypeOf(&MockStoreItem{})) + store.Populate(map[string]interface{}{ + "item1": &MockStoreItem{From: "test", Message: "test-message"}, + "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, + }) + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + + wg := sync.WaitGroup{} + var syncResp []interface{} + + mh, _ := bus.ListenStream(syncChan) + mh.Handle(func(message *model.Message) { + syncResp = append(syncResp, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) + assert.Equal(t, len(service.syncStoreListeners), 1) + assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan], true) + + resp := syncResp[0].(*model.StoreContentResponse) + + assert.Equal(t, resp.StoreId, "test-store") + items, version := store.AllValuesAndVersion() + + assert.Equal(t, resp.StoreVersion, version) + assert.Equal(t, resp.Items, items) + assert.Equal(t, resp.ResponseType, "storeContentResponse") + + // try subscribing to the same sync channel again + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(syncResp), 2) + assert.Equal(t, syncResp[1].(*model.StoreContentResponse).ResponseType, "storeContentResponse") + + syncChan2 := "transport-store-sync.2" + bus.GetChannelManager().CreateChannel(syncChan2) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) + + mh2, _ := bus.ListenStream(syncChan2) + mh2.Handle(func(message *model.Message) { + syncResp = append(syncResp, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + wg.Add(1) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(syncResp), 3) + assert.Equal(t, syncResp[2].(*model.StoreContentResponse).ResponseType, "storeContentResponse") + + assert.Equal(t, len(service.syncClients), 2) + assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) + assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) + assert.Equal(t, service.syncClients[syncChan2].openStores["test-store"], true) + + assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 2) + assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) + + bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan, nil) + + assert.Equal(t, len(service.syncClients), 1) + assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) + assert.Equal(t, service.syncClients[syncChan2].openStores["test-store"], true) + assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 1) + assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) + + bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan2, nil) + + assert.Equal(t, len(service.syncClients), 0) + assert.Equal(t, len(service.syncStoreListeners), 0) + + // try closing the syncChan2 again + bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan2, nil) } func TestStoreSyncService_CloseStore(t *testing.T) { - service, bus := testStoreSyncService() - - store := bus.GetStoreManager().CreateStoreWithType( - "test-store", reflect.TypeOf(&MockStoreItem{})) - store.Populate(map[string]interface{}{ - "item1": &MockStoreItem{From: "test", Message: "test-message"}, - "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, - }) - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - - syncChan2 := "transport-store-sync.2" - bus.GetChannelManager().CreateChannel(syncChan2) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) - - wg := sync.WaitGroup{} - var syncResp1 []interface{} - - mh, _ := bus.ListenStream(syncChan) - mh.Handle(func(message *model.Message) { - syncResp1 = append(syncResp1, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - mh2, _ := bus.ListenStream(syncChan2) - mh2.Handle(func(message *model.Message) { - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - id := uuid.New() - - wg.Add(2) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - }, nil) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - }, nil) - wg.Wait() - - assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 2) - - bus.SendRequestMessage(syncChan, &model.Request{ - Request: closeStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - Id: &id, - }, nil) - - wg.Add(2) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: closeStoreRequest, - Payload: make(map[string]interface{}), - Id: &id, - }, nil) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: closeStoreRequest, - Payload: map[string]interface{}{"storeId": ""}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp1[1].(*model.Response).ErrorMessage, "Invalid CloseStoreRequest") - assert.Equal(t, syncResp1[1].(*model.Response).Id, &id) - assert.Equal(t, syncResp1[1].(*model.Response).Error, true) - - service.lock.Lock() - assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 1) - assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) - assert.Equal(t, len(service.syncClients[syncChan].openStores), 0) - assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) - service.lock.Unlock() - - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: closeStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - Id: &id, - }, nil) - - wg.Add(2) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: closeStoreRequest, - Payload: make(map[string]interface{}), - Id: &id, - }, nil) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: closeStoreRequest, - Payload: map[string]interface{}{"storeId": ""}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp1[2].(*model.Response).ErrorMessage, "Invalid CloseStoreRequest") - assert.Equal(t, syncResp1[2].(*model.Response).Id, &id) - assert.Equal(t, syncResp1[2].(*model.Response).Error, true) - - service.lock.Lock() - assert.Equal(t, len(service.syncStoreListeners), 0) - assert.Equal(t, len(service.syncClients[syncChan].openStores), 0) - assert.Equal(t, len(service.syncClients[syncChan2].openStores), 0) - service.lock.Unlock() + service, bus := testStoreSyncService() + + store := bus.GetStoreManager().CreateStoreWithType( + "test-store", reflect.TypeOf(&MockStoreItem{})) + store.Populate(map[string]interface{}{ + "item1": &MockStoreItem{From: "test", Message: "test-message"}, + "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, + }) + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + + syncChan2 := "transport-store-sync.2" + bus.GetChannelManager().CreateChannel(syncChan2) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) + + wg := sync.WaitGroup{} + var syncResp1 []interface{} + + mh, _ := bus.ListenStream(syncChan) + mh.Handle(func(message *model.Message) { + syncResp1 = append(syncResp1, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + mh2, _ := bus.ListenStream(syncChan2) + mh2.Handle(func(message *model.Message) { + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + id := uuid.New() + + wg.Add(2) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 2) + + bus.SendRequestMessage(syncChan, &model.Request{ + Request: closeStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + Id: &id, + }, nil) + + wg.Add(2) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: closeStoreRequest, + Payload: make(map[string]interface{}), + Id: &id, + }, nil) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: closeStoreRequest, + Payload: map[string]interface{}{"storeId": ""}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp1[1].(*model.Response).ErrorMessage, "Invalid CloseStoreRequest") + assert.Equal(t, syncResp1[1].(*model.Response).Id, &id) + assert.Equal(t, syncResp1[1].(*model.Response).Error, true) + + service.lock.Lock() + assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 1) + assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) + assert.Equal(t, len(service.syncClients[syncChan].openStores), 0) + assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) + service.lock.Unlock() + + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: closeStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + Id: &id, + }, nil) + + wg.Add(2) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: closeStoreRequest, + Payload: make(map[string]interface{}), + Id: &id, + }, nil) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: closeStoreRequest, + Payload: map[string]interface{}{"storeId": ""}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp1[2].(*model.Response).ErrorMessage, "Invalid CloseStoreRequest") + assert.Equal(t, syncResp1[2].(*model.Response).Id, &id) + assert.Equal(t, syncResp1[2].(*model.Response).Error, true) + + service.lock.Lock() + assert.Equal(t, len(service.syncStoreListeners), 0) + assert.Equal(t, len(service.syncClients[syncChan].openStores), 0) + assert.Equal(t, len(service.syncClients[syncChan2].openStores), 0) + service.lock.Unlock() } func TestStoreSyncService_UpdateStoreErrors(t *testing.T) { - _, bus := testStoreSyncService() - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - - wg := sync.WaitGroup{} - var syncResp []interface{} - - mh, _ := bus.ListenStream(syncChan) - mh.Handle(func(message *model.Message) { - syncResp = append(syncResp, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - id := uuid.New() - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{}{}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp[0].(*model.Response).ErrorMessage, "Invalid UpdateStoreRequest: missing storeId") - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp[1].(*model.Response).ErrorMessage, "Invalid UpdateStoreRequest: missing itemId") - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store", "itemId": "item1"}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp[2].(*model.Response).ErrorMessage, "Cannot update non-existing store: test-store") + _, bus := testStoreSyncService() + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + + wg := sync.WaitGroup{} + var syncResp []interface{} + + mh, _ := bus.ListenStream(syncChan) + mh.Handle(func(message *model.Message) { + syncResp = append(syncResp, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + id := uuid.New() + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp[0].(*model.Response).ErrorMessage, "Invalid UpdateStoreRequest: missing storeId") + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp[1].(*model.Response).ErrorMessage, "Invalid UpdateStoreRequest: missing itemId") + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store", "itemId": "item1"}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp[2].(*model.Response).ErrorMessage, "Cannot update non-existing store: test-store") } func TestStoreSyncService_UpdateStore(t *testing.T) { - _, bus := testStoreSyncService() - - store := bus.GetStoreManager().CreateStoreWithType( - "test-store", reflect.TypeOf(&MockStoreItem{})) - store.Populate(map[string]interface{}{ - "item1": &MockStoreItem{From: "test", Message: "test-message"}, - "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, - }) - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - - syncChan2 := "transport-store-sync.2" - bus.GetChannelManager().CreateChannel(syncChan2) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) - - wg := sync.WaitGroup{} - var syncResp1 []interface{} - var syncResp2 []interface{} - - mh, _ := bus.ListenStream(syncChan) - mh.Handle(func(message *model.Message) { - syncResp1 = append(syncResp1, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - mh2, _ := bus.ListenStream(syncChan2) - mh2.Handle(func(message *model.Message) { - syncResp2 = append(syncResp2, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - wg.Add(2) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - }, nil) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{}{"storeId": "test-store"}, - }, nil) - wg.Wait() - - assert.Equal(t, len(syncResp1), 1) - assert.Equal(t, len(syncResp2), 1) - - wg.Add(2) - - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{}{ - "storeId": "test-store", - "itemId": "item3", - "newItemValue": map[string]interface{}{ - "From": "test3", - "Message": "test-message3", - }}, - }, nil) - - wg.Wait() - - assert.Equal(t, len(syncResp1), 2) - assert.Equal(t, len(syncResp2), 2) - - assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") - assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).StoreId, "test-store") - assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).StoreVersion, int64(2)) - assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).NewItemValue, &MockStoreItem{ - From: "test3", - Message: "test-message3", - }) - - assert.Equal(t, syncResp1[1], syncResp2[1]) - - assert.Equal(t, store.GetValue("item3"), &MockStoreItem{ - From: "test3", - Message: "test-message3", - }) - - wg.Add(2) - store.Remove("item2", "test-remove") - wg.Wait() - - assert.Equal(t, len(syncResp1), 3) - assert.Equal(t, len(syncResp2), 3) - - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).StoreId, "test-store") - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).ItemId, "item2") - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).StoreVersion, int64(3)) - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).NewItemValue, nil) - - assert.Equal(t, syncResp1[2], syncResp2[2]) - - wg.Add(2) - store.Put("item1", &MockStoreItem{From: "u1", Message: "m1"}, nil) - wg.Wait() - - assert.Equal(t, len(syncResp1), 4) - assert.Equal(t, len(syncResp2), 4) - - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).StoreId, "test-store") - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).ItemId, "item1") - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).StoreVersion, int64(4)) - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).NewItemValue, - &MockStoreItem{From: "u1", Message: "m1"}) - - assert.Equal(t, syncResp1[3], syncResp2[3]) - - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{}{ - "storeId": "test-store", - "itemId": "item4", - "newItemValue": nil}, - }, nil) - - wg.Add(2) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{}{ - "storeId": "test-store", - "itemId": "item3", - "newItemValue": nil}, - }, nil) - wg.Wait() - - assert.Equal(t, len(syncResp1), 5) - assert.Equal(t, len(syncResp2), 5) - - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).StoreId, "test-store") - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).ItemId, "item3") - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).StoreVersion, int64(5)) - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).NewItemValue, nil) - - assert.Equal(t, syncResp1[4], syncResp2[4]) - - assert.Equal(t, store.GetValue("item3"), nil) - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{}{ - "storeId": "test-store", - "itemId": "item3", - "newItemValue": "test"}, - }, nil) - wg.Wait() - assert.Equal(t, len(syncResp1), 6) - assert.True(t, strings.HasPrefix(syncResp1[5].(*model.Response).ErrorMessage, - "Cannot deserialize UpdateStoreRequest item value:")) + _, bus := testStoreSyncService() + + store := bus.GetStoreManager().CreateStoreWithType( + "test-store", reflect.TypeOf(&MockStoreItem{})) + store.Populate(map[string]interface{}{ + "item1": &MockStoreItem{From: "test", Message: "test-message"}, + "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, + }) + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + + syncChan2 := "transport-store-sync.2" + bus.GetChannelManager().CreateChannel(syncChan2) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) + + wg := sync.WaitGroup{} + var syncResp1 []interface{} + var syncResp2 []interface{} + + mh, _ := bus.ListenStream(syncChan) + mh.Handle(func(message *model.Message) { + syncResp1 = append(syncResp1, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + mh2, _ := bus.ListenStream(syncChan2) + mh2.Handle(func(message *model.Message) { + syncResp2 = append(syncResp2, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + wg.Add(2) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(syncResp1), 1) + assert.Equal(t, len(syncResp2), 1) + + wg.Add(2) + + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{ + "storeId": "test-store", + "itemId": "item3", + "newItemValue": map[string]interface{}{ + "From": "test3", + "Message": "test-message3", + }}, + }, nil) + + wg.Wait() + + assert.Equal(t, len(syncResp1), 2) + assert.Equal(t, len(syncResp2), 2) + + assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") + assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).StoreId, "test-store") + assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).StoreVersion, int64(2)) + assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).NewItemValue, &MockStoreItem{ + From: "test3", + Message: "test-message3", + }) + + assert.Equal(t, syncResp1[1], syncResp2[1]) + + assert.Equal(t, store.GetValue("item3"), &MockStoreItem{ + From: "test3", + Message: "test-message3", + }) + + wg.Add(2) + store.Remove("item2", "test-remove") + wg.Wait() + + assert.Equal(t, len(syncResp1), 3) + assert.Equal(t, len(syncResp2), 3) + + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).StoreId, "test-store") + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).ItemId, "item2") + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).StoreVersion, int64(3)) + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).NewItemValue, nil) + + assert.Equal(t, syncResp1[2], syncResp2[2]) + + wg.Add(2) + store.Put("item1", &MockStoreItem{From: "u1", Message: "m1"}, nil) + wg.Wait() + + assert.Equal(t, len(syncResp1), 4) + assert.Equal(t, len(syncResp2), 4) + + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).StoreId, "test-store") + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).ItemId, "item1") + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).StoreVersion, int64(4)) + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).NewItemValue, + &MockStoreItem{From: "u1", Message: "m1"}) + + assert.Equal(t, syncResp1[3], syncResp2[3]) + + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{ + "storeId": "test-store", + "itemId": "item4", + "newItemValue": nil}, + }, nil) + + wg.Add(2) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{ + "storeId": "test-store", + "itemId": "item3", + "newItemValue": nil}, + }, nil) + wg.Wait() + + assert.Equal(t, len(syncResp1), 5) + assert.Equal(t, len(syncResp2), 5) + + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).StoreId, "test-store") + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).ItemId, "item3") + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).StoreVersion, int64(5)) + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).NewItemValue, nil) + + assert.Equal(t, syncResp1[4], syncResp2[4]) + + assert.Equal(t, store.GetValue("item3"), nil) + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{ + "storeId": "test-store", + "itemId": "item3", + "newItemValue": "test"}, + }, nil) + wg.Wait() + assert.Equal(t, len(syncResp1), 6) + assert.True(t, strings.HasPrefix(syncResp1[5].(*model.Response).ErrorMessage, + "Cannot deserialize UpdateStoreRequest item value:")) } diff --git a/bus/transaction.go b/bus/transaction.go index b87f67d..b7b80f9 100644 --- a/bus/transaction.go +++ b/bus/transaction.go @@ -4,265 +4,265 @@ package bus import ( - "fmt" - "github.com/google/uuid" - "github.com/pb33f/ranch/model" - "sync" + "fmt" + "github.com/google/uuid" + "github.com/pb33f/ranch/model" + "sync" ) type transactionType int const ( - asyncTransaction transactionType = iota - syncTransaction + asyncTransaction transactionType = iota + syncTransaction ) type BusTransactionReadyFunction func(responses []*model.Message) type BusTransaction interface { - // Sends a request to a channel as a part of this transaction. - SendRequest(channel string, payload interface{}) error - // Wait for a store to be initialized as a part of this transaction. - WaitForStoreReady(storeName string) error - // Registers a new complete handler. Once all responses to requests have been received, - // the transaction is complete. - OnComplete(completeHandler BusTransactionReadyFunction) error - // Register a new error handler. If an error is thrown by any of the responders, the transaction - // is aborted and the error sent to the registered errorHandlers. - OnError(errorHandler MessageErrorFunction) error - // Commit the transaction, all requests will be sent and will wait for responses. - // Once all the responses are in, onComplete handlers will be called with the responses. - Commit() error + // Sends a request to a channel as a part of this transaction. + SendRequest(channel string, payload interface{}) error + // Wait for a store to be initialized as a part of this transaction. + WaitForStoreReady(storeName string) error + // Registers a new complete handler. Once all responses to requests have been received, + // the transaction is complete. + OnComplete(completeHandler BusTransactionReadyFunction) error + // Register a new error handler. If an error is thrown by any of the responders, the transaction + // is aborted and the error sent to the registered errorHandlers. + OnError(errorHandler MessageErrorFunction) error + // Commit the transaction, all requests will be sent and will wait for responses. + // Once all the responses are in, onComplete handlers will be called with the responses. + Commit() error } type transactionState int const ( - uncommittedState transactionState = iota - committedState - completedState - abortedState + uncommittedState transactionState = iota + committedState + completedState + abortedState ) type busTransactionRequest struct { - requestIndex int - storeName string - channelName string - payload interface{} + requestIndex int + storeName string + channelName string + payload interface{} } type busTransaction struct { - transactionType transactionType - state transactionState - lock sync.Mutex - requests []*busTransactionRequest - responses []*model.Message - onCompleteHandlers []BusTransactionReadyFunction - onErrorHandlers []MessageErrorFunction - bus EventBus - completedRequests int + transactionType transactionType + state transactionState + lock sync.Mutex + requests []*busTransactionRequest + responses []*model.Message + onCompleteHandlers []BusTransactionReadyFunction + onErrorHandlers []MessageErrorFunction + bus EventBus + completedRequests int } func newBusTransaction(bus EventBus, transactionType transactionType) BusTransaction { - transaction := new(busTransaction) + transaction := new(busTransaction) - transaction.bus = bus - transaction.state = uncommittedState - transaction.transactionType = transactionType - transaction.requests = make([]*busTransactionRequest, 0) - transaction.onCompleteHandlers = make([]BusTransactionReadyFunction, 0) - transaction.onErrorHandlers = make([]MessageErrorFunction, 0) - transaction.completedRequests = 0 + transaction.bus = bus + transaction.state = uncommittedState + transaction.transactionType = transactionType + transaction.requests = make([]*busTransactionRequest, 0) + transaction.onCompleteHandlers = make([]BusTransactionReadyFunction, 0) + transaction.onErrorHandlers = make([]MessageErrorFunction, 0) + transaction.completedRequests = 0 - return transaction + return transaction } func (tr *busTransaction) checkUncommittedState() error { - if tr.state != uncommittedState { - return fmt.Errorf("transaction has already been committed") - } - return nil + if tr.state != uncommittedState { + return fmt.Errorf("transaction has already been committed") + } + return nil } func (tr *busTransaction) SendRequest(channel string, payload interface{}) error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - tr.requests = append(tr.requests, &busTransactionRequest{ - channelName: channel, - payload: payload, - requestIndex: len(tr.requests), - }) + tr.requests = append(tr.requests, &busTransactionRequest{ + channelName: channel, + payload: payload, + requestIndex: len(tr.requests), + }) - return nil + return nil } func (tr *busTransaction) WaitForStoreReady(storeName string) error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - if tr.bus.GetStoreManager().GetStore(storeName) == nil { - return fmt.Errorf("cannot find store '%s'", storeName) - } + if tr.bus.GetStoreManager().GetStore(storeName) == nil { + return fmt.Errorf("cannot find store '%s'", storeName) + } - tr.requests = append(tr.requests, &busTransactionRequest{ - storeName: storeName, - requestIndex: len(tr.requests), - }) + tr.requests = append(tr.requests, &busTransactionRequest{ + storeName: storeName, + requestIndex: len(tr.requests), + }) - return nil + return nil } func (tr *busTransaction) OnComplete(completeHandler BusTransactionReadyFunction) error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - tr.onCompleteHandlers = append(tr.onCompleteHandlers, completeHandler) - return nil + tr.onCompleteHandlers = append(tr.onCompleteHandlers, completeHandler) + return nil } func (tr *busTransaction) OnError(errorHandler MessageErrorFunction) error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - tr.onErrorHandlers = append(tr.onErrorHandlers, errorHandler) - return nil + tr.onErrorHandlers = append(tr.onErrorHandlers, errorHandler) + return nil } func (tr *busTransaction) Commit() error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - if len(tr.requests) == 0 { - return fmt.Errorf("cannot commit empty transaction") - } + if len(tr.requests) == 0 { + return fmt.Errorf("cannot commit empty transaction") + } - tr.state = committedState + tr.state = committedState - // init responses slice - tr.responses = make([]*model.Message, len(tr.requests)) + // init responses slice + tr.responses = make([]*model.Message, len(tr.requests)) - if tr.transactionType == asyncTransaction { - tr.startAsyncTransaction() - } else { - tr.startSyncTransaction() - } + if tr.transactionType == asyncTransaction { + tr.startAsyncTransaction() + } else { + tr.startSyncTransaction() + } - return nil + return nil } func (tr *busTransaction) startSyncTransaction() { - tr.executeRequest(tr.requests[0]) + tr.executeRequest(tr.requests[0]) } func (tr *busTransaction) executeRequest(request *busTransactionRequest) { - if request.storeName != "" { - tr.waitForStore(request) - } else { - tr.sendRequest(request) - } + if request.storeName != "" { + tr.waitForStore(request) + } else { + tr.sendRequest(request) + } } func (tr *busTransaction) startAsyncTransaction() { - for _, req := range tr.requests { - tr.executeRequest(req) - } + for _, req := range tr.requests { + tr.executeRequest(req) + } } func (tr *busTransaction) sendRequest(req *busTransactionRequest) { - reqId := uuid.New() + reqId := uuid.New() - mh, err := tr.bus.ListenOnceForDestination(req.channelName, &reqId) - if err != nil { - tr.onTransactionError(err) - return - } + mh, err := tr.bus.ListenOnceForDestination(req.channelName, &reqId) + if err != nil { + tr.onTransactionError(err) + return + } - mh.Handle(func(message *model.Message) { - tr.onTransactionRequestSuccess(req, message) - }, func(e error) { - tr.onTransactionError(e) - }) + mh.Handle(func(message *model.Message) { + tr.onTransactionRequestSuccess(req, message) + }, func(e error) { + tr.onTransactionError(e) + }) - tr.bus.SendRequestMessage(req.channelName, req.payload, &reqId) + tr.bus.SendRequestMessage(req.channelName, req.payload, &reqId) } func (tr *busTransaction) onTransactionError(err error) { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if tr.state == abortedState { - return - } + if tr.state == abortedState { + return + } - tr.state = abortedState - for _, errorHandler := range tr.onErrorHandlers { - go errorHandler(err) - } + tr.state = abortedState + for _, errorHandler := range tr.onErrorHandlers { + go errorHandler(err) + } } func (tr *busTransaction) waitForStore(req *busTransactionRequest) { - store := tr.bus.GetStoreManager().GetStore(req.storeName) - if store == nil { - tr.onTransactionError(fmt.Errorf("cannot find store '%s'", req.storeName)) - return - } - store.WhenReady(func() { - tr.onTransactionRequestSuccess(req, &model.Message{ - Direction: model.ResponseDir, - Payload: store.AllValuesAsMap(), - }) - }) + store := tr.bus.GetStoreManager().GetStore(req.storeName) + if store == nil { + tr.onTransactionError(fmt.Errorf("cannot find store '%s'", req.storeName)) + return + } + store.WhenReady(func() { + tr.onTransactionRequestSuccess(req, &model.Message{ + Direction: model.ResponseDir, + Payload: store.AllValuesAsMap(), + }) + }) } func (tr *busTransaction) onTransactionRequestSuccess(req *busTransactionRequest, message *model.Message) { - var triggerOnCompleteHandler = false - tr.lock.Lock() - - if tr.state == abortedState { - tr.lock.Unlock() - return - } - - tr.responses[req.requestIndex] = message - tr.completedRequests++ - - if tr.completedRequests == len(tr.requests) { - tr.state = completedState - triggerOnCompleteHandler = true - } - - tr.lock.Unlock() - - if triggerOnCompleteHandler { - for _, completeHandler := range tr.onCompleteHandlers { - go completeHandler(tr.responses) - } - return - } - - // If this is a sync transaction execute the next request - if tr.transactionType == syncTransaction && req.requestIndex < len(tr.requests)-1 { - tr.executeRequest(tr.requests[req.requestIndex+1]) - } + var triggerOnCompleteHandler = false + tr.lock.Lock() + + if tr.state == abortedState { + tr.lock.Unlock() + return + } + + tr.responses[req.requestIndex] = message + tr.completedRequests++ + + if tr.completedRequests == len(tr.requests) { + tr.state = completedState + triggerOnCompleteHandler = true + } + + tr.lock.Unlock() + + if triggerOnCompleteHandler { + for _, completeHandler := range tr.onCompleteHandlers { + go completeHandler(tr.responses) + } + return + } + + // If this is a sync transaction execute the next request + if tr.transactionType == syncTransaction && req.requestIndex < len(tr.requests)-1 { + tr.executeRequest(tr.requests[req.requestIndex+1]) + } } diff --git a/bus/transaction_test.go b/bus/transaction_test.go index d1df210..2bb95e0 100644 --- a/bus/transaction_test.go +++ b/bus/transaction_test.go @@ -14,309 +14,309 @@ import ( func TestBusTransaction_OnCompleteSync(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - bus.GetChannelManager().CreateChannel("test-channel") + bus.GetChannelManager().CreateChannel("test-channel") - var channelReqMessage *model.Message - var requestCounter = 0 + var channelReqMessage *model.Message + var requestCounter = 0 - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - mh, _ := bus.ListenRequestStream("test-channel") - mh.Handle(func(message *model.Message) { - requestCounter++ - channelReqMessage = message - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) + mh, _ := bus.ListenRequestStream("test-channel") + mh.Handle(func(message *model.Message) { + requestCounter++ + channelReqMessage = message + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) - tr := newBusTransaction(bus, syncTransaction) + tr := newBusTransaction(bus, syncTransaction) - bus.GetStoreManager().CreateStore("testStore") - assert.Nil(t, tr.WaitForStoreReady("testStore")) - assert.Nil(t, tr.SendRequest("test-channel", "sample-request")) + bus.GetStoreManager().CreateStore("testStore") + assert.Nil(t, tr.WaitForStoreReady("testStore")) + assert.Nil(t, tr.SendRequest("test-channel", "sample-request")) - var completeCounter int64 + var completeCounter int64 - tr.OnComplete(func(responses []*model.Message) { - atomic.AddInt64(&completeCounter, 1) - wg.Done() - }) + tr.OnComplete(func(responses []*model.Message) { + atomic.AddInt64(&completeCounter, 1) + wg.Done() + }) - tr.OnError(func(e error) { - assert.Fail(t, "unexpected error") - }) + tr.OnError(func(e error) { + assert.Fail(t, "unexpected error") + }) - tr.OnComplete(func(responses []*model.Message) { - atomic.AddInt64(&completeCounter, 1) - assert.Equal(t, len(responses), 2) - assert.Equal(t, responses[1].Channel, "test-channel") - assert.Equal(t, responses[1].Payload, "sample-response") - wg.Done() - }) + tr.OnComplete(func(responses []*model.Message) { + atomic.AddInt64(&completeCounter, 1) + assert.Equal(t, len(responses), 2) + assert.Equal(t, responses[1].Channel, "test-channel") + assert.Equal(t, responses[1].Payload, "sample-response") + wg.Done() + }) - assert.Equal(t, requestCounter, 0) + assert.Equal(t, requestCounter, 0) - wg.Add(1) + wg.Add(1) - assert.Nil(t, tr.Commit()) + assert.Nil(t, tr.Commit()) - go bus.GetStoreManager().CreateStore("testStore").Initialize() + go bus.GetStoreManager().CreateStore("testStore").Initialize() - wg.Wait() + wg.Wait() - assert.Equal(t, requestCounter, 1) - assert.NotNil(t, channelReqMessage) + assert.Equal(t, requestCounter, 1) + assert.NotNil(t, channelReqMessage) - assert.Equal(t, channelReqMessage.Payload, "sample-request") + assert.Equal(t, channelReqMessage.Payload, "sample-request") - for i := 0; i < 50; i++ { - bus.SendResponseMessage("test-channel", "general-message", nil) - } + for i := 0; i < 50; i++ { + bus.SendResponseMessage("test-channel", "general-message", nil) + } - assert.Equal(t, completeCounter, int64(0)) + assert.Equal(t, completeCounter, int64(0)) - wg.Add(2) - bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId) + wg.Add(2) + bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId) - wg.Wait() + wg.Wait() - assert.Equal(t, tr.(*busTransaction).state, completedState) + assert.Equal(t, tr.(*busTransaction).state, completedState) - assert.Equal(t, completeCounter, int64(2)) + assert.Equal(t, completeCounter, int64(2)) - bus.SendResponseMessage("test-channel", "sample-response2", channelReqMessage.DestinationId) - assert.Equal(t, completeCounter, int64(2)) + bus.SendResponseMessage("test-channel", "sample-response2", channelReqMessage.DestinationId) + assert.Equal(t, completeCounter, int64(2)) } func TestBusTransaction_OnCompleteErrorHandling(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - tr := newBusTransaction(bus, syncTransaction) + tr := newBusTransaction(bus, syncTransaction) - assert.EqualError(t, tr.Commit(), "cannot commit empty transaction") + assert.EqualError(t, tr.Commit(), "cannot commit empty transaction") - assert.Equal(t, tr.(*busTransaction).state, uncommittedState) + assert.Equal(t, tr.(*busTransaction).state, uncommittedState) - bus.GetStoreManager().CreateStore("testStore") - assert.Nil(t, tr.WaitForStoreReady("testStore")) + bus.GetStoreManager().CreateStore("testStore") + assert.Nil(t, tr.WaitForStoreReady("testStore")) - assert.EqualError(t, tr.WaitForStoreReady("invalid-store"), "cannot find store 'invalid-store'") + assert.EqualError(t, tr.WaitForStoreReady("invalid-store"), "cannot find store 'invalid-store'") - tr.Commit() + tr.Commit() - assert.EqualError(t, tr.OnComplete(func(responses []*model.Message) {}), "transaction has already been committed") + assert.EqualError(t, tr.OnComplete(func(responses []*model.Message) {}), "transaction has already been committed") - assert.Equal(t, tr.(*busTransaction).state, committedState) - assert.EqualError(t, tr.Commit(), "transaction has already been committed") + assert.Equal(t, tr.(*busTransaction).state, committedState) + assert.EqualError(t, tr.Commit(), "transaction has already been committed") - assert.EqualError(t, tr.WaitForStoreReady("test"), "transaction has already been committed") - assert.EqualError(t, tr.SendRequest("test", "test"), "transaction has already been committed") + assert.EqualError(t, tr.WaitForStoreReady("test"), "transaction has already been committed") + assert.EqualError(t, tr.SendRequest("test", "test"), "transaction has already been committed") } func TestBusTransaction_OnErrorSync(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - tr := newBusTransaction(bus, syncTransaction) + tr := newBusTransaction(bus, syncTransaction) - bus.GetStoreManager().CreateStore("testStore") - assert.Nil(t, tr.WaitForStoreReady("testStore")) + bus.GetStoreManager().CreateStore("testStore") + assert.Nil(t, tr.WaitForStoreReady("testStore")) - bus.GetChannelManager().CreateChannel("test-channel") + bus.GetChannelManager().CreateChannel("test-channel") - var channelReqMessage *model.Message - var requestCounter = 0 + var channelReqMessage *model.Message + var requestCounter = 0 - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - mh, _ := bus.ListenRequestStream("test-channel") - mh.Handle(func(message *model.Message) { - requestCounter++ - channelReqMessage = message - wg.Done() - }, func(e error) { - }) + mh, _ := bus.ListenRequestStream("test-channel") + mh.Handle(func(message *model.Message) { + requestCounter++ + channelReqMessage = message + wg.Done() + }, func(e error) { + }) - tr.SendRequest("test-channel", "sample-request") - tr.SendRequest("test-channel", "sample-request") - tr.SendRequest("test-channel", "sample-request") + tr.SendRequest("test-channel", "sample-request") + tr.SendRequest("test-channel", "sample-request") + tr.SendRequest("test-channel", "sample-request") - tr.OnComplete(func(responses []*model.Message) { - assert.Fail(t, "invalid state") - }) + tr.OnComplete(func(responses []*model.Message) { + assert.Fail(t, "invalid state") + }) - var errorHandlerCount int64 = 0 - tr.OnError(func(e error) { - atomic.AddInt64(&errorHandlerCount, 1) - wg.Done() - }) + var errorHandlerCount int64 = 0 + tr.OnError(func(e error) { + atomic.AddInt64(&errorHandlerCount, 1) + wg.Done() + }) - tr.OnError(func(e error) { - atomic.AddInt64(&errorHandlerCount, 1) - assert.EqualError(t, e, "test-error") - wg.Done() - }) + tr.OnError(func(e error) { + atomic.AddInt64(&errorHandlerCount, 1) + assert.EqualError(t, e, "test-error") + wg.Done() + }) - tr.Commit() + tr.Commit() - assert.Equal(t, tr.(*busTransaction).state, committedState) + assert.Equal(t, tr.(*busTransaction).state, committedState) - wg.Add(1) + wg.Add(1) - bus.GetStoreManager().GetStore("testStore").Initialize() + bus.GetStoreManager().GetStore("testStore").Initialize() - wg.Wait() + wg.Wait() - assert.Equal(t, requestCounter, 1) - assert.NotNil(t, channelReqMessage) + assert.Equal(t, requestCounter, 1) + assert.NotNil(t, channelReqMessage) - wg.Add(2) - bus.SendErrorMessage("test-channel", errors.New("test-error"), channelReqMessage.DestinationId) + wg.Add(2) + bus.SendErrorMessage("test-channel", errors.New("test-error"), channelReqMessage.DestinationId) - wg.Wait() + wg.Wait() - assert.Equal(t, tr.(*busTransaction).state, abortedState) + assert.Equal(t, tr.(*busTransaction).state, abortedState) - assert.Equal(t, requestCounter, 1) - assert.Equal(t, errorHandlerCount, int64(2)) + assert.Equal(t, requestCounter, 1) + assert.Equal(t, errorHandlerCount, int64(2)) - assert.EqualError(t, tr.Commit(), "transaction has already been committed") + assert.EqualError(t, tr.Commit(), "transaction has already been committed") } func TestBusTransaction_OnCompleteAsync(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - bus.GetChannelManager().CreateChannel("test-channel") + bus.GetChannelManager().CreateChannel("test-channel") - var channelReqMessage *model.Message - var requestCounter = 0 + var channelReqMessage *model.Message + var requestCounter = 0 - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - mh, _ := bus.ListenRequestStream("test-channel") - mh.Handle(func(message *model.Message) { - requestCounter++ - channelReqMessage = message - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) + mh, _ := bus.ListenRequestStream("test-channel") + mh.Handle(func(message *model.Message) { + requestCounter++ + channelReqMessage = message + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) - tr := newBusTransaction(bus, asyncTransaction) + tr := newBusTransaction(bus, asyncTransaction) - bus.GetStoreManager().CreateStore("testStore") - assert.Nil(t, tr.WaitForStoreReady("testStore")) - assert.Nil(t, tr.WaitForStoreReady("testStore")) - bus.GetStoreManager().CreateStore("testStore2") - assert.Nil(t, tr.WaitForStoreReady("testStore2")) - bus.GetStoreManager().CreateStore("testStore3") - assert.Nil(t, tr.WaitForStoreReady("testStore3")) - assert.Nil(t, tr.SendRequest("test-channel", "sample-request")) + bus.GetStoreManager().CreateStore("testStore") + assert.Nil(t, tr.WaitForStoreReady("testStore")) + assert.Nil(t, tr.WaitForStoreReady("testStore")) + bus.GetStoreManager().CreateStore("testStore2") + assert.Nil(t, tr.WaitForStoreReady("testStore2")) + bus.GetStoreManager().CreateStore("testStore3") + assert.Nil(t, tr.WaitForStoreReady("testStore3")) + assert.Nil(t, tr.SendRequest("test-channel", "sample-request")) - var completeCounter int64 + var completeCounter int64 - tr.OnComplete(func(responses []*model.Message) { - atomic.AddInt64(&completeCounter, 1) - wg.Done() - }) + tr.OnComplete(func(responses []*model.Message) { + atomic.AddInt64(&completeCounter, 1) + wg.Done() + }) - tr.OnComplete(func(responses []*model.Message) { - atomic.AddInt64(&completeCounter, 1) - assert.Equal(t, len(responses), 5) - assert.Equal(t, responses[4].Channel, "test-channel") - assert.Equal(t, responses[4].Payload, "sample-response") - wg.Done() - }) + tr.OnComplete(func(responses []*model.Message) { + atomic.AddInt64(&completeCounter, 1) + assert.Equal(t, len(responses), 5) + assert.Equal(t, responses[4].Channel, "test-channel") + assert.Equal(t, responses[4].Payload, "sample-response") + wg.Done() + }) - wg.Add(1) - assert.Nil(t, tr.Commit()) - wg.Wait() + wg.Add(1) + assert.Nil(t, tr.Commit()) + wg.Wait() - assert.NotNil(t, bus.GetStoreManager().GetStore("testStore")) - assert.NotNil(t, bus.GetStoreManager().GetStore("testStore2")) - assert.NotNil(t, bus.GetStoreManager().GetStore("testStore3")) - assert.Equal(t, requestCounter, 1) - assert.NotNil(t, channelReqMessage) - assert.Equal(t, channelReqMessage.Payload, "sample-request") + assert.NotNil(t, bus.GetStoreManager().GetStore("testStore")) + assert.NotNil(t, bus.GetStoreManager().GetStore("testStore2")) + assert.NotNil(t, bus.GetStoreManager().GetStore("testStore3")) + assert.Equal(t, requestCounter, 1) + assert.NotNil(t, channelReqMessage) + assert.Equal(t, channelReqMessage.Payload, "sample-request") - for i := 0; i < 20; i++ { - bus.SendResponseMessage("test-channel", "general-message", nil) - } + for i := 0; i < 20; i++ { + bus.SendResponseMessage("test-channel", "general-message", nil) + } - assert.Equal(t, completeCounter, int64(0)) + assert.Equal(t, completeCounter, int64(0)) - wg.Add(2) + wg.Add(2) - bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId) - bus.GetStoreManager().GetStore("testStore").Initialize() - bus.GetStoreManager().GetStore("testStore2").Initialize() - bus.GetStoreManager().GetStore("testStore3").Initialize() + bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId) + bus.GetStoreManager().GetStore("testStore").Initialize() + bus.GetStoreManager().GetStore("testStore2").Initialize() + bus.GetStoreManager().GetStore("testStore3").Initialize() - wg.Wait() + wg.Wait() - assert.Equal(t, completeCounter, int64(2)) + assert.Equal(t, completeCounter, int64(2)) } func TestBusTransaction_OnErrorAsync(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - tr := newBusTransaction(bus, asyncTransaction) + tr := newBusTransaction(bus, asyncTransaction) - bus.GetChannelManager().CreateChannel("test-channel") - bus.GetChannelManager().CreateChannel("test-channel2") + bus.GetChannelManager().CreateChannel("test-channel") + bus.GetChannelManager().CreateChannel("test-channel2") - var channelReqMessage, channelReqMessage2 *model.Message + var channelReqMessage, channelReqMessage2 *model.Message - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - mh, _ := bus.ListenRequestStream("test-channel") - mh.Handle(func(message *model.Message) { - channelReqMessage = message - wg.Done() - }, func(e error) { - }) + mh, _ := bus.ListenRequestStream("test-channel") + mh.Handle(func(message *model.Message) { + channelReqMessage = message + wg.Done() + }, func(e error) { + }) - mh2, _ := bus.ListenRequestStream("test-channel2") - mh2.Handle(func(message *model.Message) { - channelReqMessage2 = message - wg.Done() - }, func(e error) { - }) + mh2, _ := bus.ListenRequestStream("test-channel2") + mh2.Handle(func(message *model.Message) { + channelReqMessage2 = message + wg.Done() + }, func(e error) { + }) - tr.OnComplete(func(responses []*model.Message) { - assert.Fail(t, "invalid state") - }) + tr.OnComplete(func(responses []*model.Message) { + assert.Fail(t, "invalid state") + }) - var errorHandlerCount int64 = 0 - tr.OnError(func(e error) { - atomic.AddInt64(&errorHandlerCount, 1) - assert.EqualError(t, e, "test-error") - wg.Done() - }) + var errorHandlerCount int64 = 0 + tr.OnError(func(e error) { + atomic.AddInt64(&errorHandlerCount, 1) + assert.EqualError(t, e, "test-error") + wg.Done() + }) - tr.SendRequest("test-channel", "sample-request") - tr.SendRequest("test-channel2", "sample-request2") + tr.SendRequest("test-channel", "sample-request") + tr.SendRequest("test-channel2", "sample-request2") - wg.Add(2) - tr.Commit() - wg.Wait() + wg.Add(2) + tr.Commit() + wg.Wait() - wg.Add(1) - bus.SendErrorMessage("test-channel2", errors.New("test-error"), channelReqMessage2.DestinationId) + wg.Add(1) + bus.SendErrorMessage("test-channel2", errors.New("test-error"), channelReqMessage2.DestinationId) - wg.Wait() + wg.Wait() - assert.Equal(t, errorHandlerCount, int64(1)) + assert.Equal(t, errorHandlerCount, int64(1)) - for i := 0; i < 50; i++ { - bus.SendErrorMessage("test-channel", errors.New("test-error-2"), channelReqMessage.DestinationId) - } + for i := 0; i < 50; i++ { + bus.SendErrorMessage("test-channel", errors.New("test-error-2"), channelReqMessage.DestinationId) + } - assert.Equal(t, errorHandlerCount, int64(1)) + assert.Equal(t, errorHandlerCount, int64(1)) } diff --git a/model/request.go b/model/request.go index 76cca5a..0c34cd6 100644 --- a/model/request.go +++ b/model/request.go @@ -4,52 +4,52 @@ package model import ( - "github.com/google/uuid" - "net/http" - "net/url" + "github.com/google/uuid" + "net/http" + "net/url" ) type Request struct { - Id *uuid.UUID `json:"id,omitempty"` - Destination string `json:"channel,omitempty"` - Payload interface{} `json:"payload,omitempty"` - RequestCommand string `json:"requestCommand,omitempty"` - HttpRequest *http.Request `json:"-"` - HttpResponseWriter http.ResponseWriter `json:"-"` - // Populated if the request was sent on a "private" channel and - // indicates where to send back the Response. - // A service should check this field and if not null copy it to the - // Response.BrokerDestination field to ensure that the response will be sent - // back on the correct the "private" channel. - BrokerDestination *BrokerDestinationConfig `json:"-"` + Id *uuid.UUID `json:"id,omitempty"` + Destination string `json:"channel,omitempty"` + Payload interface{} `json:"payload,omitempty"` + RequestCommand string `json:"request,omitempty"` + HttpRequest *http.Request `json:"-"` + HttpResponseWriter http.ResponseWriter `json:"-"` + // Populated if the request was sent on a "private" channel and + // indicates where to send back the Response. + // A service should check this field and if not null copy it to the + // Response.BrokerDestination field to ensure that the response will be sent + // back on the correct the "private" channel. + BrokerDestination *BrokerDestinationConfig `json:"-"` } // CreateServiceRequest is a small utility function that takes request type and payload and // returns a new model.Request instance populated with them func CreateServiceRequest(requestType string, body []byte) Request { - id := uuid.New() - return Request{ - Id: &id, - RequestCommand: requestType, - Payload: body} + id := uuid.New() + return Request{ + Id: &id, + RequestCommand: requestType, + Payload: body} } // CreateServiceRequestWithValues does the same as CreateServiceRequest, except the payload is url.Values and not // A byte[] array func CreateServiceRequestWithValues(requestType string, vals url.Values) Request { - id := uuid.New() - return Request{ - Id: &id, - RequestCommand: requestType, - Payload: vals} + id := uuid.New() + return Request{ + Id: &id, + RequestCommand: requestType, + Payload: vals} } // CreateServiceRequestWithHttpRequest does the same as CreateServiceRequest, except the payload is a pointer to the // Incoming http.Request, so you can essentially extract what ever you want from the incoming request within your service. func CreateServiceRequestWithHttpRequest(requestType string, r *http.Request) Request { - id := uuid.New() - return Request{ - Id: &id, - RequestCommand: requestType, - Payload: r} + id := uuid.New() + return Request{ + Id: &id, + RequestCommand: requestType, + Payload: r} } diff --git a/plank/pkg/middleware/basic_security_headers.go b/plank/pkg/middleware/basic_security_headers.go index dc2e68c..03b2749 100644 --- a/plank/pkg/middleware/basic_security_headers.go +++ b/plank/pkg/middleware/basic_security_headers.go @@ -4,15 +4,15 @@ package middleware import ( - "github.com/gorilla/mux" - "net/http" + "github.com/gorilla/mux" + "net/http" ) func BasicSecurityHeaderMiddleware() mux.MiddlewareFunc { - return func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Frame-Options", "allow-from https://pb33f.io/") - handler.ServeHTTP(w, r) - }) - } + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Frame-Options", "allow-from https://pb33f.io/") + handler.ServeHTTP(w, r) + }) + } } diff --git a/plank/pkg/middleware/cache_control.go b/plank/pkg/middleware/cache_control.go index c77c527..5458022 100644 --- a/plank/pkg/middleware/cache_control.go +++ b/plank/pkg/middleware/cache_control.go @@ -4,159 +4,159 @@ package middleware import ( - "fmt" - "github.com/gobwas/glob" - "github.com/gorilla/mux" - "github.com/pb33f/ranch/plank/utils" - "net/http" - "strings" - "time" + "fmt" + "github.com/gobwas/glob" + "github.com/gorilla/mux" + "github.com/pb33f/ranch/plank/utils" + "net/http" + "strings" + "time" ) // CacheControlRulePair is a container that lumps together the glob pattern, cache control rule that should be applied to // for the matching pattern and the compiled glob pattern for use in runtime. see https://github.com/gobwas/glob for // detailed examples of glob patterns. type CacheControlRulePair struct { - GlobPattern string - CacheControlRule string - CompiledGlobPattern glob.Glob + GlobPattern string + CacheControlRule string + CompiledGlobPattern glob.Glob } // NewCacheControlRulePair returns a new CacheControlRulePair with the provided text glob pattern and its compiled // counterpart and the cache control rule. func NewCacheControlRulePair(globPattern string, cacheControlRule string) (CacheControlRulePair, error) { - var err error - pair := CacheControlRulePair{ - GlobPattern: globPattern, - CacheControlRule: cacheControlRule, - } + var err error + pair := CacheControlRulePair{ + GlobPattern: globPattern, + CacheControlRule: cacheControlRule, + } - pair.CompiledGlobPattern, err = glob.Compile(globPattern) - return pair, err + pair.CompiledGlobPattern, err = glob.Compile(globPattern) + return pair, err } // CacheControlDirective is a data structure that represents the entry for Cache-Control rules type CacheControlDirective struct { - directives []string + directives []string } // NewCacheControlDirective returns creates a new instance of CacheControlDirective and returns its pointer func NewCacheControlDirective() *CacheControlDirective { - return &CacheControlDirective{} + return &CacheControlDirective{} } func (c *CacheControlDirective) Public() *CacheControlDirective { - c.directives = append(c.directives, "public") - return c + c.directives = append(c.directives, "public") + return c } func (c *CacheControlDirective) Private() *CacheControlDirective { - c.directives = append(c.directives, "private") - return c + c.directives = append(c.directives, "private") + return c } func (c *CacheControlDirective) NoCache() *CacheControlDirective { - c.directives = append(c.directives, "no-cache") - return c + c.directives = append(c.directives, "no-cache") + return c } func (c *CacheControlDirective) NoStore() *CacheControlDirective { - c.directives = append(c.directives, "no-store") - return c + c.directives = append(c.directives, "no-store") + return c } func (c *CacheControlDirective) MaxAge(t time.Duration) *CacheControlDirective { - c.directives = append(c.directives, fmt.Sprintf("max-age=%d", int64(t.Seconds()))) - return c + c.directives = append(c.directives, fmt.Sprintf("max-age=%d", int64(t.Seconds()))) + return c } func (c *CacheControlDirective) SharedMaxAge(t time.Duration) *CacheControlDirective { - c.directives = append(c.directives, fmt.Sprintf("s-maxage=%d", int64(t.Seconds()))) - return c + c.directives = append(c.directives, fmt.Sprintf("s-maxage=%d", int64(t.Seconds()))) + return c } func (c *CacheControlDirective) MaxStale(t time.Duration) *CacheControlDirective { - c.directives = append(c.directives, fmt.Sprintf("max-stale=%d", int64(t.Seconds()))) - return c + c.directives = append(c.directives, fmt.Sprintf("max-stale=%d", int64(t.Seconds()))) + return c } func (c *CacheControlDirective) MinFresh(t time.Duration) *CacheControlDirective { - c.directives = append(c.directives, fmt.Sprintf("min-fresh=%d", int64(t.Seconds()))) - return c + c.directives = append(c.directives, fmt.Sprintf("min-fresh=%d", int64(t.Seconds()))) + return c } func (c *CacheControlDirective) MustRevalidate() *CacheControlDirective { - c.directives = append(c.directives, "must-revalidate") - return c + c.directives = append(c.directives, "must-revalidate") + return c } func (c *CacheControlDirective) ProxyRevalidate() *CacheControlDirective { - c.directives = append(c.directives, "proxy-revalidate") - return c + c.directives = append(c.directives, "proxy-revalidate") + return c } func (c *CacheControlDirective) Immutable() *CacheControlDirective { - c.directives = append(c.directives, "immutable") - return c + c.directives = append(c.directives, "immutable") + return c } func (c *CacheControlDirective) NoTransform() *CacheControlDirective { - c.directives = append(c.directives, "no-transform") - return c + c.directives = append(c.directives, "no-transform") + return c } func (c *CacheControlDirective) OnlyIfCached() *CacheControlDirective { - c.directives = append(c.directives, "only-if-cached") - return c + c.directives = append(c.directives, "only-if-cached") + return c } func (c *CacheControlDirective) String() string { - return strings.Join(c.directives, ", ") + return strings.Join(c.directives, ", ") } // CacheControlMiddleware is the middleware function to be provided as a parameter to mux.Handler() func CacheControlMiddleware(globPatterns []string, directive *CacheControlDirective) mux.MiddlewareFunc { - parsed := parseGlobPatterns(globPatterns) - return func(handler http.Handler) http.Handler { - return cacheControlWrapper(handler, parsed, directive) - } + parsed := parseGlobPatterns(globPatterns) + return func(handler http.Handler) http.Handler { + return cacheControlWrapper(handler, parsed, directive) + } } // parseGlobPatterns takes an array of glob patterns and returns an array of glob.Glob instances func parseGlobPatterns(globPatterns []string) []glob.Glob { - results := make([]glob.Glob, 0) - for _, exp := range globPatterns { - globP, err := glob.Compile(exp) - if err != nil { - utils.Log.Errorln("Ignoring invalid glob pattern provided as cache control matcher rule", err) - continue - } - results = append(results, globP) - } - return results + results := make([]glob.Glob, 0) + for _, exp := range globPatterns { + globP, err := glob.Compile(exp) + if err != nil { + utils.Log.Errorln("Ignoring invalid glob pattern provided as cache control matcher rule", err) + continue + } + results = append(results, globP) + } + return results } // cacheControlWrapper is the internal function that actually performs the adding of cache control rules based on // glob pattern matching and the rules provided. func cacheControlWrapper(h http.Handler, globs []glob.Glob, directive *CacheControlDirective) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if len(globs) == 0 { - h.ServeHTTP(w, r) - return - } - - uriMatches := false - for _, glob := range globs { - if uriMatches = glob.Match(r.RequestURI); uriMatches { - break - } - } - - if !uriMatches { - h.ServeHTTP(w, r) - } - - w.Header().Set("Cache-Control", directive.String()) - h.ServeHTTP(w, r) - }) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(globs) == 0 { + h.ServeHTTP(w, r) + return + } + + uriMatches := false + for _, glob := range globs { + if uriMatches = glob.Match(r.RequestURI); uriMatches { + break + } + } + + if !uriMatches { + h.ServeHTTP(w, r) + } + + w.Header().Set("Cache-Control", directive.String()) + h.ServeHTTP(w, r) + }) } diff --git a/plank/pkg/middleware/middleware_manager.go b/plank/pkg/middleware/middleware_manager.go index fe21310..5b1a7e5 100644 --- a/plank/pkg/middleware/middleware_manager.go +++ b/plank/pkg/middleware/middleware_manager.go @@ -4,174 +4,174 @@ package middleware import ( - "fmt" - "github.com/gorilla/mux" - "github.com/pb33f/ranch/plank/utils" - "net/http" - "strings" - "sync" + "fmt" + "github.com/gorilla/mux" + "github.com/pb33f/ranch/plank/utils" + "net/http" + "strings" + "sync" ) type MiddlewareManager interface { - SetGlobalMiddleware(middleware []mux.MiddlewareFunc) error - SetNewMiddleware(route *mux.Route, middleware []mux.MiddlewareFunc) error - RemoveMiddleware(route *mux.Route) error - GetRouteByUriAndMethod(uri, method string) (*mux.Route, error) - GetRouteByUri(uri string) (*mux.Route, error) - GetStaticRoute(prefix string) (*mux.Route, error) + SetGlobalMiddleware(middleware []mux.MiddlewareFunc) error + SetNewMiddleware(route *mux.Route, middleware []mux.MiddlewareFunc) error + RemoveMiddleware(route *mux.Route) error + GetRouteByUriAndMethod(uri, method string) (*mux.Route, error) + GetRouteByUri(uri string) (*mux.Route, error) + GetStaticRoute(prefix string) (*mux.Route, error) } type Middleware interface { - //Intercept(h http.Handler) http.Handler - Interceptor() mux.MiddlewareFunc - Name() string + //Intercept(h http.Handler) http.Handler + Interceptor() mux.MiddlewareFunc + Name() string } type middlewareManager struct { - endpointHandlerMap *map[string]http.HandlerFunc - originalHandlersMap map[string]http.HandlerFunc - router *mux.Router - mu sync.Mutex + endpointHandlerMap *map[string]http.HandlerFunc + originalHandlersMap map[string]http.HandlerFunc + router *mux.Router + mu sync.Mutex } func (m *middlewareManager) SetGlobalMiddleware(middleware []mux.MiddlewareFunc) error { - m.mu.Lock() - defer m.mu.Unlock() - m.router.Use(middleware...) - return nil + m.mu.Lock() + defer m.mu.Unlock() + m.router.Use(middleware...) + return nil } func (m *middlewareManager) SetNewMiddleware(route *mux.Route, middleware []mux.MiddlewareFunc) error { - var key string - // expection is that a route's name ending with '*' means it's a prefix route - isPrefixRoute := route.GetName()[len(route.GetName())-1] == '*' - - if !isPrefixRoute { - uri, method := m.extractUriVerbFromMuxRoute(route) - if route == nil { - return fmt.Errorf("failed to set a new middleware. route does not exist at %s (%s)", uri, method) - } - // for REST-bridge service a key is in the format of {uri}-{verb} - key = uri + "-" + method - } else { - // if the route instance is a prefix route use the route name as-is - key = route.GetName() - } - - m.mu.Lock() - defer m.mu.Unlock() - - // find if base handler exists first. if not, error out - original, exists := (*m.endpointHandlerMap)[key] - if !exists { - return fmt.Errorf("cannot set middleware. handler does not exist at %s", key) - } - - // make a backup of the original handler that has no other middleware attached to it - if _, exists := m.originalHandlersMap[key]; !exists { - m.originalHandlersMap[key] = original - } - - // build a new middleware chain and apply it - handler := m.buildMiddlewareChain(middleware, original).(http.HandlerFunc) - (*m.endpointHandlerMap)[key] = handler - route.Handler(handler) - - for _, mw := range middleware { - utils.Log.Debugf("middleware '%v' registered for %s", mw, key) - } - - utils.Log.Infof("New middleware configured for REST bridge at %s", key) - - return nil + var key string + // expection is that a route's name ending with '*' means it's a prefix route + isPrefixRoute := route.GetName()[len(route.GetName())-1] == '*' + + if !isPrefixRoute { + uri, method := m.extractUriVerbFromMuxRoute(route) + if route == nil { + return fmt.Errorf("failed to set a new middleware. route does not exist at %s (%s)", uri, method) + } + // for REST-bridge service a key is in the format of {uri}-{verb} + key = uri + "-" + method + } else { + // if the route instance is a prefix route use the route name as-is + key = route.GetName() + } + + m.mu.Lock() + defer m.mu.Unlock() + + // find if base handler exists first. if not, error out + original, exists := (*m.endpointHandlerMap)[key] + if !exists { + return fmt.Errorf("cannot set middleware. handler does not exist at %s", key) + } + + // make a backup of the original handler that has no other middleware attached to it + if _, exists := m.originalHandlersMap[key]; !exists { + m.originalHandlersMap[key] = original + } + + // build a new middleware chain and apply it + handler := m.buildMiddlewareChain(middleware, original).(http.HandlerFunc) + (*m.endpointHandlerMap)[key] = handler + route.Handler(handler) + + for _, mw := range middleware { + utils.Log.Debugf("middleware '%v' registered for %s", mw, key) + } + + utils.Log.Infof("New middleware configured for REST bridge at %s", key) + + return nil } func (m *middlewareManager) RemoveMiddleware(route *mux.Route) error { - uri, method := m.extractUriVerbFromMuxRoute(route) - if route == nil { - return fmt.Errorf("failed to remove middleware. route does not exist at %s (%s)", uri, method) - } - m.mu.Lock() - defer m.mu.Unlock() - key := uri + "-" + method - if _, found := (*m.endpointHandlerMap)[key]; !found { - return fmt.Errorf("failed to remove handler. REST bridge handler does not exist at %s (%s)", uri, method) - } - defer func() { - if r := recover(); r != nil { - utils.Log.Errorln(r) - } - }() - - (*m.endpointHandlerMap)[key] = m.originalHandlersMap[key] - route.Handler(m.originalHandlersMap[key]) - utils.Log.Debugf("All middleware have been stripped from %s (%s)", uri, method) - - return nil + uri, method := m.extractUriVerbFromMuxRoute(route) + if route == nil { + return fmt.Errorf("failed to remove middleware. route does not exist at %s (%s)", uri, method) + } + m.mu.Lock() + defer m.mu.Unlock() + key := uri + "-" + method + if _, found := (*m.endpointHandlerMap)[key]; !found { + return fmt.Errorf("failed to remove handler. REST bridge handler does not exist at %s (%s)", uri, method) + } + defer func() { + if r := recover(); r != nil { + utils.Log.Errorln(r) + } + }() + + (*m.endpointHandlerMap)[key] = m.originalHandlersMap[key] + route.Handler(m.originalHandlersMap[key]) + utils.Log.Debugf("All middleware have been stripped from %s (%s)", uri, method) + + return nil } func (m *middlewareManager) GetRouteByUriAndMethod(uri, method string) (*mux.Route, error) { - m.mu.Lock() - defer m.mu.Unlock() - route := m.router.Get(fmt.Sprintf("%s-%s", uri, method)) - if route == nil { - return nil, fmt.Errorf("no route found at %s (%s)", uri, method) - } - return route, nil + m.mu.Lock() + defer m.mu.Unlock() + route := m.router.Get(fmt.Sprintf("%s-%s", uri, method)) + if route == nil { + return nil, fmt.Errorf("no route found at %s (%s)", uri, method) + } + return route, nil } func (m *middlewareManager) GetStaticRoute(prefix string) (*mux.Route, error) { - m.mu.Lock() - defer m.mu.Unlock() - routeName := prefix + "*" - route := m.router.Get(routeName) - if route == nil { - return nil, fmt.Errorf("no route found at static prefix %s", routeName) - } - return route, nil + m.mu.Lock() + defer m.mu.Unlock() + routeName := prefix + "*" + route := m.router.Get(routeName) + if route == nil { + return nil, fmt.Errorf("no route found at static prefix %s", routeName) + } + return route, nil } func (m *middlewareManager) GetRouteByUri(uri string) (*mux.Route, error) { - m.mu.Lock() - defer m.mu.Unlock() - route := m.router.Get(uri) - if route == nil { - return nil, fmt.Errorf("no route found at %s", uri) - } - return route, nil + m.mu.Lock() + defer m.mu.Unlock() + route := m.router.Get(uri) + if route == nil { + return nil, fmt.Errorf("no route found at %s", uri) + } + return route, nil } func (m *middlewareManager) buildMiddlewareChain(handlers []mux.MiddlewareFunc, originalHandler http.Handler) http.Handler { - var idx = len(handlers) - 1 - var finalHandler http.Handler - - for idx >= 0 { - var currHandler http.Handler - if idx == len(handlers)-1 { - currHandler = originalHandler - } else { - currHandler = finalHandler - } - middlewareFn := handlers[idx] - finalHandler = middlewareFn(currHandler) - idx-- - } - - return finalHandler + var idx = len(handlers) - 1 + var finalHandler http.Handler + + for idx >= 0 { + var currHandler http.Handler + if idx == len(handlers)-1 { + currHandler = originalHandler + } else { + currHandler = finalHandler + } + middlewareFn := handlers[idx] + finalHandler = middlewareFn(currHandler) + idx-- + } + + return finalHandler } // extractUriVerbFromMuxRoute takes *mux.Route and returns URI and verb as string values func (m *middlewareManager) extractUriVerbFromMuxRoute(route *mux.Route) (string, string) { - opRawString := route.GetName() - delimiterIdx := strings.LastIndex(opRawString, "-") - return opRawString[:delimiterIdx], opRawString[delimiterIdx+1:] + opRawString := route.GetName() + delimiterIdx := strings.LastIndex(opRawString, "-") + return opRawString[:delimiterIdx], opRawString[delimiterIdx+1:] } // NewMiddlewareManager sets up a new middleware manager singleton instance func NewMiddlewareManager(endpointHandlerMapPtr *map[string]http.HandlerFunc, router *mux.Router) MiddlewareManager { - return &middlewareManager{ - endpointHandlerMap: endpointHandlerMapPtr, - originalHandlersMap: make(map[string]http.HandlerFunc), - router: router, - } + return &middlewareManager{ + endpointHandlerMap: endpointHandlerMapPtr, + originalHandlersMap: make(map[string]http.HandlerFunc), + router: router, + } } diff --git a/plank/pkg/server/base_error.go b/plank/pkg/server/base_error.go index 20bf70f..5d6fc83 100644 --- a/plank/pkg/server/base_error.go +++ b/plank/pkg/server/base_error.go @@ -3,34 +3,34 @@ package server import "fmt" var ( - errServerInit = &baseError{message: "Server initialization failed"} - errHttp = &baseError{message: "HTTP error"} - errInternal = &baseError{message: "Internal error"} - errUndefined = &baseError{message: "Undefined error"} + errServerInit = &baseError{message: "Server initialization failed"} + errHttp = &baseError{message: "HTTP error"} + errInternal = &baseError{message: "Internal error"} + errUndefined = &baseError{message: "Undefined error"} ) type baseError struct { - wrappedErr error - baseErr *baseError - message string + wrappedErr error + baseErr *baseError + message string } func (e baseError) Is(err error) bool { - return e.baseErr == err + return e.baseErr == err } func (e baseError) Error() string { - return fmt.Sprintf("[ranch] Error: %s: %s\n", e.baseErr.message, e.wrappedErr.Error()) + return fmt.Sprintf("[ranch] Error: %s: %s\n", e.baseErr.message, e.wrappedErr.Error()) } func wrapError(baseType error, err error) error { - switch baseType { - case errServerInit: - return &baseError{baseErr: errServerInit, wrappedErr: err} - case errInternal: - return &baseError{baseErr: errInternal, wrappedErr: err} - case errHttp: - return &baseError{baseErr: errHttp, wrappedErr: err} - } - return &baseError{baseErr: errUndefined, wrappedErr: err} + switch baseType { + case errServerInit: + return &baseError{baseErr: errServerInit, wrappedErr: err} + case errInternal: + return &baseError{baseErr: errInternal, wrappedErr: err} + case errHttp: + return &baseError{baseErr: errHttp, wrappedErr: err} + } + return &baseError{baseErr: errUndefined, wrappedErr: err} } diff --git a/plank/pkg/server/base_error_test.go b/plank/pkg/server/base_error_test.go index 2cd479c..447aaf0 100644 --- a/plank/pkg/server/base_error_test.go +++ b/plank/pkg/server/base_error_test.go @@ -1,47 +1,47 @@ package server import ( - "errors" - "github.com/stretchr/testify/assert" - "testing" + "errors" + "github.com/stretchr/testify/assert" + "testing" ) func TestBaseError_Is_errServerInit(t *testing.T) { - e := wrapError(errServerInit, errors.New("some init fail")) - assert.True(t, errors.Is(e, errServerInit)) + e := wrapError(errServerInit, errors.New("some init fail")) + assert.True(t, errors.Is(e, errServerInit)) } func TestBaseError_Error_errServerInit(t *testing.T) { - e := wrapError(errServerInit, errors.New("some init fail")) - assert.EqualValues(t, "[ranch] Error: Server initialization failed: some init fail\n", e.Error()) + e := wrapError(errServerInit, errors.New("some init fail")) + assert.EqualValues(t, "[ranch] Error: Server initialization failed: some init fail\n", e.Error()) } func TestBaseError_Is_errInternal(t *testing.T) { - e := wrapError(errInternal, errors.New("internal server error")) - assert.True(t, errors.Is(e, errInternal)) + e := wrapError(errInternal, errors.New("internal server error")) + assert.True(t, errors.Is(e, errInternal)) } func TestBaseError_Error_errInternal(t *testing.T) { - e := wrapError(errInternal, errors.New("internal server error")) - assert.EqualValues(t, "[ranch] Error: Internal error: internal server error\n", e.Error()) + e := wrapError(errInternal, errors.New("internal server error")) + assert.EqualValues(t, "[ranch] Error: Internal error: internal server error\n", e.Error()) } func TestBaseError_Is_errHttp(t *testing.T) { - e := wrapError(errHttp, errors.New("404")) - assert.True(t, errors.Is(e, errHttp)) + e := wrapError(errHttp, errors.New("404")) + assert.True(t, errors.Is(e, errHttp)) } func TestBaseError_Error_errHttp(t *testing.T) { - e := wrapError(errHttp, errors.New("404")) - assert.EqualValues(t, "[ranch] Error: HTTP error: 404\n", e.Error()) + e := wrapError(errHttp, errors.New("404")) + assert.EqualValues(t, "[ranch] Error: HTTP error: 404\n", e.Error()) } func TestBaseError_Is_undefined(t *testing.T) { - e := wrapError(errors.New("some random stuff"), errors.New("?")) - assert.True(t, errors.Is(e, errUndefined)) + e := wrapError(errors.New("some random stuff"), errors.New("?")) + assert.True(t, errors.Is(e, errUndefined)) } func TestBaseError_Error_undefined(t *testing.T) { - e := wrapError(errors.New("some random stuff"), errors.New("?")) - assert.EqualValues(t, "[ranch] Error: Undefined error: ?\n", e.Error()) + e := wrapError(errors.New("some random stuff"), errors.New("?")) + assert.EqualValues(t, "[ranch] Error: Undefined error: ?\n", e.Error()) } diff --git a/plank/pkg/server/core_models.go b/plank/pkg/server/core_models.go index cbfd2b4..2adfaf7 100644 --- a/plank/pkg/server/core_models.go +++ b/plank/pkg/server/core_models.go @@ -4,92 +4,92 @@ package server import ( - "crypto/tls" - "github.com/gorilla/mux" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/model" - "github.com/pb33f/ranch/plank/pkg/middleware" - "github.com/pb33f/ranch/plank/utils" - "github.com/pb33f/ranch/service" - "github.com/pb33f/ranch/stompserver" - "io" - "net/http" - "os" - "sync" - "time" + "crypto/tls" + "github.com/gorilla/mux" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/model" + "github.com/pb33f/ranch/plank/pkg/middleware" + "github.com/pb33f/ranch/plank/utils" + "github.com/pb33f/ranch/service" + "github.com/pb33f/ranch/stompserver" + "io" + "net/http" + "os" + "sync" + "time" ) // PlatformServerConfig holds all the core configuration needed for the functionality of Plank type PlatformServerConfig struct { - RootDir string `json:"root_dir"` // root directory the server should base itself on - StaticDir []string `json:"static_dir"` // static content folders that HTTP server should serve - SpaConfig *SpaConfig `json:"spa_config"` // single page application configuration - Host string `json:"host"` // hostname for the server - Port int `json:"port"` // port for the server - LogConfig *utils.LogConfig `json:"log_config"` // log configuration (plank, Http access and error logs) - FabricConfig *FabricBrokerConfig `json:"fabric_config"` // Fabric (websocket) configuration - TLSCertConfig *TLSCertConfig `json:"tls_config"` // TLS certificate configuration - Debug bool `json:"debug"` // enable debug logging - NoBanner bool `json:"no_banner"` // start server without displaying the banner - ShutdownTimeout time.Duration `json:"shutdown_timeout_in_minutes"` // graceful server shutdown timeout in minutes - RestBridgeTimeout time.Duration `json:"rest_bridge_timeout_in_minutes"` // rest bridge timeout in minutes + RootDir string `json:"root_dir"` // root directory the server should base itself on + StaticDir []string `json:"static_dir"` // static content folders that HTTP server should serve + SpaConfig *SpaConfig `json:"spa_config"` // single page application configuration + Host string `json:"host"` // hostname for the server + Port int `json:"port"` // port for the server + LogConfig *utils.LogConfig `json:"log_config"` // log configuration (plank, Http access and error logs) + FabricConfig *FabricBrokerConfig `json:"fabric_config"` // Fabric (websocket) configuration + TLSCertConfig *TLSCertConfig `json:"tls_config"` // TLS certificate configuration + Debug bool `json:"debug"` // enable debug logging + NoBanner bool `json:"no_banner"` // start server without displaying the banner + ShutdownTimeout time.Duration `json:"shutdown_timeout_in_minutes"` // graceful server shutdown timeout in minutes + RestBridgeTimeout time.Duration `json:"rest_bridge_timeout_in_minutes"` // rest bridge timeout in minutes } // TLSCertConfig wraps around key information for TLS configuration type TLSCertConfig struct { - CertFile string `json:"cert_file"` // path to certificate file - KeyFile string `json:"key_file"` // path to private key file - SkipCertificateValidation bool `json:"skip_certificate_validation"` // whether to skip certificate validation (useful for self-signed cert) + CertFile string `json:"cert_file"` // path to certificate file + KeyFile string `json:"key_file"` // path to private key file + SkipCertificateValidation bool `json:"skip_certificate_validation"` // whether to skip certificate validation (useful for self-signed cert) } // FabricBrokerConfig defines the endpoint for WebSocket as well as detailed endpoint configuration type FabricBrokerConfig struct { - FabricEndpoint string `json:"fabric_endpoint"` // URI to WebSocket endpoint - UseTCP bool `json:"use_tcp"` // Use TCP instead of WebSocket - TCPPort int `json:"tcp_port"` // TCP port to use if UseTCP is true - EndpointConfig *bus.EndpointConfig `json:"endpoint_config"` // STOMP configuration + FabricEndpoint string `json:"fabric_endpoint"` // URI to WebSocket endpoint + UseTCP bool `json:"use_tcp"` // Use TCP instead of WebSocket + TCPPort int `json:"tcp_port"` // TCP port to use if UseTCP is true + EndpointConfig *bus.EndpointConfig `json:"endpoint_config"` // STOMP configuration } // PlatformServer exposes public API methods that control the behavior of the Plank instance. type PlatformServer interface { - StartServer(syschan chan os.Signal) // start server - StopServer() // stop server - RegisterService(svc service.FabricService, svcChannel string) error // register a new service at given channel - SetHttpChannelBridge(bridgeConfig *service.RESTBridgeConfig) // set up a REST bridge for a service - SetStaticRoute(prefix, fullpath string, middlewareFn ...mux.MiddlewareFunc) // set up a static content route - SetHttpPathPrefixChannelBridge(bridgeConfig *service.RESTBridgeConfig) // set up a REST bridge for a path prefix for a service. - CustomizeTLSConfig(tls *tls.Config) error // used to replace default tls.Config for HTTP server with a custom config - GetRestBridgeSubRoute(uri, method string) (*mux.Route, error) // get *mux.Route that maps to the provided uri and method - GetMiddlewareManager() middleware.MiddlewareManager // get middleware manager + StartServer(syschan chan os.Signal) // start server + StopServer() // stop server + RegisterService(svc service.FabricService, svcChannel string) error // register a new service at given channel + SetHttpChannelBridge(bridgeConfig *service.RESTBridgeConfig) // set up a REST bridge for a service + SetStaticRoute(prefix, fullpath string, middlewareFn ...mux.MiddlewareFunc) // set up a static content route + SetHttpPathPrefixChannelBridge(bridgeConfig *service.RESTBridgeConfig) // set up a REST bridge for a path prefix for a service. + CustomizeTLSConfig(tls *tls.Config) error // used to replace default tls.Config for HTTP server with a custom config + GetRestBridgeSubRoute(uri, method string) (*mux.Route, error) // get *mux.Route that maps to the provided uri and method + GetMiddlewareManager() middleware.MiddlewareManager // get middleware manager } // platformServer is the main struct that holds all components together including servers, various managers etc. type platformServer struct { - HttpServer *http.Server // Http server instance - SyscallChan chan os.Signal // syscall channel to receive SIGINT, SIGKILL events - eventbus bus.EventBus // event bus pointer - serverConfig *PlatformServerConfig // server config instance - middlewareManager middleware.MiddlewareManager // middleware maanger instance - router *mux.Router // *mux.Router instance - routerConcurrencyProtection *int32 // atomic int32 to protect the main router being concurrently written to - out io.Writer // platform log output pointer - endpointHandlerMap map[string]http.HandlerFunc // internal map to store rest endpoint -handler mappings - serviceChanToBridgeEndpoints map[string][]string // internal map to store service channel - endpoint handler key mappings - fabricConn stompserver.RawConnectionListener // WebSocket listener instance - ServerAvailability *ServerAvailability // server availability (not much used other than for internal monitoring for now) - lock sync.Mutex // lock - messageBridgeMap map[string]*MessageBridge + HttpServer *http.Server // Http server instance + SyscallChan chan os.Signal // syscall channel to receive SIGINT, SIGKILL events + eventbus bus.EventBus // event bus pointer + serverConfig *PlatformServerConfig // server config instance + middlewareManager middleware.MiddlewareManager // middleware maanger instance + router *mux.Router // *mux.Router instance + routerConcurrencyProtection *int32 // atomic int32 to protect the main router being concurrently written to + out io.Writer // platform log output pointer + endpointHandlerMap map[string]http.HandlerFunc // internal map to store rest endpoint -handler mappings + serviceChanToBridgeEndpoints map[string][]string // internal map to store service channel - endpoint handler key mappings + fabricConn stompserver.RawConnectionListener // WebSocket listener instance + ServerAvailability *ServerAvailability // server availability (not much used other than for internal monitoring for now) + lock sync.Mutex // lock + messageBridgeMap map[string]*MessageBridge } // MessageBridge is a conduit used for returning service responses as HTTP responses type MessageBridge struct { - ServiceListenStream bus.MessageHandler // message handler returned by bus.ListenStream responsible for relaying back messages as HTTP responses - payloadChannel chan *model.Message // internal golang channel used for passing bus responses/errors across goroutines + ServiceListenStream bus.MessageHandler // message handler returned by bus.ListenStream responsible for relaying back messages as HTTP responses + payloadChannel chan *model.Message // internal golang channel used for passing bus responses/errors across goroutines } // ServerAvailability contains boolean fields to indicate what components of the system are available or not type ServerAvailability struct { - Http bool // Http server availability - Fabric bool // stomp broker availability + Http bool // Http server availability + Fabric bool // stomp broker availability } diff --git a/plank/pkg/server/flag_helper.go b/plank/pkg/server/flag_helper.go index f71cc25..1573820 100644 --- a/plank/pkg/server/flag_helper.go +++ b/plank/pkg/server/flag_helper.go @@ -4,221 +4,221 @@ package server import ( - "github.com/pb33f/ranch/plank/utils" - flag "github.com/spf13/pflag" - "github.com/spf13/viper" - "os" + "github.com/pb33f/ranch/plank/utils" + flag "github.com/spf13/pflag" + "github.com/spf13/viper" + "os" ) type serverConfigFactory struct { - statics []string - flagSet *flag.FlagSet - flagsParsed bool + statics []string + flagSet *flag.FlagSet + flagsParsed bool } func (f *serverConfigFactory) Hostname() string { - return viper.GetString(utils.PlatformServerFlagConstants["Hostname"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["Hostname"]["FlagName"]) } func (f *serverConfigFactory) Port() int { - return viper.GetInt(utils.PlatformServerFlagConstants["Port"]["FlagName"]) + return viper.GetInt(utils.PlatformServerFlagConstants["Port"]["FlagName"]) } func (f *serverConfigFactory) RootDir() string { - return viper.GetString(utils.PlatformServerFlagConstants["RootDir"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["RootDir"]["FlagName"]) } func (f *serverConfigFactory) Cert() string { - return viper.GetString(utils.PlatformServerFlagConstants["Cert"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["Cert"]["FlagName"]) } func (f *serverConfigFactory) CertKey() string { - return viper.GetString(utils.PlatformServerFlagConstants["CertKey"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["CertKey"]["FlagName"]) } func (f *serverConfigFactory) Static() []string { - return viper.GetStringSlice(utils.PlatformServerFlagConstants["Static"]["FlagName"]) + return viper.GetStringSlice(utils.PlatformServerFlagConstants["Static"]["FlagName"]) } func (f *serverConfigFactory) SpaPath() string { - return viper.GetString(utils.PlatformServerFlagConstants["SpaPath"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["SpaPath"]["FlagName"]) } func (f *serverConfigFactory) NoFabricBroker() bool { - return viper.GetBool(utils.PlatformServerFlagConstants["NoFabricBroker"]["FlagName"]) + return viper.GetBool(utils.PlatformServerFlagConstants["NoFabricBroker"]["FlagName"]) } func (f *serverConfigFactory) FabricEndpoint() string { - return viper.GetString(utils.PlatformServerFlagConstants["FabricEndpoint"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["FabricEndpoint"]["FlagName"]) } func (f *serverConfigFactory) TopicPrefix() string { - return viper.GetString(utils.PlatformServerFlagConstants["TopicPrefix"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["TopicPrefix"]["FlagName"]) } func (f *serverConfigFactory) QueuePrefix() string { - return viper.GetString(utils.PlatformServerFlagConstants["QueuePrefix"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["QueuePrefix"]["FlagName"]) } func (f *serverConfigFactory) RequestPrefix() string { - return viper.GetString(utils.PlatformServerFlagConstants["RequestPrefix"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["RequestPrefix"]["FlagName"]) } func (f *serverConfigFactory) RequestQueuePrefix() string { - return viper.GetString(utils.PlatformServerFlagConstants["RequestQueuePrefix"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["RequestQueuePrefix"]["FlagName"]) } func (f *serverConfigFactory) ConfigFile() string { - return viper.GetString(utils.PlatformServerFlagConstants["ConfigFile"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["ConfigFile"]["FlagName"]) } func (f *serverConfigFactory) ShutdownTimeout() int64 { - return viper.GetInt64(utils.PlatformServerFlagConstants["ShutdownTimeout"]["FlagName"]) + return viper.GetInt64(utils.PlatformServerFlagConstants["ShutdownTimeout"]["FlagName"]) } func (f *serverConfigFactory) OutputLog() string { - return viper.GetString(utils.PlatformServerFlagConstants["OutputLog"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["OutputLog"]["FlagName"]) } func (f *serverConfigFactory) AccessLog() string { - return viper.GetString(utils.PlatformServerFlagConstants["AccessLog"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["AccessLog"]["FlagName"]) } func (f *serverConfigFactory) ErrorLog() string { - return viper.GetString(utils.PlatformServerFlagConstants["ErrorLog"]["FlagName"]) + return viper.GetString(utils.PlatformServerFlagConstants["ErrorLog"]["FlagName"]) } func (f *serverConfigFactory) Debug() bool { - return viper.GetBool(utils.PlatformServerFlagConstants["Debug"]["FlagName"]) + return viper.GetBool(utils.PlatformServerFlagConstants["Debug"]["FlagName"]) } func (f *serverConfigFactory) NoBanner() bool { - return viper.GetBool(utils.PlatformServerFlagConstants["NoBanner"]["FlagName"]) + return viper.GetBool(utils.PlatformServerFlagConstants["NoBanner"]["FlagName"]) } func (f *serverConfigFactory) RestBridgeTimeout() int64 { - return viper.GetInt64(utils.PlatformServerFlagConstants["RestBridgeTimeout"]["FlagName"]) + return viper.GetInt64(utils.PlatformServerFlagConstants["RestBridgeTimeout"]["FlagName"]) } // parseFlags reads OS arguments into the FlagSet in this factory instance func (f *serverConfigFactory) parseFlags(args []string) { - f.flagSet.Parse(args[1:]) - f.flagsParsed = flag.Parsed() + f.flagSet.Parse(args[1:]) + f.flagsParsed = flag.Parsed() } // configureFlags defines flags definitions as well as associate environment variables for a few // flags. see configureFlagsInFlagSet() for detailed flag defining logic. func (f *serverConfigFactory) configureFlags(flagset *flag.FlagSet) { - viper.SetEnvPrefix("PLANK_SERVER") - viper.BindEnv("hostname") - viper.BindEnv("port") - viper.BindEnv("rootdir") - f.flagSet = flagset - f.configureFlagsInFlagSet(f.flagSet) - viper.BindPFlags(f.flagSet) + viper.SetEnvPrefix("PLANK_SERVER") + viper.BindEnv("hostname") + viper.BindEnv("port") + viper.BindEnv("rootdir") + f.flagSet = flagset + f.configureFlagsInFlagSet(f.flagSet) + viper.BindPFlags(f.flagSet) } // configureFlagsInFlagSet takes the pointer to an arbitrary FlagSet instance and // populates it with the flag definitions necessary for PlatformServerConfig. func (f *serverConfigFactory) configureFlagsInFlagSet(fs *flag.FlagSet) { - wd, _ := os.Getwd() - fs.StringP( - utils.PlatformServerFlagConstants["Hostname"]["FlagName"], - utils.PlatformServerFlagConstants["Hostname"]["ShortFlag"], - "localhost", - utils.PlatformServerFlagConstants["Hostname"]["Description"]) - fs.IntP( - utils.PlatformServerFlagConstants["Port"]["FlagName"], - utils.PlatformServerFlagConstants["Port"]["ShortFlag"], - 30080, - utils.PlatformServerFlagConstants["Port"]["Description"]) - fs.StringP( - utils.PlatformServerFlagConstants["RootDir"]["FlagName"], - utils.PlatformServerFlagConstants["RootDir"]["ShortFlag"], - wd, - utils.PlatformServerFlagConstants["RootDir"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["Cert"]["FlagName"], - "", - utils.PlatformServerFlagConstants["Cert"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["CertKey"]["FlagName"], - "", - utils.PlatformServerFlagConstants["CertKey"]["Description"]) - - fs.StringSliceVarP( - &f.statics, - utils.PlatformServerFlagConstants["Static"]["FlagName"], - utils.PlatformServerFlagConstants["Static"]["ShortFlag"], - []string{}, - utils.PlatformServerFlagConstants["Static"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["SpaPath"]["FlagName"], - "", - utils.PlatformServerFlagConstants["SpaPath"]["Description"]) - fs.Bool( - utils.PlatformServerFlagConstants["NoFabricBroker"]["FlagName"], - false, - utils.PlatformServerFlagConstants["NoFabricBroker"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["FabricEndpoint"]["FlagName"], - "/ws", - utils.PlatformServerFlagConstants["FabricEndpoint"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["TopicPrefix"]["FlagName"], - "/topic", - utils.PlatformServerFlagConstants["TopicPrefix"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["QueuePrefix"]["FlagName"], - "/queue", - utils.PlatformServerFlagConstants["QueuePrefix"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["RequestPrefix"]["FlagName"], - "/pub", - utils.PlatformServerFlagConstants["RequestPrefix"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["RequestQueuePrefix"]["FlagName"], - "/pub/queue", - utils.PlatformServerFlagConstants["RequestQueuePrefix"]["Description"]) - fs.String( - utils.PlatformServerFlagConstants["ConfigFile"]["FlagName"], - "", - utils.PlatformServerFlagConstants["ConfigFile"]["Description"]) - fs.Int64( - utils.PlatformServerFlagConstants["ShutdownTimeout"]["FlagName"], - 5, - utils.PlatformServerFlagConstants["ShutdownTimeout"]["Description"]) - fs.StringP( - utils.PlatformServerFlagConstants["OutputLog"]["FlagName"], - utils.PlatformServerFlagConstants["OutputLog"]["ShortFlag"], - "stdout", - utils.PlatformServerFlagConstants["OutputLog"]["Description"]) - fs.StringP( - utils.PlatformServerFlagConstants["AccessLog"]["FlagName"], - utils.PlatformServerFlagConstants["AccessLog"]["ShortFlag"], - "stdout", - utils.PlatformServerFlagConstants["AccessLog"]["Description"]) - fs.StringP( - utils.PlatformServerFlagConstants["ErrorLog"]["FlagName"], - utils.PlatformServerFlagConstants["ErrorLog"]["ShortFlag"], - "stderr", - utils.PlatformServerFlagConstants["ErrorLog"]["Description"]) - fs.BoolP( - utils.PlatformServerFlagConstants["Debug"]["FlagName"], - utils.PlatformServerFlagConstants["Debug"]["ShortFlag"], - false, - utils.PlatformServerFlagConstants["Debug"]["Description"]) - fs.BoolP( - utils.PlatformServerFlagConstants["NoBanner"]["FlagName"], - utils.PlatformServerFlagConstants["NoBanner"]["ShortFlag"], - false, - utils.PlatformServerFlagConstants["NoBanner"]["Description"]) - fs.Bool( - utils.PlatformServerFlagConstants["Prometheus"]["FlagName"], - false, - utils.PlatformServerFlagConstants["Prometheus"]["Description"]) - fs.Int64( - utils.PlatformServerFlagConstants["RestBridgeTimeout"]["FlagName"], - 1, - utils.PlatformServerFlagConstants["RestBridgeTimeout"]["Description"]) + wd, _ := os.Getwd() + fs.StringP( + utils.PlatformServerFlagConstants["Hostname"]["FlagName"], + utils.PlatformServerFlagConstants["Hostname"]["ShortFlag"], + "localhost", + utils.PlatformServerFlagConstants["Hostname"]["Description"]) + fs.IntP( + utils.PlatformServerFlagConstants["Port"]["FlagName"], + utils.PlatformServerFlagConstants["Port"]["ShortFlag"], + 30080, + utils.PlatformServerFlagConstants["Port"]["Description"]) + fs.StringP( + utils.PlatformServerFlagConstants["RootDir"]["FlagName"], + utils.PlatformServerFlagConstants["RootDir"]["ShortFlag"], + wd, + utils.PlatformServerFlagConstants["RootDir"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["Cert"]["FlagName"], + "", + utils.PlatformServerFlagConstants["Cert"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["CertKey"]["FlagName"], + "", + utils.PlatformServerFlagConstants["CertKey"]["Description"]) + + fs.StringSliceVarP( + &f.statics, + utils.PlatformServerFlagConstants["Static"]["FlagName"], + utils.PlatformServerFlagConstants["Static"]["ShortFlag"], + []string{}, + utils.PlatformServerFlagConstants["Static"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["SpaPath"]["FlagName"], + "", + utils.PlatformServerFlagConstants["SpaPath"]["Description"]) + fs.Bool( + utils.PlatformServerFlagConstants["NoFabricBroker"]["FlagName"], + false, + utils.PlatformServerFlagConstants["NoFabricBroker"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["FabricEndpoint"]["FlagName"], + "/ws", + utils.PlatformServerFlagConstants["FabricEndpoint"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["TopicPrefix"]["FlagName"], + "/topic", + utils.PlatformServerFlagConstants["TopicPrefix"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["QueuePrefix"]["FlagName"], + "/queue", + utils.PlatformServerFlagConstants["QueuePrefix"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["RequestPrefix"]["FlagName"], + "/pub", + utils.PlatformServerFlagConstants["RequestPrefix"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["RequestQueuePrefix"]["FlagName"], + "/pub/queue", + utils.PlatformServerFlagConstants["RequestQueuePrefix"]["Description"]) + fs.String( + utils.PlatformServerFlagConstants["ConfigFile"]["FlagName"], + "", + utils.PlatformServerFlagConstants["ConfigFile"]["Description"]) + fs.Int64( + utils.PlatformServerFlagConstants["ShutdownTimeout"]["FlagName"], + 5, + utils.PlatformServerFlagConstants["ShutdownTimeout"]["Description"]) + fs.StringP( + utils.PlatformServerFlagConstants["OutputLog"]["FlagName"], + utils.PlatformServerFlagConstants["OutputLog"]["ShortFlag"], + "stdout", + utils.PlatformServerFlagConstants["OutputLog"]["Description"]) + fs.StringP( + utils.PlatformServerFlagConstants["AccessLog"]["FlagName"], + utils.PlatformServerFlagConstants["AccessLog"]["ShortFlag"], + "stdout", + utils.PlatformServerFlagConstants["AccessLog"]["Description"]) + fs.StringP( + utils.PlatformServerFlagConstants["ErrorLog"]["FlagName"], + utils.PlatformServerFlagConstants["ErrorLog"]["ShortFlag"], + "stderr", + utils.PlatformServerFlagConstants["ErrorLog"]["Description"]) + fs.BoolP( + utils.PlatformServerFlagConstants["Debug"]["FlagName"], + utils.PlatformServerFlagConstants["Debug"]["ShortFlag"], + false, + utils.PlatformServerFlagConstants["Debug"]["Description"]) + fs.BoolP( + utils.PlatformServerFlagConstants["NoBanner"]["FlagName"], + utils.PlatformServerFlagConstants["NoBanner"]["ShortFlag"], + false, + utils.PlatformServerFlagConstants["NoBanner"]["Description"]) + fs.Bool( + utils.PlatformServerFlagConstants["Prometheus"]["FlagName"], + false, + utils.PlatformServerFlagConstants["Prometheus"]["Description"]) + fs.Int64( + utils.PlatformServerFlagConstants["RestBridgeTimeout"]["FlagName"], + 1, + utils.PlatformServerFlagConstants["RestBridgeTimeout"]["Description"]) } diff --git a/plank/pkg/server/flag_helper_test.go b/plank/pkg/server/flag_helper_test.go index 80d318a..d2e8b2f 100644 --- a/plank/pkg/server/flag_helper_test.go +++ b/plank/pkg/server/flag_helper_test.go @@ -1,63 +1,63 @@ package server import ( - "fmt" - "github.com/spf13/pflag" - "github.com/stretchr/testify/assert" - "os" - "testing" + "fmt" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" + "os" + "testing" ) func TestFlagHelper_ParseDefaultFlags(t *testing.T) { - // arrange - f := &serverConfigFactory{} - pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) + // arrange + f := &serverConfigFactory{} + pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) - // act - testArgs := []string{""} - f.configureFlags(pflag.CommandLine) - f.parseFlags(testArgs) + // act + testArgs := []string{""} + f.configureFlags(pflag.CommandLine) + f.parseFlags(testArgs) - // assert - wd, _ := os.Getwd() - assert.True(t, f.flagsParsed) - assert.EqualValues(t, "localhost", f.Hostname()) - assert.EqualValues(t, 30080, f.Port()) - assert.EqualValues(t, wd, f.RootDir()) - assert.Empty(t, f.Cert()) - assert.Empty(t, f.CertKey()) - assert.Empty(t, f.Static()) - assert.Empty(t, f.SpaPath()) - assert.False(t, f.NoFabricBroker()) - assert.EqualValues(t, "/ws", f.FabricEndpoint()) - assert.EqualValues(t, "/topic", f.TopicPrefix()) - assert.EqualValues(t, "/queue", f.QueuePrefix()) - assert.EqualValues(t, "/pub", f.RequestPrefix()) - assert.EqualValues(t, "/pub/queue", f.RequestQueuePrefix()) - assert.Empty(t, f.ConfigFile()) - assert.EqualValues(t, 5, f.ShutdownTimeout()) - assert.EqualValues(t, "stdout", f.OutputLog()) - assert.EqualValues(t, "stdout", f.AccessLog()) - assert.EqualValues(t, "stderr", f.ErrorLog()) - assert.False(t, f.Debug()) - assert.False(t, f.NoBanner()) - assert.EqualValues(t, 1, f.RestBridgeTimeout()) + // assert + wd, _ := os.Getwd() + assert.True(t, f.flagsParsed) + assert.EqualValues(t, "localhost", f.Hostname()) + assert.EqualValues(t, 30080, f.Port()) + assert.EqualValues(t, wd, f.RootDir()) + assert.Empty(t, f.Cert()) + assert.Empty(t, f.CertKey()) + assert.Empty(t, f.Static()) + assert.Empty(t, f.SpaPath()) + assert.False(t, f.NoFabricBroker()) + assert.EqualValues(t, "/ws", f.FabricEndpoint()) + assert.EqualValues(t, "/topic", f.TopicPrefix()) + assert.EqualValues(t, "/queue", f.QueuePrefix()) + assert.EqualValues(t, "/pub", f.RequestPrefix()) + assert.EqualValues(t, "/pub/queue", f.RequestQueuePrefix()) + assert.Empty(t, f.ConfigFile()) + assert.EqualValues(t, 5, f.ShutdownTimeout()) + assert.EqualValues(t, "stdout", f.OutputLog()) + assert.EqualValues(t, "stdout", f.AccessLog()) + assert.EqualValues(t, "stderr", f.ErrorLog()) + assert.False(t, f.Debug()) + assert.False(t, f.NoBanner()) + assert.EqualValues(t, 1, f.RestBridgeTimeout()) } func TestFlagHelper_ParseFlags(t *testing.T) { - // arrange - f := &serverConfigFactory{} - pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) + // arrange + f := &serverConfigFactory{} + pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) - // act - testArgs := []string{"", "--hostname", "cloud.local", "--port", "443", "--static", "test", "--static", "test2"} - f.configureFlags(pflag.CommandLine) - f.parseFlags(testArgs) + // act + testArgs := []string{"", "--hostname", "cloud.local", "--port", "443", "--static", "test", "--static", "test2"} + f.configureFlags(pflag.CommandLine) + f.parseFlags(testArgs) - // assert - assert.True(t, f.flagsParsed) - assert.EqualValues(t, "cloud.local", f.Hostname()) - assert.EqualValues(t, 443, f.Port()) - assert.Len(t, f.Static(), 2) - assert.EqualValues(t, "[test test2]", fmt.Sprint(f.Static())) + // assert + assert.True(t, f.flagsParsed) + assert.EqualValues(t, "cloud.local", f.Hostname()) + assert.EqualValues(t, 443, f.Port()) + assert.Len(t, f.Static(), 2) + assert.EqualValues(t, "[test test2]", fmt.Sprint(f.Static())) } diff --git a/plank/pkg/server/helpers.go b/plank/pkg/server/helpers.go index b1757fa..39f34f7 100644 --- a/plank/pkg/server/helpers.go +++ b/plank/pkg/server/helpers.go @@ -1,176 +1,176 @@ package server import ( - "encoding/json" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/plank/utils" - "io/ioutil" - "os" - "path/filepath" - "strings" - "time" + "encoding/json" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/plank/utils" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" ) // generatePlatformServerConfig is a generic internal method that returns the pointer of a new // instance of PlatformServerConfig. for an argument it can be passed either *serverConfigFactory // or *cli.Context which the method will analyze and determine the best way to extract user provided values from it. func generatePlatformServerConfig(f *serverConfigFactory) (*PlatformServerConfig, error) { - configFile := f.ConfigFile() - host := f.Hostname() - port := f.Port() - rootDir := f.RootDir() - static := f.Static() - shutdownTimeoutInMinutes := f.ShutdownTimeout() - accessLog := f.AccessLog() - outputLog := f.OutputLog() - errorLog := f.ErrorLog() - debug := f.Debug() - noBanner := f.NoBanner() - cert := f.Cert() - certKey := f.CertKey() - spaPath := f.SpaPath() - noFabricBroker := f.NoFabricBroker() - fabricEndpoint := f.FabricEndpoint() - topicPrefix := f.TopicPrefix() - queuePrefix := f.QueuePrefix() - requestPrefix := f.RequestPrefix() - requestQueuePrefix := f.RequestQueuePrefix() - restBridgeTimeout := f.RestBridgeTimeout() - - // if config file flag is provided, read directly from the file - if len(configFile) > 0 { - var serverConfig PlatformServerConfig - b, err := ioutil.ReadFile(configFile) - if err != nil { - return nil, err - } - if err = json.Unmarshal(b, &serverConfig); err != nil { - return nil, err - } - - // handle invalid duration by setting it to the default value of 5 minutes - if serverConfig.ShutdownTimeout <= 0 { - serverConfig.ShutdownTimeout = 5 - } - - // handle invalid duration by setting it to the default value of 1 minute - if serverConfig.RestBridgeTimeout <= 0 { - serverConfig.RestBridgeTimeout = 1 - } - - // the raw value from the config.json needs to be multiplied by time.Minute otherwise it's interpreted as nanosecond - serverConfig.ShutdownTimeout = serverConfig.ShutdownTimeout * time.Minute - - // the raw value from the config.json needs to be multiplied by time.Minute otherwise it's interpreted as nanosecond - serverConfig.RestBridgeTimeout = serverConfig.RestBridgeTimeout * time.Minute - - // convert map of cache control rules of SpaConfig into an array - if serverConfig.SpaConfig != nil { - serverConfig.SpaConfig.CollateCacheControlRules() - } - - return &serverConfig, nil - } - - // handle invalid duration by setting it to the default value of 1 minute - if restBridgeTimeout <= 0 { - restBridgeTimeout = 1 - } - - // handle invalid duration by setting it to the default value of 5 minutes - if shutdownTimeoutInMinutes <= 0 { - shutdownTimeoutInMinutes = 5 - } - - // instantiate a server config - serverConfig := &PlatformServerConfig{ - Host: host, - Port: port, - RootDir: rootDir, - StaticDir: static, - ShutdownTimeout: time.Duration(shutdownTimeoutInMinutes) * time.Minute, - LogConfig: &utils.LogConfig{ - AccessLog: accessLog, - ErrorLog: errorLog, - OutputLog: outputLog, - Root: rootDir, - FormatOptions: &utils.LogFormatOption{}, - }, - Debug: debug, - NoBanner: noBanner, - RestBridgeTimeout: time.Duration(restBridgeTimeout) * time.Minute, - } - - if len(certKey) > 0 && len(certKey) > 0 { - var err error - certKey, err = filepath.Abs(certKey) - if err != nil { - return nil, err - } - cert, err = filepath.Abs(cert) - if err != nil { - return nil, err - } - - serverConfig.TLSCertConfig = &TLSCertConfig{CertFile: cert, KeyFile: certKey} - } - - if len(strings.TrimSpace(spaPath)) > 0 { - var err error - serverConfig.SpaConfig, err = NewSpaConfig(spaPath) - if err != nil { - return nil, err - } - } - - // unless --no-Fabric-broker flag is provided, set up a broker config - if !noFabricBroker { - serverConfig.FabricConfig = &FabricBrokerConfig{ - FabricEndpoint: fabricEndpoint, - EndpointConfig: &bus.EndpointConfig{ - TopicPrefix: topicPrefix, - UserQueuePrefix: queuePrefix, - AppRequestPrefix: requestPrefix, - AppRequestQueuePrefix: requestQueuePrefix, - Heartbeat: 60000}, - } - } - - return serverConfig, nil + configFile := f.ConfigFile() + host := f.Hostname() + port := f.Port() + rootDir := f.RootDir() + static := f.Static() + shutdownTimeoutInMinutes := f.ShutdownTimeout() + accessLog := f.AccessLog() + outputLog := f.OutputLog() + errorLog := f.ErrorLog() + debug := f.Debug() + noBanner := f.NoBanner() + cert := f.Cert() + certKey := f.CertKey() + spaPath := f.SpaPath() + noFabricBroker := f.NoFabricBroker() + fabricEndpoint := f.FabricEndpoint() + topicPrefix := f.TopicPrefix() + queuePrefix := f.QueuePrefix() + requestPrefix := f.RequestPrefix() + requestQueuePrefix := f.RequestQueuePrefix() + restBridgeTimeout := f.RestBridgeTimeout() + + // if config file flag is provided, read directly from the file + if len(configFile) > 0 { + var serverConfig PlatformServerConfig + b, err := ioutil.ReadFile(configFile) + if err != nil { + return nil, err + } + if err = json.Unmarshal(b, &serverConfig); err != nil { + return nil, err + } + + // handle invalid duration by setting it to the default value of 5 minutes + if serverConfig.ShutdownTimeout <= 0 { + serverConfig.ShutdownTimeout = 5 + } + + // handle invalid duration by setting it to the default value of 1 minute + if serverConfig.RestBridgeTimeout <= 0 { + serverConfig.RestBridgeTimeout = 1 + } + + // the raw value from the config.json needs to be multiplied by time.Minute otherwise it's interpreted as nanosecond + serverConfig.ShutdownTimeout = serverConfig.ShutdownTimeout * time.Minute + + // the raw value from the config.json needs to be multiplied by time.Minute otherwise it's interpreted as nanosecond + serverConfig.RestBridgeTimeout = serverConfig.RestBridgeTimeout * time.Minute + + // convert map of cache control rules of SpaConfig into an array + if serverConfig.SpaConfig != nil { + serverConfig.SpaConfig.CollateCacheControlRules() + } + + return &serverConfig, nil + } + + // handle invalid duration by setting it to the default value of 1 minute + if restBridgeTimeout <= 0 { + restBridgeTimeout = 1 + } + + // handle invalid duration by setting it to the default value of 5 minutes + if shutdownTimeoutInMinutes <= 0 { + shutdownTimeoutInMinutes = 5 + } + + // instantiate a server config + serverConfig := &PlatformServerConfig{ + Host: host, + Port: port, + RootDir: rootDir, + StaticDir: static, + ShutdownTimeout: time.Duration(shutdownTimeoutInMinutes) * time.Minute, + LogConfig: &utils.LogConfig{ + AccessLog: accessLog, + ErrorLog: errorLog, + OutputLog: outputLog, + Root: rootDir, + FormatOptions: &utils.LogFormatOption{}, + }, + Debug: debug, + NoBanner: noBanner, + RestBridgeTimeout: time.Duration(restBridgeTimeout) * time.Minute, + } + + if len(certKey) > 0 && len(certKey) > 0 { + var err error + certKey, err = filepath.Abs(certKey) + if err != nil { + return nil, err + } + cert, err = filepath.Abs(cert) + if err != nil { + return nil, err + } + + serverConfig.TLSCertConfig = &TLSCertConfig{CertFile: cert, KeyFile: certKey} + } + + if len(strings.TrimSpace(spaPath)) > 0 { + var err error + serverConfig.SpaConfig, err = NewSpaConfig(spaPath) + if err != nil { + return nil, err + } + } + + // unless --no-Fabric-broker flag is provided, set up a broker config + if !noFabricBroker { + serverConfig.FabricConfig = &FabricBrokerConfig{ + FabricEndpoint: fabricEndpoint, + EndpointConfig: &bus.EndpointConfig{ + TopicPrefix: topicPrefix, + UserQueuePrefix: queuePrefix, + AppRequestPrefix: requestPrefix, + AppRequestQueuePrefix: requestQueuePrefix, + Heartbeat: 60000}, + } + } + + return serverConfig, nil } // ensureResponseInByteSlice takes body as an interface not knowing whether it is already converted to []byte or not. // if it is not of []byte type it marshals the payload using json.Marshal. otherwise, the input // byte slice is returned as-is. func ensureResponseInByteSlice(body interface{}) (bytes []byte, err error) { - switch body.(type) { - case []byte: - bytes, err = body.([]byte), nil - default: - bytes, err = json.Marshal(body) - } - return + switch body.(type) { + case []byte: + bytes, err = body.([]byte), nil + default: + bytes, err = json.Marshal(body) + } + return } // sanitizeConfigRootPath takes *PlatformServerConfig, ensures the path specified by RootDir field exists. // if RootDir is empty then the current working directory will be populated. if for some reason the path // cannot be accessed it'll cause a panic. func sanitizeConfigRootPath(config *PlatformServerConfig) { - if len(config.RootDir) == 0 { - wd, _ := os.Getwd() - config.RootDir = wd - } - - absRootPath, err := filepath.Abs(config.RootDir) - if err != nil { - panic(err) - } - - _, err = os.Stat(absRootPath) - if err != nil { - panic(err) - } - - // once it has been confirmed that the path exists, set config.RootDir to the absolute path - config.RootDir = absRootPath + if len(config.RootDir) == 0 { + wd, _ := os.Getwd() + config.RootDir = wd + } + + absRootPath, err := filepath.Abs(config.RootDir) + if err != nil { + panic(err) + } + + _, err = os.Stat(absRootPath) + if err != nil { + panic(err) + } + + // once it has been confirmed that the path exists, set config.RootDir to the absolute path + config.RootDir = absRootPath } diff --git a/plank/pkg/server/helpers_test.go b/plank/pkg/server/helpers_test.go index 761536b..6266163 100644 --- a/plank/pkg/server/helpers_test.go +++ b/plank/pkg/server/helpers_test.go @@ -1,142 +1,142 @@ package server import ( - "encoding/json" - "github.com/spf13/pflag" - "github.com/stretchr/testify/assert" - "os" - "path/filepath" - "testing" - "time" + "encoding/json" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "testing" + "time" ) func TestGeneratePlatformServerConfig_Default(t *testing.T) { - // arrange - f := &serverConfigFactory{} - pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) - - // act - testArgs := []string{""} - f.configureFlags(pflag.CommandLine) - f.parseFlags(testArgs) - config, err := generatePlatformServerConfig(f) - - // assert - wd, _ := os.Getwd() - assert.Nil(t, err) - assert.EqualValues(t, "localhost", config.Host) - assert.EqualValues(t, 30080, config.Port) - assert.EqualValues(t, wd, config.RootDir) - assert.Empty(t, config.StaticDir) - assert.EqualValues(t, 5*time.Minute, config.ShutdownTimeout) - assert.EqualValues(t, "stdout", config.LogConfig.OutputLog) - assert.EqualValues(t, "stdout", config.LogConfig.AccessLog) - assert.EqualValues(t, "stderr", config.LogConfig.ErrorLog) - assert.EqualValues(t, wd, config.LogConfig.Root) - assert.False(t, config.Debug) - assert.False(t, config.NoBanner) - assert.EqualValues(t, time.Minute, config.RestBridgeTimeout) - assert.EqualValues(t, "/ws", config.FabricConfig.FabricEndpoint) - assert.EqualValues(t, "/topic", config.FabricConfig.EndpointConfig.TopicPrefix) - assert.EqualValues(t, "/queue", config.FabricConfig.EndpointConfig.UserQueuePrefix) - assert.EqualValues(t, "/pub", config.FabricConfig.EndpointConfig.AppRequestPrefix) - assert.EqualValues(t, "/pub/queue", config.FabricConfig.EndpointConfig.AppRequestQueuePrefix) - assert.EqualValues(t, 60000, config.FabricConfig.EndpointConfig.Heartbeat) + // arrange + f := &serverConfigFactory{} + pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) + + // act + testArgs := []string{""} + f.configureFlags(pflag.CommandLine) + f.parseFlags(testArgs) + config, err := generatePlatformServerConfig(f) + + // assert + wd, _ := os.Getwd() + assert.Nil(t, err) + assert.EqualValues(t, "localhost", config.Host) + assert.EqualValues(t, 30080, config.Port) + assert.EqualValues(t, wd, config.RootDir) + assert.Empty(t, config.StaticDir) + assert.EqualValues(t, 5*time.Minute, config.ShutdownTimeout) + assert.EqualValues(t, "stdout", config.LogConfig.OutputLog) + assert.EqualValues(t, "stdout", config.LogConfig.AccessLog) + assert.EqualValues(t, "stderr", config.LogConfig.ErrorLog) + assert.EqualValues(t, wd, config.LogConfig.Root) + assert.False(t, config.Debug) + assert.False(t, config.NoBanner) + assert.EqualValues(t, time.Minute, config.RestBridgeTimeout) + assert.EqualValues(t, "/ws", config.FabricConfig.FabricEndpoint) + assert.EqualValues(t, "/topic", config.FabricConfig.EndpointConfig.TopicPrefix) + assert.EqualValues(t, "/queue", config.FabricConfig.EndpointConfig.UserQueuePrefix) + assert.EqualValues(t, "/pub", config.FabricConfig.EndpointConfig.AppRequestPrefix) + assert.EqualValues(t, "/pub/queue", config.FabricConfig.EndpointConfig.AppRequestQueuePrefix) + assert.EqualValues(t, 60000, config.FabricConfig.EndpointConfig.Heartbeat) } func TestGeneratePlatformServerConfig_CertConfig(t *testing.T) { - // arrange - f := &serverConfigFactory{} - pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) - dummyCert := filepath.Join(os.TempDir(), "plank-tests", "cert.pem") - dummyKey := filepath.Join(os.TempDir(), "plank-tests", "key.pem") - - // act - testArgs := []string{"", "--cert", dummyCert, "--cert-key", dummyKey} - f.configureFlags(pflag.CommandLine) - f.parseFlags(testArgs) - config, err := generatePlatformServerConfig(f) - - // assert - assert.Nil(t, err) - assert.EqualValues(t, dummyCert, config.TLSCertConfig.CertFile) - assert.EqualValues(t, dummyKey, config.TLSCertConfig.KeyFile) + // arrange + f := &serverConfigFactory{} + pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) + dummyCert := filepath.Join(os.TempDir(), "plank-tests", "cert.pem") + dummyKey := filepath.Join(os.TempDir(), "plank-tests", "key.pem") + + // act + testArgs := []string{"", "--cert", dummyCert, "--cert-key", dummyKey} + f.configureFlags(pflag.CommandLine) + f.parseFlags(testArgs) + config, err := generatePlatformServerConfig(f) + + // assert + assert.Nil(t, err) + assert.EqualValues(t, dummyCert, config.TLSCertConfig.CertFile) + assert.EqualValues(t, dummyKey, config.TLSCertConfig.KeyFile) } func TestGeneratePlatformServerConfig_SpaConfig(t *testing.T) { - // arrange - f := &serverConfigFactory{} - pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) - spaRoot := filepath.Join(os.TempDir(), "plank-tests", "spaRoot") - - // act - testArgs := []string{"", "--spa-path", spaRoot} - f.configureFlags(pflag.CommandLine) - f.parseFlags(testArgs) - config, err := generatePlatformServerConfig(f) - - // assert - assert.Nil(t, err) - assert.EqualValues(t, spaRoot, config.SpaConfig.RootFolder) - assert.EqualValues(t, "spaRoot", config.SpaConfig.BaseUri) + // arrange + f := &serverConfigFactory{} + pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) + spaRoot := filepath.Join(os.TempDir(), "plank-tests", "spaRoot") + + // act + testArgs := []string{"", "--spa-path", spaRoot} + f.configureFlags(pflag.CommandLine) + f.parseFlags(testArgs) + config, err := generatePlatformServerConfig(f) + + // assert + assert.Nil(t, err) + assert.EqualValues(t, spaRoot, config.SpaConfig.RootFolder) + assert.EqualValues(t, "spaRoot", config.SpaConfig.BaseUri) } func TestGeneratePlatformServerConfig_ConfigFile(t *testing.T) { - // arrange - f := &serverConfigFactory{} - pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) - configFile, err := CreateConfigJsonForTest() - if err != nil { - assert.FailNow(t, err.Error()) - } - defer os.RemoveAll(filepath.Dir(configFile)) - - // act - testArgs := []string{"", "--config-file", configFile} - f.configureFlags(pflag.CommandLine) - f.parseFlags(testArgs) - config, err := generatePlatformServerConfig(f) - - // assert - assert.Nil(t, err) - assert.EqualValues(t, "localhost", config.Host) - assert.EqualValues(t, 31234, config.Port) - assert.EqualValues(t, "./", config.RootDir) - assert.Empty(t, config.StaticDir) - assert.EqualValues(t, 5*time.Minute, config.ShutdownTimeout) - assert.EqualValues(t, "stdout", config.LogConfig.OutputLog) - assert.EqualValues(t, "access.log", config.LogConfig.AccessLog) - assert.EqualValues(t, "errors.log", config.LogConfig.ErrorLog) - assert.EqualValues(t, ".", config.LogConfig.Root) - assert.True(t, config.Debug) - assert.True(t, config.NoBanner) - assert.EqualValues(t, time.Minute, config.RestBridgeTimeout) - assert.EqualValues(t, "cert/server.key", config.TLSCertConfig.KeyFile) - assert.EqualValues(t, "cert/fullchain.pem", config.TLSCertConfig.CertFile) - assert.EqualValues(t, "/ws", config.FabricConfig.FabricEndpoint) - assert.EqualValues(t, "/topic", config.FabricConfig.EndpointConfig.TopicPrefix) - assert.EqualValues(t, "/queue", config.FabricConfig.EndpointConfig.UserQueuePrefix) - assert.EqualValues(t, "/pub", config.FabricConfig.EndpointConfig.AppRequestPrefix) - assert.EqualValues(t, "/pub/queue", config.FabricConfig.EndpointConfig.AppRequestQueuePrefix) - assert.EqualValues(t, 60000, config.FabricConfig.EndpointConfig.Heartbeat) - assert.EqualValues(t, "public/", config.SpaConfig.RootFolder) - assert.EqualValues(t, "/", config.SpaConfig.BaseUri) - assert.EqualValues(t, "public/assets:/assets", config.SpaConfig.StaticAssets[0]) + // arrange + f := &serverConfigFactory{} + pflag.CommandLine = pflag.NewFlagSet("", pflag.ExitOnError) + configFile, err := CreateConfigJsonForTest() + if err != nil { + assert.FailNow(t, err.Error()) + } + defer os.RemoveAll(filepath.Dir(configFile)) + + // act + testArgs := []string{"", "--config-file", configFile} + f.configureFlags(pflag.CommandLine) + f.parseFlags(testArgs) + config, err := generatePlatformServerConfig(f) + + // assert + assert.Nil(t, err) + assert.EqualValues(t, "localhost", config.Host) + assert.EqualValues(t, 31234, config.Port) + assert.EqualValues(t, "./", config.RootDir) + assert.Empty(t, config.StaticDir) + assert.EqualValues(t, 5*time.Minute, config.ShutdownTimeout) + assert.EqualValues(t, "stdout", config.LogConfig.OutputLog) + assert.EqualValues(t, "access.log", config.LogConfig.AccessLog) + assert.EqualValues(t, "errors.log", config.LogConfig.ErrorLog) + assert.EqualValues(t, ".", config.LogConfig.Root) + assert.True(t, config.Debug) + assert.True(t, config.NoBanner) + assert.EqualValues(t, time.Minute, config.RestBridgeTimeout) + assert.EqualValues(t, "cert/server.key", config.TLSCertConfig.KeyFile) + assert.EqualValues(t, "cert/fullchain.pem", config.TLSCertConfig.CertFile) + assert.EqualValues(t, "/ws", config.FabricConfig.FabricEndpoint) + assert.EqualValues(t, "/topic", config.FabricConfig.EndpointConfig.TopicPrefix) + assert.EqualValues(t, "/queue", config.FabricConfig.EndpointConfig.UserQueuePrefix) + assert.EqualValues(t, "/pub", config.FabricConfig.EndpointConfig.AppRequestPrefix) + assert.EqualValues(t, "/pub/queue", config.FabricConfig.EndpointConfig.AppRequestQueuePrefix) + assert.EqualValues(t, 60000, config.FabricConfig.EndpointConfig.Heartbeat) + assert.EqualValues(t, "public/", config.SpaConfig.RootFolder) + assert.EqualValues(t, "/", config.SpaConfig.BaseUri) + assert.EqualValues(t, "public/assets:/assets", config.SpaConfig.StaticAssets[0]) } func TestMarshalResponseBody_byteSlice(t *testing.T) { - payload := []byte("hello") - results, err := ensureResponseInByteSlice(payload) + payload := []byte("hello") + results, err := ensureResponseInByteSlice(payload) - assert.Nil(t, err) - assert.EqualValues(t, payload, results) + assert.Nil(t, err) + assert.EqualValues(t, payload, results) } func TestMarshalResponseBody_nonByteSlice(t *testing.T) { - payload := PlatformServerConfig{} - jsonMarshalled, _ := json.Marshal(payload) - results, err := ensureResponseInByteSlice(payload) + payload := PlatformServerConfig{} + jsonMarshalled, _ := json.Marshal(payload) + results, err := ensureResponseInByteSlice(payload) - assert.Nil(t, err) - assert.EqualValues(t, jsonMarshalled, results) + assert.Nil(t, err) + assert.EqualValues(t, jsonMarshalled, results) } diff --git a/plank/pkg/server/initialize_rest_bridge_override_test.go b/plank/pkg/server/initialize_rest_bridge_override_test.go index 31195ef..56dd379 100644 --- a/plank/pkg/server/initialize_rest_bridge_override_test.go +++ b/plank/pkg/server/initialize_rest_bridge_override_test.go @@ -1,109 +1,109 @@ package server import ( - "fmt" - "github.com/google/uuid" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/model" - "github.com/pb33f/ranch/plank/services" - "github.com/pb33f/ranch/plank/utils" - "github.com/pb33f/ranch/service" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "net/http" - "os" - "path/filepath" - "sync" - "syscall" - "testing" - "time" + "fmt" + "github.com/google/uuid" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/model" + "github.com/pb33f/ranch/plank/services" + "github.com/pb33f/ranch/plank/utils" + "github.com/pb33f/ranch/service" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "net/http" + "os" + "path/filepath" + "sync" + "syscall" + "testing" + "time" ) func TestInitialize_DebugLogging(t *testing.T) { - // arrange - testRoot := filepath.Join(os.TempDir(), "plank-tests") - _ = os.MkdirAll(testRoot, 0755) - defer os.RemoveAll(testRoot) + // arrange + testRoot := filepath.Join(os.TempDir(), "plank-tests") + _ = os.MkdirAll(testRoot, 0755) + defer os.RemoveAll(testRoot) - cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", GetTestPort(), true) - cfg.Debug = true + cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", GetTestPort(), true) + cfg.Debug = true - // act - _, _, _ = CreateTestServer(cfg) + // act + _, _, _ = CreateTestServer(cfg) - // assert - assert.EqualValues(t, logrus.DebugLevel, utils.Log.GetLevel()) + // assert + assert.EqualValues(t, logrus.DebugLevel, utils.Log.GetLevel()) } func TestInitialize_RestBridgeOverride(t *testing.T) { - // arrange - newBus := bus.ResetBus() - service.ResetServiceRegistry() - testRoot := filepath.Join(os.TempDir(), "plank-tests") - _ = os.MkdirAll(testRoot, 0755) - defer os.RemoveAll(testRoot) - defer service.GetServiceRegistry().UnregisterService(services.PingPongServiceChan) - - cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", GetTestPort(), true) - baseUrl, _, testServerInterface := CreateTestServer(cfg) - testServer := testServerInterface.(*platformServer) - testServer.eventbus = newBus - - // register ping pong service with default bridge points of /rest/ping-pong, /rest/ping-pong2 and /rest/ping-pong/{from}/{to}/{message} - testServerInterface.RegisterService(services.NewPingPongService(), services.PingPongServiceChan) - - // start server - syschan := make(chan os.Signal) - wg := sync.WaitGroup{} - wg.Add(1) - go testServerInterface.StartServer(syschan) - - // act - // replace existing rest bridges with a new config - oldRouter := testServer.router - - // assert - RunWhenServerReady(t, newBus, func(t2 *testing.T) { - _ = newBus.SendResponseMessage(service.LifecycleManagerChannelName, &service.SetupRESTBridgeRequest{ - ServiceChannel: services.PingPongServiceChan, - Override: true, - Config: []*service.RESTBridgeConfig{ - { - ServiceChannel: services.PingPongServiceChan, - Uri: "/ping-new", - Method: "GET", - FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { - return model.Request{Id: &uuid.UUID{}, Request: "ping-get", Payload: r.URL.Query().Get("message")} - }, - }, - }, - }, newBus.GetId()) - - // router instance should have been swapped - time.Sleep(1 * time.Second) - assert.NotEqual(t, testServer.router, oldRouter) - - // old endpoints should 404 - rsp, err := http.Get(fmt.Sprintf("%s/rest/ping-pong", baseUrl)) - assert.Nil(t, err) - assert.EqualValues(t, 404, rsp.StatusCode) - - rsp, err = http.Get(fmt.Sprintf("%s/rest/ping-pong2", baseUrl)) - assert.Nil(t, err) - assert.EqualValues(t, 404, rsp.StatusCode) - - rsp, err = http.Get(fmt.Sprintf("%s/rest/ping-pong/a/b/c", baseUrl)) - assert.Nil(t, err) - assert.EqualValues(t, 404, rsp.StatusCode) - - // new endpoints should respond successfully - rsp, err = http.Get(fmt.Sprintf("%s/ping-new", baseUrl)) - assert.Nil(t, err) - assert.EqualValues(t, 200, rsp.StatusCode) - - syschan <- syscall.SIGINT - wg.Done() - }) - - wg.Wait() + // arrange + newBus := bus.ResetBus() + service.ResetServiceRegistry() + testRoot := filepath.Join(os.TempDir(), "plank-tests") + _ = os.MkdirAll(testRoot, 0755) + defer os.RemoveAll(testRoot) + defer service.GetServiceRegistry().UnregisterService(services.PingPongServiceChan) + + cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", GetTestPort(), true) + baseUrl, _, testServerInterface := CreateTestServer(cfg) + testServer := testServerInterface.(*platformServer) + testServer.eventbus = newBus + + // register ping pong service with default bridge points of /rest/ping-pong, /rest/ping-pong2 and /rest/ping-pong/{from}/{to}/{message} + testServerInterface.RegisterService(services.NewPingPongService(), services.PingPongServiceChan) + + // start server + syschan := make(chan os.Signal) + wg := sync.WaitGroup{} + wg.Add(1) + go testServerInterface.StartServer(syschan) + + // act + // replace existing rest bridges with a new config + oldRouter := testServer.router + + // assert + RunWhenServerReady(t, newBus, func(t2 *testing.T) { + _ = newBus.SendResponseMessage(service.LifecycleManagerChannelName, &service.SetupRESTBridgeRequest{ + ServiceChannel: services.PingPongServiceChan, + Override: true, + Config: []*service.RESTBridgeConfig{ + { + ServiceChannel: services.PingPongServiceChan, + Uri: "/ping-new", + Method: "GET", + FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { + return model.Request{Id: &uuid.UUID{}, Request: "ping-get", Payload: r.URL.Query().Get("message")} + }, + }, + }, + }, newBus.GetId()) + + // router instance should have been swapped + time.Sleep(1 * time.Second) + assert.NotEqual(t, testServer.router, oldRouter) + + // old endpoints should 404 + rsp, err := http.Get(fmt.Sprintf("%s/rest/ping-pong", baseUrl)) + assert.Nil(t, err) + assert.EqualValues(t, 404, rsp.StatusCode) + + rsp, err = http.Get(fmt.Sprintf("%s/rest/ping-pong2", baseUrl)) + assert.Nil(t, err) + assert.EqualValues(t, 404, rsp.StatusCode) + + rsp, err = http.Get(fmt.Sprintf("%s/rest/ping-pong/a/b/c", baseUrl)) + assert.Nil(t, err) + assert.EqualValues(t, 404, rsp.StatusCode) + + // new endpoints should respond successfully + rsp, err = http.Get(fmt.Sprintf("%s/ping-new", baseUrl)) + assert.Nil(t, err) + assert.EqualValues(t, 200, rsp.StatusCode) + + syschan <- syscall.SIGINT + wg.Done() + }) + + wg.Wait() } diff --git a/plank/pkg/server/server_smoke_test.go b/plank/pkg/server/server_smoke_test.go index 100fdc8..640e1d9 100644 --- a/plank/pkg/server/server_smoke_test.go +++ b/plank/pkg/server/server_smoke_test.go @@ -1,181 +1,181 @@ package server import ( - "crypto/tls" - "fmt" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/service" - "github.com/stretchr/testify/assert" - "io/ioutil" - "net/http" - "os" - "path/filepath" - "sync" - "testing" + "crypto/tls" + "fmt" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/service" + "github.com/stretchr/testify/assert" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "sync" + "testing" ) // TestSmokeTests_TLS tests if Plank starts with TLS enabled func TestSmokeTests_TLS(t *testing.T) { - // pre-arrange - newBus := bus.ResetBus() - service.ResetServiceRegistry() - testRoot := filepath.Join(os.TempDir(), "plank-tests") - _ = os.MkdirAll(testRoot, 0755) - defer os.RemoveAll(testRoot) - - // arrange - port := GetTestPort() - cfg := GetBasicTestServerConfig(testRoot, "stdout", "null", "stderr", port, true) - cfg.FabricConfig = GetTestFabricBrokerConfig() - cfg.TLSCertConfig = GetTestTLSCertConfig(testRoot) - - // act - var wg sync.WaitGroup - sigChan := make(chan os.Signal) - baseUrl, _, testServer := CreateTestServer(cfg) - testServerInternal := testServer.(*platformServer) - testServerInternal.setEventBusRef(newBus) - - // assert to make sure the server was created with the correct test arguments - assert.EqualValues(t, fmt.Sprintf("https://localhost:%d", port), baseUrl) - - wg.Add(1) - go testServer.StartServer(sigChan) - - originalTransport := http.DefaultTransport - originalTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - - RunWhenServerReady(t, newBus, func(t *testing.T) { - resp, err := http.Get(baseUrl) - if err != nil { - defer func() { - testServer.StopServer() - wg.Done() - }() - t.Fatal(err) - } - assert.EqualValues(t, http.StatusNotFound, resp.StatusCode) - testServer.StopServer() - wg.Done() - }) - wg.Wait() + // pre-arrange + newBus := bus.ResetBus() + service.ResetServiceRegistry() + testRoot := filepath.Join(os.TempDir(), "plank-tests") + _ = os.MkdirAll(testRoot, 0755) + defer os.RemoveAll(testRoot) + + // arrange + port := GetTestPort() + cfg := GetBasicTestServerConfig(testRoot, "stdout", "null", "stderr", port, true) + cfg.FabricConfig = GetTestFabricBrokerConfig() + cfg.TLSCertConfig = GetTestTLSCertConfig(testRoot) + + // act + var wg sync.WaitGroup + sigChan := make(chan os.Signal) + baseUrl, _, testServer := CreateTestServer(cfg) + testServerInternal := testServer.(*platformServer) + testServerInternal.setEventBusRef(newBus) + + // assert to make sure the server was created with the correct test arguments + assert.EqualValues(t, fmt.Sprintf("https://localhost:%d", port), baseUrl) + + wg.Add(1) + go testServer.StartServer(sigChan) + + originalTransport := http.DefaultTransport + originalTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + + RunWhenServerReady(t, newBus, func(t *testing.T) { + resp, err := http.Get(baseUrl) + if err != nil { + defer func() { + testServer.StopServer() + wg.Done() + }() + t.Fatal(err) + } + assert.EqualValues(t, http.StatusNotFound, resp.StatusCode) + testServer.StopServer() + wg.Done() + }) + wg.Wait() } // TestSmokeTests_TLS_InvalidCert tests if Plank fails to start because of an invalid cert func TestSmokeTests_TLS_InvalidCert(t *testing.T) { - // TODO: make StartServer return an error object so it's easier to test + // TODO: make StartServer return an error object so it's easier to test } func TestSmokeTests(t *testing.T) { - newBus := bus.ResetBus() - service.ResetServiceRegistry() - testRoot := filepath.Join(os.TempDir(), "plank-tests") - //testOutFile := filepath.Join(testRoot, "plank-server-tests.log") - _ = os.MkdirAll(testRoot, 0755) - defer os.RemoveAll(testRoot) - - port := GetTestPort() - cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", port, true) - cfg.NoBanner = true - cfg.FabricConfig = GetTestFabricBrokerConfig() - - baseUrl, _, testServer := CreateTestServer(cfg) - testServer.(*platformServer).eventbus = newBus - - assert.EqualValues(t, fmt.Sprintf("http://localhost:%d", port), baseUrl) - - syschan := make(chan os.Signal, 1) - wg := sync.WaitGroup{} - wg.Add(1) - go testServer.StartServer(syschan) - RunWhenServerReady(t, newBus, func(t *testing.T) { - // root url - 404 - t.Run("404 on root", func(t2 *testing.T) { - cl := http.DefaultClient - rsp, err := cl.Get(baseUrl) - assert.Nil(t2, err) - assert.EqualValues(t2, 404, rsp.StatusCode) - }) - - // connection to fabric endpoint - t.Run("fabric endpoint should exist", func(t2 *testing.T) { - cl := http.DefaultClient - rsp, err := cl.Get(fmt.Sprintf("%s/ws", baseUrl)) - assert.Nil(t2, err) - assert.EqualValues(t2, 400, rsp.StatusCode) - }) - - testServer.StopServer() - wg.Done() - }) - wg.Wait() + newBus := bus.ResetBus() + service.ResetServiceRegistry() + testRoot := filepath.Join(os.TempDir(), "plank-tests") + //testOutFile := filepath.Join(testRoot, "plank-server-tests.log") + _ = os.MkdirAll(testRoot, 0755) + defer os.RemoveAll(testRoot) + + port := GetTestPort() + cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", port, true) + cfg.NoBanner = true + cfg.FabricConfig = GetTestFabricBrokerConfig() + + baseUrl, _, testServer := CreateTestServer(cfg) + testServer.(*platformServer).eventbus = newBus + + assert.EqualValues(t, fmt.Sprintf("http://localhost:%d", port), baseUrl) + + syschan := make(chan os.Signal, 1) + wg := sync.WaitGroup{} + wg.Add(1) + go testServer.StartServer(syschan) + RunWhenServerReady(t, newBus, func(t *testing.T) { + // root url - 404 + t.Run("404 on root", func(t2 *testing.T) { + cl := http.DefaultClient + rsp, err := cl.Get(baseUrl) + assert.Nil(t2, err) + assert.EqualValues(t2, 404, rsp.StatusCode) + }) + + // connection to fabric endpoint + t.Run("fabric endpoint should exist", func(t2 *testing.T) { + cl := http.DefaultClient + rsp, err := cl.Get(fmt.Sprintf("%s/ws", baseUrl)) + assert.Nil(t2, err) + assert.EqualValues(t2, 400, rsp.StatusCode) + }) + + testServer.StopServer() + wg.Done() + }) + wg.Wait() } func TestSmokeTests_NoFabric(t *testing.T) { - newBus := bus.ResetBus() - service.ResetServiceRegistry() - testRoot := filepath.Join(os.TempDir(), "plank-tests") - _ = os.MkdirAll(testRoot, 0755) - defer os.RemoveAll(testRoot) - - port := GetTestPort() - cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", port, true) - cfg.FabricConfig = nil - baseUrl, _, testServer := CreateTestServer(cfg) - testServer.(*platformServer).eventbus = newBus - - assert.EqualValues(t, fmt.Sprintf("http://localhost:%d", port), baseUrl) - - syschan := make(chan os.Signal, 1) - wg := sync.WaitGroup{} - wg.Add(1) - go testServer.StartServer(syschan) - RunWhenServerReady(t, newBus, func(t *testing.T) { - // fabric - 404 - t.Run("404 on fabric endpoint", func(t2 *testing.T) { - cl := http.DefaultClient - rsp, err := cl.Get(fmt.Sprintf("%s/ws", baseUrl)) - assert.Nil(t2, err) - assert.EqualValues(t2, 404, rsp.StatusCode) - }) - - testServer.StopServer() - wg.Done() - }) - wg.Wait() + newBus := bus.ResetBus() + service.ResetServiceRegistry() + testRoot := filepath.Join(os.TempDir(), "plank-tests") + _ = os.MkdirAll(testRoot, 0755) + defer os.RemoveAll(testRoot) + + port := GetTestPort() + cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", port, true) + cfg.FabricConfig = nil + baseUrl, _, testServer := CreateTestServer(cfg) + testServer.(*platformServer).eventbus = newBus + + assert.EqualValues(t, fmt.Sprintf("http://localhost:%d", port), baseUrl) + + syschan := make(chan os.Signal, 1) + wg := sync.WaitGroup{} + wg.Add(1) + go testServer.StartServer(syschan) + RunWhenServerReady(t, newBus, func(t *testing.T) { + // fabric - 404 + t.Run("404 on fabric endpoint", func(t2 *testing.T) { + cl := http.DefaultClient + rsp, err := cl.Get(fmt.Sprintf("%s/ws", baseUrl)) + assert.Nil(t2, err) + assert.EqualValues(t2, 404, rsp.StatusCode) + }) + + testServer.StopServer() + wg.Done() + }) + wg.Wait() } func TestSmokeTests_HealthEndpoint(t *testing.T) { - newBus := bus.ResetBus() - service.ResetServiceRegistry() - testRoot := filepath.Join(os.TempDir(), "plank-tests") - _ = os.MkdirAll(testRoot, 0755) - defer os.RemoveAll(testRoot) - - port := GetTestPort() - cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", port, true) - cfg.FabricConfig = nil - baseUrl, _, testServer := CreateTestServer(cfg) - testServer.(*platformServer).eventbus = newBus - - assert.EqualValues(t, fmt.Sprintf("http://localhost:%d", port), baseUrl) - - syschan := make(chan os.Signal, 1) - wg := sync.WaitGroup{} - wg.Add(1) - go testServer.StartServer(syschan) - RunWhenServerReady(t, newBus, func(*testing.T) { - t.Run("/health returns OK", func(t2 *testing.T) { - cl := http.DefaultClient - rsp, err := cl.Get(fmt.Sprintf("%s/health", baseUrl)) - assert.Nil(t2, err) - defer rsp.Body.Close() - bodyBytes, _ := ioutil.ReadAll(rsp.Body) - assert.Contains(t, string(bodyBytes), "OK") - }) - - testServer.StopServer() - wg.Done() - }) - wg.Wait() + newBus := bus.ResetBus() + service.ResetServiceRegistry() + testRoot := filepath.Join(os.TempDir(), "plank-tests") + _ = os.MkdirAll(testRoot, 0755) + defer os.RemoveAll(testRoot) + + port := GetTestPort() + cfg := GetBasicTestServerConfig(testRoot, "stdout", "stdout", "stderr", port, true) + cfg.FabricConfig = nil + baseUrl, _, testServer := CreateTestServer(cfg) + testServer.(*platformServer).eventbus = newBus + + assert.EqualValues(t, fmt.Sprintf("http://localhost:%d", port), baseUrl) + + syschan := make(chan os.Signal, 1) + wg := sync.WaitGroup{} + wg.Add(1) + go testServer.StartServer(syschan) + RunWhenServerReady(t, newBus, func(*testing.T) { + t.Run("/health returns OK", func(t2 *testing.T) { + cl := http.DefaultClient + rsp, err := cl.Get(fmt.Sprintf("%s/health", baseUrl)) + assert.Nil(t2, err) + defer rsp.Body.Close() + bodyBytes, _ := ioutil.ReadAll(rsp.Body) + assert.Contains(t, string(bodyBytes), "OK") + }) + + testServer.StopServer() + wg.Done() + }) + wg.Wait() } diff --git a/plank/pkg/server/spa_config.go b/plank/pkg/server/spa_config.go index 11978a8..602cbe6 100644 --- a/plank/pkg/server/spa_config.go +++ b/plank/pkg/server/spa_config.go @@ -4,11 +4,11 @@ package server import ( - "github.com/gorilla/mux" - "github.com/pb33f/ranch/plank/pkg/middleware" - "github.com/pb33f/ranch/plank/utils" - "net/http" - "regexp" + "github.com/gorilla/mux" + "github.com/pb33f/ranch/plank/pkg/middleware" + "github.com/pb33f/ranch/plank/utils" + "net/http" + "regexp" ) // SpaConfig shorthand for SinglePageApplication Config is used to configure routes for your SPAs like @@ -16,59 +16,59 @@ import ( // are served from /app/static, BaseUri can be set to /app and StaticAssets to "/app/assets". see config.json // for details. type SpaConfig struct { - RootFolder string `json:"root_folder"` // location where Plank will serve SPA - BaseUri string `json:"base_uri"` // base URI for the SPA - StaticAssets []string `json:"static_assets"` // locations for static assets used by the SPA - CacheControlRules map[string]string `json:"cache_control_rules"` // map holding glob pattern - cache-control header value + RootFolder string `json:"root_folder"` // location where Plank will serve SPA + BaseUri string `json:"base_uri"` // base URI for the SPA + StaticAssets []string `json:"static_assets"` // locations for static assets used by the SPA + CacheControlRules map[string]string `json:"cache_control_rules"` // map holding glob pattern - cache-control header value - cacheControlRulePairs []middleware.CacheControlRulePair + cacheControlRulePairs []middleware.CacheControlRulePair } type regexCacheControlRulePair struct { - regex *regexp.Regexp - cacheControlRule string + regex *regexp.Regexp + cacheControlRule string } // NewSpaConfig takes location to where the SPA content is as an input and returns a sanitized // instance of *SpaConfig. func NewSpaConfig(input string) (spaConfig *SpaConfig, err error) { - p, uri := utils.DeriveStaticURIFromPath(input) - spaConfig = &SpaConfig{ - RootFolder: p, - BaseUri: uri, - CacheControlRules: make(map[string]string), - cacheControlRulePairs: make([]middleware.CacheControlRulePair, 0), - } + p, uri := utils.DeriveStaticURIFromPath(input) + spaConfig = &SpaConfig{ + RootFolder: p, + BaseUri: uri, + CacheControlRules: make(map[string]string), + cacheControlRulePairs: make([]middleware.CacheControlRulePair, 0), + } - spaConfig.CollateCacheControlRules() - return spaConfig, err + spaConfig.CollateCacheControlRules() + return spaConfig, err } // CollateCacheControlRules compiles glob patterns and stores them as an array. func (s *SpaConfig) CollateCacheControlRules() { - for globP, rule := range s.CacheControlRules { - pair, err := middleware.NewCacheControlRulePair(globP, rule) - if err != nil { - utils.Log.Errorln("Ignoring invalid glob pattern provided as cache control matcher rule", err) - continue - } + for globP, rule := range s.CacheControlRules { + pair, err := middleware.NewCacheControlRulePair(globP, rule) + if err != nil { + utils.Log.Errorln("Ignoring invalid glob pattern provided as cache control matcher rule", err) + continue + } - s.cacheControlRulePairs = append(s.cacheControlRulePairs, pair) - } + s.cacheControlRulePairs = append(s.cacheControlRulePairs, pair) + } } // CacheControlMiddleware returns the middleware func to be used in route configuration func (s *SpaConfig) CacheControlMiddleware() mux.MiddlewareFunc { - return func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // apply cache control rule that matches first - for _, pair := range s.cacheControlRulePairs { - if pair.CompiledGlobPattern.Match(r.RequestURI) { - w.Header().Set("Cache-Control", pair.CacheControlRule) - break - } - } - handler.ServeHTTP(w, r) - }) - } + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // apply cache control rule that matches first + for _, pair := range s.cacheControlRulePairs { + if pair.CompiledGlobPattern.Match(r.RequestURI) { + w.Header().Set("Cache-Control", pair.CacheControlRule) + break + } + } + handler.ServeHTTP(w, r) + }) + } } diff --git a/plank/pkg/server/test_suite_harness.go b/plank/pkg/server/test_suite_harness.go index 0db39bf..81fac0e 100644 --- a/plank/pkg/server/test_suite_harness.go +++ b/plank/pkg/server/test_suite_harness.go @@ -4,21 +4,21 @@ package server import ( - "errors" - "fmt" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/model" - "github.com/pb33f/ranch/plank/utils" - svc "github.com/pb33f/ranch/service" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "io/ioutil" - "net" - "os" - "path/filepath" - "testing" - "time" + "errors" + "fmt" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/model" + "github.com/pb33f/ranch/plank/utils" + svc "github.com/pb33f/ranch/service" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "io/ioutil" + "net" + "os" + "path/filepath" + "testing" + "time" ) var testSuitePortMap = make(map[string]int) @@ -27,185 +27,185 @@ var testSuitePortMap = make(map[string]int) // In a realistic manner. This is a convenience mechanism to avoid having to rig this harnessing up yourself. type PlankIntegrationTestSuite struct { - // suite.Suite is a reference to a testify test suite. - suite.Suite + // suite.Suite is a reference to a testify test suite. + suite.Suite - // server.PlatformServer is a reference to plank - PlatformServer + // server.PlatformServer is a reference to plank + PlatformServer - // os.Signal is the signal channel passed into plan for OS level notifications. - Syschan chan os.Signal + // os.Signal is the signal channel passed into plan for OS level notifications. + Syschan chan os.Signal - // bus.ChannelManager is a reference to the Transport's Channel Manager. - bus.ChannelManager + // bus.ChannelManager is a reference to the Transport's Channel Manager. + bus.ChannelManager - // bus.EventBus is a reference to Transport. - bus.EventBus + // bus.EventBus is a reference to Transport. + bus.EventBus } // PlankIntegrationTest allows test suites that use PlankIntegrationTestSuite as an embedded struct to set everything // we need on our test. This is the only contract required to use this harness. type PlankIntegrationTest interface { - SetPlatformServer(PlatformServer) - SetSysChan(chan os.Signal) - SetChannelManager(bus.ChannelManager) - SetBus(eventBus bus.EventBus) + SetPlatformServer(PlatformServer) + SetSysChan(chan os.Signal) + SetChannelManager(bus.ChannelManager) + SetBus(eventBus bus.EventBus) } // SetupPlankTestSuiteForTest will copy over everything from the newly set up test suite, to a suite being run. func SetupPlankTestSuiteForTest(suite *PlankIntegrationTestSuite, test PlankIntegrationTest) { - test.SetPlatformServer(suite.PlatformServer) - test.SetSysChan(suite.Syschan) - test.SetChannelManager(suite.ChannelManager) - test.SetBus(suite.EventBus) + test.SetPlatformServer(suite.PlatformServer) + test.SetSysChan(suite.Syschan) + test.SetChannelManager(suite.ChannelManager) + test.SetBus(suite.EventBus) } // GetBasicTestServerConfig will generate a simple platform server config, ready to use in a test. func GetBasicTestServerConfig(rootDir, outLog, accessLog, errLog string, port int, noBanner bool) *PlatformServerConfig { - cfg := &PlatformServerConfig{ - RootDir: rootDir, - Host: "localhost", - Port: port, - RestBridgeTimeout: time.Minute, - LogConfig: &utils.LogConfig{ - OutputLog: outLog, - AccessLog: accessLog, - ErrorLog: errLog, - FormatOptions: &utils.LogFormatOption{}, - }, - NoBanner: noBanner, - ShutdownTimeout: time.Minute, - } - return cfg + cfg := &PlatformServerConfig{ + RootDir: rootDir, + Host: "localhost", + Port: port, + RestBridgeTimeout: time.Minute, + LogConfig: &utils.LogConfig{ + OutputLog: outLog, + AccessLog: accessLog, + ErrorLog: errLog, + FormatOptions: &utils.LogFormatOption{}, + }, + NoBanner: noBanner, + ShutdownTimeout: time.Minute, + } + return cfg } // SetupPlankTestSuite will boot a new instance of plank on your chosen port and will also fire up your service // Ready to be tested. This always runs on localhost. func SetupPlankTestSuite(service svc.FabricService, serviceChannel string, port int, - config *PlatformServerConfig) (*PlankIntegrationTestSuite, error) { + config *PlatformServerConfig) (*PlankIntegrationTestSuite, error) { - s := &PlankIntegrationTestSuite{} + s := &PlankIntegrationTestSuite{} - customFormatter := new(logrus.TextFormatter) - utils.Log.SetFormatter(customFormatter) - customFormatter.DisableTimestamp = true + customFormatter := new(logrus.TextFormatter) + utils.Log.SetFormatter(customFormatter) + customFormatter.DisableTimestamp = true - // check if config has been supplied, if not, generate default one. - if config == nil { - config = GetBasicTestServerConfig("/", "stdout", "stdout", "stderr", port, true) - } + // check if config has been supplied, if not, generate default one. + if config == nil { + config = GetBasicTestServerConfig("/", "stdout", "stdout", "stderr", port, true) + } - s.PlatformServer = NewPlatformServer(config) - if err := s.PlatformServer.RegisterService(service, serviceChannel); err != nil { - return nil, errors.New("cannot create printing press service, test failed") - } + s.PlatformServer = NewPlatformServer(config) + if err := s.PlatformServer.RegisterService(service, serviceChannel); err != nil { + return nil, errors.New("cannot create printing press service, test failed") + } - s.Syschan = make(chan os.Signal, 1) - go s.PlatformServer.StartServer(s.Syschan) + s.Syschan = make(chan os.Signal, 1) + go s.PlatformServer.StartServer(s.Syschan) - s.EventBus = bus.ResetBus() - svc.ResetServiceRegistry() + s.EventBus = bus.ResetBus() + svc.ResetServiceRegistry() - // get a pointer to the channel manager - s.ChannelManager = s.EventBus.GetChannelManager() + // get a pointer to the channel manager + s.ChannelManager = s.EventBus.GetChannelManager() - // wait, service may be slow loading, rest mapping happens last. - wait := make(chan bool) - go func() { - time.Sleep(10 * time.Millisecond) - wait <- true // Sending something to the channel to let the main thread continue + // wait, service may be slow loading, rest mapping happens last. + wait := make(chan bool) + go func() { + time.Sleep(10 * time.Millisecond) + wait <- true // Sending something to the channel to let the main thread continue - }() + }() - <-wait - return s, nil + <-wait + return s, nil } // CreateTestServer consumes *server.PlatformServerConfig and returns a new instance of PlatformServer along with // its base URL and log file name func CreateTestServer(config *PlatformServerConfig) (baseUrl, logFile string, s PlatformServer) { - testServer := NewPlatformServer(config) + testServer := NewPlatformServer(config) - protocol := "http" - if config.TLSCertConfig != nil { - protocol += "s" - } + protocol := "http" + if config.TLSCertConfig != nil { + protocol += "s" + } - baseUrl = fmt.Sprintf("%s://%s:%d", protocol, config.Host, config.Port) - return baseUrl, config.LogConfig.OutputLog, testServer + baseUrl = fmt.Sprintf("%s://%s:%d", protocol, config.Host, config.Port) + return baseUrl, config.LogConfig.OutputLog, testServer } // RunWhenServerReady runs test function fn after Plank has booted up func RunWhenServerReady(t *testing.T, eventBus bus.EventBus, fn func(*testing.T)) { - handler, _ := eventBus.ListenOnce(RANCH_SERVER_ONLINE_CHANNEL) - handler.Handle(func(message *model.Message) { - fn(t) - }, func(err error) { - assert.FailNow(t, err.Error()) - }) + handler, _ := eventBus.ListenOnce(RANCH_SERVER_ONLINE_CHANNEL) + handler.Handle(func(message *model.Message) { + fn(t) + }, func(err error) { + assert.FailNow(t, err.Error()) + }) } // GetTestPort returns an available port for use by Plank tests func GetTestPort() int { - minPort := 9980 - fr := utils.GetCallerStackFrame() - port, exists := testSuitePortMap[fr.File] - if exists { - return port - } - - // try 5 more times with every failure advancing port by one - tryPort := minPort - for i := 0; i <= 4; i++ { - tryPort = minPort + len(testSuitePortMap) + i - _, err := net.Dial("tcp", fmt.Sprintf(":%d", tryPort)) - if err == nil { // port in use, try next one - continue - } - - testSuitePortMap[fr.File] = tryPort - break - } - - if testSuitePortMap[fr.File] == 0 { // port could not be assigned - panic(fmt.Errorf("could not assign a port for tests. last tried port is %d", tryPort)) - } - - return testSuitePortMap[fr.File] + minPort := 9980 + fr := utils.GetCallerStackFrame() + port, exists := testSuitePortMap[fr.File] + if exists { + return port + } + + // try 5 more times with every failure advancing port by one + tryPort := minPort + for i := 0; i <= 4; i++ { + tryPort = minPort + len(testSuitePortMap) + i + _, err := net.Dial("tcp", fmt.Sprintf(":%d", tryPort)) + if err == nil { // port in use, try next one + continue + } + + testSuitePortMap[fr.File] = tryPort + break + } + + if testSuitePortMap[fr.File] == 0 { // port could not be assigned + panic(fmt.Errorf("could not assign a port for tests. last tried port is %d", tryPort)) + } + + return testSuitePortMap[fr.File] } // GetTestTLSCertConfig returns a new &TLSCertConfig for testing. func GetTestTLSCertConfig(testRootPath string) *TLSCertConfig { - crtFile := filepath.Join(testRootPath, "test_server.crt") - keyFile := filepath.Join(testRootPath, "test_server.key") - _ = ioutil.WriteFile(crtFile, []byte(testServerCertTmpl), 0700) - _ = ioutil.WriteFile(keyFile, []byte(testServerKeyTmpl), 0700) - return &TLSCertConfig{ - CertFile: crtFile, - KeyFile: keyFile, - SkipCertificateValidation: true, - } + crtFile := filepath.Join(testRootPath, "test_server.crt") + keyFile := filepath.Join(testRootPath, "test_server.key") + _ = ioutil.WriteFile(crtFile, []byte(testServerCertTmpl), 0700) + _ = ioutil.WriteFile(keyFile, []byte(testServerKeyTmpl), 0700) + return &TLSCertConfig{ + CertFile: crtFile, + KeyFile: keyFile, + SkipCertificateValidation: true, + } } // GetTestFabricBrokerConfig returns a basic fabric broker config. func GetTestFabricBrokerConfig() *FabricBrokerConfig { - return &FabricBrokerConfig{ - FabricEndpoint: "/ws", - UseTCP: false, - TCPPort: 61613, - EndpointConfig: &bus.EndpointConfig{ - TopicPrefix: "/topic", - UserQueuePrefix: "/queue", - AppRequestPrefix: "/pub", - AppRequestQueuePrefix: "/pub/queue", - Heartbeat: 30000, - }, - } + return &FabricBrokerConfig{ + FabricEndpoint: "/ws", + UseTCP: false, + TCPPort: 61613, + EndpointConfig: &bus.EndpointConfig{ + TopicPrefix: "/topic", + UserQueuePrefix: "/queue", + AppRequestPrefix: "/pub", + AppRequestQueuePrefix: "/pub/queue", + Heartbeat: 30000, + }, + } } // CreateConfigJsonForTest creates and returns the path to a file containing the plank configuration in JSON format func CreateConfigJsonForTest() (string, error) { - configJsonContent := `{ + configJsonContent := `{ "debug": true, "no_banner": true, "root_dir": "./", @@ -260,15 +260,15 @@ func CreateConfigJsonForTest() (string, error) { "key_file": "cert/server.key" } }` - testDir := filepath.Join(os.TempDir(), "plank-tests") - testFile := filepath.Join(testDir, "test-config.json") - err := os.MkdirAll(testDir, 0744) - if err != nil { - return "", err - } - err = ioutil.WriteFile(testFile, []byte(configJsonContent), 0744) - if err != nil { - return "", err - } - return testFile, nil + testDir := filepath.Join(os.TempDir(), "plank-tests") + testFile := filepath.Join(testDir, "test-config.json") + err := os.MkdirAll(testDir, 0744) + if err != nil { + return "", err + } + err = ioutil.WriteFile(testFile, []byte(configJsonContent), 0744) + if err != nil { + return "", err + } + return testFile, nil } diff --git a/plank/pkg/server/test_suite_harness_test.go b/plank/pkg/server/test_suite_harness_test.go index 2ed5e23..e209eae 100644 --- a/plank/pkg/server/test_suite_harness_test.go +++ b/plank/pkg/server/test_suite_harness_test.go @@ -1,56 +1,56 @@ package server import ( - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/model" - "github.com/pb33f/ranch/service" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "os" - "testing" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/model" + "github.com/pb33f/ranch/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "os" + "testing" ) func TestGetBasicTestServerConfig(t *testing.T) { - config := GetBasicTestServerConfig("/", "stdout", "stdout", "stderr", 999, true) - assert.Equal(t, "/", config.RootDir) - assert.Equal(t, 999, config.Port) + config := GetBasicTestServerConfig("/", "stdout", "stdout", "stderr", 999, true) + assert.Equal(t, "/", config.RootDir) + assert.Equal(t, 999, config.Port) } // define mock integration suite type testPlankTestIntegration struct { - PlankIntegrationTestSuite + PlankIntegrationTestSuite } func (m *testPlankTestIntegration) SetPlatformServer(s PlatformServer) { - m.PlatformServer = s + m.PlatformServer = s } func (m *testPlankTestIntegration) SetSysChan(c chan os.Signal) { - m.Syschan = c + m.Syschan = c } func (m *testPlankTestIntegration) SetChannelManager(cm bus.ChannelManager) { - m.ChannelManager = cm + m.ChannelManager = cm } func (m *testPlankTestIntegration) SetBus(eventBus bus.EventBus) { - m.EventBus = eventBus + m.EventBus = eventBus } func TestSetupPlankTestSuiteForTest(t *testing.T) { - b := bus.ResetBus() - service.ResetServiceRegistry() - cm := b.GetChannelManager() - pit := &PlankIntegrationTestSuite{ - Suite: suite.Suite{}, - PlatformServer: nil, - Syschan: make(chan os.Signal), - ChannelManager: cm, - EventBus: b, - } + b := bus.ResetBus() + service.ResetServiceRegistry() + cm := b.GetChannelManager() + pit := &PlankIntegrationTestSuite{ + Suite: suite.Suite{}, + PlatformServer: nil, + Syschan: make(chan os.Signal), + ChannelManager: cm, + EventBus: b, + } - test := &testPlankTestIntegration{} - SetupPlankTestSuiteForTest(pit, test) - assert.Equal(t, cm, test.ChannelManager) - assert.Equal(t, b, test.EventBus) - assert.Nil(t, nil, test.PlatformServer) + test := &testPlankTestIntegration{} + SetupPlankTestSuiteForTest(pit, test) + assert.Equal(t, cm, test.ChannelManager) + assert.Equal(t, b, test.EventBus) + assert.Nil(t, nil, test.PlatformServer) } type testService struct { @@ -60,7 +60,7 @@ func (t *testService) HandleServiceRequest(rt *model.Request, c service.FabricSe } func TestSetupPlankTestSuite(t *testing.T) { - suite, err := SetupPlankTestSuite(&testService{}, "nowhere", 62986, nil) - assert.NoError(t, err) - assert.NotNil(t, suite) + suite, err := SetupPlankTestSuite(&testService{}, "nowhere", 62986, nil) + assert.NoError(t, err) + assert.NotNil(t, suite) } diff --git a/plank/services/ping-pong-service.go b/plank/services/ping-pong-service.go index 6deab0f..8b0cd4a 100644 --- a/plank/services/ping-pong-service.go +++ b/plank/services/ping-pong-service.go @@ -4,19 +4,19 @@ package services import ( - "encoding/json" - "fmt" - "github.com/google/uuid" - "github.com/gorilla/mux" - "github.com/pb33f/ranch/model" - "github.com/pb33f/ranch/service" - "io/ioutil" - "net/http" - "time" + "encoding/json" + "fmt" + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/pb33f/ranch/model" + "github.com/pb33f/ranch/service" + "io/ioutil" + "net/http" + "time" ) const ( - PingPongServiceChan = "ping-pong-service" + PingPongServiceChan = "ping-pong-service" ) // PingPongService is a very simple service to demonstrate how request-response cycles are handled in Transport & Plank. @@ -27,45 +27,45 @@ const ( type PingPongService struct{} func NewPingPongService() *PingPongService { - return &PingPongService{} + return &PingPongService{} } // Init will fire when the service is being registered by the fabric, it passes a reference of the same core // Passed through when implementing HandleServiceRequest func (ps *PingPongService) Init(core service.FabricServiceCore) error { - // set default headers for this service. - core.SetHeaders(map[string]string{ - "Content-Type": "application/json", - }) + // set default headers for this service. + core.SetHeaders(map[string]string{ + "Content-Type": "application/json", + }) - return nil + return nil } // HandleServiceRequest routes the incoming request and based on the Request property of request, it invokes the // appropriate handler logic defined and separated by a switch statement like the one shown below. func (ps *PingPongService) HandleServiceRequest(request *model.Request, core service.FabricServiceCore) { - switch request.RequestCommand { - // ping-post request type accepts the payload as a POJO - case "ping-post": - m := make(map[string]interface{}) - m["timestamp"] = time.Now().Unix() - err := json.Unmarshal(request.Payload.([]byte), &m) - if err != nil { - core.SendErrorResponse(request, 400, err.Error()) - } else { - core.SendResponse(request, m) - } - // ping-get request type accepts the payload as a string - case "ping-get": - rsp := make(map[string]interface{}) - val := request.Payload.(string) - rsp["payload"] = val + "-response" - rsp["timestamp"] = time.Now().Unix() - core.SendResponse(request, rsp) - default: - core.HandleUnknownRequest(request) - } + switch request.RequestCommand { + // ping-post request type accepts the payload as a POJO + case "ping-post": + m := make(map[string]interface{}) + m["timestamp"] = time.Now().Unix() + err := json.Unmarshal(request.Payload.([]byte), &m) + if err != nil { + core.SendErrorResponse(request, 400, err.Error()) + } else { + core.SendResponse(request, m) + } + // ping-get request type accepts the payload as a string + case "ping-get": + rsp := make(map[string]interface{}) + val := request.Payload.(string) + rsp["payload"] = val + "-response" + rsp["timestamp"] = time.Now().Unix() + core.SendResponse(request, rsp) + default: + core.HandleUnknownRequest(request) + } } // OnServiceReady contains logic that handles the service initialization that needs to be carried out @@ -74,10 +74,10 @@ func (ps *PingPongService) HandleServiceRequest(request *model.Request, core ser // you need to perform any and every init logic here and return a channel that would receive a payload // once your service truly becomes ready to accept requests. func (ps *PingPongService) OnServiceReady() chan bool { - // for sample purposes this service initializes instantly - readyChan := make(chan bool, 1) - readyChan <- true - return readyChan + // for sample purposes this service initializes instantly + readyChan := make(chan bool, 1) + readyChan <- true + return readyChan } // OnServerShutdown is the opposite of OnServiceReady. it is called when the server enters graceful shutdown @@ -85,8 +85,8 @@ func (ps *PingPongService) OnServiceReady() chan bool { // to return anything because the main server thread is going to shut down soon, but if there's any important teardown // or cleanup that needs to be done, this is the right place to perform that. func (ps *PingPongService) OnServerShutdown() { - // for sample purposes emulate a 1 second teardown process - time.Sleep(1 * time.Second) + // for sample purposes emulate a 1 second teardown process + time.Sleep(1 * time.Second) } // GetRESTBridgeConfig returns a list of REST bridge configurations that Plank will use to automatically register @@ -95,43 +95,43 @@ func (ps *PingPongService) OnServerShutdown() { // as the service author you have full control over every aspect of the translation process which basically turns // an incoming *http.Request into model.Request. See FabricRequestBuilder below to see it in action. func (ps *PingPongService) GetRESTBridgeConfig() []*service.RESTBridgeConfig { - return []*service.RESTBridgeConfig{ - { - ServiceChannel: PingPongServiceChan, - Uri: "/rest/ping-pong", - Method: http.MethodPost, - AllowHead: true, - AllowOptions: true, - FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { - body, _ := ioutil.ReadAll(r.Body) - return model.CreateServiceRequest("ping-post", body) - }, - }, - { - ServiceChannel: PingPongServiceChan, - Uri: "/rest/ping-pong2", - Method: http.MethodGet, - AllowHead: true, - AllowOptions: true, - FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { - return model.Request{Id: &uuid.UUID{}, RequestCommand: "ping-get", Payload: r.URL.Query().Get("message")} - }, - }, - { - ServiceChannel: PingPongServiceChan, - Uri: "/rest/ping-pong/{from}/{to}/{message}", - Method: http.MethodGet, - FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { - pathParams := mux.Vars(r) - return model.Request{ - Id: &uuid.UUID{}, - RequestCommand: "ping-get", - Payload: fmt.Sprintf( - "From %s to %s: %s", - pathParams["from"], - pathParams["to"], - pathParams["message"])} - }, - }, - } + return []*service.RESTBridgeConfig{ + { + ServiceChannel: PingPongServiceChan, + Uri: "/rest/ping-pong", + Method: http.MethodPost, + AllowHead: true, + AllowOptions: true, + FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { + body, _ := ioutil.ReadAll(r.Body) + return model.CreateServiceRequest("ping-post", body) + }, + }, + { + ServiceChannel: PingPongServiceChan, + Uri: "/rest/ping-pong2", + Method: http.MethodGet, + AllowHead: true, + AllowOptions: true, + FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { + return model.Request{Id: &uuid.UUID{}, RequestCommand: "ping-get", Payload: r.URL.Query().Get("message")} + }, + }, + { + ServiceChannel: PingPongServiceChan, + Uri: "/rest/ping-pong/{from}/{to}/{message}", + Method: http.MethodGet, + FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { + pathParams := mux.Vars(r) + return model.Request{ + Id: &uuid.UUID{}, + RequestCommand: "ping-get", + Payload: fmt.Sprintf( + "From %s to %s: %s", + pathParams["from"], + pathParams["to"], + pathParams["message"])} + }, + }, + } } diff --git a/plank/utils/cli.go b/plank/utils/cli.go index 979804f..bfb0a21 100644 --- a/plank/utils/cli.go +++ b/plank/utils/cli.go @@ -4,97 +4,97 @@ package utils var PlatformServerFlagConstants = map[string]map[string]string{ - "Hostname": { - "FlagName": "hostname", - "ShortFlag": "n", - "Description": "Hostname where Plank accepts connections", - }, - "Port": { - "FlagName": "port", - "ShortFlag": "p", - "Description": "Port where Plank is to be served", - }, - "RootDir": { - "FlagName": "rootdir", - "ShortFlag": "r", - "Description": "Root directory for the server (default: Current directory)", - }, - "Cert": { - "FlagName": "cert", - "Description": "X509 Certificate file for TLS", - }, - "CertKey": { - "FlagName": "cert-key", - "Description": "X509 Certificate private Key file for TLS", - }, - "Static": { - "FlagName": "static", - "ShortFlag": "s", - "Description": "Path(s) where static files will be served", - }, - "SpaPath": { - "FlagName": "spa-path", - "Description": "Path to serve Single Page Application (SPA) from. The URI is derived from the leaf directory. A different URI can be specified by providing it following a colon (e.g. --spa-path ./path/to/spa-app:my-spa", - }, - "NoFabricBroker": { - "FlagName": "no-fabric-broker", - "Description": "Disable Fabric (STOMP) broker", - }, - "FabricEndpoint": { - "FlagName": "fabric-endpoint", - "Description": "Fabric broker endpoint", - }, - "TopicPrefix": { - "FlagName": "topic-prefix", - "Description": "Topic prefix for Fabric broker", - }, - "QueuePrefix": { - "FlagName": "query-prefix", - "Description": "Queue prefix for Fabric broker", - }, - "RequestPrefix": { - "FlagName": "request-prefix", - "Description": "Application request prefix for Fabric broker", - }, - "RequestQueuePrefix": { - "FlagName": "request-queue-prefix", - "Description": "Application request queue prefix for Fabric broker", - }, - "ConfigFile": { - "FlagName": "config-file", - "Description": "Path to the server config JSON file", - }, - "ShutdownTimeout": { - "FlagName": "shutdown-timeout", - "Description": "Graceful server shutdown timeout in minutes", - }, - "OutputLog": { - "FlagName": "output-log", - "ShortFlag": "l", - "Description": "Platform log output. Possible values: stdout, stderr, null, or path to a file", - }, - "AccessLog": { - "FlagName": "access-log", - "ShortFlag": "a", - "Description": "HTTP server access log output. Possible values: stdout, stderr, null, or path to a file", - }, - "ErrorLog": { - "FlagName": "error-log", - "ShortFlag": "e", - "Description": "HTTP server error log output. Possible values: stdout, stderr, null, or path to a file", - }, - "Debug": { - "FlagName": "debug", - "ShortFlag": "d", - "Description": "Enable debug logging", - }, - "NoBanner": { - "FlagName": "no-banner", - "ShortFlag": "b", - "Description": "Do not print Plank banner at startup", - }, - "RestBridgeTimeout": { - "FlagName": "rest-bridge-timeout", - "Description": "Time in minutes before a REST endpoint for a service request to timeout", - }, + "Hostname": { + "FlagName": "hostname", + "ShortFlag": "n", + "Description": "Hostname where Plank accepts connections", + }, + "Port": { + "FlagName": "port", + "ShortFlag": "p", + "Description": "Port where Plank is to be served", + }, + "RootDir": { + "FlagName": "rootdir", + "ShortFlag": "r", + "Description": "Root directory for the server (default: Current directory)", + }, + "Cert": { + "FlagName": "cert", + "Description": "X509 Certificate file for TLS", + }, + "CertKey": { + "FlagName": "cert-key", + "Description": "X509 Certificate private Key file for TLS", + }, + "Static": { + "FlagName": "static", + "ShortFlag": "s", + "Description": "Path(s) where static files will be served", + }, + "SpaPath": { + "FlagName": "spa-path", + "Description": "Path to serve Single Page Application (SPA) from. The URI is derived from the leaf directory. A different URI can be specified by providing it following a colon (e.g. --spa-path ./path/to/spa-app:my-spa", + }, + "NoFabricBroker": { + "FlagName": "no-fabric-broker", + "Description": "Disable Fabric (STOMP) broker", + }, + "FabricEndpoint": { + "FlagName": "fabric-endpoint", + "Description": "Fabric broker endpoint", + }, + "TopicPrefix": { + "FlagName": "topic-prefix", + "Description": "Topic prefix for Fabric broker", + }, + "QueuePrefix": { + "FlagName": "query-prefix", + "Description": "Queue prefix for Fabric broker", + }, + "RequestPrefix": { + "FlagName": "request-prefix", + "Description": "Application request prefix for Fabric broker", + }, + "RequestQueuePrefix": { + "FlagName": "request-queue-prefix", + "Description": "Application request queue prefix for Fabric broker", + }, + "ConfigFile": { + "FlagName": "config-file", + "Description": "Path to the server config JSON file", + }, + "ShutdownTimeout": { + "FlagName": "shutdown-timeout", + "Description": "Graceful server shutdown timeout in minutes", + }, + "OutputLog": { + "FlagName": "output-log", + "ShortFlag": "l", + "Description": "Platform log output. Possible values: stdout, stderr, null, or path to a file", + }, + "AccessLog": { + "FlagName": "access-log", + "ShortFlag": "a", + "Description": "HTTP server access log output. Possible values: stdout, stderr, null, or path to a file", + }, + "ErrorLog": { + "FlagName": "error-log", + "ShortFlag": "e", + "Description": "HTTP server error log output. Possible values: stdout, stderr, null, or path to a file", + }, + "Debug": { + "FlagName": "debug", + "ShortFlag": "d", + "Description": "Enable debug logging", + }, + "NoBanner": { + "FlagName": "no-banner", + "ShortFlag": "b", + "Description": "Do not print Plank banner at startup", + }, + "RestBridgeTimeout": { + "FlagName": "rest-bridge-timeout", + "Description": "Time in minutes before a REST endpoint for a service request to timeout", + }, } diff --git a/service/fabric_core_test.go b/service/fabric_core_test.go index c775c20..285f2da 100644 --- a/service/fabric_core_test.go +++ b/service/fabric_core_test.go @@ -4,321 +4,321 @@ package service import ( - "errors" - "github.com/google/uuid" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "sync" - "testing" + "errors" + "github.com/google/uuid" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "sync" + "testing" ) func newTestFabricCore(channelName string) FabricServiceCore { - eventBus := bus.NewEventBusInstance() - eventBus.GetChannelManager().CreateChannel(channelName) - return &fabricCore{ - channelName: channelName, - bus: eventBus, - } + eventBus := bus.NewEventBusInstance() + eventBus.GetChannelManager().CreateChannel(channelName) + return &fabricCore{ + channelName: channelName, + bus: eventBus, + } } func TestFabricCore_Bus(t *testing.T) { - core := newTestFabricCore("test-channel") - assert.NotNil(t, core.Bus()) + core := newTestFabricCore("test-channel") + assert.NotNil(t, core.Bus()) } func TestFabricCore_SendMethods(t *testing.T) { - core := newTestFabricCore("test-channel") - - mh, _ := core.Bus().ListenStream("test-channel") - - wg := sync.WaitGroup{} - - var count = 0 - var lastMessage *model.Message - - mh.Handle(func(message *model.Message) { - count++ - lastMessage = message - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) - - id := uuid.New() - req := model.Request{ - Id: &id, - Request: "test-request", - BrokerDestination: &model.BrokerDestinationConfig{ - Destination: "test", - }, - } - - wg.Add(1) - core.SendResponse(&req, "test-response") - wg.Wait() - - assert.Equal(t, count, 1) - - response, ok := lastMessage.Payload.(*model.Response) - assert.True(t, ok) - assert.Equal(t, response.Id, req.Id) - assert.Equal(t, response.Payload, "test-response") - assert.False(t, response.Error) - assert.Equal(t, response.BrokerDestination.Destination, "test") - - wg.Add(1) - h := make(map[string]string) - h["hello"] = "there" - core.SendResponseWithHeaders(&req, "test-response-with-headers", h) - wg.Wait() - - assert.Equal(t, count, 2) - - response, ok = lastMessage.Payload.(*model.Response) - assert.True(t, ok) - assert.Equal(t, response.Id, req.Id) - assert.Equal(t, response.Payload, "test-response-with-headers") - assert.False(t, response.Error) - assert.Equal(t, response.BrokerDestination.Destination, "test") - assert.Equal(t, response.Headers["hello"], "there") - - wg.Add(1) - core.SendErrorResponse(&req, 404, "test-error") - wg.Wait() - - assert.Equal(t, count, 3) - response = lastMessage.Payload.(*model.Response) - - assert.Equal(t, response.Id, req.Id) - assert.Nil(t, response.Payload) - assert.True(t, response.Error) - assert.Equal(t, response.ErrorCode, 404) - assert.Equal(t, response.ErrorMessage, "test-error") - - wg.Add(1) - - h = make(map[string]string) - h["chicken"] = "nugget" - core.SendErrorResponseWithHeaders(&req, 422, "test-header-error", h) - wg.Wait() - - assert.Equal(t, count, 4) - response = lastMessage.Payload.(*model.Response) - - assert.Equal(t, response.Id, req.Id) - assert.Equal(t, response.Headers["chicken"], "nugget") - assert.Nil(t, response.Payload) - assert.True(t, response.Error) - assert.Equal(t, response.ErrorCode, 422) - assert.Equal(t, response.ErrorMessage, "test-header-error") - - wg.Add(1) - - h = make(map[string]string) - h["potato"] = "dog" - core.SendErrorResponseWithHeadersAndPayload(&req, 500, "test-header-payload-error", "oh my!", h) - wg.Wait() - - assert.Equal(t, count, 5) - response = lastMessage.Payload.(*model.Response) - - assert.Equal(t, response.Id, req.Id) - assert.Equal(t, "dog", response.Headers["potato"]) - assert.Equal(t, "oh my!", response.Payload.(string)) - assert.True(t, response.Error) - assert.Equal(t, response.ErrorCode, 500) - assert.Equal(t, response.ErrorMessage, "test-header-payload-error") - - wg.Add(1) - core.HandleUnknownRequest(&req) - wg.Wait() - - assert.Equal(t, count, 6) - response = lastMessage.Payload.(*model.Response) - - assert.Equal(t, response.Id, req.Id) - assert.True(t, response.Error) - assert.Equal(t, 403, response.ErrorCode) - assert.Equal(t, nil, response.Payload) + core := newTestFabricCore("test-channel") + + mh, _ := core.Bus().ListenStream("test-channel") + + wg := sync.WaitGroup{} + + var count = 0 + var lastMessage *model.Message + + mh.Handle(func(message *model.Message) { + count++ + lastMessage = message + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) + + id := uuid.New() + req := model.Request{ + Id: &id, + Request: "test-request", + BrokerDestination: &model.BrokerDestinationConfig{ + Destination: "test", + }, + } + + wg.Add(1) + core.SendResponse(&req, "test-response") + wg.Wait() + + assert.Equal(t, count, 1) + + response, ok := lastMessage.Payload.(*model.Response) + assert.True(t, ok) + assert.Equal(t, response.Id, req.Id) + assert.Equal(t, response.Payload, "test-response") + assert.False(t, response.Error) + assert.Equal(t, response.BrokerDestination.Destination, "test") + + wg.Add(1) + h := make(map[string]string) + h["hello"] = "there" + core.SendResponseWithHeaders(&req, "test-response-with-headers", h) + wg.Wait() + + assert.Equal(t, count, 2) + + response, ok = lastMessage.Payload.(*model.Response) + assert.True(t, ok) + assert.Equal(t, response.Id, req.Id) + assert.Equal(t, response.Payload, "test-response-with-headers") + assert.False(t, response.Error) + assert.Equal(t, response.BrokerDestination.Destination, "test") + assert.Equal(t, response.Headers["hello"], "there") + + wg.Add(1) + core.SendErrorResponse(&req, 404, "test-error") + wg.Wait() + + assert.Equal(t, count, 3) + response = lastMessage.Payload.(*model.Response) + + assert.Equal(t, response.Id, req.Id) + assert.Nil(t, response.Payload) + assert.True(t, response.Error) + assert.Equal(t, response.ErrorCode, 404) + assert.Equal(t, response.ErrorMessage, "test-error") + + wg.Add(1) + + h = make(map[string]string) + h["chicken"] = "nugget" + core.SendErrorResponseWithHeaders(&req, 422, "test-header-error", h) + wg.Wait() + + assert.Equal(t, count, 4) + response = lastMessage.Payload.(*model.Response) + + assert.Equal(t, response.Id, req.Id) + assert.Equal(t, response.Headers["chicken"], "nugget") + assert.Nil(t, response.Payload) + assert.True(t, response.Error) + assert.Equal(t, response.ErrorCode, 422) + assert.Equal(t, response.ErrorMessage, "test-header-error") + + wg.Add(1) + + h = make(map[string]string) + h["potato"] = "dog" + core.SendErrorResponseWithHeadersAndPayload(&req, 500, "test-header-payload-error", "oh my!", h) + wg.Wait() + + assert.Equal(t, count, 5) + response = lastMessage.Payload.(*model.Response) + + assert.Equal(t, response.Id, req.Id) + assert.Equal(t, "dog", response.Headers["potato"]) + assert.Equal(t, "oh my!", response.Payload.(string)) + assert.True(t, response.Error) + assert.Equal(t, response.ErrorCode, 500) + assert.Equal(t, response.ErrorMessage, "test-header-payload-error") + + wg.Add(1) + core.HandleUnknownRequest(&req) + wg.Wait() + + assert.Equal(t, count, 6) + response = lastMessage.Payload.(*model.Response) + + assert.Equal(t, response.Id, req.Id) + assert.True(t, response.Error) + assert.Equal(t, 403, response.ErrorCode) + assert.Equal(t, nil, response.Payload) } func TestFabricCore_RestServiceRequest(t *testing.T) { - core := newTestFabricCore("test-channel") - - core.Bus().GetChannelManager().CreateChannel(restServiceChannel) - - var lastRequest *model.Request - - wg := sync.WaitGroup{} - - mh, _ := core.Bus().ListenRequestStream(restServiceChannel) - mh.Handle( - func(message *model.Message) { - lastRequest = message.Payload.(*model.Request) - wg.Done() - }, - func(e error) {}) - - var lastSuccess, lastError *model.Response - - restRequest := &RestServiceRequest{ - Uri: "test", - Headers: map[string]string{"h1": "value1"}, - } - - wg.Add(1) - core.RestServiceRequest(restRequest, func(response *model.Response) { - lastSuccess = response - wg.Done() - }, func(response *model.Response) { - lastError = response - wg.Done() - }) - - wg.Wait() - - wg.Add(1) - core.Bus().SendResponseMessage(restServiceChannel, &model.Response{ - Payload: "test", - }, lastRequest.Id) - wg.Wait() - - assert.NotNil(t, lastSuccess) - assert.Nil(t, lastError) - - assert.Equal(t, lastRequest.Payload, restRequest) - assert.Equal(t, len(lastRequest.Payload.(*RestServiceRequest).Headers), 1) - assert.Equal(t, lastRequest.Payload.(*RestServiceRequest).Headers["h1"], "value1") - assert.Equal(t, lastSuccess.Payload, "test") - - lastSuccess, lastError = nil, nil - - core.SetHeaders(map[string]string{"h2": "value2", "h1": "new-value"}) - - wg.Add(1) - core.RestServiceRequest(restRequest, func(response *model.Response) { - lastSuccess = response - wg.Done() - }, func(response *model.Response) { - lastError = response - wg.Done() - }) - - wg.Wait() - - wg.Add(1) - core.Bus().SendResponseMessage(restServiceChannel, &model.Response{ - ErrorMessage: "error", - Error: true, - ErrorCode: 1, - }, lastRequest.Id) - wg.Wait() - - assert.Nil(t, lastSuccess) - assert.NotNil(t, lastError) - assert.Equal(t, lastError.ErrorMessage, "error") - assert.Equal(t, lastError.ErrorCode, 1) - - assert.Equal(t, len(lastRequest.Payload.(*RestServiceRequest).Headers), 2) - assert.Equal(t, lastRequest.Payload.(*RestServiceRequest).Headers["h1"], "value1") - assert.Equal(t, lastRequest.Payload.(*RestServiceRequest).Headers["h2"], "value2") - - lastSuccess, lastError = nil, nil - wg.Add(1) - core.RestServiceRequest(restRequest, func(response *model.Response) { - lastSuccess = response - wg.Done() - }, func(response *model.Response) { - lastError = response - wg.Done() - }) - - wg.Wait() - - wg.Add(1) - core.Bus().SendErrorMessage(restServiceChannel, errors.New("test-error"), lastRequest.Id) - wg.Wait() - - assert.Nil(t, lastSuccess) - assert.NotNil(t, lastError) - assert.Equal(t, lastError.ErrorMessage, "test-error") - assert.Equal(t, lastError.ErrorCode, 500) + core := newTestFabricCore("test-channel") + + core.Bus().GetChannelManager().CreateChannel(restServiceChannel) + + var lastRequest *model.Request + + wg := sync.WaitGroup{} + + mh, _ := core.Bus().ListenRequestStream(restServiceChannel) + mh.Handle( + func(message *model.Message) { + lastRequest = message.Payload.(*model.Request) + wg.Done() + }, + func(e error) {}) + + var lastSuccess, lastError *model.Response + + restRequest := &RestServiceRequest{ + Uri: "test", + Headers: map[string]string{"h1": "value1"}, + } + + wg.Add(1) + core.RestServiceRequest(restRequest, func(response *model.Response) { + lastSuccess = response + wg.Done() + }, func(response *model.Response) { + lastError = response + wg.Done() + }) + + wg.Wait() + + wg.Add(1) + core.Bus().SendResponseMessage(restServiceChannel, &model.Response{ + Payload: "test", + }, lastRequest.Id) + wg.Wait() + + assert.NotNil(t, lastSuccess) + assert.Nil(t, lastError) + + assert.Equal(t, lastRequest.Payload, restRequest) + assert.Equal(t, len(lastRequest.Payload.(*RestServiceRequest).Headers), 1) + assert.Equal(t, lastRequest.Payload.(*RestServiceRequest).Headers["h1"], "value1") + assert.Equal(t, lastSuccess.Payload, "test") + + lastSuccess, lastError = nil, nil + + core.SetHeaders(map[string]string{"h2": "value2", "h1": "new-value"}) + + wg.Add(1) + core.RestServiceRequest(restRequest, func(response *model.Response) { + lastSuccess = response + wg.Done() + }, func(response *model.Response) { + lastError = response + wg.Done() + }) + + wg.Wait() + + wg.Add(1) + core.Bus().SendResponseMessage(restServiceChannel, &model.Response{ + ErrorMessage: "error", + Error: true, + ErrorCode: 1, + }, lastRequest.Id) + wg.Wait() + + assert.Nil(t, lastSuccess) + assert.NotNil(t, lastError) + assert.Equal(t, lastError.ErrorMessage, "error") + assert.Equal(t, lastError.ErrorCode, 1) + + assert.Equal(t, len(lastRequest.Payload.(*RestServiceRequest).Headers), 2) + assert.Equal(t, lastRequest.Payload.(*RestServiceRequest).Headers["h1"], "value1") + assert.Equal(t, lastRequest.Payload.(*RestServiceRequest).Headers["h2"], "value2") + + lastSuccess, lastError = nil, nil + wg.Add(1) + core.RestServiceRequest(restRequest, func(response *model.Response) { + lastSuccess = response + wg.Done() + }, func(response *model.Response) { + lastError = response + wg.Done() + }) + + wg.Wait() + + wg.Add(1) + core.Bus().SendErrorMessage(restServiceChannel, errors.New("test-error"), lastRequest.Id) + wg.Wait() + + assert.Nil(t, lastSuccess) + assert.NotNil(t, lastError) + assert.Equal(t, lastError.ErrorMessage, "test-error") + assert.Equal(t, lastError.ErrorCode, 500) } func TestFabricCore_GenerateJSONHeaders(t *testing.T) { - core := newTestFabricCore("test-channel") - h := core.GenerateJSONHeaders() - assert.EqualValues(t, "application/json", h["Content-Type"]) + core := newTestFabricCore("test-channel") + h := core.GenerateJSONHeaders() + assert.EqualValues(t, "application/json", h["Content-Type"]) } func TestFabricCore_SetDefaultJSONHeaders(t *testing.T) { - core := newTestFabricCore("test-channel") - core.SetDefaultJSONHeaders() + core := newTestFabricCore("test-channel") + core.SetDefaultJSONHeaders() - mh, _ := core.Bus().ListenStream("test-channel") + mh, _ := core.Bus().ListenStream("test-channel") - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - var lastMessage *model.Message + var lastMessage *model.Message - mh.Handle(func(message *model.Message) { - lastMessage = message - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) + mh.Handle(func(message *model.Message) { + lastMessage = message + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) - id := uuid.New() - req := model.Request{ - Id: &id, - Payload: "test-headers", - } + id := uuid.New() + req := model.Request{ + Id: &id, + Payload: "test-headers", + } - wg.Add(1) - core.SendResponse(&req, "test-response") - wg.Wait() + wg.Add(1) + core.SendResponse(&req, "test-response") + wg.Wait() - response := lastMessage.Payload.(*model.Response) + response := lastMessage.Payload.(*model.Response) - // content-type and accept should have been set. - assert.Len(t, response.Headers, 1) - assert.EqualValues(t, "application/json", response.Headers["Content-Type"]) + // content-type and accept should have been set. + assert.Len(t, response.Headers, 1) + assert.EqualValues(t, "application/json", response.Headers["Content-Type"]) } func TestFabricCore_SetDefaultJSONHeadersEmpty(t *testing.T) { - core := newTestFabricCore("test-channel") + core := newTestFabricCore("test-channel") - // set empty headers - core.SetHeaders(nil) + // set empty headers + core.SetHeaders(nil) - mh, _ := core.Bus().ListenStream("test-channel") + mh, _ := core.Bus().ListenStream("test-channel") - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - var lastMessage *model.Message + var lastMessage *model.Message - mh.Handle(func(message *model.Message) { - lastMessage = message - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) + mh.Handle(func(message *model.Message) { + lastMessage = message + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) - id := uuid.New() - req := model.Request{ - Id: &id, - Payload: "test-headers", - } + id := uuid.New() + req := model.Request{ + Id: &id, + Payload: "test-headers", + } - wg.Add(1) - core.SendResponseWithHeaders(&req, "test-response", map[string]string{"Content-Type": "pizza/cake"}) - wg.Wait() + wg.Add(1) + core.SendResponseWithHeaders(&req, "test-response", map[string]string{"Content-Type": "pizza/cake"}) + wg.Wait() - response := lastMessage.Payload.(*model.Response) + response := lastMessage.Payload.(*model.Response) - // content-type and accept should have been set. - assert.Len(t, response.Headers, 1) - assert.EqualValues(t, "pizza/cake", response.Headers["Content-Type"]) + // content-type and accept should have been set. + assert.Len(t, response.Headers, 1) + assert.EqualValues(t, "pizza/cake", response.Headers["Content-Type"]) } diff --git a/service/fabric_error.go b/service/fabric_error.go index 6fef4f5..1889835 100644 --- a/service/fabric_error.go +++ b/service/fabric_error.go @@ -2,20 +2,20 @@ package service // FabricError is a RFC7807 standard error properties (https://tools.ietf.org/html/rfc7807) type FabricError struct { - Type string `json:"type,omitempty"` - Title string `json:"title"` - Status int `json:"status"` - Detail string `json:"detail"` - Instance string `json:"instance,omitempty"` + Type string `json:"type,omitempty"` + Title string `json:"title"` + Status int `json:"status"` + Detail string `json:"detail"` + Instance string `json:"instance,omitempty"` } // GetFabricError will return a structured, standardized Error object that is compliant // with RFC7807 standard error properties (https://tools.ietf.org/html/rfc7807) func GetFabricError(message string, code int, detail string) FabricError { - return FabricError{ - Title: message, - Status: code, - Detail: detail, - Type: "https://github.com/pb33f/ranch/blob/main/plank/services/fabric_error.md", - } + return FabricError{ + Title: message, + Status: code, + Detail: detail, + Type: "https://github.com/pb33f/ranch/blob/main/plank/services/fabric_error.md", + } } diff --git a/service/fabric_service.go b/service/fabric_service.go index 4ec3db2..fe07741 100644 --- a/service/fabric_service.go +++ b/service/fabric_service.go @@ -4,17 +4,17 @@ package service import ( - "github.com/pb33f/ranch/model" + "github.com/pb33f/ranch/model" ) // FabricService Interface containing all APIs which should be implemented by Fabric Services. type FabricService interface { - // Handles a single Fabric Request - HandleServiceRequest(request *model.Request, core FabricServiceCore) + // Handles a single Fabric Request + HandleServiceRequest(request *model.Request, core FabricServiceCore) } // FabricInitializableService Optional interface, if implemented by a fabric service, its Init method // will be invoked when the service is registered in the ServiceRegistry. type FabricInitializableService interface { - Init(core FabricServiceCore) error + Init(core FabricServiceCore) error } diff --git a/service/rest_service.go b/service/rest_service.go index afb7ff3..e63bd60 100644 --- a/service/rest_service.go +++ b/service/rest_service.go @@ -4,211 +4,211 @@ package service import ( - "bytes" - "encoding/json" - "github.com/pb33f/ranch/model" - "io" - "net/http" - "net/url" - "reflect" - "strings" + "bytes" + "encoding/json" + "github.com/pb33f/ranch/model" + "io" + "net/http" + "net/url" + "reflect" + "strings" ) const ( - restServiceChannel = "fabric-rest" + restServiceChannel = "fabric-rest" ) type RestServiceRequest struct { - // The destination URL of the request. - Uri string `json:"uri"` - // HTTP Method to use, e.g. GET, POST, PATCH etc. - Method string `json:"method"` - // The body of the request. String and []byte payloads will be sent as is, - // all other payloads will be serialized as json. - Body interface{} `json:"body"` - // HTTP headers of the request. - Headers map[string]string `json:"headers"` - // Optional type of the response body. If provided the service will try to deserialize - // the response to this type. - // If omitted the response body will be deserialized as map[string]interface{} - // Note that if the response body is not a valid json you should set - // the ResponseType to string or []byte otherwise you might get deserialization error - // or empty result. - ResponseType reflect.Type - // Shouldn't be populated directly, the field is used to deserialize - // com.vmware.bifrost.core.model.RestServiceRequest Java/Typescript requests - ApiClass string `json:"apiClass"` + // The destination URL of the request. + Uri string `json:"uri"` + // HTTP Method to use, e.g. GET, POST, PATCH etc. + Method string `json:"method"` + // The body of the request. String and []byte payloads will be sent as is, + // all other payloads will be serialized as json. + Body interface{} `json:"body"` + // HTTP headers of the request. + Headers map[string]string `json:"headers"` + // Optional type of the response body. If provided the service will try to deserialize + // the response to this type. + // If omitted the response body will be deserialized as map[string]interface{} + // Note that if the response body is not a valid json you should set + // the ResponseType to string or []byte otherwise you might get deserialization error + // or empty result. + ResponseType reflect.Type + // Shouldn't be populated directly, the field is used to deserialize + // com.vmware.bifrost.core.model.RestServiceRequest Java/Typescript requests + ApiClass string `json:"apiClass"` } func (request *RestServiceRequest) marshalBody() ([]byte, error) { - // don't marshal string and []byte payloads as json - stringPayload, ok := request.Body.(string) - if ok { - return []byte(stringPayload), nil - } - bytePayload, ok := request.Body.([]byte) - if ok { - return bytePayload, nil - } - // encode the message payload as JSON - return json.Marshal(request.Body) + // don't marshal string and []byte payloads as json + stringPayload, ok := request.Body.(string) + if ok { + return []byte(stringPayload), nil + } + bytePayload, ok := request.Body.([]byte) + if ok { + return bytePayload, nil + } + // encode the message payload as JSON + return json.Marshal(request.Body) } type restService struct { - httpClient http.Client - baseHost string + httpClient http.Client + baseHost string } func (rs *restService) setBaseHost(host string) { - rs.baseHost = host + rs.baseHost = host } func (rs *restService) HandleServiceRequest(request *model.Request, core FabricServiceCore) { - restReq, ok := rs.getRestServiceRequest(request) - if !ok { - core.SendErrorResponse(request, 500, "invalid RestServiceRequest payload") - return - } - - body, err := restReq.marshalBody() - if err != nil { - core.SendErrorResponse(request, 500, "cannot marshal request body: "+err.Error()) - return - } - - httpReq, err := http.NewRequest(restReq.Method, - rs.getRequestUrl(restReq.Uri, core), bytes.NewBuffer(body)) - - if err != nil { - core.SendErrorResponse(request, 500, err.Error()) - return - } - - // update headers - for k, v := range restReq.Headers { - httpReq.Header.Add(k, v) - } - - // add default Content-Type header if such is not provided in the request - if httpReq.Header.Get("Content-Type") == "" { - httpReq.Header.Add("Content-Type", "application/merge-patch+json") - } - - contentType := httpReq.Header.Get("Content-Type") - if strings.Contains(contentType, "json") { - // leaving restReq.ResponseType empty is equivalent to treating the response as JSON. see deserializeResponse(). - } else { - // otherwise default to byte slice. note that we have an arm for the string type, but defaulting to the byte - // slice makes the payload more flexible to handle in downstream consumers - restReq.ResponseType = reflect.TypeOf([]byte{}) - } - - httpResp, err := rs.httpClient.Do(httpReq) - if err != nil { - core.SendErrorResponse(request, 500, err.Error()) - return - } - defer httpResp.Body.Close() - - if httpResp.StatusCode >= 300 { - core.SendErrorResponseWithPayload(request, httpResp.StatusCode, - "rest-service error, unable to complete request: "+httpResp.Status, - map[string]interface{}{ - "errorCode": httpResp.StatusCode, - "message": "rest-service error, unable to complete request: " + httpResp.Status, - }) - return - } - - result, err := rs.deserializeResponse(httpResp.Body, restReq.ResponseType) - if err != nil { - core.SendErrorResponse(request, 500, "failed to deserialize response:"+err.Error()) - } else { - core.SendResponse(request, result) - } + restReq, ok := rs.getRestServiceRequest(request) + if !ok { + core.SendErrorResponse(request, 500, "invalid RestServiceRequest payload") + return + } + + body, err := restReq.marshalBody() + if err != nil { + core.SendErrorResponse(request, 500, "cannot marshal request body: "+err.Error()) + return + } + + httpReq, err := http.NewRequest(restReq.Method, + rs.getRequestUrl(restReq.Uri, core), bytes.NewBuffer(body)) + + if err != nil { + core.SendErrorResponse(request, 500, err.Error()) + return + } + + // update headers + for k, v := range restReq.Headers { + httpReq.Header.Add(k, v) + } + + // add default Content-Type header if such is not provided in the request + if httpReq.Header.Get("Content-Type") == "" { + httpReq.Header.Add("Content-Type", "application/merge-patch+json") + } + + contentType := httpReq.Header.Get("Content-Type") + if strings.Contains(contentType, "json") { + // leaving restReq.ResponseType empty is equivalent to treating the response as JSON. see deserializeResponse(). + } else { + // otherwise default to byte slice. note that we have an arm for the string type, but defaulting to the byte + // slice makes the payload more flexible to handle in downstream consumers + restReq.ResponseType = reflect.TypeOf([]byte{}) + } + + httpResp, err := rs.httpClient.Do(httpReq) + if err != nil { + core.SendErrorResponse(request, 500, err.Error()) + return + } + defer httpResp.Body.Close() + + if httpResp.StatusCode >= 300 { + core.SendErrorResponseWithPayload(request, httpResp.StatusCode, + "rest-service error, unable to complete request: "+httpResp.Status, + map[string]interface{}{ + "errorCode": httpResp.StatusCode, + "message": "rest-service error, unable to complete request: " + httpResp.Status, + }) + return + } + + result, err := rs.deserializeResponse(httpResp.Body, restReq.ResponseType) + if err != nil { + core.SendErrorResponse(request, 500, "failed to deserialize response:"+err.Error()) + } else { + core.SendResponse(request, result) + } } func (rs *restService) getRestServiceRequest(request *model.Request) (*RestServiceRequest, bool) { - restReq, ok := request.Payload.(*RestServiceRequest) - if ok { - return restReq, true - } - - // check if the request.Payload is map[string]interface{} and convert it to RestServiceRequest - // This is needed to handle requests coming from Java/Typescript Transport clients. - reqAsMap, ok := request.Payload.(map[string]interface{}) - if ok { - restServReqInt, err := model.ConvertValueToType(reqAsMap, reflect.TypeOf(&RestServiceRequest{})) - if err == nil && restServReqInt != nil { - restServReq := restServReqInt.(*RestServiceRequest) - if restServReq.ApiClass == "java.lang.String" { - restServReq.ResponseType = reflect.TypeOf("") - } - return restServReq, true - } - } - - return nil, false + restReq, ok := request.Payload.(*RestServiceRequest) + if ok { + return restReq, true + } + + // check if the request.Payload is map[string]interface{} and convert it to RestServiceRequest + // This is needed to handle requests coming from Java/Typescript Transport clients. + reqAsMap, ok := request.Payload.(map[string]interface{}) + if ok { + restServReqInt, err := model.ConvertValueToType(reqAsMap, reflect.TypeOf(&RestServiceRequest{})) + if err == nil && restServReqInt != nil { + restServReq := restServReqInt.(*RestServiceRequest) + if restServReq.ApiClass == "java.lang.String" { + restServReq.ResponseType = reflect.TypeOf("") + } + return restServReq, true + } + } + + return nil, false } func (rs *restService) getRequestUrl(address string, core FabricServiceCore) string { - if rs.baseHost == "" { - return address - } - - result, err := url.Parse(address) - if err != nil { - return address - } - result.Host = rs.baseHost - return result.String() + if rs.baseHost == "" { + return address + } + + result, err := url.Parse(address) + if err != nil { + return address + } + result.Host = rs.baseHost + return result.String() } func (rs *restService) deserializeResponse( - body io.ReadCloser, responseType reflect.Type) (interface{}, error) { - - if responseType != nil { - - // check for string responseType - if responseType.Kind() == reflect.String { - buf := new(bytes.Buffer) - _, err := buf.ReadFrom(body) - if err != nil { - return nil, err - } - return buf.String(), nil - } - - // check for []byte responseType - if responseType.Kind() == reflect.Slice && - responseType == reflect.TypeOf([]byte{}) { - buf := new(bytes.Buffer) - _, err := buf.ReadFrom(body) - if err != nil { - return nil, err - } - return buf.Bytes(), nil - } - - var returnResultAsPointer bool - if responseType.Kind() == reflect.Ptr { - returnResultAsPointer = true - responseType = responseType.Elem() - } - decodedValuePtr := reflect.New(responseType).Interface() - err := json.NewDecoder(body).Decode(&decodedValuePtr) - if err != nil { - return nil, err - } - if returnResultAsPointer { - return decodedValuePtr, nil - } else { - return reflect.ValueOf(decodedValuePtr).Elem().Interface(), nil - } - } else { - var result map[string]interface{} - err := json.NewDecoder(body).Decode(&result) - return result, err - } + body io.ReadCloser, responseType reflect.Type) (interface{}, error) { + + if responseType != nil { + + // check for string responseType + if responseType.Kind() == reflect.String { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(body) + if err != nil { + return nil, err + } + return buf.String(), nil + } + + // check for []byte responseType + if responseType.Kind() == reflect.Slice && + responseType == reflect.TypeOf([]byte{}) { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(body) + if err != nil { + return nil, err + } + return buf.Bytes(), nil + } + + var returnResultAsPointer bool + if responseType.Kind() == reflect.Ptr { + returnResultAsPointer = true + responseType = responseType.Elem() + } + decodedValuePtr := reflect.New(responseType).Interface() + err := json.NewDecoder(body).Decode(&decodedValuePtr) + if err != nil { + return nil, err + } + if returnResultAsPointer { + return decodedValuePtr, nil + } else { + return reflect.ValueOf(decodedValuePtr).Elem().Interface(), nil + } + } else { + var result map[string]interface{} + err := json.NewDecoder(body).Decode(&result) + return result, err + } } diff --git a/service/rest_service_test.go b/service/rest_service_test.go index 2afbc3e..f189e3c 100644 --- a/service/rest_service_test.go +++ b/service/rest_service_test.go @@ -4,45 +4,45 @@ package service import ( - "bytes" - "encoding/json" - "errors" - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "io/ioutil" - "net/http" - "reflect" - "strings" - "sync" - "testing" + "bytes" + "encoding/json" + "errors" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "io/ioutil" + "net/http" + "reflect" + "strings" + "sync" + "testing" ) type testItem struct { - Name string `json:"name"` - Count int `json:"count"` + Name string `json:"name"` + Count int `json:"count"` } func TestRestServiceRequest_marshalBody(t *testing.T) { - reqWithStringBody := &RestServiceRequest{Body: "test-body"} - body, err := reqWithStringBody.marshalBody() - assert.Nil(t, err) - assert.Equal(t, []byte("test-body"), body) - - reqWithBytesBody := &RestServiceRequest{Body: []byte{1, 2, 3, 4}} - body, err = reqWithBytesBody.marshalBody() - assert.Nil(t, err) - assert.Equal(t, reqWithBytesBody.Body, body) - - item := testItem{Name: "test-name", Count: 5} - reqWithTestItem := &RestServiceRequest{Body: item} - body, err = reqWithTestItem.marshalBody() - assert.Nil(t, err) - expectedValue, _ := json.Marshal(item) - assert.Equal(t, expectedValue, body) + reqWithStringBody := &RestServiceRequest{Body: "test-body"} + body, err := reqWithStringBody.marshalBody() + assert.Nil(t, err) + assert.Equal(t, []byte("test-body"), body) + + reqWithBytesBody := &RestServiceRequest{Body: []byte{1, 2, 3, 4}} + body, err = reqWithBytesBody.marshalBody() + assert.Nil(t, err) + assert.Equal(t, reqWithBytesBody.Body, body) + + item := testItem{Name: "test-name", Count: 5} + reqWithTestItem := &RestServiceRequest{Body: item} + body, err = reqWithTestItem.marshalBody() + assert.Nil(t, err) + expectedValue, _ := json.Marshal(item) + assert.Equal(t, expectedValue, body) } func TestRestService_AutoRegistration(t *testing.T) { - assert.NotNil(t, GetServiceRegistry().(*serviceRegistry).services[restServiceChannel]) + assert.NotNil(t, GetServiceRegistry().(*serviceRegistry).services[restServiceChannel]) } // RoundTripFunc . @@ -50,314 +50,314 @@ type RoundTripFunc func(req *http.Request) (*http.Response, error) // RoundTrip . func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req) + return f(req) } func TestRestService_HandleServiceRequest(t *testing.T) { - core := newTestFabricCore(restServiceChannel) - - restService := &restService{} - var lastHttpRequest *http.Request - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), - Header: make(http.Header), - }, nil - }) - - var lastResponse *model.Response - - wg := sync.WaitGroup{} - wg.Add(1) - - mh, _ := core.Bus().ListenStream(restServiceChannel) - mh.Handle( - func(message *model.Message) { - lastResponse = message.Payload.(*model.Response) - wg.Done() - }, - func(e error) { - assert.Fail(t, "unexpected error") - }) - - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - Headers: map[string]string{"header1": "value1", "header2": "value2"}, - Method: "UPDATE", - Body: "test-body", - ResponseType: reflect.TypeOf(""), - }, - }, core) - - wg.Wait() - - assert.NotNil(t, lastHttpRequest) - assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") - assert.Equal(t, lastHttpRequest.Method, "UPDATE") - assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") - assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") - assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") - sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) - assert.Equal(t, sentBody, []byte("test-body")) - - assert.NotNil(t, lastResponse) - assert.Equal(t, lastResponse.Payload, "test-response-body") - assert.False(t, lastResponse.Error) - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString(`{"name": "test-name", "count": 2}`)), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - Headers: map[string]string{"Content-Type": "json"}, - ResponseType: reflect.TypeOf(testItem{}), - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "json") - assert.Equal(t, lastResponse.Payload, testItem{Name: "test-name", Count: 2}) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - ResponseType: reflect.TypeOf(&testItem{}), - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastResponse.Payload, &testItem{Name: "test-name", Count: 2}) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastResponse.Payload, map[string]interface{}{"name": "test-name", "count": float64(2)}) - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte{1, 2, 3, 4, 5})), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - ResponseType: reflect.TypeOf([]byte{}), - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastResponse.Payload, []byte{1, 2, 3, 4, 5}) + core := newTestFabricCore(restServiceChannel) + + restService := &restService{} + var lastHttpRequest *http.Request + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), + Header: make(http.Header), + }, nil + }) + + var lastResponse *model.Response + + wg := sync.WaitGroup{} + wg.Add(1) + + mh, _ := core.Bus().ListenStream(restServiceChannel) + mh.Handle( + func(message *model.Message) { + lastResponse = message.Payload.(*model.Response) + wg.Done() + }, + func(e error) { + assert.Fail(t, "unexpected error") + }) + + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + Headers: map[string]string{"header1": "value1", "header2": "value2"}, + Method: "UPDATE", + Body: "test-body", + ResponseType: reflect.TypeOf(""), + }, + }, core) + + wg.Wait() + + assert.NotNil(t, lastHttpRequest) + assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") + assert.Equal(t, lastHttpRequest.Method, "UPDATE") + assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") + assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") + assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") + sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) + assert.Equal(t, sentBody, []byte("test-body")) + + assert.NotNil(t, lastResponse) + assert.Equal(t, lastResponse.Payload, "test-response-body") + assert.False(t, lastResponse.Error) + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`{"name": "test-name", "count": 2}`)), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + Headers: map[string]string{"Content-Type": "json"}, + ResponseType: reflect.TypeOf(testItem{}), + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "json") + assert.Equal(t, lastResponse.Payload, testItem{Name: "test-name", Count: 2}) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + ResponseType: reflect.TypeOf(&testItem{}), + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastResponse.Payload, &testItem{Name: "test-name", Count: 2}) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastResponse.Payload, map[string]interface{}{"name": "test-name", "count": float64(2)}) + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte{1, 2, 3, 4, 5})), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + ResponseType: reflect.TypeOf([]byte{}), + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastResponse.Payload, []byte{1, 2, 3, 4, 5}) } func TestRestService_HandleJavaServiceRequest(t *testing.T) { - core := newTestFabricCore(restServiceChannel) - - wg := sync.WaitGroup{} - - restService := &restService{} - var lastHttpRequest *http.Request - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - defer wg.Done() - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: map[string]interface{}{ - "uri": "http://localhost:4444/test-url", - "headers": map[string]string{"header1": "value1", "header2": "value2"}, - "method": "UPDATE", - "Body": "test-body", - "apiClass": "java.lang.String", - }, - }, core) - - wg.Wait() - - assert.NotNil(t, lastHttpRequest) - assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") - assert.Equal(t, lastHttpRequest.Method, "UPDATE") - assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") - assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") - assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") - sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) - assert.Equal(t, sentBody, []byte("test-body")) + core := newTestFabricCore(restServiceChannel) + + wg := sync.WaitGroup{} + + restService := &restService{} + var lastHttpRequest *http.Request + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + defer wg.Done() + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: map[string]interface{}{ + "uri": "http://localhost:4444/test-url", + "headers": map[string]string{"header1": "value1", "header2": "value2"}, + "method": "UPDATE", + "Body": "test-body", + "apiClass": "java.lang.String", + }, + }, core) + + wg.Wait() + + assert.NotNil(t, lastHttpRequest) + assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") + assert.Equal(t, lastHttpRequest.Method, "UPDATE") + assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") + assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") + assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") + sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) + assert.Equal(t, sentBody, []byte("test-body")) } func TestRestService_HandleServiceRequest_InvalidInput(t *testing.T) { - core := newTestFabricCore(restServiceChannel) - - restService := &restService{} - var lastResponse *model.Response - - wg := sync.WaitGroup{} - wg.Add(1) - mh, _ := core.Bus().ListenStream(restServiceChannel) - mh.Handle( - func(message *model.Message) { - lastResponse = message.Payload.(*model.Response) - wg.Done() - }, - func(e error) { - assert.Fail(t, "unexpected error") - }) - - restService.HandleServiceRequest(&model.Request{ - Payload: RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - Method: "UPDATE", - }, - }, core) - - wg.Wait() - - assert.NotNil(t, lastResponse) - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 500) - assert.Equal(t, lastResponse.ErrorMessage, "invalid RestServiceRequest payload") - - wg.Add(1) - - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - Method: "@!#$%^&**()", - }, - }, core) - - wg.Wait() - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 500) - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - return nil, errors.New("custom-rest-error") - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - wg.Wait() - - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 500) - assert.True(t, strings.Contains(lastResponse.ErrorMessage, "custom-rest-error")) - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 404, - Status: "404 Not Found", - Body: ioutil.NopCloser(bytes.NewBufferString("error-response")), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - wg.Wait() - - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 404) - assert.Equal(t, lastResponse.ErrorMessage, "rest-service error, unable to complete request: 404 Not Found") - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString("}")), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - ResponseType: reflect.TypeOf(&testItem{}), - }, - }, core) - wg.Wait() - - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 500) - assert.True(t, strings.HasPrefix(lastResponse.ErrorMessage, "failed to deserialize response:")) + core := newTestFabricCore(restServiceChannel) + + restService := &restService{} + var lastResponse *model.Response + + wg := sync.WaitGroup{} + wg.Add(1) + mh, _ := core.Bus().ListenStream(restServiceChannel) + mh.Handle( + func(message *model.Message) { + lastResponse = message.Payload.(*model.Response) + wg.Done() + }, + func(e error) { + assert.Fail(t, "unexpected error") + }) + + restService.HandleServiceRequest(&model.Request{ + Payload: RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + Method: "UPDATE", + }, + }, core) + + wg.Wait() + + assert.NotNil(t, lastResponse) + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 500) + assert.Equal(t, lastResponse.ErrorMessage, "invalid RestServiceRequest payload") + + wg.Add(1) + + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + Method: "@!#$%^&**()", + }, + }, core) + + wg.Wait() + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 500) + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("custom-rest-error") + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + wg.Wait() + + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 500) + assert.True(t, strings.Contains(lastResponse.ErrorMessage, "custom-rest-error")) + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 404, + Status: "404 Not Found", + Body: ioutil.NopCloser(bytes.NewBufferString("error-response")), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + wg.Wait() + + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 404) + assert.Equal(t, lastResponse.ErrorMessage, "rest-service error, unable to complete request: 404 Not Found") + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("}")), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + ResponseType: reflect.TypeOf(&testItem{}), + }, + }, core) + wg.Wait() + + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 500) + assert.True(t, strings.HasPrefix(lastResponse.ErrorMessage, "failed to deserialize response:")) } func TestRestService_setBaseHost(t *testing.T) { - core := newTestFabricCore(restServiceChannel) - restService := &restService{} - - wg := sync.WaitGroup{} - - var lastHttpRequest *http.Request - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - wg.Done() - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), - Header: make(http.Header), - }, nil - }) - - restService.setBaseHost("appfabric.vmware.com:9999") - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastHttpRequest.Host, "appfabric.vmware.com:9999") - - restService.setBaseHost("") - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - wg.Wait() - - assert.Equal(t, lastHttpRequest.Host, "localhost:4444") + core := newTestFabricCore(restServiceChannel) + restService := &restService{} + + wg := sync.WaitGroup{} + + var lastHttpRequest *http.Request + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + wg.Done() + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), + Header: make(http.Header), + }, nil + }) + + restService.setBaseHost("appfabric.vmware.com:9999") + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastHttpRequest.Host, "appfabric.vmware.com:9999") + + restService.setBaseHost("") + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + wg.Wait() + + assert.Equal(t, lastHttpRequest.Host, "localhost:4444") } diff --git a/service/service_lifecycle_manager.go b/service/service_lifecycle_manager.go index 948d594..357d11e 100644 --- a/service/service_lifecycle_manager.go +++ b/service/service_lifecycle_manager.go @@ -1,8 +1,8 @@ package service import ( - "github.com/pb33f/ranch/model" - "net/http" + "github.com/pb33f/ranch/model" + "net/http" ) var svcLifecycleManagerInstance ServiceLifecycleManager @@ -10,116 +10,116 @@ var svcLifecycleManagerInstance ServiceLifecycleManager type RequestBuilder func(w http.ResponseWriter, r *http.Request) model.Request type ServiceLifecycleManager interface { - //GetServiceHooks(serviceChannelName string) ServiceLifecycleHookEnabled - GetOnReadyCapableService(serviceChannelName string) OnServiceReadyEnabled - GetOnServerShutdownService(serviceChannelName string) OnServerShutdownEnabled - GetRESTBridgeEnabledService(serviceChannelName string) RESTBridgeEnabled - OverrideRESTBridgeConfig(serviceChannelName string, config []*RESTBridgeConfig) error + //GetServiceHooks(serviceChannelName string) ServiceLifecycleHookEnabled + GetOnReadyCapableService(serviceChannelName string) OnServiceReadyEnabled + GetOnServerShutdownService(serviceChannelName string) OnServerShutdownEnabled + GetRESTBridgeEnabledService(serviceChannelName string) RESTBridgeEnabled + OverrideRESTBridgeConfig(serviceChannelName string, config []*RESTBridgeConfig) error } type ServiceLifecycleHookEnabled interface { - OnServiceReady() chan bool // service initialization logic should be implemented here - OnServerShutdown() // teardown logic goes here and will be automatically invoked on graceful server shutdown - GetRESTBridgeConfig() []*RESTBridgeConfig // service-to-REST endpoint mappings go here + OnServiceReady() chan bool // service initialization logic should be implemented here + OnServerShutdown() // teardown logic goes here and will be automatically invoked on graceful server shutdown + GetRESTBridgeConfig() []*RESTBridgeConfig // service-to-REST endpoint mappings go here } type RESTBridgeEnabled interface { - GetRESTBridgeConfig() []*RESTBridgeConfig // service-to-REST endpoint mappings go here + GetRESTBridgeConfig() []*RESTBridgeConfig // service-to-REST endpoint mappings go here } type OnServiceReadyEnabled interface { - OnServiceReady() chan bool // service initialization logic should be implemented here + OnServiceReady() chan bool // service initialization logic should be implemented here } type OnServerShutdownEnabled interface { - OnServerShutdown() // teardown logic goes here and will be automatically invoked on graceful server shutdown + OnServerShutdown() // teardown logic goes here and will be automatically invoked on graceful server shutdown } type SetupRESTBridgeRequest struct { - ServiceChannel string - Override bool - Config []*RESTBridgeConfig + ServiceChannel string + Override bool + Config []*RESTBridgeConfig } type RESTBridgeConfig struct { - ServiceChannel string // transport service channel - Uri string // URI to map the transport service to - Method string // HTTP verb to map the transport service request to URI with - AllowHead bool // whether HEAD calls are allowed for this bridge point - AllowOptions bool // whether OPTIONS calls are allowed for this bridge point - FabricRequestBuilder RequestBuilder // function to transform HTTP request into a transport request + ServiceChannel string // transport service channel + Uri string // URI to map the transport service to + Method string // HTTP verb to map the transport service request to URI with + AllowHead bool // whether HEAD calls are allowed for this bridge point + AllowOptions bool // whether OPTIONS calls are allowed for this bridge point + FabricRequestBuilder RequestBuilder // function to transform HTTP request into a transport request } type serviceLifecycleManager struct { - serviceRegistryRef ServiceRegistry // service registry reference + serviceRegistryRef ServiceRegistry // service registry reference } // GetOnReadyCapableService returns a service that implements OnServiceReadyEnabled func (lm *serviceLifecycleManager) GetOnReadyCapableService(serviceChannelName string) OnServiceReadyEnabled { - service, err := lm.serviceRegistryRef.GetService(serviceChannelName) - if err != nil { - return nil - } - - if lifecycleHookEnabled, ok := service.(OnServiceReadyEnabled); ok { - return lifecycleHookEnabled - } - return nil + service, err := lm.serviceRegistryRef.GetService(serviceChannelName) + if err != nil { + return nil + } + + if lifecycleHookEnabled, ok := service.(OnServiceReadyEnabled); ok { + return lifecycleHookEnabled + } + return nil } // GetOnServerShutdownService returns a service that implements OnServerShutdownEnabled func (lm *serviceLifecycleManager) GetOnServerShutdownService(serviceChannelName string) OnServerShutdownEnabled { - service, err := lm.serviceRegistryRef.GetService(serviceChannelName) - if err != nil { - return nil - } - - if lifecycleHookEnabled, ok := service.(OnServerShutdownEnabled); ok { - return lifecycleHookEnabled - } - return nil + service, err := lm.serviceRegistryRef.GetService(serviceChannelName) + if err != nil { + return nil + } + + if lifecycleHookEnabled, ok := service.(OnServerShutdownEnabled); ok { + return lifecycleHookEnabled + } + return nil } // GetRESTBridgeEnabledService returns a service that implements OnServerShutdownEnabled func (lm *serviceLifecycleManager) GetRESTBridgeEnabledService(serviceChannelName string) RESTBridgeEnabled { - service, err := lm.serviceRegistryRef.GetService(serviceChannelName) - if err != nil { - return nil - } - - if lifecycleHookEnabled, ok := service.(RESTBridgeEnabled); ok { - return lifecycleHookEnabled - } - return nil + service, err := lm.serviceRegistryRef.GetService(serviceChannelName) + if err != nil { + return nil + } + + if lifecycleHookEnabled, ok := service.(RESTBridgeEnabled); ok { + return lifecycleHookEnabled + } + return nil } // OverrideRESTBridgeConfig overrides the REST bridge configuration currently present with the provided new bridge configs func (lm *serviceLifecycleManager) OverrideRESTBridgeConfig(serviceChannelName string, config []*RESTBridgeConfig) error { - _, err := lm.serviceRegistryRef.GetService(serviceChannelName) - if err != nil { - return err - } - reg := lm.serviceRegistryRef.(*serviceRegistry) - if err = reg.bus.SendResponseMessage( - LifecycleManagerChannelName, - &SetupRESTBridgeRequest{ServiceChannel: serviceChannelName, Config: config, Override: true}, - reg.bus.GetId()); err != nil { - return err - } - return nil + _, err := lm.serviceRegistryRef.GetService(serviceChannelName) + if err != nil { + return err + } + reg := lm.serviceRegistryRef.(*serviceRegistry) + if err = reg.bus.SendResponseMessage( + LifecycleManagerChannelName, + &SetupRESTBridgeRequest{ServiceChannel: serviceChannelName, Config: config, Override: true}, + reg.bus.GetId()); err != nil { + return err + } + return nil } // GetServiceLifecycleManager returns a singleton instance of ServiceLifecycleManager func GetServiceLifecycleManager() ServiceLifecycleManager { - if svcLifecycleManagerInstance == nil { - svcLifecycleManagerInstance = &serviceLifecycleManager{ - serviceRegistryRef: registry, - } - } - return svcLifecycleManagerInstance + if svcLifecycleManagerInstance == nil { + svcLifecycleManagerInstance = &serviceLifecycleManager{ + serviceRegistryRef: registry, + } + } + return svcLifecycleManagerInstance } // newServiceLifecycleManager returns a new instance of ServiceLifecycleManager func newServiceLifecycleManager(reg ServiceRegistry) ServiceLifecycleManager { - return &serviceLifecycleManager{serviceRegistryRef: reg} + return &serviceLifecycleManager{serviceRegistryRef: reg} } diff --git a/service/service_lifecycle_manager_test.go b/service/service_lifecycle_manager_test.go index 55ba14d..931a3cb 100644 --- a/service/service_lifecycle_manager_test.go +++ b/service/service_lifecycle_manager_test.go @@ -1,116 +1,116 @@ package service import ( - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "net/http" - "sync" - "testing" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "net/http" + "sync" + "testing" ) func TestServiceLifecycleManager(t *testing.T) { - // arrange - sr := newTestServiceRegistry() + // arrange + sr := newTestServiceRegistry() - // act - lcm := newTestServiceLifecycleManager(sr) + // act + lcm := newTestServiceLifecycleManager(sr) - // assert - assert.NotNil(t, lcm) + // assert + assert.NotNil(t, lcm) } func TestServiceLifecycleManager_GetServiceHooks(t *testing.T) { - // arrange - sr := newTestServiceRegistry() - lcm := newTestServiceLifecycleManager(sr) - svc := &mockLifecycleHookEnabledService{} - sr.RegisterService(svc, "another-test-channel") + // arrange + sr := newTestServiceRegistry() + lcm := newTestServiceLifecycleManager(sr) + svc := &mockLifecycleHookEnabledService{} + sr.RegisterService(svc, "another-test-channel") - // act - hooks := lcm.GetOnReadyCapableService("another-test-channel") + // act + hooks := lcm.GetOnReadyCapableService("another-test-channel") - // assert - assert.NotNil(t, hooks) + // assert + assert.NotNil(t, hooks) } func TestServiceLifecycleManager_GetServiceHooks_NoSuchService(t *testing.T) { - // arrange - sr := newTestServiceRegistry() - lcm := newTestServiceLifecycleManager(sr) + // arrange + sr := newTestServiceRegistry() + lcm := newTestServiceLifecycleManager(sr) - // act - hooks := lcm.GetOnReadyCapableService("i-don-t-exist") + // act + hooks := lcm.GetOnReadyCapableService("i-don-t-exist") - // assert - assert.Nil(t, hooks) + // assert + assert.Nil(t, hooks) } func TestServiceLifecycleManager_GetServiceHooks_LifecycleHooksNotImplemented(t *testing.T) { - // arrange - sr := newTestServiceRegistry() - lcm := newTestServiceLifecycleManager(sr) - svc := &mockInitializableService{} - sr.RegisterService(svc, "test-channel") + // arrange + sr := newTestServiceRegistry() + lcm := newTestServiceLifecycleManager(sr) + svc := &mockInitializableService{} + sr.RegisterService(svc, "test-channel") - // act - hooks := lcm.GetOnReadyCapableService("test-channel") + // act + hooks := lcm.GetOnReadyCapableService("test-channel") - // assert - assert.Nil(t, hooks) + // assert + assert.Nil(t, hooks) } func TestServiceLifecycleManager_OverrideRESTBridgeConfig_NoSuchService(t *testing.T) { - // arrange - sr := newTestServiceRegistry() - lcm := newTestServiceLifecycleManager(sr) + // arrange + sr := newTestServiceRegistry() + lcm := newTestServiceLifecycleManager(sr) - // act - err := lcm.OverrideRESTBridgeConfig("no-such-service", []*RESTBridgeConfig{}) + // act + err := lcm.OverrideRESTBridgeConfig("no-such-service", []*RESTBridgeConfig{}) - // assert - assert.NotNil(t, err) + // assert + assert.NotNil(t, err) } func TestServiceLifecycleManager_OverrideRESTBridgeConfig(t *testing.T) { - // arrange - wg := sync.WaitGroup{} - sr := newTestServiceRegistry() - lcm := newTestServiceLifecycleManager(sr) - sr.lifecycleManager = lcm.(*serviceLifecycleManager) - _ = sr.RegisterService(&mockLifecycleHookEnabledService{}, "another-test-channel") - - // arrange: test payload - payload := &RESTBridgeConfig{ - ServiceChannel: "another-test-channel", - Uri: "/rest/new-uri", - Method: http.MethodGet, - AllowHead: false, - AllowOptions: false, - FabricRequestBuilder: nil, - } - - // arrange: set up a handler to expect to receive a payload that matches the test payload set above - stream, err := sr.bus.ListenStreamForDestination(LifecycleManagerChannelName, sr.bus.GetId()) - assert.Nil(t, err) - defer stream.Close() - - wg.Add(1) - stream.Handle(func(message *model.Message) { - req, parsed := message.Payload.(*SetupRESTBridgeRequest) - if !parsed { - assert.Fail(t, "should have expected *SetupRESTBridgeRequest payload") - } - // assert - assert.True(t, req.Override) - assert.EqualValues(t, "another-test-channel", req.ServiceChannel) - assert.EqualValues(t, payload, req.Config[0]) - wg.Done() - }, func(err error) { - assert.Fail(t, "should not have errored", err) - }) - - // act - err = lcm.OverrideRESTBridgeConfig("another-test-channel", []*RESTBridgeConfig{payload}) - assert.Nil(t, err) - wg.Wait() + // arrange + wg := sync.WaitGroup{} + sr := newTestServiceRegistry() + lcm := newTestServiceLifecycleManager(sr) + sr.lifecycleManager = lcm.(*serviceLifecycleManager) + _ = sr.RegisterService(&mockLifecycleHookEnabledService{}, "another-test-channel") + + // arrange: test payload + payload := &RESTBridgeConfig{ + ServiceChannel: "another-test-channel", + Uri: "/rest/new-uri", + Method: http.MethodGet, + AllowHead: false, + AllowOptions: false, + FabricRequestBuilder: nil, + } + + // arrange: set up a handler to expect to receive a payload that matches the test payload set above + stream, err := sr.bus.ListenStreamForDestination(LifecycleManagerChannelName, sr.bus.GetId()) + assert.Nil(t, err) + defer stream.Close() + + wg.Add(1) + stream.Handle(func(message *model.Message) { + req, parsed := message.Payload.(*SetupRESTBridgeRequest) + if !parsed { + assert.Fail(t, "should have expected *SetupRESTBridgeRequest payload") + } + // assert + assert.True(t, req.Override) + assert.EqualValues(t, "another-test-channel", req.ServiceChannel) + assert.EqualValues(t, payload, req.Config[0]) + wg.Done() + }, func(err error) { + assert.Fail(t, "should not have errored", err) + }) + + // act + err = lcm.OverrideRESTBridgeConfig("another-test-channel", []*RESTBridgeConfig{payload}) + assert.Nil(t, err) + wg.Wait() } diff --git a/service/service_registry_test.go b/service/service_registry_test.go index 6bb6bcd..a58a92c 100644 --- a/service/service_registry_test.go +++ b/service/service_registry_test.go @@ -4,214 +4,214 @@ package service import ( - "errors" - "github.com/google/uuid" - "github.com/pb33f/ranch/bus" - "github.com/pb33f/ranch/model" - "github.com/stretchr/testify/assert" - "net/http" - "sync" - "testing" + "errors" + "github.com/google/uuid" + "github.com/pb33f/ranch/bus" + "github.com/pb33f/ranch/model" + "github.com/stretchr/testify/assert" + "net/http" + "sync" + "testing" ) func newTestServiceRegistry() *serviceRegistry { - eventBus := bus.NewEventBusInstance() - return newServiceRegistry(eventBus).(*serviceRegistry) + eventBus := bus.NewEventBusInstance() + return newServiceRegistry(eventBus).(*serviceRegistry) } func newTestServiceLifecycleManager(sr ServiceRegistry) ServiceLifecycleManager { - return newServiceLifecycleManager(sr) + return newServiceLifecycleManager(sr) } type mockFabricService struct { - processedRequests []*model.Request - core FabricServiceCore - wg sync.WaitGroup + processedRequests []*model.Request + core FabricServiceCore + wg sync.WaitGroup } func (fs *mockFabricService) HandleServiceRequest(request *model.Request, core FabricServiceCore) { - fs.processedRequests = append(fs.processedRequests, request) - fs.core = core - fs.wg.Done() + fs.processedRequests = append(fs.processedRequests, request) + fs.core = core + fs.wg.Done() } type mockLifecycleHookEnabledService struct { - initChan chan bool - core FabricServiceCore - shutdown bool + initChan chan bool + core FabricServiceCore + shutdown bool } func (s *mockLifecycleHookEnabledService) HandleServiceRequest(request *model.Request, core FabricServiceCore) { } func (s *mockLifecycleHookEnabledService) OnServiceReady() chan bool { - s.initChan = make(chan bool, 1) - s.initChan <- true - return s.initChan + s.initChan = make(chan bool, 1) + s.initChan <- true + return s.initChan } func (s *mockLifecycleHookEnabledService) OnServerShutdown() { - s.shutdown = true + s.shutdown = true } func (s *mockLifecycleHookEnabledService) GetRESTBridgeConfig() []*RESTBridgeConfig { - return []*RESTBridgeConfig{ - { - ServiceChannel: "another-test-channel", - Uri: "/rest/test", - Method: http.MethodGet, - AllowHead: true, - AllowOptions: true, - FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { - return model.Request{ - Id: &uuid.UUID{}, - Payload: "test", - } - }, - }, - } + return []*RESTBridgeConfig{ + { + ServiceChannel: "another-test-channel", + Uri: "/rest/test", + Method: http.MethodGet, + AllowHead: true, + AllowOptions: true, + FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { + return model.Request{ + Id: &uuid.UUID{}, + Payload: "test", + } + }, + }, + } } type mockInitializableService struct { - initialized bool - core FabricServiceCore - initError error + initialized bool + core FabricServiceCore + initError error } func (fs *mockInitializableService) Init(core FabricServiceCore) error { - fs.core = core - fs.initialized = true - return fs.initError + fs.core = core + fs.initialized = true + return fs.initError } func (fs *mockInitializableService) HandleServiceRequest(request *model.Request, core FabricServiceCore) { } func TestGetServiceRegistry(t *testing.T) { - sr := GetServiceRegistry() - sr2 := GetServiceRegistry() - assert.NotNil(t, sr) - assert.Equal(t, sr, sr2) + sr := GetServiceRegistry() + sr2 := GetServiceRegistry() + assert.NotNil(t, sr) + assert.Equal(t, sr, sr2) } func TestServiceRegistry_RegisterService(t *testing.T) { - registry := newTestServiceRegistry() - mockService := &mockFabricService{} + registry := newTestServiceRegistry() + mockService := &mockFabricService{} - assert.Nil(t, registry.RegisterService(mockService, "test-channel")) - assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel")) + assert.Nil(t, registry.RegisterService(mockService, "test-channel")) + assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel")) - id := uuid.New() - req := model.Request{ - Id: &id, - RequestCommand: "test-request", - Payload: "request-payload", - } + id := uuid.New() + req := model.Request{ + Id: &id, + RequestCommand: "test-request", + Payload: "request-payload", + } - mockService.wg.Add(1) - registry.bus.SendRequestMessage("test-channel", req, nil) - mockService.wg.Wait() + mockService.wg.Add(1) + registry.bus.SendRequestMessage("test-channel", req, nil) + mockService.wg.Wait() - assert.Equal(t, len(mockService.processedRequests), 1) - assert.Equal(t, *mockService.processedRequests[0], req) - assert.NotNil(t, mockService.core) + assert.Equal(t, len(mockService.processedRequests), 1) + assert.Equal(t, *mockService.processedRequests[0], req) + assert.NotNil(t, mockService.core) - registry.bus.SendRequestMessage("test-channel", "invalid-request", nil) - registry.bus.SendRequestMessage("test-channel", nil, nil) - registry.bus.SendResponseMessage("test-channel", req, nil) - registry.bus.SendErrorMessage("test-channel", errors.New("test-error"), nil) + registry.bus.SendRequestMessage("test-channel", "invalid-request", nil) + registry.bus.SendRequestMessage("test-channel", nil, nil) + registry.bus.SendResponseMessage("test-channel", req, nil) + registry.bus.SendErrorMessage("test-channel", errors.New("test-error"), nil) - mockService.wg.Add(1) - registry.bus.SendRequestMessage("test-channel", &req, nil) - mockService.wg.Wait() + mockService.wg.Add(1) + registry.bus.SendRequestMessage("test-channel", &req, nil) + mockService.wg.Wait() - assert.Equal(t, len(mockService.processedRequests), 2) - assert.Equal(t, mockService.processedRequests[1], &req) - assert.NotNil(t, mockService.core) + assert.Equal(t, len(mockService.processedRequests), 2) + assert.Equal(t, mockService.processedRequests[1], &req) + assert.NotNil(t, mockService.core) - mockService.wg.Add(1) - uuid := uuid.New() - registry.bus.SendRequestMessage("test-channel", model.Request{ - RequestCommand: "test-request-2", - Payload: "request-payload", - }, &uuid) - mockService.wg.Wait() + mockService.wg.Add(1) + uuid := uuid.New() + registry.bus.SendRequestMessage("test-channel", model.Request{ + RequestCommand: "test-request-2", + Payload: "request-payload", + }, &uuid) + mockService.wg.Wait() - assert.Equal(t, len(mockService.processedRequests), 3) - assert.Equal(t, mockService.processedRequests[2].Id, &uuid) + assert.Equal(t, len(mockService.processedRequests), 3) + assert.Equal(t, mockService.processedRequests[2].Id, &uuid) - assert.EqualError(t, registry.RegisterService(&mockFabricService{}, "test-channel"), - "unable to register service: service channel name is already used: test-channel") + assert.EqualError(t, registry.RegisterService(&mockFabricService{}, "test-channel"), + "unable to register service: service channel name is already used: test-channel") - assert.EqualError(t, registry.RegisterService(nil, "test-channel2"), - "unable to register service: nil service") + assert.EqualError(t, registry.RegisterService(nil, "test-channel2"), + "unable to register service: nil service") - assert.False(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel2")) + assert.False(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel2")) } func TestServiceRegistry_RegisterInitializableService(t *testing.T) { - registry := newTestServiceRegistry() - mockService := &mockInitializableService{} - assert.Nil(t, registry.RegisterService(mockService, "test-channel")) + registry := newTestServiceRegistry() + mockService := &mockInitializableService{} + assert.Nil(t, registry.RegisterService(mockService, "test-channel")) - assert.True(t, mockService.initialized) - assert.NotNil(t, mockService.core) + assert.True(t, mockService.initialized) + assert.NotNil(t, mockService.core) - assert.EqualError(t, - registry.RegisterService(&mockInitializableService{initError: errors.New("init-error")}, "test-channel2"), - "init-error") + assert.EqualError(t, + registry.RegisterService(&mockInitializableService{initError: errors.New("init-error")}, "test-channel2"), + "init-error") } func TestServiceRegistry_UnregisterService(t *testing.T) { - registry := newTestServiceRegistry() - mockService := &mockFabricService{} + registry := newTestServiceRegistry() + mockService := &mockFabricService{} - assert.Nil(t, registry.RegisterService(mockService, "test-channel")) - assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel")) + assert.Nil(t, registry.RegisterService(mockService, "test-channel")) + assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel")) - id := uuid.New() - req := model.Request{ - Id: &id, - RequestCommand: "test-request", - Payload: "request-payload", - } + id := uuid.New() + req := model.Request{ + Id: &id, + RequestCommand: "test-request", + Payload: "request-payload", + } - assert.Nil(t, registry.UnregisterService("test-channel")) - registry.bus.SendRequestMessage("test-channel", req, nil) + assert.Nil(t, registry.UnregisterService("test-channel")) + registry.bus.SendRequestMessage("test-channel", req, nil) - assert.Equal(t, len(mockService.processedRequests), 0) - assert.EqualError(t, registry.UnregisterService("test-channel"), - "unable to unregister service: no service is registered for channel \"test-channel\"") + assert.Equal(t, len(mockService.processedRequests), 0) + assert.EqualError(t, registry.UnregisterService("test-channel"), + "unable to unregister service: no service is registered for channel \"test-channel\"") } func TestServiceRegistry_SetGlobalRestServiceBaseHost(t *testing.T) { - registry := newTestServiceRegistry() - registry.SetGlobalRestServiceBaseHost("localhost:9999") - assert.Equal(t, "localhost:9999", - registry.services[restServiceChannel].service.(*restService).baseHost) + registry := newTestServiceRegistry() + registry.SetGlobalRestServiceBaseHost("localhost:9999") + assert.Equal(t, "localhost:9999", + registry.services[restServiceChannel].service.(*restService).baseHost) } func TestServiceRegistry_GetAllServiceChannels(t *testing.T) { - registry := newTestServiceRegistry() - mockService := &mockFabricService{} + registry := newTestServiceRegistry() + mockService := &mockFabricService{} - registry.RegisterService(mockService, "test-channel") - chans := registry.GetAllServiceChannels() + registry.RegisterService(mockService, "test-channel") + chans := registry.GetAllServiceChannels() - assert.Len(t, chans, 1) - assert.EqualValues(t, "test-channel", chans[0]) + assert.Len(t, chans, 1) + assert.EqualValues(t, "test-channel", chans[0]) } func TestServiceRegistry_RegisterService_LifecycleHookEnabled(t *testing.T) { - svc := &mockLifecycleHookEnabledService{} - registry := newTestServiceRegistry() - registry.RegisterService(svc, "another-test-channel") + svc := &mockLifecycleHookEnabledService{} + registry := newTestServiceRegistry() + registry.RegisterService(svc, "another-test-channel") - assert.True(t, <-svc.OnServiceReady()) + assert.True(t, <-svc.OnServiceReady()) - svc.OnServerShutdown() - assert.True(t, svc.shutdown) + svc.OnServerShutdown() + assert.True(t, svc.shutdown) - restBridgeConfig := svc.GetRESTBridgeConfig() - assert.NotNil(t, restBridgeConfig) + restBridgeConfig := svc.GetRESTBridgeConfig() + assert.NotNil(t, restBridgeConfig) } diff --git a/stompserver/stomp_connection.go b/stompserver/stomp_connection.go index 19191af..b81af26 100644 --- a/stompserver/stomp_connection.go +++ b/stompserver/stomp_connection.go @@ -4,458 +4,458 @@ package stompserver import ( - "fmt" - "github.com/go-stomp/stomp/v3" - "github.com/go-stomp/stomp/v3/frame" - "github.com/google/uuid" - "log" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" + "fmt" + "github.com/go-stomp/stomp/v3" + "github.com/go-stomp/stomp/v3/frame" + "github.com/google/uuid" + "log" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" ) type subscription struct { - id string - destination string + id string + destination string } type StompConn interface { - // Return unique connection Id string - GetId() string - SendFrameToSubscription(f *frame.Frame, sub *subscription) - Close() + // Return unique connection Id string + GetId() string + SendFrameToSubscription(f *frame.Frame, sub *subscription) + Close() } const ( - maxHeartBeatDuration = time.Duration(999999999) * time.Millisecond + maxHeartBeatDuration = time.Duration(999999999) * time.Millisecond ) const ( - connecting int32 = iota - connected - closed + connecting int32 = iota + connected + closed ) type stompConn struct { - rawConnection RawConnection - state int32 - version stomp.Version - inFrames chan *frame.Frame - outFrames chan *frame.Frame - readTimeoutMs int64 - writeTimeout time.Duration - id string - events chan *ConnEvent - config StompConfig - subscriptions map[string]*subscription - currentMessageId uint64 - closeOnce sync.Once + rawConnection RawConnection + state int32 + version stomp.Version + inFrames chan *frame.Frame + outFrames chan *frame.Frame + readTimeoutMs int64 + writeTimeout time.Duration + id string + events chan *ConnEvent + config StompConfig + subscriptions map[string]*subscription + currentMessageId uint64 + closeOnce sync.Once } func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *ConnEvent) StompConn { - conn := &stompConn{ - rawConnection: rawConnection, - state: connecting, - inFrames: make(chan *frame.Frame, 32), - outFrames: make(chan *frame.Frame, 32), - config: config, - id: uuid.New().String(), - events: events, - subscriptions: make(map[string]*subscription), - } - - go conn.run() - go conn.readInFrames() - - return conn + conn := &stompConn{ + rawConnection: rawConnection, + state: connecting, + inFrames: make(chan *frame.Frame, 32), + outFrames: make(chan *frame.Frame, 32), + config: config, + id: uuid.New().String(), + events: events, + subscriptions: make(map[string]*subscription), + } + + go conn.run() + go conn.readInFrames() + + return conn } func (conn *stompConn) SendFrameToSubscription(f *frame.Frame, sub *subscription) { - f.Header.Add(frame.Subscription, sub.id) - conn.outFrames <- f + f.Header.Add(frame.Subscription, sub.id) + conn.outFrames <- f } func (conn *stompConn) Close() { - conn.closeOnce.Do(func() { - atomic.StoreInt32(&conn.state, closed) - conn.rawConnection.Close() - - conn.events <- &ConnEvent{ - ConnId: conn.GetId(), - eventType: ConnectionClosed, - conn: conn, - } - }) + conn.closeOnce.Do(func() { + atomic.StoreInt32(&conn.state, closed) + conn.rawConnection.Close() + + conn.events <- &ConnEvent{ + ConnId: conn.GetId(), + eventType: ConnectionClosed, + conn: conn, + } + }) } func (conn *stompConn) GetId() string { - return conn.id + return conn.id } func (conn *stompConn) run() { - defer conn.Close() - - var timerChannel <-chan time.Time - var timer *time.Timer - - for { - - if atomic.LoadInt32(&conn.state) == closed { - return - } - - if timer == nil && conn.writeTimeout > 0 { - timer = time.NewTimer(conn.writeTimeout) - timerChannel = timer.C - } - - select { - case f, ok := <-conn.outFrames: - if !ok { - // close connection - return - } - - // reset heart-beat timer - if timer != nil { - timer.Stop() - timer = nil - } - - conn.populateMessageIdHeader(f) - - // write the frame to the client - err := conn.rawConnection.WriteFrame(f) - if err != nil || f.Command == frame.ERROR { - return - } - - case f, ok := <-conn.inFrames: - if !ok { - return - } - - if err := conn.handleIncomingFrame(f); err != nil { - conn.sendError(err) - return - } - - case _ = <-timerChannel: - // write a heart-beat - err := conn.rawConnection.WriteFrame(nil) - if err != nil { - return - } - if timer != nil { - timer.Stop() - timer = nil - } - } - } + defer conn.Close() + + var timerChannel <-chan time.Time + var timer *time.Timer + + for { + + if atomic.LoadInt32(&conn.state) == closed { + return + } + + if timer == nil && conn.writeTimeout > 0 { + timer = time.NewTimer(conn.writeTimeout) + timerChannel = timer.C + } + + select { + case f, ok := <-conn.outFrames: + if !ok { + // close connection + return + } + + // reset heart-beat timer + if timer != nil { + timer.Stop() + timer = nil + } + + conn.populateMessageIdHeader(f) + + // write the frame to the client + err := conn.rawConnection.WriteFrame(f) + if err != nil || f.Command == frame.ERROR { + return + } + + case f, ok := <-conn.inFrames: + if !ok { + return + } + + if err := conn.handleIncomingFrame(f); err != nil { + conn.sendError(err) + return + } + + case _ = <-timerChannel: + // write a heart-beat + err := conn.rawConnection.WriteFrame(nil) + if err != nil { + return + } + if timer != nil { + timer.Stop() + timer = nil + } + } + } } func (conn *stompConn) handleIncomingFrame(f *frame.Frame) error { - switch f.Command { + switch f.Command { - case frame.CONNECT, frame.STOMP: - return conn.handleConnect(f) + case frame.CONNECT, frame.STOMP: + return conn.handleConnect(f) - case frame.DISCONNECT: - return conn.handleDisconnect(f) + case frame.DISCONNECT: + return conn.handleDisconnect(f) - case frame.SEND: - return conn.handleSend(f) + case frame.SEND: + return conn.handleSend(f) - case frame.SUBSCRIBE: - return conn.handleSubscribe(f) + case frame.SUBSCRIBE: + return conn.handleSubscribe(f) - case frame.UNSUBSCRIBE: - return conn.handleUnsubscribe(f) - } + case frame.UNSUBSCRIBE: + return conn.handleUnsubscribe(f) + } - return unsupportedStompCommandError + return unsupportedStompCommandError } // Returns true if the frame contains ANY of the specified // headers func containsHeader(f *frame.Frame, headers ...string) bool { - for _, h := range headers { - if _, ok := f.Header.Contains(h); ok { - return true - } - } - return false + for _, h := range headers { + if _, ok := f.Header.Contains(h); ok { + return true + } + } + return false } func (conn *stompConn) handleConnect(f *frame.Frame) error { - if atomic.LoadInt32(&conn.state) == connected { - return unexpectedStompCommandError - } - - if containsHeader(f, frame.Receipt) { - return invalidHeaderError - } - - var err error - conn.version, err = determineVersion(f) - if err != nil { - log.Println("cannot determine version") - return err - } - - if conn.version == stomp.V10 { - return unsupportedStompVersionError - } - - cxDuration, cyDuration, err := getHeartBeat(f) - if err != nil { - log.Println("invalid heart-beat") - return err - } - - min := time.Duration(conn.config.HeartBeat()) * time.Millisecond - if min > maxHeartBeatDuration { - min = maxHeartBeatDuration - } - - // apply a minimum heartbeat - if cxDuration > 0 { - if min == 0 || cxDuration < min { - cxDuration = min - } - } - if cyDuration > 0 { - if min == 0 || cyDuration < min { - cyDuration = min - } - } - - conn.writeTimeout = cyDuration - - cx, cy := int64(cxDuration/time.Millisecond), int64(cyDuration/time.Millisecond) - atomic.StoreInt64(&conn.readTimeoutMs, cx) - - response := frame.New(frame.CONNECTED, - frame.Version, string(conn.version), - frame.Server, "pb33f-ranch/0.0.1", - frame.HeartBeat, fmt.Sprintf("%d,%d", cy, cx)) - - err = conn.rawConnection.WriteFrame(response) - if err != nil { - return err - } - - atomic.StoreInt32(&conn.state, connected) - - conn.events <- &ConnEvent{ - ConnId: conn.GetId(), - eventType: ConnectionEstablished, - conn: conn, - } - - return nil + if atomic.LoadInt32(&conn.state) == connected { + return unexpectedStompCommandError + } + + if containsHeader(f, frame.Receipt) { + return invalidHeaderError + } + + var err error + conn.version, err = determineVersion(f) + if err != nil { + log.Println("cannot determine version") + return err + } + + if conn.version == stomp.V10 { + return unsupportedStompVersionError + } + + cxDuration, cyDuration, err := getHeartBeat(f) + if err != nil { + log.Println("invalid heart-beat") + return err + } + + min := time.Duration(conn.config.HeartBeat()) * time.Millisecond + if min > maxHeartBeatDuration { + min = maxHeartBeatDuration + } + + // apply a minimum heartbeat + if cxDuration > 0 { + if min == 0 || cxDuration < min { + cxDuration = min + } + } + if cyDuration > 0 { + if min == 0 || cyDuration < min { + cyDuration = min + } + } + + conn.writeTimeout = cyDuration + + cx, cy := int64(cxDuration/time.Millisecond), int64(cyDuration/time.Millisecond) + atomic.StoreInt64(&conn.readTimeoutMs, cx) + + response := frame.New(frame.CONNECTED, + frame.Version, string(conn.version), + frame.Server, "pb33f-ranch/0.0.1", + frame.HeartBeat, fmt.Sprintf("%d,%d", cy, cx)) + + err = conn.rawConnection.WriteFrame(response) + if err != nil { + return err + } + + atomic.StoreInt32(&conn.state, connected) + + conn.events <- &ConnEvent{ + ConnId: conn.GetId(), + eventType: ConnectionEstablished, + conn: conn, + } + + return nil } func (conn *stompConn) handleDisconnect(f *frame.Frame) error { - if atomic.LoadInt32(&conn.state) == connecting { - return notConnectedStompError - } + if atomic.LoadInt32(&conn.state) == connecting { + return notConnectedStompError + } - conn.sendReceiptResponse(f) - conn.Close() + conn.sendReceiptResponse(f) + conn.Close() - return nil + return nil } func (conn *stompConn) handleSubscribe(f *frame.Frame) error { - switch atomic.LoadInt32(&conn.state) { - case connecting: - return notConnectedStompError - case closed: - return nil - } - - subId, ok := f.Header.Contains(frame.Id) - if !ok { - return invalidSubscriptionError - } - - dest, ok := f.Header.Contains(frame.Destination) - if !ok { - return invalidFrameError - } - - if _, exists := conn.subscriptions[subId]; exists { - // subscription already exists - return nil - } - - conn.subscriptions[subId] = &subscription{ - id: subId, - destination: dest, - } - - conn.events <- &ConnEvent{ - ConnId: conn.GetId(), - eventType: SubscribeToTopic, - destination: dest, - conn: conn, - sub: conn.subscriptions[subId], - frame: f, - } - - return nil + switch atomic.LoadInt32(&conn.state) { + case connecting: + return notConnectedStompError + case closed: + return nil + } + + subId, ok := f.Header.Contains(frame.Id) + if !ok { + return invalidSubscriptionError + } + + dest, ok := f.Header.Contains(frame.Destination) + if !ok { + return invalidFrameError + } + + if _, exists := conn.subscriptions[subId]; exists { + // subscription already exists + return nil + } + + conn.subscriptions[subId] = &subscription{ + id: subId, + destination: dest, + } + + conn.events <- &ConnEvent{ + ConnId: conn.GetId(), + eventType: SubscribeToTopic, + destination: dest, + conn: conn, + sub: conn.subscriptions[subId], + frame: f, + } + + return nil } func (conn *stompConn) handleUnsubscribe(f *frame.Frame) error { - switch atomic.LoadInt32(&conn.state) { - case connecting: - return notConnectedStompError - case closed: - return nil - } - - id, ok := f.Header.Contains(frame.Id) - if !ok { - return invalidSubscriptionError - } - - conn.sendReceiptResponse(f) - - sub, ok := conn.subscriptions[id] - if !ok { - // subscription already removed - return nil - } - - // remove the subscription - delete(conn.subscriptions, id) - - conn.events <- &ConnEvent{ - ConnId: conn.GetId(), - eventType: UnsubscribeFromTopic, - conn: conn, - sub: sub, - destination: sub.destination, - } - - return nil + switch atomic.LoadInt32(&conn.state) { + case connecting: + return notConnectedStompError + case closed: + return nil + } + + id, ok := f.Header.Contains(frame.Id) + if !ok { + return invalidSubscriptionError + } + + conn.sendReceiptResponse(f) + + sub, ok := conn.subscriptions[id] + if !ok { + // subscription already removed + return nil + } + + // remove the subscription + delete(conn.subscriptions, id) + + conn.events <- &ConnEvent{ + ConnId: conn.GetId(), + eventType: UnsubscribeFromTopic, + conn: conn, + sub: sub, + destination: sub.destination, + } + + return nil } func (conn *stompConn) handleSend(f *frame.Frame) error { - switch atomic.LoadInt32(&conn.state) { - case connecting: - return notConnectedStompError - case closed: - return nil - } - - // TODO: Remove if we start supporting transactions - if containsHeader(f, frame.Transaction) { - return unsupportedStompCommandError - } - - // no destination triggers an error - dest, ok := f.Header.Contains(frame.Destination) - if !ok { - return invalidFrameError - } - - // reject SENDing directly to non-request channels by clients - if !conn.config.IsAppRequestDestination(f.Header.Get(frame.Destination)) { - return invalidSendDestinationError - } - - err := conn.sendReceiptResponse(f) - if err != nil { - return err - } - - f.Command = frame.MESSAGE - conn.events <- &ConnEvent{ - ConnId: conn.GetId(), - eventType: IncomingMessage, - destination: dest, - frame: f, - conn: conn, - } - - return nil + switch atomic.LoadInt32(&conn.state) { + case connecting: + return notConnectedStompError + case closed: + return nil + } + + // TODO: Remove if we start supporting transactions + if containsHeader(f, frame.Transaction) { + return unsupportedStompCommandError + } + + // no destination triggers an error + dest, ok := f.Header.Contains(frame.Destination) + if !ok { + return invalidFrameError + } + + // reject SENDing directly to non-request channels by clients + if !conn.config.IsAppRequestDestination(f.Header.Get(frame.Destination)) { + return invalidSendDestinationError + } + + err := conn.sendReceiptResponse(f) + if err != nil { + return err + } + + f.Command = frame.MESSAGE + conn.events <- &ConnEvent{ + ConnId: conn.GetId(), + eventType: IncomingMessage, + destination: dest, + frame: f, + conn: conn, + } + + return nil } func (conn *stompConn) sendReceiptResponse(f *frame.Frame) error { - if receipt, ok := f.Header.Contains(frame.Receipt); ok { - f.Header.Del(frame.Receipt) - return conn.rawConnection.WriteFrame(frame.New(frame.RECEIPT, frame.ReceiptId, receipt)) - } - return nil + if receipt, ok := f.Header.Contains(frame.Receipt); ok { + f.Header.Del(frame.Receipt) + return conn.rawConnection.WriteFrame(frame.New(frame.RECEIPT, frame.ReceiptId, receipt)) + } + return nil } func (conn *stompConn) readInFrames() { - defer func() { - close(conn.inFrames) - }() - - // we never close the connection, even if the heartbeating is inaccurate. - infiniteTimeout := time.Time{} - for { - conn.rawConnection.SetReadDeadline(infiniteTimeout) - f, err := conn.rawConnection.ReadFrame() - if err != nil { - return - } - - if f == nil { - // heartbeat frame - continue - } - - conn.inFrames <- f - } + defer func() { + close(conn.inFrames) + }() + + // we never close the connection, even if the heartbeating is inaccurate. + infiniteTimeout := time.Time{} + for { + conn.rawConnection.SetReadDeadline(infiniteTimeout) + f, err := conn.rawConnection.ReadFrame() + if err != nil { + return + } + + if f == nil { + // heartbeat frame + continue + } + + conn.inFrames <- f + } } func determineVersion(f *frame.Frame) (stomp.Version, error) { - if acceptVersion, ok := f.Header.Contains(frame.AcceptVersion); ok { - versions := strings.Split(acceptVersion, ",") - for _, supportedVersion := range []stomp.Version{stomp.V12, stomp.V11, stomp.V10} { - for _, v := range versions { - if v == supportedVersion.String() { - // return the highest supported version - return supportedVersion, nil - } - } - } - } else { - return stomp.V10, nil - } - - var emptyVersion stomp.Version - return emptyVersion, unsupportedStompVersionError + if acceptVersion, ok := f.Header.Contains(frame.AcceptVersion); ok { + versions := strings.Split(acceptVersion, ",") + for _, supportedVersion := range []stomp.Version{stomp.V12, stomp.V11, stomp.V10} { + for _, v := range versions { + if v == supportedVersion.String() { + // return the highest supported version + return supportedVersion, nil + } + } + } + } else { + return stomp.V10, nil + } + + var emptyVersion stomp.Version + return emptyVersion, unsupportedStompVersionError } func getHeartBeat(f *frame.Frame) (cx, cy time.Duration, err error) { - if heartBeat, ok := f.Header.Contains(frame.HeartBeat); ok { - return frame.ParseHeartBeat(heartBeat) - } - return 0, 0, nil + if heartBeat, ok := f.Header.Contains(frame.HeartBeat); ok { + return frame.ParseHeartBeat(heartBeat) + } + return 0, 0, nil } func (conn *stompConn) sendError(err error) { - errorFrame := frame.New(frame.ERROR, - frame.Message, err.Error()) + errorFrame := frame.New(frame.ERROR, + frame.Message, err.Error()) - conn.rawConnection.WriteFrame(errorFrame) + conn.rawConnection.WriteFrame(errorFrame) } func (conn *stompConn) populateMessageIdHeader(f *frame.Frame) { - if f.Command == frame.MESSAGE { - // allocate the value of message-id for this frame - conn.currentMessageId++ - messageId := strconv.FormatUint(conn.currentMessageId, 10) - f.Header.Set(frame.MessageId, messageId) - // remove the Ack header (if any) as we don't support those - f.Header.Del(frame.Ack) - } + if f.Command == frame.MESSAGE { + // allocate the value of message-id for this frame + conn.currentMessageId++ + messageId := strconv.FormatUint(conn.currentMessageId, 10) + f.Header.Set(frame.MessageId, messageId) + // remove the Ack header (if any) as we don't support those + f.Header.Del(frame.Ack) + } } diff --git a/stompserver/stomp_connection_test.go b/stompserver/stomp_connection_test.go index 9fed1f4..d488b46 100644 --- a/stompserver/stomp_connection_test.go +++ b/stompserver/stomp_connection_test.go @@ -4,830 +4,830 @@ package stompserver import ( - "errors" - "fmt" - "github.com/go-stomp/stomp/v3/frame" - "github.com/stretchr/testify/assert" - "sync" - "testing" - "time" + "errors" + "fmt" + "github.com/go-stomp/stomp/v3/frame" + "github.com/stretchr/testify/assert" + "sync" + "testing" + "time" ) type MockRawConnection struct { - connected bool - incomingFrames chan interface{} - lock sync.Mutex - currentDeadline time.Time - sentFrames []*frame.Frame - nextWriteErr error - nextReadErr error - writeWg *sync.WaitGroup + connected bool + incomingFrames chan interface{} + lock sync.Mutex + currentDeadline time.Time + sentFrames []*frame.Frame + nextWriteErr error + nextReadErr error + writeWg *sync.WaitGroup } func NewMockRawConnection() *MockRawConnection { - return &MockRawConnection{ - connected: true, - incomingFrames: make(chan interface{}), - sentFrames: []*frame.Frame{}, - nextWriteErr: nil, - } + return &MockRawConnection{ + connected: true, + incomingFrames: make(chan interface{}), + sentFrames: []*frame.Frame{}, + nextWriteErr: nil, + } } func (con *MockRawConnection) ReadFrame() (*frame.Frame, error) { - obj := <-con.incomingFrames + obj := <-con.incomingFrames - if obj == nil { - // heart-beat - return nil, nil - } + if obj == nil { + // heart-beat + return nil, nil + } - f, ok := obj.(*frame.Frame) - if ok { - return f, nil - } + f, ok := obj.(*frame.Frame) + if ok { + return f, nil + } - return nil, obj.(error) + return nil, obj.(error) } func (con *MockRawConnection) WriteFrame(frame *frame.Frame) error { - defer func() { con.nextWriteErr = nil }() - if con.nextWriteErr != nil { - return con.nextWriteErr - } + defer func() { con.nextWriteErr = nil }() + if con.nextWriteErr != nil { + return con.nextWriteErr + } - con.lock.Lock() - con.sentFrames = append(con.sentFrames, frame) - if con.writeWg != nil { - con.writeWg.Done() - } - con.lock.Unlock() - return nil + con.lock.Lock() + con.sentFrames = append(con.sentFrames, frame) + if con.writeWg != nil { + con.writeWg.Done() + } + con.lock.Unlock() + return nil } func (con *MockRawConnection) LastSentFrame() *frame.Frame { - return con.sentFrames[len(con.sentFrames)-1] + return con.sentFrames[len(con.sentFrames)-1] } func (con *MockRawConnection) SetReadDeadline(t time.Time) { - con.lock.Lock() - con.currentDeadline = t - con.lock.Unlock() + con.lock.Lock() + con.currentDeadline = t + con.lock.Unlock() } func (con *MockRawConnection) getCurrentReadDeadline() time.Time { - con.lock.Lock() - defer con.lock.Unlock() - return con.currentDeadline + con.lock.Lock() + defer con.lock.Unlock() + return con.currentDeadline } func (con *MockRawConnection) Close() error { - con.connected = false - return nil + con.connected = false + return nil } func (con *MockRawConnection) SendConnectFrame() { - con.incomingFrames <- frame.New( - frame.CONNECT, - frame.AcceptVersion, "1.2") + con.incomingFrames <- frame.New( + frame.CONNECT, + frame.AcceptVersion, "1.2") } func getTestStompConn(conf StompConfig, events chan *ConnEvent) (*stompConn, *MockRawConnection, chan *ConnEvent) { - if events == nil { - events = make(chan *ConnEvent, 1000) - } + if events == nil { + events = make(chan *ConnEvent, 1000) + } - rawConn := NewMockRawConnection() - return NewStompConn(rawConn, conf, events).(*stompConn), rawConn, events + rawConn := NewMockRawConnection() + return NewStompConn(rawConn, conf, events).(*stompConn), rawConn, events } func TestStompConn_Connect(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - assert.NotNil(t, stompConn.GetId()) + assert.NotNil(t, stompConn.GetId()) - assert.Equal(t, stompConn.state, connecting) + assert.Equal(t, stompConn.state, connecting) - rawConn.incomingFrames <- frame.New(frame.CONNECT, frame.AcceptVersion, "1.0,1.2,1.1,1.3") + rawConn.incomingFrames <- frame.New(frame.CONNECT, frame.AcceptVersion, "1.0,1.2,1.1,1.3") - e := <-events + e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) - assert.Equal(t, e.conn, stompConn) + assert.Equal(t, e.eventType, ConnectionEstablished) + assert.Equal(t, e.conn, stompConn) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED, - frame.Version, "1.2", - frame.HeartBeat, "0,0", - frame.Server, "pb33f-ranch/0.0.1"), true) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED, + frame.Version, "1.2", + frame.HeartBeat, "0,0", + frame.Server, "pb33f-ranch/0.0.1"), true) - assert.Equal(t, stompConn.state, connected) + assert.Equal(t, stompConn.state, connected) } func TestStompConn_ConnectStomp10(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - assert.Equal(t, stompConn.state, connecting) + assert.Equal(t, stompConn.state, connecting) - rawConn.incomingFrames <- frame.New(frame.CONNECT, frame.AcceptVersion, "1.0") + rawConn.incomingFrames <- frame.New(frame.CONNECT, frame.AcceptVersion, "1.0") - e := <-events + e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, e.conn, stompConn) + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, e.conn, stompConn) - assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, unsupportedStompVersionError.Error()), true) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, unsupportedStompVersionError.Error()), true) - assert.Equal(t, stompConn.state, closed) - assert.Equal(t, rawConn.connected, false) + assert.Equal(t, stompConn.state, closed) + assert.Equal(t, rawConn.connected, false) } func TestStompConn_ConnectInvalidStompVersion(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - assert.Equal(t, stompConn.state, connecting) + assert.Equal(t, stompConn.state, connecting) - rawConn.incomingFrames <- frame.New(frame.CONNECT, frame.AcceptVersion, "5.0") + rawConn.incomingFrames <- frame.New(frame.CONNECT, frame.AcceptVersion, "5.0") - e := <-events + e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, e.conn, stompConn) + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, e.conn, stompConn) - assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, unsupportedStompVersionError.Error()), true) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, unsupportedStompVersionError.Error()), true) - assert.Equal(t, stompConn.state, closed) - assert.Equal(t, rawConn.connected, false) + assert.Equal(t, stompConn.state, closed) + assert.Equal(t, rawConn.connected, false) } func TestStompConn_ConnectWithReceiptHeader(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - assert.Equal(t, stompConn.state, connecting) + assert.Equal(t, stompConn.state, connecting) - rawConn.incomingFrames <- frame.New(frame.CONNECT, - frame.AcceptVersion, "1.2", - frame.Receipt, "receipt-id") + rawConn.incomingFrames <- frame.New(frame.CONNECT, + frame.AcceptVersion, "1.2", + frame.Receipt, "receipt-id") - e := <-events + e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, e.conn, stompConn) + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, e.conn, stompConn) - assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, invalidHeaderError.Error()), true) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, invalidHeaderError.Error()), true) - assert.Equal(t, stompConn.state, closed) - assert.Equal(t, rawConn.connected, false) + assert.Equal(t, stompConn.state, closed) + assert.Equal(t, rawConn.connected, false) } func TestStompConn_ConnectMissingStompVersionHeader(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - assert.Equal(t, stompConn.state, connecting) + assert.Equal(t, stompConn.state, connecting) - rawConn.incomingFrames <- frame.New(frame.CONNECT) + rawConn.incomingFrames <- frame.New(frame.CONNECT) - e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, e.conn, stompConn) + e := <-events + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, e.conn, stompConn) - assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, unsupportedStompVersionError.Error()), true) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, unsupportedStompVersionError.Error()), true) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_ConnectInvalidHeartbeatHeader(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - assert.Equal(t, stompConn.state, connecting) + assert.Equal(t, stompConn.state, connecting) - rawConn.incomingFrames <- frame.New(frame.CONNECT, - frame.AcceptVersion, "1.2", - frame.HeartBeat, "12,asd") + rawConn.incomingFrames <- frame.New(frame.CONNECT, + frame.AcceptVersion, "1.2", + frame.HeartBeat, "12,asd") - e := <-events + e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, e.conn, stompConn) + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, e.conn, stompConn) - assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, frame.ErrInvalidHeartBeat.Error()), true) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, frame.ErrInvalidHeartBeat.Error()), true) - assert.Equal(t, stompConn.state, closed) - assert.Equal(t, rawConn.connected, false) + assert.Equal(t, stompConn.state, closed) + assert.Equal(t, rawConn.connected, false) } func TestStompConn_InvalidStompCommand(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - assert.Equal(t, stompConn.state, connecting) + assert.Equal(t, stompConn.state, connecting) - rawConn.incomingFrames <- frame.New("invalid-stomp-command", - frame.AcceptVersion, "1.2") + rawConn.incomingFrames <- frame.New("invalid-stomp-command", + frame.AcceptVersion, "1.2") - e := <-events + e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, e.conn, stompConn) + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, e.conn, stompConn) - assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, unsupportedStompCommandError.Error()), true) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, unsupportedStompCommandError.Error()), true) - assert.Equal(t, stompConn.state, closed) - assert.Equal(t, rawConn.connected, false) + assert.Equal(t, stompConn.state, closed) + assert.Equal(t, rawConn.connected, false) } func TestStompConn_ConnectNoServerHeartbeat(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.incomingFrames <- frame.New( - frame.CONNECT, - frame.AcceptVersion, "1.1,1.0", - frame.HeartBeat, "4000,4000") + rawConn.incomingFrames <- frame.New( + frame.CONNECT, + frame.AcceptVersion, "1.1,1.0", + frame.HeartBeat, "4000,4000") - <-events + <-events - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED, - frame.Version, "1.1", - frame.HeartBeat, "0,0"), false) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED, + frame.Version, "1.1", + frame.HeartBeat, "0,0"), false) } func TestStompConn_ConnectServerHeartbeat(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(9999999991, []string{}), nil) - rawConn.incomingFrames <- frame.New( - frame.CONNECT, - frame.AcceptVersion, "1.1,1.0", - frame.HeartBeat, "4000,4000") + _, rawConn, events := getTestStompConn(NewStompConfig(9999999991, []string{}), nil) + rawConn.incomingFrames <- frame.New( + frame.CONNECT, + frame.AcceptVersion, "1.1,1.0", + frame.HeartBeat, "4000,4000") - <-events + <-events - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED, - frame.Version, "1.1", - frame.HeartBeat, "999999999,999999999"), false) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED, + frame.Version, "1.1", + frame.HeartBeat, "999999999,999999999"), false) } func TestStompConn_ConnectClientHeartbeat(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(7000, []string{}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(7000, []string{}), nil) - rawConn.incomingFrames <- frame.New( - frame.CONNECT, - frame.AcceptVersion, "1.2", - frame.HeartBeat, "8000,9000") + rawConn.incomingFrames <- frame.New( + frame.CONNECT, + frame.AcceptVersion, "1.2", + frame.HeartBeat, "8000,9000") - <-events + <-events - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED, - frame.HeartBeat, "9000,8000"), false) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED, + frame.HeartBeat, "9000,8000"), false) } func TestStompConn_ConnectWhenConnected(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e = <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, - frame.Message, unexpectedStompCommandError.Error()), true) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, + frame.Message, unexpectedStompCommandError.Error()), true) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_SubscribeNotConnected(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.incomingFrames <- frame.New( - frame.SUBSCRIBE, - frame.Destination, "/topic/test") + rawConn.incomingFrames <- frame.New( + frame.SUBSCRIBE, + frame.Destination, "/topic/test") - e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, e.conn, stompConn) + e := <-events + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, e.conn, stompConn) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, notConnectedStompError.Error()), true) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, notConnectedStompError.Error()), true) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_SubscribeMissingIdHeader(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - rawConn.incomingFrames <- frame.New( - frame.SUBSCRIBE, - frame.Destination, "/topic/test") + rawConn.incomingFrames <- frame.New( + frame.SUBSCRIBE, + frame.Destination, "/topic/test") - e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e = <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, - frame.Message, invalidSubscriptionError.Error()), true) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, + frame.Message, invalidSubscriptionError.Error()), true) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_SubscribeMissingDestinationHeader(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - rawConn.incomingFrames <- frame.New( - frame.SUBSCRIBE, - frame.Id, "sub-id") + rawConn.incomingFrames <- frame.New( + frame.SUBSCRIBE, + frame.Id, "sub-id") - e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e = <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, - frame.Message, invalidFrameError.Error()), true) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, + frame.Message, invalidFrameError.Error()), true) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_Subscribe(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) - rawConn.incomingFrames <- frame.New( - frame.SUBSCRIBE, - frame.Id, "sub-id", - frame.Destination, "/topic/test") + rawConn.incomingFrames <- frame.New( + frame.SUBSCRIBE, + frame.Id, "sub-id", + frame.Destination, "/topic/test") - e = <-events - assert.Equal(t, e.eventType, SubscribeToTopic) - assert.Equal(t, e.conn, stompConn) - assert.Equal(t, e.destination, "/topic/test") - assert.Equal(t, e.sub.destination, "/topic/test") - assert.Equal(t, e.sub.id, "sub-id") - assert.Equal(t, e.frame.Command, frame.SUBSCRIBE) + e = <-events + assert.Equal(t, e.eventType, SubscribeToTopic) + assert.Equal(t, e.conn, stompConn) + assert.Equal(t, e.destination, "/topic/test") + assert.Equal(t, e.sub.destination, "/topic/test") + assert.Equal(t, e.sub.id, "sub-id") + assert.Equal(t, e.frame.Command, frame.SUBSCRIBE) - assert.Equal(t, len(rawConn.sentFrames), 1) - assert.Equal(t, stompConn.state, connected) + assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, stompConn.state, connected) - assert.Equal(t, stompConn.subscriptions["sub-id"].destination, "/topic/test") + assert.Equal(t, stompConn.subscriptions["sub-id"].destination, "/topic/test") - // trigger send subscribe request with the same id - rawConn.incomingFrames <- frame.New( - frame.SUBSCRIBE, - frame.Id, "sub-id", - frame.Destination, "/topic/test") + // trigger send subscribe request with the same id + rawConn.incomingFrames <- frame.New( + frame.SUBSCRIBE, + frame.Id, "sub-id", + frame.Destination, "/topic/test") - // verify that there was no second subscription created for the same subscription id - assert.Equal(t, e.sub, stompConn.subscriptions["sub-id"]) + // verify that there was no second subscription created for the same subscription id + assert.Equal(t, e.sub, stompConn.subscriptions["sub-id"]) } func TestStompConn_SendNotConnected(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) - rawConn.incomingFrames <- frame.New( - frame.SEND, - frame.Destination, "/pub/test") + rawConn.incomingFrames <- frame.New( + frame.SEND, + frame.Destination, "/pub/test") - e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e := <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, notConnectedStompError.Error()), true) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, notConnectedStompError.Error()), true) } func TestStompConn_SendMissingDestinationHeader(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - rawConn.incomingFrames <- frame.New( - frame.SEND) + rawConn.incomingFrames <- frame.New( + frame.SEND) - e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e = <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, - frame.Message, invalidFrameError.Error()), true) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, + frame.Message, invalidFrameError.Error()), true) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_Send_InvalidSend(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) - // try sending a frame to a topic channel directly not request channel - rawConn.incomingFrames <- frame.New(frame.SEND, - frame.Destination, "/topic/test") - e = <-events + // try sending a frame to a topic channel directly not request channel + rawConn.incomingFrames <- frame.New(frame.SEND, + frame.Destination, "/topic/test") + e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, - frame.Message, invalidSendDestinationError.Error()), true) + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, + frame.Message, invalidSendDestinationError.Error()), true) } func TestStompConn_Send(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - msgF := frame.New(frame.SEND, frame.Destination, "/pub/test") + msgF := frame.New(frame.SEND, frame.Destination, "/pub/test") - rawConn.incomingFrames <- msgF + rawConn.incomingFrames <- msgF - e = <-events - assert.Equal(t, e.eventType, IncomingMessage) - assert.Equal(t, e.frame, msgF) - assert.Equal(t, e.frame.Command, frame.MESSAGE) + e = <-events + assert.Equal(t, e.eventType, IncomingMessage) + assert.Equal(t, e.frame, msgF) + assert.Equal(t, e.frame.Command, frame.MESSAGE) - rawConn.incomingFrames <- frame.New(frame.SEND, - frame.Destination, "/pub/test", frame.Receipt, "receipt-id") + rawConn.incomingFrames <- frame.New(frame.SEND, + frame.Destination, "/pub/test", frame.Receipt, "receipt-id") - e = <-events - assert.Equal(t, e.eventType, IncomingMessage) + e = <-events + assert.Equal(t, e.eventType, IncomingMessage) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.RECEIPT, - frame.ReceiptId, "receipt-id"), true) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.RECEIPT, + frame.ReceiptId, "receipt-id"), true) } func TestStompConn_UnsubscribeNotConnected(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.incomingFrames <- frame.New( - frame.UNSUBSCRIBE, - frame.Destination, "/topic/test") + rawConn.incomingFrames <- frame.New( + frame.UNSUBSCRIBE, + frame.Destination, "/topic/test") - e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e := <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, notConnectedStompError.Error()), true) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, notConnectedStompError.Error()), true) } func TestStompConn_UnsubscribeMissingIdHeader(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - rawConn.incomingFrames <- frame.New( - frame.UNSUBSCRIBE) + rawConn.incomingFrames <- frame.New( + frame.UNSUBSCRIBE) - e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e = <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, - frame.Message, invalidSubscriptionError.Error()), true) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.ERROR, + frame.Message, invalidSubscriptionError.Error()), true) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_Unsubscribe(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - rawConn.incomingFrames <- frame.New( - frame.UNSUBSCRIBE, - frame.Id, "invalid-sub-id", - frame.Receipt, "receipt-id") + rawConn.incomingFrames <- frame.New( + frame.UNSUBSCRIBE, + frame.Id, "invalid-sub-id", + frame.Receipt, "receipt-id") - rawConn.incomingFrames <- frame.New( - frame.SUBSCRIBE, - frame.Id, "sub-id", - frame.Destination, "/topic/test") + rawConn.incomingFrames <- frame.New( + frame.SUBSCRIBE, + frame.Id, "sub-id", + frame.Destination, "/topic/test") - e = <-events - assert.Equal(t, e.eventType, SubscribeToTopic) + e = <-events + assert.Equal(t, e.eventType, SubscribeToTopic) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.RECEIPT, - frame.ReceiptId, "receipt-id"), true) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], frame.New(frame.RECEIPT, + frame.ReceiptId, "receipt-id"), true) - rawConn.incomingFrames <- frame.New( - frame.UNSUBSCRIBE, - frame.Id, "sub-id") + rawConn.incomingFrames <- frame.New( + frame.UNSUBSCRIBE, + frame.Id, "sub-id") - e = <-events - assert.Equal(t, e.eventType, UnsubscribeFromTopic) - assert.Equal(t, e.conn, stompConn) - assert.Equal(t, e.destination, "/topic/test") - assert.Equal(t, e.sub.destination, "/topic/test") - assert.Equal(t, e.sub.id, "sub-id") + e = <-events + assert.Equal(t, e.eventType, UnsubscribeFromTopic) + assert.Equal(t, e.conn, stompConn) + assert.Equal(t, e.destination, "/topic/test") + assert.Equal(t, e.sub.destination, "/topic/test") + assert.Equal(t, e.sub.id, "sub-id") - assert.Equal(t, len(stompConn.subscriptions), 0) - assert.Equal(t, stompConn.state, connected) + assert.Equal(t, len(stompConn.subscriptions), 0) + assert.Equal(t, stompConn.state, connected) } func TestStompConn_DisconnectNotConnected(t *testing.T) { - _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + _, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.incomingFrames <- frame.New( - frame.DISCONNECT) + rawConn.incomingFrames <- frame.New( + frame.DISCONNECT) - e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e := <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, - frame.Message, notConnectedStompError.Error()), true) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.ERROR, + frame.Message, notConnectedStompError.Error()), true) } func TestStompConn_Disconnect(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) - rawConn.incomingFrames <- frame.New( - frame.DISCONNECT) + rawConn.incomingFrames <- frame.New( + frame.DISCONNECT) - e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e = <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 1) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_DisconnectWithReceipt(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - assert.Equal(t, len(rawConn.sentFrames), 1) - verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) + assert.Equal(t, len(rawConn.sentFrames), 1) + verifyFrame(t, rawConn.sentFrames[0], frame.New(frame.CONNECTED), false) - rawConn.incomingFrames <- frame.New( - frame.DISCONNECT, - frame.Receipt, "test-receipt") + rawConn.incomingFrames <- frame.New( + frame.DISCONNECT, + frame.Receipt, "test-receipt") - e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e = <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 2) - verifyFrame(t, rawConn.sentFrames[1], - frame.New(frame.RECEIPT, frame.ReceiptId, "test-receipt"), true) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, len(rawConn.sentFrames), 2) + verifyFrame(t, rawConn.sentFrames[1], + frame.New(frame.RECEIPT, frame.ReceiptId, "test-receipt"), true) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_Close(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - stompConn.Close() + stompConn.Close() - e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e := <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 0) - assert.Equal(t, stompConn.state, closed) - assert.Equal(t, rawConn.connected, false) + assert.Equal(t, len(rawConn.sentFrames), 0) + assert.Equal(t, stompConn.state, closed) + assert.Equal(t, rawConn.connected, false) } func TestStompConn_SendFrameToSubscription(t *testing.T) { - stompConn, rawConn, _ := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, _ := getTestStompConn(NewStompConfig(0, []string{}), nil) - sub := &subscription{ - id: "sub-id", - destination: "/topic/test", - } + sub := &subscription{ + id: "sub-id", + destination: "/topic/test", + } - f := frame.New(frame.MESSAGE, frame.Destination, "/topic/test") + f := frame.New(frame.MESSAGE, frame.Destination, "/topic/test") - rawConn.writeWg = &sync.WaitGroup{} - rawConn.writeWg.Add(1) + rawConn.writeWg = &sync.WaitGroup{} + rawConn.writeWg.Add(1) - stompConn.SendFrameToSubscription(f, sub) + stompConn.SendFrameToSubscription(f, sub) - rawConn.writeWg.Wait() - assert.Equal(t, len(rawConn.sentFrames), 1) + rawConn.writeWg.Wait() + assert.Equal(t, len(rawConn.sentFrames), 1) - assert.Equal(t, rawConn.sentFrames[0], f) - assert.Equal(t, rawConn.sentFrames[0].Header.Get(frame.MessageId), "1") + assert.Equal(t, rawConn.sentFrames[0], f) + assert.Equal(t, rawConn.sentFrames[0].Header.Get(frame.MessageId), "1") - rawConn.writeWg.Add(1) - stompConn.SendFrameToSubscription(f, sub) - rawConn.writeWg.Wait() - assert.Equal(t, len(rawConn.sentFrames), 2) - assert.Equal(t, rawConn.sentFrames[1].Header.Get(frame.MessageId), "2") + rawConn.writeWg.Add(1) + stompConn.SendFrameToSubscription(f, sub) + rawConn.writeWg.Wait() + assert.Equal(t, len(rawConn.sentFrames), 2) + assert.Equal(t, rawConn.sentFrames[1].Header.Get(frame.MessageId), "2") - rawConn.writeWg.Add(50) - for i := 0; i < 50; i++ { - go stompConn.SendFrameToSubscription(f.Clone(), sub) - } + rawConn.writeWg.Add(50) + for i := 0; i < 50; i++ { + go stompConn.SendFrameToSubscription(f.Clone(), sub) + } - rawConn.writeWg.Wait() - assert.Equal(t, len(rawConn.sentFrames), 52) + rawConn.writeWg.Wait() + assert.Equal(t, len(rawConn.sentFrames), 52) } func TestStompConn_SendErrorFrameToSubscription(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - sub := &subscription{ - id: "sub-id", - destination: "/topic/test", - } + sub := &subscription{ + id: "sub-id", + destination: "/topic/test", + } - f := frame.New(frame.ERROR, frame.Destination, "/topic/test") - stompConn.SendFrameToSubscription(f, sub) + f := frame.New(frame.ERROR, frame.Destination, "/topic/test") + stompConn.SendFrameToSubscription(f, sub) - e := <-events + e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 1) + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, len(rawConn.sentFrames), 1) } func TestStompConn_ReadError(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.incomingFrames <- errors.New("read error") + rawConn.incomingFrames <- errors.New("read error") - e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) + e := <-events + assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, len(rawConn.sentFrames), 0) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, len(rawConn.sentFrames), 0) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_WriteErrorDuringConnect(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{}), nil) - rawConn.nextWriteErr = errors.New("write error") - rawConn.SendConnectFrame() + rawConn.nextWriteErr = errors.New("write error") + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, stompConn.state, closed) + e := <-events + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_WriteErrorDuringSend(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(0, []string{"/pub/"}), nil) - rawConn.SendConnectFrame() + rawConn.SendConnectFrame() - e := <-events - assert.Equal(t, e.eventType, ConnectionEstablished) + e := <-events + assert.Equal(t, e.eventType, ConnectionEstablished) - rawConn.nextWriteErr = errors.New("write error") - rawConn.incomingFrames <- frame.New( - frame.SEND, - frame.Destination, "/pub/", - frame.Receipt, "receipt-id") + rawConn.nextWriteErr = errors.New("write error") + rawConn.incomingFrames <- frame.New( + frame.SEND, + frame.Destination, "/pub/", + frame.Receipt, "receipt-id") - e = <-events + e = <-events - assert.Equal(t, e.eventType, ConnectionClosed) - assert.Equal(t, stompConn.state, closed) + assert.Equal(t, e.eventType, ConnectionClosed) + assert.Equal(t, stompConn.state, closed) } func TestStompConn_SetReadDeadline(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(20000, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(20000, []string{}), nil) - infiniteTimeout := time.Time{} + infiniteTimeout := time.Time{} - assert.Equal(t, rawConn.getCurrentReadDeadline(), infiniteTimeout) + assert.Equal(t, rawConn.getCurrentReadDeadline(), infiniteTimeout) - rawConn.incomingFrames <- frame.New( - frame.CONNECT, - frame.AcceptVersion, "1.2", - frame.HeartBeat, "200,200") + rawConn.incomingFrames <- frame.New( + frame.CONNECT, + frame.AcceptVersion, "1.2", + frame.HeartBeat, "200,200") - <-events + <-events - // verify timeout is set to 20 seconds - assert.Equal(t, stompConn.readTimeoutMs, int64(20000)) + // verify timeout is set to 20 seconds + assert.Equal(t, stompConn.readTimeoutMs, int64(20000)) - rawConn.incomingFrames <- nil - rawConn.incomingFrames <- nil + rawConn.incomingFrames <- nil + rawConn.incomingFrames <- nil - diff := rawConn.getCurrentReadDeadline().Sub(time.Now()) + diff := rawConn.getCurrentReadDeadline().Sub(time.Now()) - // verify the read deadline for the connection is - // between 15 and 21 seconds - // assert.Greater(t, diff.Seconds(), float64(15)) <-- we don't care anymore. - assert.Greater(t, float64(21), diff.Seconds()) + // verify the read deadline for the connection is + // between 15 and 21 seconds + // assert.Greater(t, diff.Seconds(), float64(15)) <-- we don't care anymore. + assert.Greater(t, float64(21), diff.Seconds()) } func TestStompConn_WriteHeartbeat(t *testing.T) { - stompConn, rawConn, events := getTestStompConn(NewStompConfig(100, []string{}), nil) + stompConn, rawConn, events := getTestStompConn(NewStompConfig(100, []string{}), nil) - rawConn.incomingFrames <- frame.New( - frame.CONNECT, - frame.AcceptVersion, "1.2", - frame.HeartBeat, "50,50") + rawConn.incomingFrames <- frame.New( + frame.CONNECT, + frame.AcceptVersion, "1.2", + frame.HeartBeat, "50,50") - <-events + <-events - rawConn.lock.Lock() - rawConn.writeWg = new(sync.WaitGroup) - rawConn.writeWg.Add(2) - rawConn.lock.Unlock() + rawConn.lock.Lock() + rawConn.writeWg = new(sync.WaitGroup) + rawConn.writeWg.Add(2) + rawConn.lock.Unlock() - rawConn.writeWg.Wait() + rawConn.writeWg.Wait() - // verify the last frame is heartbeat (nil) - rawConn.lock.Lock() - assert.Nil(t, rawConn.LastSentFrame()) - rawConn.lock.Unlock() + // verify the last frame is heartbeat (nil) + rawConn.lock.Lock() + assert.Nil(t, rawConn.LastSentFrame()) + rawConn.lock.Unlock() - rawConn.writeWg.Add(3) - stompConn.SendFrameToSubscription(frame.New(frame.MESSAGE), &subscription{id: "sub-1"}) - rawConn.writeWg.Wait() + rawConn.writeWg.Add(3) + stompConn.SendFrameToSubscription(frame.New(frame.MESSAGE), &subscription{id: "sub-1"}) + rawConn.writeWg.Wait() - // verify the last frame is heartbeat (nil) - rawConn.lock.Lock() - rawConn.writeWg = nil - assert.Nil(t, rawConn.LastSentFrame()) - rawConn.lock.Unlock() + // verify the last frame is heartbeat (nil) + rawConn.lock.Lock() + rawConn.writeWg = nil + assert.Nil(t, rawConn.LastSentFrame()) + rawConn.lock.Unlock() } func verifyFrame(t *testing.T, actualFrame *frame.Frame, expectedFrame *frame.Frame, exactHeaderMatch bool) { - assert.Equal(t, expectedFrame.Command, actualFrame.Command) - if exactHeaderMatch { - assert.Equal(t, expectedFrame.Header.Len(), actualFrame.Header.Len()) - } + assert.Equal(t, expectedFrame.Command, actualFrame.Command) + if exactHeaderMatch { + assert.Equal(t, expectedFrame.Header.Len(), actualFrame.Header.Len()) + } - for i := 0; i < expectedFrame.Header.Len(); i++ { - key, value := expectedFrame.Header.GetAt(i) - assert.Equal(t, actualFrame.Header.Get(key), value) - } + for i := 0; i < expectedFrame.Header.Len(); i++ { + key, value := expectedFrame.Header.GetAt(i) + assert.Equal(t, actualFrame.Header.Get(key), value) + } } func printFrame(f *frame.Frame) { - if f == nil { - fmt.Println("HEARTBEAT FRAME") - } else { - fmt.Println("FRAME:", f.Command) - for i := 0; i < f.Header.Len(); i++ { - key, value := f.Header.GetAt(i) - fmt.Println(key+":", value) - } - fmt.Println("BODY:", string(f.Body)) - } + if f == nil { + fmt.Println("HEARTBEAT FRAME") + } else { + fmt.Println("FRAME:", f.Command) + for i := 0; i < f.Header.Len(); i++ { + key, value := f.Header.GetAt(i) + fmt.Println(key+":", value) + } + fmt.Println("BODY:", string(f.Body)) + } } diff --git a/stompserver/websocket_connection_listener.go b/stompserver/websocket_connection_listener.go index 09caece..9600ee1 100644 --- a/stompserver/websocket_connection_listener.go +++ b/stompserver/websocket_connection_listener.go @@ -4,174 +4,174 @@ package stompserver import ( - "github.com/go-stomp/stomp/v3/frame" - "github.com/gorilla/mux" - "github.com/gorilla/websocket" - "net" - "net/http" - "net/url" - "strings" - - "time" + "github.com/go-stomp/stomp/v3/frame" + "github.com/gorilla/mux" + "github.com/gorilla/websocket" + "net" + "net/http" + "net/url" + "strings" + + "time" ) type webSocketStompConnection struct { - wsCon *websocket.Conn + wsCon *websocket.Conn } func (c *webSocketStompConnection) ReadFrame() (*frame.Frame, error) { - _, r, err := c.wsCon.NextReader() - if err != nil { - return nil, err - } - frameR := frame.NewReader(r) - f, e := frameR.Read() - return f, e + _, r, err := c.wsCon.NextReader() + if err != nil { + return nil, err + } + frameR := frame.NewReader(r) + f, e := frameR.Read() + return f, e } func (c *webSocketStompConnection) WriteFrame(f *frame.Frame) error { - wr, err := c.wsCon.NextWriter(websocket.TextMessage) - if err != nil { - return err - } - frameWr := frame.NewWriter(wr) - err = frameWr.Write(f) - if err != nil { - return err - } - err = wr.Close() - return err + wr, err := c.wsCon.NextWriter(websocket.TextMessage) + if err != nil { + return err + } + frameWr := frame.NewWriter(wr) + err = frameWr.Write(f) + if err != nil { + return err + } + err = wr.Close() + return err } func (c *webSocketStompConnection) SetReadDeadline(t time.Time) { - c.wsCon.SetReadDeadline(t) + c.wsCon.SetReadDeadline(t) } func (c *webSocketStompConnection) Close() error { - return c.wsCon.Close() + return c.wsCon.Close() } type webSocketConnectionListener struct { - httpServer *http.Server - requestHandler *http.ServeMux - tcpConnectionListener net.Listener - connectionsChannel chan rawConnResult - allowedOrigins []string + httpServer *http.Server + requestHandler *http.ServeMux + tcpConnectionListener net.Listener + connectionsChannel chan rawConnResult + allowedOrigins []string } type rawConnResult struct { - conn RawConnection - err error + conn RawConnection + err error } func NewWebSocketConnectionFromExistingHttpServer(httpServer *http.Server, handler *mux.Router, - endpoint string, allowedOrigins []string) (RawConnectionListener, error) { - l := &webSocketConnectionListener{ - httpServer: httpServer, - connectionsChannel: make(chan rawConnResult), - allowedOrigins: allowedOrigins, - } - - var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - - upgrader.CheckOrigin = l.checkOrigin - - handler.HandleFunc(endpoint, func(writer http.ResponseWriter, request *http.Request) { - upgrader.Subprotocols = websocket.Subprotocols(request) - conn, err := upgrader.Upgrade(writer, request, nil) - if err != nil { - l.connectionsChannel <- rawConnResult{err: err} - - } else { - l.connectionsChannel <- rawConnResult{ - conn: &webSocketStompConnection{ - wsCon: conn, - }, - } - } - }) - - return l, nil + endpoint string, allowedOrigins []string) (RawConnectionListener, error) { + l := &webSocketConnectionListener{ + httpServer: httpServer, + connectionsChannel: make(chan rawConnResult), + allowedOrigins: allowedOrigins, + } + + var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + upgrader.CheckOrigin = l.checkOrigin + + handler.HandleFunc(endpoint, func(writer http.ResponseWriter, request *http.Request) { + upgrader.Subprotocols = websocket.Subprotocols(request) + conn, err := upgrader.Upgrade(writer, request, nil) + if err != nil { + l.connectionsChannel <- rawConnResult{err: err} + + } else { + l.connectionsChannel <- rawConnResult{ + conn: &webSocketStompConnection{ + wsCon: conn, + }, + } + } + }) + + return l, nil } func NewWebSocketConnectionListener(addr string, endpoint string, allowedOrigins []string) (RawConnectionListener, error) { - rh := http.NewServeMux() - l := &webSocketConnectionListener{ - requestHandler: rh, - httpServer: &http.Server{ - Addr: addr, - Handler: rh, - }, - connectionsChannel: make(chan rawConnResult), - allowedOrigins: allowedOrigins, - } - - var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - - upgrader.CheckOrigin = l.checkOrigin - - rh.HandleFunc(endpoint, func(writer http.ResponseWriter, request *http.Request) { - upgrader.Subprotocols = websocket.Subprotocols(request) - conn, err := upgrader.Upgrade(writer, request, nil) - if err != nil { - l.connectionsChannel <- rawConnResult{err: err} - - } else { - l.connectionsChannel <- rawConnResult{ - conn: &webSocketStompConnection{ - wsCon: conn, - }, - } - } - }) - - var err error - l.tcpConnectionListener, err = net.Listen("tcp", addr) - if err != nil { - return nil, err - } - - go l.httpServer.Serve(l.tcpConnectionListener) - return l, nil + rh := http.NewServeMux() + l := &webSocketConnectionListener{ + requestHandler: rh, + httpServer: &http.Server{ + Addr: addr, + Handler: rh, + }, + connectionsChannel: make(chan rawConnResult), + allowedOrigins: allowedOrigins, + } + + var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + upgrader.CheckOrigin = l.checkOrigin + + rh.HandleFunc(endpoint, func(writer http.ResponseWriter, request *http.Request) { + upgrader.Subprotocols = websocket.Subprotocols(request) + conn, err := upgrader.Upgrade(writer, request, nil) + if err != nil { + l.connectionsChannel <- rawConnResult{err: err} + + } else { + l.connectionsChannel <- rawConnResult{ + conn: &webSocketStompConnection{ + wsCon: conn, + }, + } + } + }) + + var err error + l.tcpConnectionListener, err = net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + go l.httpServer.Serve(l.tcpConnectionListener) + return l, nil } func (l *webSocketConnectionListener) checkOrigin(r *http.Request) bool { - if len(l.allowedOrigins) == 0 { - return true - } - - origin := r.Header["Origin"] - if len(origin) == 0 { - return true - } - u, err := url.Parse(origin[0]) - if err != nil { - return false - } - if strings.ToLower(u.Host) == strings.ToLower(r.Host) { - return true - } - - for _, allowedOrigin := range l.allowedOrigins { - if strings.ToLower(u.Host) == strings.ToLower(allowedOrigin) { - return true - } - } - - return false + if len(l.allowedOrigins) == 0 { + return true + } + + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + if strings.ToLower(u.Host) == strings.ToLower(r.Host) { + return true + } + + for _, allowedOrigin := range l.allowedOrigins { + if strings.ToLower(u.Host) == strings.ToLower(allowedOrigin) { + return true + } + } + + return false } func (l *webSocketConnectionListener) Accept() (RawConnection, error) { - cr := <-l.connectionsChannel - return cr.conn, cr.err + cr := <-l.connectionsChannel + return cr.conn, cr.err } func (l *webSocketConnectionListener) Close() error { - return l.httpServer.Close() + return l.httpServer.Close() }