Skip to content

Commit

Permalink
Add optional method ProxyTLSConnection (closes gorilla#779)
Browse files Browse the repository at this point in the history
Removed the call to NetDialTLSContext from the HTTP proxy CONNECT step and replaced it with a regular net.Dial in order to prevent connection issues. Custom TLS connections can now be made via the new optional ProxyTLSConnection method, after the proxy connection has been successfully established.
  • Loading branch information
sleeyax committed May 12, 2022
1 parent 78cf1bc commit 502bd65
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
45 changes: 28 additions & 17 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ type Dialer struct {
// TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)

// ProxyTLSConnection specifies the dial function for creating TLS connections through a Proxy. If
// ProxyTLSConnection is nil, NetDialTLSContext is used.
// If ProxyTLSConnection is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
ProxyTLSConnection func(ctx context.Context, proxyConn net.Conn) (net.Conn, error)

// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
Expand Down Expand Up @@ -333,26 +339,31 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()

if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
if u.Scheme == "https" {
if d.ProxyTLSConnection != nil && d.Proxy != nil {
// If we are connected to a proxy, perform the TLS handshake through the existing tunnel
netConn, err = d.ProxyTLSConnection(ctx, netConn)
} else if d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done

cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn

if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(ctx, tlsConn, cfg)
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(ctx, tlsConn, cfg)
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}

if err != nil {
return nil, nil, err
if err != nil {
return nil, nil, err
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type httpProxyDialer struct {

func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.forwardDial(network, hostPort)
conn, err := net.Dial(network, hostPort)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 502bd65

Please sign in to comment.