Skip to content

Commit

Permalink
feat(trace): add drop-only option (#632)
Browse files Browse the repository at this point in the history
Co-authored-by: mzz <[email protected]>
  • Loading branch information
linglilongyi and mzz2017 authored Sep 27, 2024
1 parent b218ecf commit b814105
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
4 changes: 3 additions & 1 deletion cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var (
L4Proto string
Port int
OutputFile string
DropOnly bool
)

func init() {
Expand Down Expand Up @@ -56,7 +57,7 @@ func init() {

ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
if err := trace.StartTrace(ctx, IPVersion, L4ProtoNo, Port, OutputFile); err != nil {
if err := trace.StartTrace(ctx, IPVersion, L4ProtoNo, Port, DropOnly, OutputFile); err != nil {
logrus.Fatalln(err)
}
},
Expand All @@ -66,6 +67,7 @@ func init() {
traceCmd.PersistentFlags().BoolVarP(&IPv6, "ipv6", "6", false, "Capture IPv6 traffic")
traceCmd.PersistentFlags().StringVarP(&L4Proto, "l4-proto", "p", "tcp", "Layer 4 protocol")
traceCmd.PersistentFlags().IntVarP(&Port, "port", "P", 80, "Port")
traceCmd.PersistentFlags().BoolVarP(&DropOnly, "drop-only", "", false, "only trace the dropped package")
traceCmd.PersistentFlags().StringVarP(&OutputFile, "output", "o", "/dev/stdout", "Output file")

rootCmd.AddCommand(traceCmd)
Expand Down
60 changes: 43 additions & 17 deletions trace/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"slices"
"net"
"os"
"syscall"
Expand Down Expand Up @@ -43,7 +44,7 @@ func init() {
}
}

func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int, outputFile string) (err error) {
func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int, dropOnly bool, outputFile string) (err error) {
kernelVersion, err := internal.KernelVersion()
if err != nil {
return fmt.Errorf("failed to get kernel version: %w", err)
Expand Down Expand Up @@ -80,7 +81,7 @@ func StartTrace(ctx context.Context, ipVersion int, l4ProtoNo uint16, port int,
}()

fmt.Printf("\nstart tracing\n")
if err = handleEvents(ctx, objs, outputFile, kfreeSkbReasons); err != nil {
if err = handleEvents(ctx, objs, outputFile, kfreeSkbReasons, dropOnly); err != nil {
return
}
return
Expand Down Expand Up @@ -221,7 +222,7 @@ func attachBpfToTargets(objs *bpfObjects, targets map[string]int) (links []link.
return links, nil
}

func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfreeSkbReasons map[uint64]string) (err error) {
func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfreeSkbReasons map[uint64]string, dropOnly bool) (err error) {
writer, err := os.Create(outputFile)
if err != nil {
return
Expand Down Expand Up @@ -258,6 +259,10 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre
PayloadLen uint16
}

skb2events := make(map[uint64][]bpfEvent)
// a map to save slices of bpfEvent of the Skb
skb2symNames := make(map[uint64][]string)
// a map to save slices of function name called with the Skb
for {
rec, err := eventsReader.Read()
if err != nil {
Expand All @@ -273,22 +278,43 @@ func handleEvents(ctx context.Context, objs *bpfObjects, outputFile string, kfre
logrus.Debugf("failed to parse ringbuf event: %+v", err)
continue
}

fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", event.Skb, event.Mark, event.Netns, event.Ifindex, TrimNull(string(event.Ifname[:])), event.Pid, TrimNull(string(event.Pname[:])))
if event.L3Proto == syscall.ETH_P_IP {
fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(event.Saddr[:4]).String(), Ntohs(event.Sport), net.IP(event.Daddr[:4]).String(), Ntohs(event.Dport))
} else {
fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(event.Saddr[:]).String(), Ntohs(event.Sport), net.IP(event.Daddr[:]).String(), Ntohs(event.Dport))
if skb2events[event.Skb]==nil {
skb2events[event.Skb] = []bpfEvent{}
}
if event.L4Proto == syscall.IPPROTO_TCP {
fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(event.TcpFlags))
skb2events[event.Skb] = append(skb2events[event.Skb],event)


sym := NearestSymbol(event.Pc);
if skb2symNames[event.Skb]==nil {
skb2symNames[event.Skb] = []string{}
}
fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen)
sym := NearestSymbol(event.Pc)
fmt.Fprintf(writer, "%s", sym.Name)
if sym.Name == "kfree_skb_reason" {
fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[event.SecondParam])
skb2symNames[event.Skb] = append(skb2symNames[event.Skb],sym.Name)
switch sym.Name {
case "__kfree_skb","kfree_skbmem":
// most skb end in the call of kfree_skbmem
if !dropOnly || slices.Contains(skb2symNames[event.Skb],"kfree_skb_reason") {
// trace dropOnly with drop reason or all skb
for _,skb_ev := range skb2events[event.Skb] {
fmt.Fprintf(writer, "%x mark=%x netns=%010d if=%d(%s) proc=%d(%s) ", skb_ev.Skb, skb_ev.Mark, skb_ev.Netns, skb_ev.Ifindex, TrimNull(string(skb_ev.Ifname[:])), skb_ev.Pid, TrimNull(string(skb_ev.Pname[:])))
if event.L3Proto == syscall.ETH_P_IP {
fmt.Fprintf(writer, "%s:%d > %s:%d ", net.IP(skb_ev.Saddr[:4]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:4]).String(), Ntohs(skb_ev.Dport))
} else {
fmt.Fprintf(writer, "[%s]:%d > [%s]:%d ", net.IP(skb_ev.Saddr[:]).String(), Ntohs(skb_ev.Sport), net.IP(skb_ev.Daddr[:]).String(), Ntohs(skb_ev.Dport))
}
if event.L4Proto == syscall.IPPROTO_TCP {
fmt.Fprintf(writer, "tcp_flags=%s ", TcpFlags(skb_ev.TcpFlags))
}
fmt.Fprintf(writer, "payload_len=%d ", event.PayloadLen)
sym := NearestSymbol(skb_ev.Pc)
fmt.Fprintf(writer, "%s", sym.Name)
if sym.Name == "kfree_skb_reason" {
fmt.Fprintf(writer, "(%s)", kfreeSkbReasons[skb_ev.SecondParam])
}
fmt.Fprintf(writer, "\n")
}
delete(skb2events, event.Skb)
delete(skb2symNames, event.Skb)
}
}
fmt.Fprintf(writer, "\n")
}
}

0 comments on commit b814105

Please sign in to comment.