Skip to content

Commit

Permalink
simplified gapless decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
vladmikhalin committed Nov 13, 2024
1 parent 4d4d9d5 commit 6dcf249
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 91 deletions.
45 changes: 26 additions & 19 deletions src/core/libraries/ajm/ajm_at9.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ void AjmAt9Decoder::GetInfo(void* out_info) const {
}

std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples_per_channel) {
AjmInstanceGapless& gapless) {
int ret = 0;
int bytes_used = 0;
switch (m_format) {
Expand All @@ -79,32 +78,37 @@ std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOut

m_superframe_bytes_remain -= bytes_used;

u32 skipped_samples = 0;
if (gapless.skipped_samples < gapless.skip_samples) {
skipped_samples = std::min(u32(m_codec_info.frameSamples),
u32(gapless.skip_samples - gapless.skipped_samples));
gapless.skipped_samples += skipped_samples;
u32 skip_samples = 0;
if (gapless.current.skip_samples > 0) {
skip_samples = std::min(u16(m_codec_info.frameSamples), gapless.current.skip_samples);
gapless.current.skip_samples -= skip_samples;
}

const auto max_samples = max_samples_per_channel.has_value()
? max_samples_per_channel.value() * m_codec_info.channels
: std::numeric_limits<u32>::max();
const auto max_pcm = gapless.init.total_samples != 0
? gapless.current.total_samples * m_codec_info.channels
: std::numeric_limits<u32>::max();

size_t samples_written = 0;
size_t pcm_written = 0;
switch (m_format) {
case AjmFormatEncoding::S16:
samples_written = WriteOutputSamples<s16>(output, skipped_samples, max_samples);
pcm_written = WriteOutputSamples<s16>(output, skip_samples, max_pcm);
break;
case AjmFormatEncoding::S32:
samples_written = WriteOutputSamples<s32>(output, skipped_samples, max_samples);
pcm_written = WriteOutputSamples<s32>(output, skip_samples, max_pcm);
break;
case AjmFormatEncoding::Float:
samples_written = WriteOutputSamples<float>(output, skipped_samples, max_samples);
pcm_written = WriteOutputSamples<float>(output, skip_samples, max_pcm);
break;
default:
UNREACHABLE();
}

const auto samples_written = pcm_written / m_codec_info.channels;
gapless.current.skipped_samples += m_codec_info.frameSamples - samples_written;
if (gapless.init.total_samples != 0) {
gapless.current.total_samples -= samples_written;
}

m_num_frames += 1;
if ((m_num_frames % m_codec_info.framesInSuperframe) == 0) {
if (m_superframe_bytes_remain) {
Expand All @@ -114,7 +118,7 @@ std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
m_num_frames = 0;
}

return {1, samples_written / m_codec_info.channels};
return {1, samples_written};
}

AjmSidebandFormat AjmAt9Decoder::GetFormat() const {
Expand All @@ -129,10 +133,13 @@ AjmSidebandFormat AjmAt9Decoder::GetFormat() const {
};
}

u32 AjmAt9Decoder::GetNextFrameSize(u32 skip_samples, u32 max_samples) const {
skip_samples = std::min({skip_samples, u32(m_codec_info.frameSamples), max_samples});
return (std::min(u32(m_codec_info.frameSamples), max_samples) - skip_samples) *
m_codec_info.channels * GetPCMSize(m_format);
// Returns the number of output bytes the next decoded frame would occupy once
// gapless trimming is applied: (samples kept per channel) * channels * PCM size.
u32 AjmAt9Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const {
    const u32 frame_samples = u32(m_codec_info.frameSamples);
    // When a sample limit is configured, the frame cannot contribute more than
    // the samples that are still left in the current budget.
    const u32 available = gapless.init.total_samples != 0
                              ? std::min(gapless.current.total_samples, frame_samples)
                              : frame_samples;
    // Samples that are still pending to be skipped are dropped from the frame.
    const u32 to_skip = std::min(u32(gapless.current.skip_samples), available);
    return (available - to_skip) * m_codec_info.channels * GetPCMSize(m_format);
}

} // namespace Libraries::Ajm
5 changes: 2 additions & 3 deletions src/core/libraries/ajm/ajm_at9.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ struct AjmAt9Decoder final : AjmCodec {
void Initialize(const void* buffer, u32 buffer_size) override;
void GetInfo(void* out_info) const override;
AjmSidebandFormat GetFormat() const override;
u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const override;
u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override;
std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples) override;
AjmInstanceGapless& gapless) override;

private:
template <class T>
Expand Down
49 changes: 20 additions & 29 deletions src/core/libraries/ajm/ajm_instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
m_format = {};
m_gapless = {};
m_resample_parameters = {};
m_gapless_samples = 0;
m_total_samples = 0;
m_codec->Reset();
}
Expand All @@ -79,10 +78,14 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
if (job.input.gapless_decode.has_value()) {
auto& params = job.input.gapless_decode.value();
if (params.total_samples != 0) {
m_gapless.total_samples = std::max(params.total_samples, m_gapless.total_samples);
const auto max = std::max(params.total_samples, m_gapless.init.total_samples);
m_gapless.current.total_samples += max - m_gapless.init.total_samples;
m_gapless.init.total_samples = max;
}
if (params.skip_samples != 0) {
m_gapless.skip_samples = std::max(params.skip_samples, m_gapless.skip_samples);
const auto max = std::max(params.skip_samples, m_gapless.init.skip_samples);
m_gapless.current.skip_samples += max - m_gapless.init.skip_samples;
m_gapless.init.skip_samples = max;
}
}

Expand All @@ -93,22 +96,29 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
u32 frames_decoded = 0;
auto in_size = in_buf.size();
auto out_size = out_buf.Size();
while (!in_buf.empty() && !out_buf.IsEmpty() && !IsGaplessEnd()) {
while (!in_buf.empty() && !out_buf.IsEmpty() && !m_gapless.IsEnd()) {
if (!HasEnoughSpace(out_buf)) {
if (job.output.p_mframe == nullptr || frames_decoded == 0) {
job.output.p_result->result = ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM;
break;
}
}
const auto [nframes, nsamples] =
m_codec->ProcessData(in_buf, out_buf, m_gapless, GetNumRemainingSamples());

const auto [nframes, nsamples] = m_codec->ProcessData(in_buf, out_buf, m_gapless);
frames_decoded += nframes;
m_total_samples += nsamples;
m_gapless_samples += nsamples;
if (job.output.p_mframe == nullptr) {

if (False(job.flags.run_flags & AjmJobRunFlags::MultipleFrames)) {
break;
}
}

if (m_gapless.IsEnd()) {
in_buf = in_buf.subspan(in_buf.size());
m_gapless.current.total_samples = m_gapless.init.total_samples;
m_gapless.current.skip_samples = m_gapless.init.skip_samples;
m_codec->Reset();
}
if (job.output.p_mframe) {
job.output.p_mframe->num_frames = frames_decoded;
}
Expand All @@ -119,38 +129,19 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
}
}

if (m_flags.gapless_loop && m_gapless.total_samples != 0 &&
m_gapless_samples >= m_gapless.total_samples) {
m_gapless_samples = 0;
m_gapless.skipped_samples = 0;
m_codec->Reset();
}
if (job.output.p_format != nullptr) {
*job.output.p_format = m_codec->GetFormat();
}
if (job.output.p_gapless_decode != nullptr) {
*job.output.p_gapless_decode = m_gapless;
*job.output.p_gapless_decode = m_gapless.current;
}
if (job.output.p_codec_info != nullptr) {
m_codec->GetInfo(job.output.p_codec_info);
}
}

