diff --git a/cmd/trace.go b/cmd/trace.go index 613c1ad269..0c67775f52 100644 --- a/cmd/trace.go +++ b/cmd/trace.go @@ -24,6 +24,7 @@ var ( L4Proto string Port int OutputFile string + DropOnly bool ) func init() { @@ -56,7 +57,7 @@ func init() { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() - if err := trace.StartTrace(ctx, IPVersion, L4ProtoNo, Port, OutputFile); err != nil { + if err := trace.StartTrace(ctx, IPVersion, L4ProtoNo, Port, DropOnly, OutputFile); err != nil { logrus.Fatalln(err) } }, @@ -66,6 +67,7 @@ func init() { traceCmd.PersistentFlags().BoolVarP(&IPv6, "ipv6", "6", false, "Capture IPv6 traffic") traceCmd.PersistentFlags().StringVarP(&L4Proto, "l4-proto", "p", "tcp", "Layer 4 protocol") traceCmd.PersistentFlags().IntVarP(&Port, "port", "P", 80, "Port") + traceCmd.PersistentFlags().BoolVarP(&DropOnly, "drop-only", "", false, "only trace the dropped package") traceCmd.PersistentFlags().StringVarP(&OutputFile, "output", "o", "/dev/stdout", "Output file") rootCmd.AddCommand(traceCmd) diff --git a/trace/trace.go b/trace/trace.go index a1d6eaf377..b801ecbf37 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -11,6 +11,7 @@ import ( "encoding/binary" "errors" "fmt" + "slices" "net" "os" "syscall" @@ -43,7 +44,7 @@ func init() { } } -func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int, outputFile string) (err error) { +func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int, dropOnly bool, outputFile string) (err error) { kernelVersion, err := internal.KernelVersion() if err != nil { return fmt.Errorf("failed to get kernel version: %w", err) @@ -80,7 +81,7 @@ func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int, }() fmt.Printf("\nstart tracing\n") - if err = handleEvents(ctx, objs, outputFile, kfreeSkbReasons); err != nil { + if err = handleEvents(ctx, objs, outputFile, kfreeSkbReasons, dropOnly); err != nil { return } return @@ -221,7 +222,7 @@ func attachBpfToTargets(objs *bpfObjects, targets map[string]int) (links []link. return links, nil } -func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfreeSkbReasons map[uint64]string) (err error) { +func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfreeSkbReasons map[uint64]string, dropOnly bool) (err error) { writer, err := os.Create(outputFile) if err != nil { return @@ -258,6 +259,10 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre PayloadLen uint16 } + skb2events := make(map[uint64][]bpfEvent) + // a map to save slices of bpfEvent of the Skb + skb2symNames := make(map[uint64][]string) + // a map to save slices of function name called with the Skb for { rec, err := eventsReader.Read() if err != nil { @@ -273,22 +278,43 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre logrus.Debugf("failed to parse ringbuf event: %+v", err) continue } - - fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", event.Skb, event.Mark, event.Netns, event.Ifindex, TrimNull(string(event.Ifname[:])), event.Pid, TrimNull(string(event.Pname[:]))) - if event.L3Proto == syscall.ETH_P_IP { - fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(event.Saddr[:4]).String(), Ntohs(event.Sport), net.IP(event.Daddr[:4]).String(), Ntohs(event.Dport)) - } else { - fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(event.Saddr[:]).String(), Ntohs(event.Sport), net.IP(event.Daddr[:]).String(), Ntohs(event.Dport)) + if skb2events[event.Skb]==nil { + skb2events[event.Skb] = []bpfEvent{} } - if event.L4Proto == syscall.IPPROTO_TCP { - fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(event.TcpFlags)) + skb2events[event.Skb] = append(skb2events[event.Skb],event) + + + sym := NearestSymbol(event.Pc); + if skb2symNames[event.Skb]==nil { + skb2symNames[event.Skb] = []string{} } - fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen) - sym := NearestSymbol(event.Pc) - fmt.Fprintf(writer, "%s", sym.Name) - if sym.Name == "kfree_skb_reason" { - fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[event.SecondParam]) + skb2symNames[event.Skb] = append(skb2symNames[event.Skb],sym.Name) + switch sym.Name { + case "__kfree_skb","kfree_skbmem": + // most skb end in the call of kfree_skbmem + if !dropOnly || slices.Contains(skb2symNames[event.Skb],"kfree_skb_reason") { + // trace dropOnly with drop reason or all skb + for _,skb_ev := range skb2events[event.Skb] { + fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:]))) + if event.L3Proto == syscall.ETH_P_IP { + fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport)) + } else { + fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport)) + } + if event.L4Proto == syscall.IPPROTO_TCP { + fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags)) + } + fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen) + sym := NearestSymbol(skb_ev.Pc) + fmt.Fprintf(writer, "%s", sym.Name) + if sym.Name == "kfree_skb_reason" { + fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam]) + } + fmt.Fprintf(writer, "\n") + } + delete(skb2events, event.Skb) + delete(skb2symNames, event.Skb) + } } - fmt.Fprintf(writer, "\n") } }