diff --git a/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java b/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java index dbe56f8a0e0..5e58bf03207 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java @@ -15,16 +15,14 @@ */ package io.netty.handler.codec.compression; -import com.github.luben.zstd.BaseZstdBufferDecompressingStreamNoFinalizer; -import com.github.luben.zstd.ZstdBufferDecompressingStreamNoFinalizer; -import com.github.luben.zstd.ZstdDirectBufferDecompressingStreamNoFinalizer; +import com.github.luben.zstd.ZstdInputStreamNoFinalizer; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; +import java.io.Closeable; import java.io.IOException; -import java.nio.ByteBuffer; +import java.io.InputStream; import java.util.List; /** @@ -36,15 +34,15 @@ public final class ZstdDecoder extends ByteToMessageDecoder { { try { Zstd.ensureAvailability(); - outCapacity = ZstdBufferDecompressingStreamNoFinalizer.recommendedTargetBufferSize(); } catch (Throwable throwable) { throw new ExceptionInInitializerError(throwable); } } - private final int outCapacity; + + private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream(); + private ZstdInputStreamNoFinalizer zstdIs; private State currentState = State.DECOMPRESS_DATA; - private ZstdStream stream; /** * Current state of stream. @@ -62,140 +60,91 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t return; } final int compressedLength = in.readableBytes(); - if (compressedLength == 0) { - // Nothing to decompress, try again later. - return; - } - if (stream == null) { - // We assume that if the first buffer is direct the next buffer will also most likely be direct. - stream = new ZstdStream(in.isDirect(), outCapacity); - } - do { - ByteBuf decompressed = stream.decompress(ctx.alloc(), in); - if (decompressed == null) { - return; + inputStream.current = in; + + ByteBuf outBuffer = null; + try { + int w; + do { + if (outBuffer == null) { + // Let's start with the compressedLength * 2 as often we will not have everything + // we need in the in buffer and don't want to reserve too much memory. + outBuffer = ctx.alloc().heapBuffer(compressedLength * 2); + } + do { + w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes()); + } while (w != -1 && outBuffer.isWritable()); + if (outBuffer.isReadable()) { + out.add(outBuffer); + outBuffer = null; + } + } while (w != -1); + } finally { + if (outBuffer != null) { + outBuffer.release(); } - out.add(decompressed); - } while (in.isReadable()); - } catch (DecompressionException e) { + } + } catch (Exception e) { currentState = State.CORRUPTED; - throw e; + throw new DecompressionException(e); + } finally { + inputStream.current = null; } } + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + zstdIs = new ZstdInputStreamNoFinalizer(inputStream); + zstdIs.setContinuous(true); + } + @Override protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { try { - if (stream != null) { - stream.close(); - stream = null; - } + closeSilently(zstdIs); } finally { super.handlerRemoved0(ctx); } } - private static final class ZstdStream { - private static final ByteBuffer EMPTY_HEAP_BUFFER = ByteBuffer.allocate(0); - private static final ByteBuffer EMPTY_DIRECT_BUFFER = ByteBuffer.allocateDirect(0); - - private final boolean direct; - private final int outCapacity; - private final BaseZstdBufferDecompressingStreamNoFinalizer decompressingStream; - private ByteBuffer current; - - ZstdStream(boolean direct, int outCapacity) { - this.direct = direct; - this.outCapacity = outCapacity; - if (direct) { - decompressingStream = new ZstdDirectBufferDecompressingStreamNoFinalizer(EMPTY_DIRECT_BUFFER) { - @Override - protected ByteBuffer refill(ByteBuffer toRefill) { - return ZstdStream.this.refill(toRefill); - } - }; - } else { - decompressingStream = new ZstdBufferDecompressingStreamNoFinalizer(EMPTY_HEAP_BUFFER) { - @Override - protected ByteBuffer refill(ByteBuffer toRefill) { - return ZstdStream.this.refill(toRefill); - } - }; + private static void closeSilently(Closeable closeable) { + if (closeable != null) { + try { + closeable.close(); + } catch (IOException ignore) { + // ignore } } + } - ByteBuf decompress(ByteBufAllocator alloc, ByteBuf in) throws DecompressionException { - final ByteBuf source; - // Ensure we use the correct input buffer type. - if (direct && !in.isDirect()) { - source = alloc.directBuffer(in.readableBytes()); - source.writeBytes(in, in.readerIndex(), in.readableBytes()); - } else if (!direct && !in.hasArray()) { - source = alloc.heapBuffer(in.readableBytes()); - source.writeBytes(in, in.readerIndex(), in.readableBytes()); - } else { - source = in; - } - int inPosition = -1; - ByteBuf outBuffer = null; - try { - ByteBuffer inNioBuffer = CompressionUtil.safeNioBuffer( - source, source.readerIndex(), source.readableBytes()); - inPosition = inNioBuffer.position(); - assert inNioBuffer.hasRemaining(); - current = inNioBuffer; - - // allocate the outBuffer based on what we expect from the decompressingStream. - if (direct) { - outBuffer = alloc.directBuffer(outCapacity); - } else { - outBuffer = alloc.heapBuffer(outCapacity); - } - ByteBuffer target = outBuffer.internalNioBuffer(outBuffer.writerIndex(), outBuffer.writableBytes()); - int position = target.position(); - do { - do { - if (decompressingStream.read(target) == 0) { - break; - } - } while (decompressingStream.hasRemaining() && target.hasRemaining() && current.hasRemaining()); - int written = target.position() - position; - if (written > 0) { - outBuffer.writerIndex(outBuffer.writerIndex() + written); - ByteBuf out = outBuffer; - outBuffer = null; - return out; - } - } while (decompressingStream.hasRemaining() && current.hasRemaining()); - } catch (IOException e) { - throw new DecompressionException(e); - } finally { - if (outBuffer != null) { - outBuffer.release(); - } - // Release in case of copy - if (source != in) { - source.release(); - } - ByteBuffer buffer = current; - current = null; - if (inPosition != -1) { - int read = buffer.position() - inPosition; - if (read > 0) { - in.skipBytes(read); - } - } + private static final class MutableByteBufInputStream extends InputStream { + ByteBuf current; + + @Override + public int read() { + if (current == null || !current.isReadable()) { + return -1; } - return null; + return current.readByte() & 0xff; } - private ByteBuffer refill(@SuppressWarnings("unused") ByteBuffer toRefill) { - return current; + @Override + public int read(byte[] b, int off, int len) { + int available = available(); + if (available == 0) { + return -1; + } + + len = Math.min(available, len); + current.readBytes(b, off, len); + return len; } - void close() { - decompressingStream.close(); + @Override + public int available() { + return current == null ? 0 : current.readableBytes(); } } }