From a0e6e17a6cf05356923c477c6f9ee0b3c41684c6 Mon Sep 17 00:00:00 2001 From: amir gh Date: Thu, 22 Aug 2024 18:25:51 -0700 Subject: [PATCH] add DoH test and add trace to ctx --- x/connectivity/connectivity.go | 188 +++++++++++++++++++++++++++++++++ 1 file changed, 188 insertions(+) diff --git a/x/connectivity/connectivity.go b/x/connectivity/connectivity.go index c6d26124..7579378f 100644 --- a/x/connectivity/connectivity.go +++ b/x/connectivity/connectivity.go @@ -16,12 +16,22 @@ package connectivity import ( "context" + ctls "crypto/tls" "errors" "fmt" + "io" + "log" + "net" + "net/http" + "net/http/httptrace" + "os" + "strings" "syscall" "time" "github.com/Jigsaw-Code/outline-sdk/dns" + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/x/trace" "golang.org/x/net/dns/dnsmessage" ) @@ -83,6 +93,7 @@ func TestConnectivityWithResolver(ctx context.Context, resolver dns.Resolver, te return nil, fmt.Errorf("question creation failed: %w", err) } + // Pass this context to your DNS resolver function _, err = resolver.Query(ctx, *q) if errors.Is(err, dns.ErrBadRequest) { @@ -97,3 +108,180 @@ func TestConnectivityWithResolver(ctx context.Context, resolver dns.Resolver, te } return nil, nil } + +func TestStreamConnectivitywithHTTP(ctx context.Context, baseDialer transport.StreamDialer, domain string, timeout time.Duration, method string) (*ConnectivityError, error) { + dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("invalid address: %w", err) + } + if !strings.HasPrefix(network, "tcp") { + return nil, fmt.Errorf("protocol not supported: %v", network) + } + return baseDialer.DialStream(ctx, net.JoinHostPort(host, port)) + } + httpClient := &http.Client{ + Transport: &http.Transport{DialContext: dialContext}, + Timeout: time.Duration(timeout) * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + req, err := http.NewRequest(method, domain, nil) + if err != nil { + log.Fatalln("Failed to create request:", err) + } + // TODO: Add this as test param + // headerText := strings.Join(headersFlag, "\r\n") + "\r\n\r\n" + // h, err := textproto.NewReader(bufio.NewReader(strings.NewReader(headerText))).ReadMIMEHeader() + // if err != nil { + // log.Fatalf("invalid header line: %v", err) + // } + // for name, values := range h { + // for _, value := range values { + // req.Header.Add(name, value) + // } + // } + + req = req.WithContext(ctx) + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + for k, v := range resp.Header { + fmt.Printf("%v: %v\n", k, v) + } + + fmt.Printf("StatusCode %v\n", resp.StatusCode) + + _, err = io.Copy(os.Stdout, resp.Body) + if err != nil { + log.Fatalf("Read of page body failed: %v\n", err) + } + + return nil, nil +} + +func AddLoggerTrace(ctx context.Context) context.Context { + t := &trace.DNSClientTrace{ + QuestionSent: func(question dnsmessage.Question) { + fmt.Println("DNS query started for", question.Name.String()) + }, + ResponsDone: func(question dnsmessage.Question, msg *dnsmessage.Message, err error) { + if err != nil { + fmt.Printf("DNS query for %s failed: %v\n", question.Name.String(), err) + } else { + // Prepare to collect IP addresses + var ips []string + + // Iterate over the answer section + for _, answer := range msg.Answers { + switch rr := answer.Body.(type) { + case *dnsmessage.AResource: + // Handle IPv4 addresses - convert [4]byte to IP string + ipv4 := net.IP(rr.A[:]) // Convert [4]byte to net.IP + ips = append(ips, ipv4.String()) + case *dnsmessage.AAAAResource: + // Handle IPv6 addresses - convert [16]byte to IP string + ipv6 := net.IP(rr.AAAA[:]) // Convert [16]byte to net.IP + ips = append(ips, ipv6.String()) + } + } + + // Print all resolved IP addresses + if len(ips) > 0 { + fmt.Printf("Resolved IPs for %s: %v\n", question.Name.String(), ips) + } else { + fmt.Printf("No IPs found for %s\n", question.Name.String()) + } + } + }, + ConnectDone: func(network, addr string, err error) { + if err != nil { + fmt.Printf("%v Connection to %s failed: %v\n", network, addr, err) + } else { + fmt.Printf("%v Connection to %s succeeded\n", network, addr) + } + }, + WroteDone: func(err error) { + if err != nil { + fmt.Printf("Write failed: %v\n", err) + } else { + fmt.Println("Write succeeded") + } + }, + ReadDone: func(err error) { + if err != nil { + fmt.Printf("Read failed: %v\n", err) + } else { + fmt.Println("Read succeeded") + } + }, + } + + // Variables to store the timestamps + var startTLS time.Time + + ht := &httptrace.ClientTrace{ + DNSStart: func(info httptrace.DNSStartInfo) { + fmt.Printf("DNS start: %v\n", info) + }, + DNSDone: func(info httptrace.DNSDoneInfo) { + fmt.Printf("DNS done: %v\n", info) + }, + ConnectStart: func(network, addr string) { + fmt.Printf("Connect start: %v %v\n", network, addr) + }, + ConnectDone: func(network, addr string, err error) { + fmt.Printf("Connect done: %v %v %v\n", network, addr, err) + }, + GotFirstResponseByte: func() { + fmt.Println("Got first response byte") + }, + WroteHeaderField: func(key string, value []string) { + fmt.Printf("Wrote header field: %v %v\n", key, value) + }, + WroteHeaders: func() { + fmt.Println("Wrote headers") + }, + WroteRequest: func(info httptrace.WroteRequestInfo) { + fmt.Printf("Wrote request: %v\n", info) + }, + TLSHandshakeStart: func() { + startTLS = time.Now() + }, + TLSHandshakeDone: func(state ctls.ConnectionState, err error) { + if err != nil { + fmt.Printf("TLS handshake failed: %v\n", err) + } + fmt.Printf("SNI: %v\n", state.ServerName) + fmt.Printf("TLS version: %v\n", state.Version) + fmt.Printf("ALPN: %v\n", state.NegotiatedProtocol) + fmt.Printf("TLS handshake took %v seconds.\n", time.Since(startTLS).Seconds()) + }, + } + + tlsTrace := &trace.TLSClientTrace{ + TLSHandshakeStart: func() { + fmt.Println("TLS handshake started") + startTLS = time.Now() + }, + TLSHandshakeDone: func(state ctls.ConnectionState, err error) { + if err != nil { + fmt.Printf("TLS handshake failed: %v\n", err) + } + fmt.Printf("SNI: %v\n", state.ServerName) + fmt.Printf("TLS version: %v\n", state.Version) + fmt.Printf("ALPN: %v\n", state.NegotiatedProtocol) + fmt.Printf("TLS handshake took %v seconds.\n", time.Since(startTLS).Seconds()) + }, + } + + ctx = httptrace.WithClientTrace(ctx, ht) + ctx = trace.WithDNSClientTrace(ctx, t) + ctx = trace.WithTLSClientTrace(ctx, tlsTrace) + return ctx +}