diff --git a/client.go b/client.go index edd8c0c..ef8e4b7 100644 --- a/client.go +++ b/client.go @@ -393,7 +393,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) { tlsCfg = c.options.OnConnectAttempt(broker, c.options.TLSConfig) } // Start by opening the network connection (tcp, tls, ws) etc - conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions, c.options.TcpOptions) + conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions, c.options.Dialer) if err != nil { ERROR.Println(CLI, err.Error()) WARN.Println(CLI, "failed to connect to broker, trying next") diff --git a/netconn.go b/netconn.go index fc2b6f4..7e3899e 100644 --- a/netconn.go +++ b/netconn.go @@ -37,7 +37,7 @@ import ( // openConnection opens a network connection using the protocol indicated in the URL. // Does not carry out any MQTT specific handshakes. -func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions, tcpOptions *TcpOptions) (net.Conn, error) { +func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions, dialer *net.Dialer) (net.Conn, error) { switch uri.Scheme { case "ws": conn, err := NewWebsocket(uri.String(), nil, timeout, headers, websocketOptions) @@ -48,7 +48,6 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade case "mqtt", "tcp": allProxy := os.Getenv("all_proxy") if len(allProxy) == 0 { - dialer := net.Dialer{Timeout: timeout, KeepAlive: tcpOptions.KeepAlive} conn, err := dialer.Dial("tcp", uri.Host) if err != nil { return nil, err @@ -68,7 +67,6 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade // this check is preserved for compatibility with older versions // which used uri.Host only (it works for local paths, e.g. unix://socket.sock in current dir) - dialer := net.Dialer{Timeout: timeout, KeepAlive: tcpOptions.KeepAlive} if len(uri.Host) > 0 { conn, err = dialer.Dial("unix", uri.Host) } else { @@ -82,7 +80,7 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade case "ssl", "tls", "mqtts", "mqtt+ssl", "tcps": allProxy := os.Getenv("all_proxy") if len(allProxy) == 0 { - conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout, KeepAlive: tcpOptions.KeepAlive}, "tcp", uri.Host, tlsc) + conn, err := tls.DialWithDialer(dialer, "tcp", uri.Host, tlsc) if err != nil { return nil, err } diff --git a/options.go b/options.go index c7fb218..e745258 100644 --- a/options.go +++ b/options.go @@ -23,6 +23,7 @@ package mqtt import ( "crypto/tls" + "net" "net/http" "net/url" "strings" @@ -96,7 +97,7 @@ type ClientOptions struct { HTTPHeaders http.Header WebsocketOptions *WebsocketOptions MaxResumePubInFlight int // // 0 = no limit; otherwise this is the maximum simultaneous messages sent while resuming - TcpOptions *TcpOptions + Dialer *net.Dialer } // NewClientOptions will create a new ClientClientOptions type with some @@ -138,7 +139,7 @@ func NewClientOptions() *ClientOptions { ResumeSubs: false, HTTPHeaders: make(map[string][]string), WebsocketOptions: &WebsocketOptions{}, - TcpOptions: &TcpOptions{}, + Dialer: &net.Dialer{Timeout: 30 * time.Second}, } return o } @@ -357,6 +358,7 @@ func (o *ClientOptions) SetWriteTimeout(t time.Duration) *ClientOptions { // Default 30 seconds. Currently only operational on TCP/TLS connections. func (o *ClientOptions) SetConnectTimeout(t time.Duration) *ClientOptions { o.ConnectTimeout = t + o.Dialer.Timeout = t return o } @@ -422,8 +424,8 @@ func (o *ClientOptions) SetMaxResumePubInFlight(MaxResumePubInFlight int) *Clien return o } -// SetTcpOptions sets the additional tcp options used in a tcp connection -func (o *ClientOptions) SetTcpOptions(w *TcpOptions) *ClientOptions { - o.TcpOptions = w +// SetDialer sets the tcp dialer options used in a tcp connection +func (o *ClientOptions) SetDialer(dialer *net.Dialer) *ClientOptions { + o.Dialer = dialer return o } diff --git a/tcp.go b/tcp.go deleted file mode 100644 index 5ec12ff..0000000 --- a/tcp.go +++ /dev/null @@ -1,7 +0,0 @@ -package mqtt - -import "time" - -type TcpOptions struct { - KeepAlive time.Duration -}