Skip to content

Commit

Permalink
Validate response to get rid of malformed packets
Browse files Browse the repository at this point in the history
  • Loading branch information
ba0f3 committed Sep 21, 2022
1 parent 4960de2 commit 478c268
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 53 deletions.
5 changes: 2 additions & 3 deletions dnsclient.nimble
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# Package

version = "0.2.0"
version = "0.3.0"
author = "Huy Doan"
description = "Simple DNS Client & Library"
license = "MIT"
srcDir = "src"
installExt = @["nim"]
bin = @["dnsclient"]


skipDirs = @["fuzz", "tests"]

# Dependencies

Expand Down
11 changes: 11 additions & 0 deletions fuzz/config.nims
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
switch("path", "$projectDir/../src")
#switch("cc", "gcc")
let cc = "afl-clang-fast"

switch("gcc.linkerexe", cc)
switch("gcc.exe", cc)
switch("gcc.path", "/usr/local/bin")

--debugger:native
--define:release
--define:danger
16 changes: 16 additions & 0 deletions fuzz/harness.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import streams, dnsclient, dnsclientpkg/protocol

proc fuzz_target(input: string) {.exportc.} =

try:
var resp = parseResponse(input)
echo "OK"
except ValueError:
echo "FAIL: ", getCurrentExceptionMsg()


when isMainModule:
let input = readAll(stdin);
fuzz_target(input)


1 change: 1 addition & 0 deletions fuzz/in/001
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
000
7 changes: 2 additions & 5 deletions src/dnsclient.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright Huy Doan
# Simple DNS client

import strutils, streams, net, nativesockets, endians
import streams, net, nativesockets
import dnsclientpkg/[protocol, records, types]

export records, types, TimeoutError
Expand Down Expand Up @@ -48,10 +48,7 @@ proc sendQuery*(c: DNSClient, query: string, kind: QKind = A, timeout = 500): Re
else:
raise newException(TimeoutError, "Call to 'sendQuery' timed out.")

buf.setPosition(0)
buf.write(resp)
buf.setPosition(0)
result = parseResponse(buf)
result = parseResponse(resp)

proc close*(c: DNSClient) = c.socket.close()

Expand Down
52 changes: 37 additions & 15 deletions src/dnsclientpkg/protocol.nim
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,11 @@ proc toStream*(h: var Header): StringStream =
result.writeShort(h.nscount)
result.writeShort(h.arcount)


proc initQuestion*(name: string, kind: QKind = A): Question =
result.name = name
result.kind = kind
result.class= IN


proc toStream*(q: var Question, data: StringStream) =
var labelLen: uint8
for label in q.name.split('.'):
Expand All @@ -82,7 +80,6 @@ proc toStream*(q: var Question, data: StringStream) =
data.writeShort(q.kind.uint16)
data.writeShort(q.class.uint16)


proc parseHeader(data: StringStream): Header =
result.id = data.readShort()
var flags = data.readUint16()
Expand All @@ -94,27 +91,47 @@ proc parseHeader(data: StringStream): Header =
flags = flags shr 1
result.tc = flags and 1
flags = flags shr 1
result.aa = QAuthority(flags and 1)
var tmp = flags and 1
if tmp < QAuthority.low.uint16 or tmp > QAuthority.high.uint16:
raise newException(ValueError, "Invalid value for Authoritative Answer, expect 0 or 1 got " & $tmp)
result.aa = QAuthority(tmp)
flags = flags shr 1
result.opcode = QOpCode(flags and 15)
tmp = flags and 15
if tmp < QOpCode.low.uint16 or tmp > QOpCode.high.uint16:
raise newException(ValueError, "Invalid value for Op Code, expect value >= 0 and value <= 2, got " & $tmp)
result.opcode = QOpCode(tmp)
flags = flags shr 4
if tmp < QQuery.low.uint16 or tmp > QQuery.high.uint16:
raise newException(ValueError, "Invalid value for QR, expect 0 or 1, got " & $tmp)
result.qr = QQuery(flags)
result.qdcount = data.readShort()
if result.qdcount == 0:
raise newException(ValueError, "Question section must not be empty")
result.ancount = data.readShort()
result.nscount = data.readShort()
result.arcount = data.readShort()


proc parseQuestion(data: StringStream): Question =
result.name = data.getName()
result.kind = QKind(data.readShort())
result.class = QClass(data.readShort())
var tmp = data.readShort()
if tmp < QKind.low.uint16 or tmp > QKind.high.uint16:
raise newException(ValueError, "Invalid question QTYPE, got " & $tmp)
result.kind = QKind(tmp)
tmp = data.readShort()
if tmp < QClass.low.uint16 or tmp > QClass.high.uint16:
raise newException(ValueError, "Invalid question QCLASS, got " & $tmp)
result.class = QClass(tmp)

