diff --git a/client.go b/client.go index 11d04de..edd8c0c 100644 --- a/client.go +++ b/client.go @@ -393,7 +393,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) { tlsCfg = c.options.OnConnectAttempt(broker, c.options.TLSConfig) } // Start by opening the network connection (tcp, tls, ws) etc - conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions) + conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions, c.options.TcpOptions) if err != nil { ERROR.Println(CLI, err.Error()) WARN.Println(CLI, "failed to connect to broker, trying next") diff --git a/netconn.go b/netconn.go index c123849..fc2b6f4 100644 --- a/netconn.go +++ b/netconn.go @@ -37,7 +37,7 @@ import ( // openConnection opens a network connection using the protocol indicated in the URL. // Does not carry out any MQTT specific handshakes. -func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions) (net.Conn, error) { +func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions, tcpOptions *TcpOptions) (net.Conn, error) { switch uri.Scheme { case "ws": conn, err := NewWebsocket(uri.String(), nil, timeout, headers, websocketOptions) @@ -48,7 +48,8 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade case "mqtt", "tcp": allProxy := os.Getenv("all_proxy") if len(allProxy) == 0 { - conn, err := net.DialTimeout("tcp", uri.Host, timeout) + dialer := net.Dialer{Timeout: timeout, KeepAlive: tcpOptions.KeepAlive} + conn, err := dialer.Dial("tcp", uri.Host) if err != nil { return nil, err } @@ -67,10 +68,11 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade // this check is preserved for compatibility with older versions // which used uri.Host only (it works for local paths, e.g. unix://socket.sock in current dir) + dialer := net.Dialer{Timeout: timeout, KeepAlive: tcpOptions.KeepAlive} if len(uri.Host) > 0 { - conn, err = net.DialTimeout("unix", uri.Host, timeout) + conn, err = dialer.Dial("unix", uri.Host) } else { - conn, err = net.DialTimeout("unix", uri.Path, timeout) + conn, err = dialer.Dial("unix", uri.Path) } if err != nil { @@ -80,14 +82,13 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade case "ssl", "tls", "mqtts", "mqtt+ssl", "tcps": allProxy := os.Getenv("all_proxy") if len(allProxy) == 0 { - conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc) + conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout, KeepAlive: tcpOptions.KeepAlive}, "tcp", uri.Host, tlsc) if err != nil { return nil, err } return conn, nil } proxyDialer := proxy.FromEnvironment() - conn, err := proxyDialer.Dial("tcp", uri.Host) if err != nil { return nil, err diff --git a/options.go b/options.go index d7a24e4..c7fb218 100644 --- a/options.go +++ b/options.go @@ -96,6 +96,7 @@ type ClientOptions struct { HTTPHeaders http.Header WebsocketOptions *WebsocketOptions MaxResumePubInFlight int // // 0 = no limit; otherwise this is the maximum simultaneous messages sent while resuming + TcpOptions *TcpOptions } // NewClientOptions will create a new ClientClientOptions type with some @@ -137,6 +138,7 @@ func NewClientOptions() *ClientOptions { ResumeSubs: false, HTTPHeaders: make(map[string][]string), WebsocketOptions: &WebsocketOptions{}, + TcpOptions: &TcpOptions{}, } return o } @@ -419,3 +421,9 @@ func (o *ClientOptions) SetMaxResumePubInFlight(MaxResumePubInFlight int) *Clien o.MaxResumePubInFlight = MaxResumePubInFlight return o } + +// SetTcpOptions sets the additional tcp options used in a tcp connection +func (o *ClientOptions) SetTcpOptions(w *TcpOptions) *ClientOptions { + o.TcpOptions = w + return o +} diff --git a/tcp.go b/tcp.go new file mode 100644 index 0000000..5ec12ff --- /dev/null +++ b/tcp.go @@ -0,0 +1,7 @@ +package mqtt + +import "time" + +type TcpOptions struct { + KeepAlive time.Duration +}