From 37c9abc92017e762c19037b436f3b70890f11adf Mon Sep 17 00:00:00 2001 From: Murtaza Aliakbar Date: Sat, 16 Nov 2024 06:09:04 +0530 Subject: [PATCH] ipn/wg: amnezia --- intra/ipn/wg/amnezia.go | 256 ++++++++++++++++++++++++++++++++++++++++ intra/ipn/wg/wgconn.go | 32 +++-- intra/ipn/wgproxy.go | 41 ++++++- 3 files changed, 320 insertions(+), 9 deletions(-) create mode 100644 intra/ipn/wg/amnezia.go diff --git a/intra/ipn/wg/amnezia.go b/intra/ipn/wg/amnezia.go new file mode 100644 index 00000000..3b1230d5 --- /dev/null +++ b/intra/ipn/wg/amnezia.go @@ -0,0 +1,256 @@ +// Copyright (c) 2024 RethinkDNS and its authors. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package wg + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + + "github.com/celzero/firestack/intra/log" + "golang.zx2c4.com/wireguard/device" +) + +// https://github.com/amnezia-vpn/amneziawg-go/pull/2/files + +const ( + sNoop = 0 // no-op size +) + +// Jc (Junk packet count) - number of packets with random data that are sent before the start of the session +// Jmin (Junk packet minimum size) - minimum packet size for Junk packet. That is, all randomly generated packets will have a size no smaller than Jmin. +// Jmax (Junk packet maximum size) - maximum size for Junk packets +// S1 (Init packet junk size) - the size of random data that will be added to the init packet, the size of which is initially fixed. +// S2 (Response packet junk size) - the size of random data that will be added to the response packet, the size of which is initially fixed. +// H1 (Init packet magic header) - the header of the first byte of the handshake +// H2 (Response packet magic header) - header of the first byte of the handshake response +// H4 (Transport packet magic header) - header of the packet of the data packet +// H3 (Underload packet magic header) - UnderLoad packet header." +type Amnezia struct { + id string + + Jc, Jmin, Jmax uint16 // unused: junk packet count, min, max + + S1, S2 uint16 // handshake init/resp pkt sizes + H1, H2, H3, H4 uint32 // modified msg types [4]byte +} + +func NewAmnezia(id string) *Amnezia { + return &Amnezia{ + id: id, + } +} + +func (a *Amnezia) String() string { + if a == nil { + return "" + } + return fmt.Sprintf("%s: amnezia: jc(%d), jmin(%d), jmax(%d), s1(%d), s2(%d), h1(%d), h2(%d), h3(%d), h4(%d)", + a.id, a.Jc, a.Jmin, a.Jmax, a.S1, a.S2, a.H1, a.H2, a.H3, a.H4) +} + +func (a *Amnezia) Set() bool { + if a == nil { + return false + } + + return a.S1 > 0 || a.S2 > 0 || a.H1 > 0 || a.H2 > 0 || a.H3 > 0 || a.H4 > 0 +} + +func (a *Amnezia) Same(b *Amnezia) bool { + if a == nil && b == nil { + return false + } else if a == nil || b == nil { + return false + } + + return a.S1 == b.S1 && + a.S2 == b.S2 && + a.H1 == b.H1 && + a.H2 == b.H2 && + a.H3 == b.H3 && + a.H4 == b.H4 +} + +func (a *Amnezia) send(pktptr *[]byte) (ok bool) { + if a == nil || !a.Set() { + return + } + + pkt := *pktptr + + n := len(pkt) + if n < device.MinMessageSize { + return + } + + h := uint16(device.MessageTransportOffsetReceiver) + typ := binary.LittleEndian.Uint32(pkt[:h]) + + defer a.logIfNeeded("send", typ, n) + + *pktptr, _ = a.instate(pkt) + return true +} + +func (a *Amnezia) recv(pktptr *[]byte) (ok bool) { + if a == nil || !a.Set() { + return + } + + var typ uint32 + pkt := *pktptr + + if len(pkt) < device.MinMessageSize { + return + } + h := uint16(device.MessageTransportOffsetReceiver) + + pkt, typ = a.strip(pkt) + + switch typ { + case a.H1: + typ = device.MessageInitiationType + binary.LittleEndian.PutUint32(pkt[:h], device.MessageInitiationType) + case a.H2: + typ = device.MessageResponseType + binary.LittleEndian.PutUint32(pkt[:h], device.MessageResponseType) + case a.H3: + typ = device.MessageCookieReplyType + binary.LittleEndian.PutUint32(pkt[:h], device.MessageCookieReplyType) + case a.H4: + typ = device.MessageTransportType + binary.LittleEndian.PutUint32(pkt[:h], device.MessageTransportType) + } + + defer a.logIfNeeded("recv", typ, len(pkt)) + + *pktptr = pkt + return true +} + +func (a *Amnezia) instate(pkt []byte) ([]byte, uint32) { + n := len(pkt) + + h := uint16(device.MessageTransportOffsetReceiver) + defaultType := binary.LittleEndian.Uint32(pkt[:h]) + + var pad uint16 = 0 + var obsType uint32 = 0 + maybeInstate := false + + switch defaultType { + case device.MessageInitiationType: + if n == device.MessageInitiationSize { + // github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/send.go#L130 + pad = a.S1 + obsType = a.H1 + maybeInstate = true + } + case device.MessageResponseType: + if n == device.MessageResponseSize { + // github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/send.go#L198 + pad = a.S2 + obsType = a.H2 + maybeInstate = true + } + case device.MessageCookieReplyType: + if n == device.MessageCookieReplySize { + pad = sNoop + obsType = a.H3 + maybeInstate = true + } + case device.MessageTransportType: + if n >= device.MinMessageSize { + pad = sNoop + obsType = a.H4 + maybeInstate = true + } + } + + if maybeInstate { + random, err := blob(pad) // pad may be 0 + if err != nil { // unlikely + log.E("wg: %s: amnezia: instate: %v", a.id, err) + return pkt, defaultType + } + binary.LittleEndian.PutUint32(pkt[:h], obsType) + if len(random) <= 0 { + return pkt, obsType + } else { + return append(random, pkt...), obsType + } + } + + return pkt, defaultType +} + +func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) { + size := uint16(len(pkt)) + h := uint16(device.MessageTransportOffsetReceiver) + defaultType := binary.LittleEndian.Uint32(pkt[:h]) + + var discard uint16 = 0 + var possibleType uint32 = 0 + maybeStrip := false + + // ref: https://github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/device.go#L765 + if size == a.S1+device.MessageInitiationSize { + discard = a.S1 + possibleType = a.H1 + maybeStrip = true + } else if size == a.S2+device.MessageResponseSize { + discard = a.S2 + possibleType = a.H2 + maybeStrip = true + } // else: default + + if maybeStrip { + hdr := pkt[discard : discard+h] + strippedType := binary.LittleEndian.Uint32(hdr) + if strippedType == possibleType { + return pkt[discard:], strippedType + } // else: sizes match but msg types do not + } // else: nothing to discard + + return pkt, defaultType +} + +func (a *Amnezia) logIfNeeded(dir string, typ uint32, n int) { + switch typ { + case device.MessageInitiationType: + notok := n != device.MessageInitiationSize + logif(notok)("wg: %s: amnezia: %s: err initiation %d != %d", + a.id, dir, n, device.MessageInitiationSize) + case device.MessageResponseType: + notok := n != device.MessageResponseSize + logif(notok)("wg: %s: amnezia: %s: err response %d != %d", + a.id, dir, n, device.MessageResponseSize) + case device.MessageCookieReplyType: + notok := n != device.MessageCookieReplySize + logif(notok)("wg: %s: amnezia: %s: err cookie %d != %d", + a.id, dir, n, device.MessageCookieReplySize) + case device.MessageTransportType: + notok := n < device.MinMessageSize + logif(notok)("wg: %s: amnezia: %s: err data %d < %d", + a.id, dir, n, device.MinMessageSize) + default: + log.W("wg: %s: amnezia: %s: unexpected type %d; sz(pkt): %d", + a.id, dir, typ, n) + } +} + +// ref: github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/util.go#L1 +func blob(sz uint16) ([]byte, error) { + if sz == 0 { + return nil, nil + } + + junk := make([]byte, sz) + n, err := rand.Read(junk) + return junk[:n], err +} diff --git a/intra/ipn/wg/wgconn.go b/intra/ipn/wg/wgconn.go index e6a539d8..f8de7b22 100644 --- a/intra/ipn/wg/wgconn.go +++ b/intra/ipn/wg/wgconn.go @@ -104,6 +104,7 @@ type StdNetBind struct { reserved []byte // overwrite the 3 wg reserved bytes overwriteReserve bool + amnezia *Amnezia floodBa *core.Barrier[int, netip.AddrPort] mu sync.Mutex // protects following fields @@ -117,17 +118,18 @@ type StdNetBind struct { } // TODO: get d, ep, f, rb through an Opts bag? -func NewEndpoint(id string, d connector, ep *multihost.MH, f rwobserver, rb [3]byte) *StdNetBind { +func NewEndpoint(id string, d connector, ep *multihost.MH, f rwobserver, a *Amnezia, rb [3]byte) *StdNetBind { s := &StdNetBind{ id: id, connect: d, mh: ep, observer: f, + amnezia: a, reserved: rb[:3], // github.com/bepass-org/warp-plus/blob/19ac233cc6/wiresocks/config.go#L184 floodBa: core.NewKeyedBarrier[int, netip.AddrPort](minFloodInterval), sendAddr: core.NewZeroVolatile[netip.AddrPort](), } - s.overwriteReserve = isReservedOverwitten(s.reserved) + s.overwriteReserve = a.Set() || isReservedOverwitten(s.reserved) return s } @@ -324,10 +326,14 @@ func (s *StdNetBind) makeReceiveFn(uc *net.UDPConn) conn.ReceiveFunc { extend(uc, wgtimeout) n, addr, err := uc.ReadFromUDPAddrPort(b) if err == nil { - recvOverwritten = isReservedOverwitten(b) - // github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/receive.go#L138 - if n > 3 && isWgMsgType(b[0]) && recvOverwritten { - copy(b[1:4], reservedZeros) + if isReservedOverwitten(b) { + if s.amnezia.Set() { + recvOverwritten = s.amnezia.recv(&b) + } else if n > 3 && isWgMsgType(b[0]) && recvOverwritten { + // github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/receive.go#L138 + copy(b[1:4], reservedZeros) + recvOverwritten = true + } } numMsgs++ } @@ -397,14 +403,17 @@ func (s *StdNetBind) Send(buf [][]byte, peer conn.Endpoint) (err error) { return syscall.EAFNOSUPPORT } - // overwrite the 3 reserved bytes on non-random packets if s.overwriteReserve { - if len(data) > 3 && isWgMsgType(data[0]) { + if s.amnezia.Set() { + overwritten = s.amnezia.send(&data) + } else if len(data) > 3 && isWgMsgType(data[0]) { + // overwrite the 3 reserved bytes on non-random packets // from: github.com/bepass-org/warp-plus/blob/19ac233cc6/wireguard/device/peer.go#L138 copy(data[1:4], s.reserved) overwritten = true } } + if !flooded && !overwritten && (experimentalWg || s.overwriteReserve) { if len(data) == device.MessageInitiationSize { go s.flood(uc, dst, fkHandshake) // probably a handshake @@ -562,6 +571,13 @@ func loge(err error) log.LogFn { return l } +func logif(warn bool) log.LogFn { + if warn { + return log.W + } + return log.N +} + func extend(c net.Conn, t time.Duration) { if c != nil && core.IsNotNil(c) { _ = c.SetDeadline(time.Now().Add(t)) diff --git a/intra/ipn/wgproxy.go b/intra/ipn/wgproxy.go index e27a6deb..45833f3b 100644 --- a/intra/ipn/wgproxy.go +++ b/intra/ipn/wgproxy.go @@ -78,6 +78,7 @@ type wgifopts struct { dns, ep *multihost.MH mtu int clientid [3]byte + amnezia *wg.Amnezia } type wgtun struct { @@ -89,6 +90,7 @@ type wgtun struct { ep *channel.Endpoint // reads and writes packets to/from stack ingress chan *buffer.View // pipes ep writes to wg events chan tun.Event // wg specific tun (interface) events + amnezia *wg.Amnezia // amnezia config, if any clientid [3]byte // client id; applicable only for warp finalize chan struct{} // close signal for incomingPacket once sync.Once // closer fn; exec exactly once @@ -343,6 +345,9 @@ func (w *wgproxy) update(id, txt string) bool { } if settings.Debug { + if !w.amnezia.Same(opts.amnezia) { + log.D("proxy: wg: !update(%s): amnezia %v != %v", w.id, opts.amnezia, w.amnezia) + } if opts.dns != nil && !opts.dns.EqualAddrs(w.dns.Load()) { log.D("proxy: wg: !update(%s): new/mismatched dns", w.id) } // nb: client code MUST re-add wg DNS, not our responsibility @@ -361,6 +366,7 @@ func (w *wgproxy) update(id, txt string) bool { w.remote.Store(opts.ep) // requires refresh w.dns.Store(opts.dns) // requires refresh w.ep.SetMTU(uint32(maybeNewMtu)) + w.amnezia = opts.amnezia // TODO: core.Volatile? return reuse } @@ -391,6 +397,7 @@ func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) { opts.dns = multihost.New(id + "dns") opts.ep = multihost.New(id + "endpoint") opts.peers = make(map[string]device.NoisePublicKey) + opts.amnezia = wg.NewAmnezia(id) for r.Scan() { line := r.Text() if len(line) <= 0 { @@ -466,11 +473,42 @@ func wgIfConfigOf(id string, txtptr *string) (opts wgifopts, err error) { // peer config: carry over public keys log.D("proxy: wg: %s ifconfig: processing key %q, err? %v", id, k, exx) pcfg.WriteString(line + "\n") + case "jc": + // github.com/amnezia-vpn/amneziawg-go/blob/2e3f7d122c/device/uapi.go#L286 + jc, _ := strconv.Atoi(v) + opts.amnezia.Jc = uint16(jc) + case "jmin": + jmin, _ := strconv.Atoi(v) + opts.amnezia.Jmin = uint16(jmin) + case "jmax": + jmax, _ := strconv.Atoi(v) + opts.amnezia.Jmax = uint16(jmax) + case "s1": + s1, _ := strconv.Atoi(v) + opts.amnezia.S1 = uint16(s1) + case "s2": + s2, _ := strconv.Atoi(v) + opts.amnezia.S2 = uint16(s2) + case "h1": + h1, _ := strconv.Atoi(v) + opts.amnezia.H1 = uint32(h1) + case "h2": + h2, _ := strconv.Atoi(v) + opts.amnezia.H2 = uint32(h2) + case "h3": + h3, _ := strconv.Atoi(v) + opts.amnezia.H3 = uint32(h3) + case "h4": + h4, _ := strconv.Atoi(v) + opts.amnezia.H4 = uint32(h4) default: log.D("proxy: wg: %s ifconfig: skipping key %q", id, k) pcfg.WriteString(line + "\n") } } + if opts.amnezia.Set() { + log.I("proxy: wg: %s amnezia: %s", id, opts.amnezia) + } *txtptr = pcfg.String() if err == nil && len(opts.ifaddrs) <= 0 || opts.dns.Len() <= 0 || opts.mtu <= 0 { err = errProxyConfig @@ -536,7 +574,7 @@ func NewWgProxy(id string, ctl protect.Controller, rev netstack.GConnHandler, cf // todo: use wgtun.serve fn instead of ctl wgep = wg.NewEndpoint2(id, ctl, opts.ep, wgtun.listener) } else { - wgep = wg.NewEndpoint(id, wgtun.serve, opts.ep, wgtun.listener, wgtun.clientid) + wgep = wg.NewEndpoint(id, wgtun.serve, opts.ep, wgtun.listener, wgtun.amnezia, wgtun.clientid) } wgdev := device.NewDevice(wgtun, wgep, wglogger(id)) @@ -609,6 +647,7 @@ func makeWgTun(id, cfg string, ctl protect.Controller, rev netstack.GConnHandler peers: core.NewVolatile(ifopts.peers), // its entries must never be modified rt: x.NewIpTree(), // must be set to allowedaddrs ba: core.NewBarrier[[]netip.Addr](wgbarrierttl), + amnezia: ifopts.amnezia, clientid: ifopts.clientid, status: core.NewVolatile(TUP), preferOffload: preferOffload(id),