diff --git a/dhcpv4/nclient4/conn_unix.go b/dhcpv4/nclient4/conn_unix.go index 7f79f50d..1495dc28 100644 --- a/dhcpv4/nclient4/conn_unix.go +++ b/dhcpv4/nclient4/conn_unix.go @@ -99,14 +99,9 @@ func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { pkt = pkt[:n] buf := uio.NewBigEndianBuffer(pkt) - // To read the header length, access data directly. - if !buf.Has(ipv4MinimumSize) { - continue - } - ipHdr := ipv4(buf.Data()) - if !buf.Has(int(ipHdr.headerLength())) { + if !ipHdr.isValid(n) { continue } diff --git a/dhcpv4/nclient4/ipv4.go b/dhcpv4/nclient4/ipv4.go index c2219650..f2bfb651 100644 --- a/dhcpv4/nclient4/ipv4.go +++ b/dhcpv4/nclient4/ipv4.go @@ -14,6 +14,7 @@ // // This file contains code taken from gVisor. +//go:build go1.12 // +build go1.12 package nclient4 @@ -95,6 +96,23 @@ const ( // ipv4AddressSize is the size, in bytes, of an IPv4 address. ipv4AddressSize = 4 + + // IPv4Version is the version of the IPv4 protocol. + ipv4Version = 4 +) + +// ipVersion returns the version of IP used in the given packet. It returns -1 +// if the packet is not large enough to contain the version field. +func ipVersion(b []byte) int { + // Length must be at least offset+length of version field. + if len(b) < versIHL+1 { + return -1 + } + return int(b[versIHL] >> ipVersionShift) +} + +const ( + ipVersionShift = 4 ) // headerLength returns the value of the "header length" field of the ipv4 @@ -170,6 +188,25 @@ func (b ipv4) encode(i *ipv4Fields) { copy(b[dstAddr:dstAddr+ipv4AddressSize], i.DstAddr) } +// isValid performs basic validation on the packet. +func (b ipv4) isValid(pktSize int) bool { + if len(b) < ipv4MinimumSize { + return false + } + + hlen := int(b.headerLength()) + tlen := int(b.totalLength()) + if hlen < ipv4MinimumSize || hlen > tlen || tlen > pktSize { + return false + } + + if ipVersion(b) != ipv4Version { + return false + } + + return true +} + const ( udpSrcPort = 0 udpDstPort = 2