bool AjmInstance::IsGaplessEnd() const {
return m_gapless.total_samples != 0 && m_gapless_samples >= m_gapless.total_samples;
}

bool AjmInstance::HasEnoughSpace(const SparseOutputBuffer& output) const {
const auto skip =
m_gapless.skip_samples - std::min(m_gapless.skip_samples, m_gapless.skipped_samples);
const auto remain = GetNumRemainingSamples().value_or(std::numeric_limits<u32>::max());
return output.Size() >= m_codec->GetNextFrameSize(skip, remain);
}

std::optional<u32> AjmInstance::GetNumRemainingSamples() const {
return m_gapless.total_samples != 0
? std::optional<u32>{m_gapless.total_samples - m_gapless_samples}
: std::optional<u32>{};
return output.Size() >= m_codec->GetNextFrameSize(m_gapless);
}

} // namespace Libraries::Ajm
20 changes: 12 additions & 8 deletions src/core/libraries/ajm/ajm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ class SparseOutputBuffer {
std::span<std::span<u8>>::iterator m_current;
};

// Gapless decoding state kept per instance. `init` stores the values configured
// through the job's gapless sideband input; `current` stores the working
// counters that the codecs update while decoding.
struct AjmInstanceGapless {
    AjmSidebandGaplessDecode init{};
    AjmSidebandGaplessDecode current{};

