Skip to content

Commit

Permalink
remove huge lookup table in decoder
Browse files Browse the repository at this point in the history
Use linked list instead.
  • Loading branch information
ianic committed Feb 8, 2024
1 parent 1e7a9b7 commit 71eb86a
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 38 deletions.
128 changes: 98 additions & 30 deletions src/huffman_decoder.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ pub const Symbol = packed struct {
match,
};

symbol: u8, // symbol from alphabet
code_bits: u4, // code bits count
symbol: u8 = 0, // symbol from alphabet
code_bits: u4 = 0, // code bits count
kind: Kind = .literal,

code: u16 = 0,
next: u16 = 0, // pointer to the next symbol in linked list
// it is safe to use 0 as null pointer, when sorted 0 has shortest code and fits into lookup

// Sorting less than function.
pub fn asc(_: void, a: Symbol, b: Symbol) bool {
if (a.code_bits == b.code_bits) {
Expand All @@ -26,7 +30,7 @@ pub const Symbol = packed struct {

pub const LiteralDecoder = HuffmanDecoder(286, 15, 9);
pub const DistanceDecoder = HuffmanDecoder(30, 15, 9);
pub const CodegenDecoder = HuffmanDecoder(19, 7, 0);
pub const CodegenDecoder = HuffmanDecoder(19, 7, 7);

/// Creates huffman tree codes from list of code lengths (in `build`).
///
Expand All @@ -44,22 +48,20 @@ pub const CodegenDecoder = HuffmanDecoder(19, 7, 0);
fn HuffmanDecoder(
comptime alphabet_size: u16,
comptime max_code_bits: u4,
comptime small_lookup_bits: u4,
comptime lookup_bits: u4,
) type {
const small_lookup_shift = max_code_bits - small_lookup_bits;
const lookup_shift = max_code_bits - lookup_bits;

return struct {
// all symbols in alaphabet, sorted by code_len, symbol
symbols: [alphabet_size]Symbol = undefined,
// lookup table code -> symbol
lookup: [1 << max_code_bits]Symbol = undefined,
// small lookup table
lookup_s: [1 << small_lookup_bits]Symbol = undefined,
lookup: [1 << lookup_bits]Symbol = undefined,

const Self = @This();

/// Builds symbols and lookup tables from list of code lens for each symbol.
pub fn build(self: *Self, lens: []const u4) void {
pub fn generate(self: *Self, lens: []const u4) void {
// init alphabet with code_bits
for (self.symbols, 0..) |_, i| {
const cb: u4 = if (i < lens.len) lens[i] else 0;
Expand All @@ -72,41 +74,59 @@ fn HuffmanDecoder(
}
std.sort.heap(Symbol, &self.symbols, {}, Symbol.asc);

// reset lookup table
for (0..self.lookup.len) |i| {
self.lookup[i] = .{};
}

// assign code to symbols
// reference: https://youtu.be/9_YEGLe33NA?list=PLU4IQLU9e_OrY8oASHx0u3IXAL9TOdidm&t=2639
var code: u16 = 0;
var code_s: u16 = 0;
for (self.symbols) |sym| {
var idx: u16 = 0;
for (&self.symbols, 0..) |*sym, pos| {
if (sym.code_bits == 0) continue; // skip unused
sym.code = code;

const next = code + (@as(u16, 1) << (max_code_bits - sym.code_bits));
const next_code = code + (@as(u16, 1) << (max_code_bits - sym.code_bits));
const next_idx = next_code >> lookup_shift;

if (sym.code_bits <= small_lookup_bits) {
if (sym.code_bits <= lookup_bits) {
// fill small lookup table
const next_s = next >> small_lookup_shift;
for (code_s..next_s) |j|
self.lookup_s[j] = sym;
code_s = next_s;
for (idx..next_idx) |j|
self.lookup[j] = sym.*;
} else {
// fill lookup table
// assign symbol to all codes between current and next code
for (code..next) |j|
self.lookup[j] = sym;
// insert into linked table starting at root
const root = &self.lookup[idx];
const root_next = root.next;
root.next = @intCast(pos);
sym.next = root_next;
}
code = next;

idx = next_idx;
code = next_code;
}
for (code_s..self.lookup_s.len) |i|
self.lookup_s[i].code_bits = 0; // unused
}

/// Finds symbol for lookup table code.
pub inline fn find(self: *Self, code: u16) Symbol {
if (small_lookup_bits > 0) {
const code_s = code >> small_lookup_shift;
const sym = self.lookup_s[code_s];
if (sym.code_bits != 0) return sym;
// try to find in lookup table
const idx = code >> lookup_shift;
const sym = self.lookup[idx];
if (sym.code_bits != 0) return sym;
// if not use linked list of symbols with same prefix
return self.findLinked(code, sym.next);
}

inline fn findLinked(self: *Self, code: u16, start: u16) Symbol {
var pos = start;
while (pos > 0) {
const sym = self.symbols[pos];
const shift = max_code_bits - sym.code_bits;
// compare code_bits number of upper bits
if ((code ^ sym.code) >> shift == 0) return sym;
pos = sym.next;
}
return self.lookup[code];
return .{};
}
};
}
Expand All @@ -115,7 +135,7 @@ test "Huffman init/find" {
// example data from: https://youtu.be/SJPvNi4HrWQ?t=8423
const code_lens = [_]u4{ 4, 3, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2 };
var h: CodegenDecoder = .{};
h.build(&code_lens);
h.generate(&code_lens);

const expected = [_]struct {
sym: Symbol,
Expand Down Expand Up @@ -186,3 +206,51 @@ test "Huffman init/find" {
for (0b1111_000..0b1_0000_000) |c| // 120...128 (8)
try testing.expectEqual(16, h.find(@intCast(c)).symbol);
}

const print = std.debug.print;
const assert = std.debug.assert;
const expect = std.testing.expect;

test "full " {
const LiteralEncoder = @import("huffman_encoder.zig").LiteralEncoder;
var enc: LiteralEncoder = .{};
// worst case, all freqencies are used
var freq = [_]u16{0} ** 286;
for (&freq, 1..) |*f, i| {
if (i % 2 == 0)
f.* = @intCast(i);
}
// encoder from freqencies
enc.generate(&freq, 15);

// generate decoder from code lens
var code_lens = [_]u4{0} ** 286;
for (code_lens, 0..) |_, i| {
code_lens[i] = @intCast(enc.codes[i].len);
}
var dec: LiteralDecoder = .{};
dec.generate(&code_lens);

// expect decoder code to match original encoder code
for (dec.symbols) |s| {
const c_code: u15 = @bitReverse(@as(u15, @intCast(s.code)));
const symbol: u16 = switch (s.kind) {
.literal => s.symbol,
.end_of_block => 256,
.match => @as(u16, s.symbol) + 257,
};

const c = enc.codes[symbol];
try expect(c.code == c_code);
}

// find each symbol by code
for (enc.codes) |c| {
if (c.len == 0) continue;

const s_code: u15 = @bitReverse(@as(u15, @intCast(c.code)));
const s = dec.find(s_code);
try expect(s.code == s_code);
try expect(s.code_bits == c.len);
}
}
16 changes: 8 additions & 8 deletions src/inflate.zig
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ pub fn decompressor(comptime container: Container, reader: anytype) Inflate(cont
/// `step` function runs decoder until internal `hist` buffer is full. Client
/// than needs to read that data in order to proceed with decoding.
///
/// Allocates ~200K of internal buffers, most important are:
/// Allocates 74.5K of internal buffers, most important are:
/// * 64K for history (CircularBuffer)
/// * 2 * 64K (2 * 32K of u16) for huffman decoders (Literal and DistanceDecoder)
/// * ~10K huffman decoders (Literal and DistanceDecoder)
///
pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
return struct {
Expand Down Expand Up @@ -151,7 +151,7 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
cl_l[codegen_order[i]] = try self.bits.read(u3);
}
var cl_h: hfd.CodegenDecoder = .{};
cl_h.build(&cl_l);
cl_h.generate(&cl_l);

// literal code lengths
var lit_l = [_]u4{0} ** (286);
Expand All @@ -171,8 +171,8 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
pos += try self.dynamicCodeLength(sym.symbol, &dst_l, pos);
}

self.lit_h.build(&lit_l);
self.dst_h.build(&dst_l);
self.lit_h.generate(&lit_l);
self.dst_h.generate(&dst_l);
}

// Decode code length symbol to code length. Writes decoded length into
Expand Down Expand Up @@ -338,16 +338,16 @@ test "Struct sizes" {
const ReaderType = @TypeOf(fbs.reader());
const inflate_size = @sizeOf(Inflate(.gzip, ReaderType));

try testing.expectEqual(199352, inflate_size);
try testing.expectEqual(76320, inflate_size);
try testing.expectEqual(
@sizeOf(CircularBuffer) + @sizeOf(hfd.LiteralDecoder) + @sizeOf(hfd.DistanceDecoder) + 48,
inflate_size,
);
try testing.expectEqual(65536 + 8 + 8, @sizeOf(CircularBuffer));
try testing.expectEqual(8, @sizeOf(Container.raw.Hasher()));
try testing.expectEqual(24, @sizeOf(BitReader(ReaderType)));
try testing.expectEqual(67132, @sizeOf(hfd.LiteralDecoder));
try testing.expectEqual(66620, @sizeOf(hfd.DistanceDecoder));
try testing.expectEqual(6384, @sizeOf(hfd.LiteralDecoder));
try testing.expectEqual(4336, @sizeOf(hfd.DistanceDecoder));
}

test "flate decompress" {
Expand Down

0 comments on commit 71eb86a

Please sign in to comment.