proc parseRR(data: StringStream): ResourceRecord =
# name offset
new(result)
let
name = data.getName()
kind = QKind(data.readShort())
tmp = data.readShort()
if tmp < QKind.low.uint16 or tmp > QKind.high.uint16:
raise newException(ValueError, "Invalid resource record QTYPE, got " & $tmp)
let kind = QKind(tmp)
case kind
of A:
result = ARecord(name: name, kind: kind)
Expand Down Expand Up @@ -150,16 +167,21 @@ proc parseRR(data: StringStream): ResourceRecord =
result.rdlength = data.readShort()
result.parse(data)

proc parseResponse*(data: StringStream): Response =
result.header = parseHeader(data)
result.question = parseQuestion(data)
proc parseResponse*(data: string): Response =
if data.len < 12: # header length
raise newException(ValueError, "Invalid response header, got length of " & $data.len & " but expect 12")

let stream = newStringStream(data)

result.header = parseHeader(stream)
result.question = parseQuestion(stream)

for _ in 0..<result.header.ancount.int:
var answer = parseRR(data)
var answer = parseRR(stream)
result.answers.add(answer)

for _ in 0..<result.header.nscount.int:
var answer = parseRR(data)
var answer = parseRR(stream)
result.authorityRRs.add(answer)
assert data.atEnd()
data.close()
assert stream.atEnd()
stream.close()
4 changes: 2 additions & 2 deletions src/dnsclientpkg/types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ type



method parse*(r: ResourceRecord, data: StringStream) {.base, raises: [Defect, OSError, IOError].} =
method parse*(r: ResourceRecord, data: StringStream) {.base.} =
raise newException(LibraryError, "parser for " & $r.kind & " is not implemented yet")

method toString*(r: ResourceRecord): string {.base, raises: [Defect, OSError, IOError].} =
method toString*(r: ResourceRecord): string {.base.} =
raise newException(LibraryError, "to override!")
73 changes: 45 additions & 28 deletions src/dnsclientpkg/utils.nim
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import endians, streams, strutils

const MAX_LABEL_LENGTH = 63
const MAX_NAME_LENGTH = 255

#proc pack*(inp: int16): uint16 {.inline.} =
# var inp = inp.uint16
# bigEndian16(addr result, addr inp)


#proc pack*(inp: uint16): uint16 {.inline.} =
# var inp = inp
# bigEndian16(addr result, addr inp)
const TYPE_MASK = 0xC0'u8
type LabelType = enum
TYPE_LABEL = 0x00'u8
TYPE_EDNS0 = 0x40'u8
TYPE_RESERVED = 0x80'u8
TYPE_INDIR = 0xc0'u8

proc readTTL*(s: StringStream): int32 {.inline.} =
var value = s.readInt32()
Expand All @@ -25,29 +25,46 @@ proc writeShort*[T: int16|uint16](s: StringStream, value: T) {.inline.} =
bigEndian16(addr input, addr value)
s.write(input)

proc getBits(data: auto, offset: int, bits = 1): int =
let mask = ((1 shl bits) - 1) shl offset
result = (data.int and mask) shr offset

proc getName*(data: StringStream): string =
var labels: seq[string]
var
labels: seq[string]
length: uint8
offset: uint16
kind: LabelType
lastPos: int = 0
lenLeft = MAX_NAME_LENGTH

while true:
let
length = data.readUint8()
magic = length.getBits(6, 2)
if magic == 3:
data.setPosition(data.getPosition() - 1)
let offset = int(data.readShort() xor 0xC000)
let currentPosition = data.getPosition()
data.setPosition(offset)
labels.add(data.getName())
data.setPosition(currentPosition)
break
elif length.int > 0:
labels.add(data.readStr(length.int))
length = data.readUint8()
if length == 0: break
kind = LabelType(length and TYPE_MASK)
case kind
of TYPE_INDIR:
# lenght is first octet << 8 + last octet
offset = (length.uint16 shl 8 + data.readUint8()) xor 0xC000'u16
lastPos = data.getPosition()
data.setPosition(offset.int)
if data.atEnd():
raise newException(ValueError, "Invalid compression label offset")
if LabelType(data.peekUint8() and TYPE_MASK) == TYPE_INDIR:
raise newException(ValueError, "Nested compression label is not supported")
# we will get the label in next loop as TYPE_LABEL
of TYPE_LABEL:
if length.int > MAX_LABEL_LENGTH:
raise newException(ValueError, "Label too long, max 63 got " & $length)
dec(lenLeft, length.int + 1)
if lenLeft <= 0:
raise newException(ValueError, "Name too long")
labels.add(data.readStr(length.int))
# if next octet is zero, means and of labels
# go back to last position and stop
if data.peekUint8() == 0 and lastPos > 0:
data.setPosition(lastPos)
break # last label was INDIR, stop the loop
else:
break
result = labels.join(".")
#reversed
discard
result = if labels.len == 1: labels[0] else: labels.join(".")

proc ipv4ToString*(ip: int32): string =
let arr = cast[array[4, uint8]](ip)
Expand Down

0 comments on commit 478c268

Please sign in to comment.