    // True once a non-zero sample limit has been configured and the remaining
    // sample count in `current` has reached zero.
    bool IsEnd() const {
        if (init.total_samples == 0) {
            return false;
        }
        return current.total_samples == 0;
    }
};

class AjmCodec {
public:
virtual ~AjmCodec() = default;
Expand All @@ -66,10 +75,9 @@ class AjmCodec {
virtual void Reset() = 0;
virtual void GetInfo(void* out_info) const = 0;
virtual AjmSidebandFormat GetFormat() const = 0;
virtual u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const = 0;
virtual u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const = 0;
virtual std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples_per_channel) = 0;
AjmInstanceGapless& gapless) = 0;
};

class AjmInstance {
Expand All @@ -79,18 +87,14 @@ class AjmInstance {
void ExecuteJob(AjmJob& job);

private:
bool IsGaplessEnd() const;
bool HasEnoughSpace(const SparseOutputBuffer& output) const;
std::optional<u32> GetNumRemainingSamples() const;

AjmInstanceFlags m_flags{};
AjmSidebandFormat m_format{};
AjmSidebandGaplessDecode m_gapless{};
AjmInstanceGapless m_gapless{};
AjmSidebandResampleParameters m_resample_parameters{};

u32 m_gapless_samples{};
u32 m_total_samples{};

std::unique_ptr<AjmCodec> m_codec;
};

Expand Down
54 changes: 28 additions & 26 deletions src/core/libraries/ajm/ajm_mp3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ void AjmMp3Decoder::GetInfo(void* out_info) const {
}

std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples_per_channel) {
AjmInstanceGapless& gapless) {
AVPacket* pkt = av_packet_alloc();

if ((!m_header.has_value() || m_frame_samples == 0) && in_buf.size() >= 4) {
Expand All @@ -154,12 +153,7 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
in_buf = in_buf.subspan(ret);

u32 frames_decoded = 0;
u32 samples_decoded = 0;

auto max_samples =
max_samples_per_channel.has_value()
? max_samples_per_channel.value() * m_codec_context->ch_layout.nb_channels
: std::numeric_limits<u32>::max();
u32 samples_written = 0;

if (pkt->size) {
// Send the packet with the compressed data to the decoder
Expand All @@ -182,32 +176,37 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
frame = ConvertAudioFrame(frame);

frames_decoded += 1;
u32 skipped_samples = 0;
if (gapless.skipped_samples < gapless.skip_samples) {
skipped_samples = std::min(u32(frame->nb_samples),
u32(gapless.skip_samples - gapless.skipped_samples));
gapless.skipped_samples += skipped_samples;
u32 skip_samples = 0;
if (gapless.current.skip_samples > 0) {
skip_samples = std::min(u16(frame->nb_samples), gapless.current.skip_samples);
gapless.current.skip_samples -= skip_samples;
}

const auto max_pcm =
gapless.init.total_samples != 0
? gapless.current.total_samples * m_codec_context->ch_layout.nb_channels
: std::numeric_limits<u32>::max();

u32 pcm_written = 0;
switch (m_format) {
case AjmFormatEncoding::S16:
samples_decoded +=
WriteOutputSamples<s16>(frame, output, skipped_samples, max_samples);
pcm_written = WriteOutputPCM<s16>(frame, output, skip_samples, max_pcm);
break;
case AjmFormatEncoding::S32:
samples_decoded +=
WriteOutputSamples<s32>(frame, output, skipped_samples, max_samples);
pcm_written = WriteOutputPCM<s32>(frame, output, skip_samples, max_pcm);
break;
case AjmFormatEncoding::Float:
samples_decoded +=
WriteOutputSamples<float>(frame, output, skipped_samples, max_samples);
pcm_written = WriteOutputPCM<float>(frame, output, skip_samples, max_pcm);
break;
default:
UNREACHABLE();
}

if (max_samples_per_channel.has_value()) {
max_samples -= samples_decoded;
const auto samples = pcm_written / m_codec_context->ch_layout.nb_channels;
samples_written += samples;
gapless.current.skipped_samples += frame->nb_samples - samples;
if (gapless.init.total_samples != 0) {
gapless.current.total_samples -= samples;
}

av_frame_free(&frame);
Expand All @@ -216,13 +215,16 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut

av_packet_free(&pkt);

return {frames_decoded, samples_decoded / m_codec_context->ch_layout.nb_channels};
return {frames_decoded, samples_written};
}

u32 AjmMp3Decoder::GetNextFrameSize(u32 skip_samples, u32 max_samples) const {
skip_samples = std::min({skip_samples, m_frame_samples, max_samples});
return (std::min(m_frame_samples, max_samples) - skip_samples) *
m_codec_context->ch_layout.nb_channels * GetPCMSize(m_format);
// Returns the number of output bytes the next decoded frame would occupy once
// gapless trimming is applied: (samples kept per channel) * channels * PCM size.
u32 AjmMp3Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const {
    // Start from a full frame and clamp it to the remaining sample budget when
    // a sample limit has been configured.
    auto available = m_frame_samples;
    if (gapless.init.total_samples != 0) {
        available = std::min(gapless.current.total_samples, m_frame_samples);
    }
    // Samples that are still pending to be skipped are dropped from the frame.
    const auto to_skip = std::min(u32(gapless.current.skip_samples), available);
    const auto kept = available - to_skip;
    return kept * m_codec_context->ch_layout.nb_channels * GetPCMSize(m_format);
}

class BitReader {
Expand Down
11 changes: 5 additions & 6 deletions src/core/libraries/ajm/ajm_mp3.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,21 @@ class AjmMp3Decoder : public AjmCodec {
void Initialize(const void* buffer, u32 buffer_size) override {}
void GetInfo(void* out_info) const override;
AjmSidebandFormat GetFormat() const override;
u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const override;
u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override;
std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples_per_channel) override;
AjmInstanceGapless& gapless) override;

static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl,
AjmDecMp3ParseFrame* frame);

private:
template <class T>
size_t WriteOutputSamples(AVFrame* frame, SparseOutputBuffer& output, u32 skipped_samples,
u32 max_samples) {
size_t WriteOutputPCM(AVFrame* frame, SparseOutputBuffer& output, u32 skipped_samples,
u32 max_pcm) {
std::span<T> pcm_data(reinterpret_cast<T*>(frame->data[0]),
frame->nb_samples * frame->ch_layout.nb_channels);
pcm_data = pcm_data.subspan(skipped_samples * frame->ch_layout.nb_channels);
return output.Write(pcm_data.subspan(0, std::min(u32(pcm_data.size()), max_samples)));
return output.Write(pcm_data.subspan(0, std::min(u32(pcm_data.size()), max_pcm)));
}

AVFrame* ConvertAudioFrame(AVFrame* frame);
Expand Down

0 comments on commit 6dcf249

Please sign in to comment.