diff --git a/conn.go b/conn.go index 1f51820..9de9a74 100644 --- a/conn.go +++ b/conn.go @@ -16,6 +16,9 @@ import ( // to avoid premature disconnections due to network latency. const DefaultHeartBeatError = 5 * time.Second +// Default timeout of calling Conn.Send function +const DefaultMsgSendTimeout = 10 * time.Second + // A Conn is a connection to a STOMP server. Create a Conn using either // the Dial or Connect function. type Conn struct { @@ -27,6 +30,7 @@ type Conn struct { server string readTimeout time.Duration writeTimeout time.Duration + msgSendTimeout time.Duration hbGracePeriodMultiplier float64 closed bool closeMutex *sync.Mutex @@ -172,6 +176,8 @@ func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error) } } + c.msgSendTimeout = options.MsgSendTimeout + // TODO(jpj): make any non-standard headers in the CONNECTED // frame available. This could be implemented as: // (a) a callback function supplied as an option; or @@ -439,7 +445,10 @@ func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(* C: make(chan *frame.Frame), } - c.writeCh <- request + err := sendDataToWriteChWithTimeout(c.writeCh, request, c.msgSendTimeout) + if err != nil { + return err + } response := <-request.C if response.Command != frame.RECEIPT { return newError(response) @@ -447,12 +456,32 @@ func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(* } else { // no receipt required request := writeRequest{Frame: f} - c.writeCh <- request + + err := sendDataToWriteChWithTimeout(c.writeCh, request, c.msgSendTimeout) + if err != nil { + return err + } } return nil } +func sendDataToWriteChWithTimeout(ch chan writeRequest, request writeRequest, timeout time.Duration) error { + if timeout <= 0 { + ch <- request + return nil + } + + timer := time.NewTimer(timeout) + select { + case <-timer.C: + return ErrMsgSendTimeout + case ch <- request: + timer.Stop() + return nil + } +} + func createSendFrame(destination, contentType string, body []byte, opts []func(*frame.Frame) error) (*frame.Frame, error) { // Set the content-length before the options, because this provides // an opportunity to remove content-length. diff --git a/conn_options.go b/conn_options.go index 0e0aed0..86020b8 100644 --- a/conn_options.go +++ b/conn_options.go @@ -16,6 +16,7 @@ type connOptions struct { ReadTimeout time.Duration WriteTimeout time.Duration HeartBeatError time.Duration + MsgSendTimeout time.Duration HeartBeatGracePeriodMultiplier float64 Login, Passcode string AcceptVersions []string @@ -30,6 +31,7 @@ func newConnOptions(conn *Conn, opts []func(*Conn) error) (*connOptions, error) WriteTimeout: time.Minute, HeartBeatGracePeriodMultiplier: 1.0, HeartBeatError: DefaultHeartBeatError, + MsgSendTimeout: DefaultMsgSendTimeout, } // This is a slight of hand, attach the options to the Conn long @@ -127,6 +129,14 @@ var ConnOpt struct { // shorter time duration during unit testing. HeartBeatError func(errorTimeout time.Duration) func(*Conn) error + // MsgSendTimeout is a connect option that allows the client to specify + // the timeout for the Conn.Send function. + // The msgSendTimeout parameter specifies maximum blocking time for calling + // the Conn.Send function. + // If not specified, this option defaults to 10 seconds. + // Less than or equal to zero means infinite + MsgSendTimeout func(msgSendTimeout time.Duration) func(*Conn) error + // HeartBeatGracePeriodMultiplier is used to calculate the effective read heart-beat timeout // the broker will enforce for each client’s connection. The multiplier is applied to // the read-timeout interval the client specifies in its CONNECT frame @@ -196,6 +206,13 @@ func init() { } } + ConnOpt.MsgSendTimeout = func(msgSendTimeout time.Duration) func(*Conn) error { + return func(c *Conn) error { + c.options.MsgSendTimeout = msgSendTimeout + return nil + } + } + ConnOpt.HeartBeatGracePeriodMultiplier = func(multiplier float64) func(*Conn) error { return func(c *Conn) error { c.options.HeartBeatGracePeriodMultiplier = multiplier diff --git a/errors.go b/errors.go index 66fb650..7fb3f56 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,7 @@ var ( ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed") ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly") ErrAlreadyClosed = newErrorMessage("connection already closed") + ErrMsgSendTimeout = newErrorMessage("msg send timeout") ErrNilOption = newErrorMessage("nil option") )