diff --git a/src/huffman_decoder.zig b/src/huffman_decoder.zig index 455da3b..22c0064 100644 --- a/src/huffman_decoder.zig +++ b/src/huffman_decoder.zig @@ -8,10 +8,14 @@ pub const Symbol = packed struct { match, }; - symbol: u8, // symbol from alphabet - code_bits: u4, // code bits count + symbol: u8 = 0, // symbol from alphabet + code_bits: u4 = 0, // code bits count kind: Kind = .literal, + code: u16 = 0, + next: u16 = 0, // pointer to the next symbol in linked list + // it is safe to use 0 as null pointer, when sorted 0 has shortest code and fits into lookup + // Sorting less than function. pub fn asc(_: void, a: Symbol, b: Symbol) bool { if (a.code_bits == b.code_bits) { @@ -26,7 +30,7 @@ pub const Symbol = packed struct { pub const LiteralDecoder = HuffmanDecoder(286, 15, 9); pub const DistanceDecoder = HuffmanDecoder(30, 15, 9); -pub const CodegenDecoder = HuffmanDecoder(19, 7, 0); +pub const CodegenDecoder = HuffmanDecoder(19, 7, 7); /// Creates huffman tree codes from list of code lengths (in `build`). /// @@ -44,22 +48,20 @@ pub const CodegenDecoder = HuffmanDecoder(19, 7, 0); fn HuffmanDecoder( comptime alphabet_size: u16, comptime max_code_bits: u4, - comptime small_lookup_bits: u4, + comptime lookup_bits: u4, ) type { - const small_lookup_shift = max_code_bits - small_lookup_bits; + const lookup_shift = max_code_bits - lookup_bits; return struct { // all symbols in alaphabet, sorted by code_len, symbol symbols: [alphabet_size]Symbol = undefined, // lookup table code -> symbol - lookup: [1 << max_code_bits]Symbol = undefined, - // small lookup table - lookup_s: [1 << small_lookup_bits]Symbol = undefined, + lookup: [1 << lookup_bits]Symbol = undefined, const Self = @This(); /// Builds symbols and lookup tables from list of code lens for each symbol. - pub fn build(self: *Self, lens: []const u4) void { + pub fn generate(self: *Self, lens: []const u4) void { // init alphabet with code_bits for (self.symbols, 0..) |_, i| { const cb: u4 = if (i < lens.len) lens[i] else 0; @@ -72,41 +74,59 @@ fn HuffmanDecoder( } std.sort.heap(Symbol, &self.symbols, {}, Symbol.asc); + // reset lookup table + for (0..self.lookup.len) |i| { + self.lookup[i] = .{}; + } + // assign code to symbols // reference: https://youtu.be/9_YEGLe33NA?list=PLU4IQLU9e_OrY8oASHx0u3IXAL9TOdidm&t=2639 var code: u16 = 0; - var code_s: u16 = 0; - for (self.symbols) |sym| { + var idx: u16 = 0; + for (&self.symbols, 0..) |*sym, pos| { if (sym.code_bits == 0) continue; // skip unused + sym.code = code; - const next = code + (@as(u16, 1) << (max_code_bits - sym.code_bits)); + const next_code = code + (@as(u16, 1) << (max_code_bits - sym.code_bits)); + const next_idx = next_code >> lookup_shift; - if (sym.code_bits <= small_lookup_bits) { + if (sym.code_bits <= lookup_bits) { // fill small lookup table - const next_s = next >> small_lookup_shift; - for (code_s..next_s) |j| - self.lookup_s[j] = sym; - code_s = next_s; + for (idx..next_idx) |j| + self.lookup[j] = sym.*; } else { - // fill lookup table - // assign symbol to all codes between current and next code - for (code..next) |j| - self.lookup[j] = sym; + // insert into linked table starting at root + const root = &self.lookup[idx]; + const root_next = root.next; + root.next = @intCast(pos); + sym.next = root_next; } - code = next; + + idx = next_idx; + code = next_code; } - for (code_s..self.lookup_s.len) |i| - self.lookup_s[i].code_bits = 0; // unused } /// Finds symbol for lookup table code. pub inline fn find(self: *Self, code: u16) Symbol { - if (small_lookup_bits > 0) { - const code_s = code >> small_lookup_shift; - const sym = self.lookup_s[code_s]; - if (sym.code_bits != 0) return sym; + // try to find in lookup table + const idx = code >> lookup_shift; + const sym = self.lookup[idx]; + if (sym.code_bits != 0) return sym; + // if not use linked list of symbols with same prefix + return self.findLinked(code, sym.next); + } + + inline fn findLinked(self: *Self, code: u16, start: u16) Symbol { + var pos = start; + while (pos > 0) { + const sym = self.symbols[pos]; + const shift = max_code_bits - sym.code_bits; + // compare code_bits number of upper bits + if ((code ^ sym.code) >> shift == 0) return sym; + pos = sym.next; } - return self.lookup[code]; + return .{}; } }; } @@ -115,7 +135,7 @@ test "Huffman init/find" { // example data from: https://youtu.be/SJPvNi4HrWQ?t=8423 const code_lens = [_]u4{ 4, 3, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2 }; var h: CodegenDecoder = .{}; - h.build(&code_lens); + h.generate(&code_lens); const expected = [_]struct { sym: Symbol, @@ -186,3 +206,51 @@ test "Huffman init/find" { for (0b1111_000..0b1_0000_000) |c| // 120...128 (8) try testing.expectEqual(16, h.find(@intCast(c)).symbol); } + +const print = std.debug.print; +const assert = std.debug.assert; +const expect = std.testing.expect; + +test "full " { + const LiteralEncoder = @import("huffman_encoder.zig").LiteralEncoder; + var enc: LiteralEncoder = .{}; + // worst case, all freqencies are used + var freq = [_]u16{0} ** 286; + for (&freq, 1..) |*f, i| { + if (i % 2 == 0) + f.* = @intCast(i); + } + // encoder from freqencies + enc.generate(&freq, 15); + + // generate decoder from code lens + var code_lens = [_]u4{0} ** 286; + for (code_lens, 0..) |_, i| { + code_lens[i] = @intCast(enc.codes[i].len); + } + var dec: LiteralDecoder = .{}; + dec.generate(&code_lens); + + // expect decoder code to match original encoder code + for (dec.symbols) |s| { + const c_code: u15 = @bitReverse(@as(u15, @intCast(s.code))); + const symbol: u16 = switch (s.kind) { + .literal => s.symbol, + .end_of_block => 256, + .match => @as(u16, s.symbol) + 257, + }; + + const c = enc.codes[symbol]; + try expect(c.code == c_code); + } + + // find each symbol by code + for (enc.codes) |c| { + if (c.len == 0) continue; + + const s_code: u15 = @bitReverse(@as(u15, @intCast(c.code))); + const s = dec.find(s_code); + try expect(s.code == s_code); + try expect(s.code_bits == c.len); + } +} diff --git a/src/inflate.zig b/src/inflate.zig index 4b94bf9..2e6bbf4 100644 --- a/src/inflate.zig +++ b/src/inflate.zig @@ -36,9 +36,9 @@ pub fn decompressor(comptime container: Container, reader: anytype) Inflate(cont /// `step` function runs decoder until internal `hist` buffer is full. Client /// than needs to read that data in order to proceed with decoding. /// -/// Allocates ~200K of internal buffers, most important are: +/// Allocates 74.5K of internal buffers, most important are: /// * 64K for history (CircularBuffer) -/// * 2 * 64K (2 * 32K of u16) for huffman decoders (Literal and DistanceDecoder) +/// * ~10K huffman decoders (Literal and DistanceDecoder) /// pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { return struct { @@ -151,7 +151,7 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { cl_l[codegen_order[i]] = try self.bits.read(u3); } var cl_h: hfd.CodegenDecoder = .{}; - cl_h.build(&cl_l); + cl_h.generate(&cl_l); // literal code lengths var lit_l = [_]u4{0} ** (286); @@ -171,8 +171,8 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { pos += try self.dynamicCodeLength(sym.symbol, &dst_l, pos); } - self.lit_h.build(&lit_l); - self.dst_h.build(&dst_l); + self.lit_h.generate(&lit_l); + self.dst_h.generate(&dst_l); } // Decode code length symbol to code length. Writes decoded length into @@ -338,7 +338,7 @@ test "Struct sizes" { const ReaderType = @TypeOf(fbs.reader()); const inflate_size = @sizeOf(Inflate(.gzip, ReaderType)); - try testing.expectEqual(199352, inflate_size); + try testing.expectEqual(76320, inflate_size); try testing.expectEqual( @sizeOf(CircularBuffer) + @sizeOf(hfd.LiteralDecoder) + @sizeOf(hfd.DistanceDecoder) + 48, inflate_size, @@ -346,8 +346,8 @@ test "Struct sizes" { try testing.expectEqual(65536 + 8 + 8, @sizeOf(CircularBuffer)); try testing.expectEqual(8, @sizeOf(Container.raw.Hasher())); try testing.expectEqual(24, @sizeOf(BitReader(ReaderType))); - try testing.expectEqual(67132, @sizeOf(hfd.LiteralDecoder)); - try testing.expectEqual(66620, @sizeOf(hfd.DistanceDecoder)); + try testing.expectEqual(6384, @sizeOf(hfd.LiteralDecoder)); + try testing.expectEqual(4336, @sizeOf(hfd.DistanceDecoder)); } test "flate decompress" {