From e97ddc0c4d0f72e89a477ff7f4685bba63c3c066 Mon Sep 17 00:00:00 2001 From: Frank Bossen Date: Thu, 18 Apr 2024 05:45:28 -0400 Subject: [PATCH] Port C code changes to Rust --- src/decode.rs | 67 +++++++++++++++++++++++++++++++++---------------- src/internal.rs | 3 ++- src/ipred.rs | 9 ++++--- src/recon.rs | 8 +++--- 4 files changed, 57 insertions(+), 30 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 1e0e81f5f..eff0d47df 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -708,7 +708,8 @@ fn order_palette( unsafe fn read_pal_indices( ts: &mut Rav1dTileState, scratch_pal: &mut Rav1dTaskContext_scratch_pal, - pal_idx: &mut [u8], + pal_tmp: &mut [u8], + pal_idx: Option<&mut [u8]>, // if None, use pal_tmp instead of pal_idx b: &Av1Block, pl: bool, w4: c_int, @@ -721,7 +722,7 @@ unsafe fn read_pal_indices( let pal_sz = b.pal_sz()[pli] as usize; let stride = bw4 * 4; - pal_idx[0] = rav1d_msac_decode_uniform(&mut ts.msac, pal_sz as c_uint) as u8; + pal_tmp[0] = rav1d_msac_decode_uniform(&mut ts.msac, pal_sz as c_uint) as u8; let color_map_cdf = &mut ts.cdf.m.color_map[pli][pal_sz - 2]; let Rav1dTaskContext_scratch_pal { pal_order: order, @@ -731,32 +732,52 @@ unsafe fn read_pal_indices( // top/left-to-bottom/right diagonals ("wave-front") let first = cmp::min(i, w4 * 4 - 1); let last = (i + 1).checked_sub(h4 * 4).unwrap_or(0); - order_palette(pal_idx, stride, i, first, last, order, ctx); + order_palette(pal_tmp, stride, i, first, last, order, ctx); for (m, j) in (last..=first).rev().enumerate() { let color_idx = rav1d_msac_decode_symbol_adapt8( &mut ts.msac, &mut color_map_cdf[ctx[m] as usize], pal_sz - 1, ) as usize; - pal_idx[(i - j) * stride + j] = order[m][color_idx]; + pal_tmp[(i - j) * stride + j] = order[m][color_idx]; } } - // fill invisible edges + // fill invisible edges and pack to 4-bit (2 pixels per byte) if bw4 > w4 { for y in 0..4 * h4 { let offset = y * stride + (4 * w4); let len = 4 * (bw4 - w4); - let filler = pal_idx[offset - 1]; - pal_idx[offset..][..len].fill(filler); + let filler = pal_tmp[offset - 1]; + pal_tmp[offset..][..len].fill(filler); } } - if h4 < bh4 { - let y_start = h4 * 4; - let len = bw4 * 4; - let (src, dests) = pal_idx.split_at_mut(stride * y_start); - let src = &src[stride * (y_start - 1)..][..len]; - for y in 0..(bh4 - h4) * 4 { - dests[y * stride..][..len].copy_from_slice(src); + if let Some(pal_idx) = pal_idx { + for i in 0..bw4 * h4 * 8 { + pal_idx[i] = pal_tmp[2 * i + 0] | (pal_tmp[2 * i + 1] << 4); + } + if h4 < bh4 { + let y_start = h4 * 4; + let len = bw4 * 2; + let packed_stride = bw4 * 2; + let (src, dests) = pal_idx.split_at_mut(packed_stride * y_start); + let src = &src[bw4 * h4 * 8 - packed_stride..][..len]; + for y in 0..(bh4 - h4) * 4 { + dests[y * packed_stride..][..len].copy_from_slice(src); + } + } + } else { + for i in 0..bw4 * h4 * 8 { + pal_tmp[i] = pal_tmp[2 * i + 0] | (pal_tmp[2 * i + 1] << 4); + } + if h4 < bh4 { + let y_start = h4 * 4; + let len = bw4 * 2; + let packed_stride = bw4 * 2; + let (src, dests) = pal_tmp.split_at_mut(packed_stride * y_start); + let src = &src[bw4 * h4 * 8 - packed_stride..][..len]; + for y in 0..(bh4 - h4) * 4 { + dests[y * packed_stride..][..len].copy_from_slice(src); + } } } } @@ -1795,17 +1816,18 @@ unsafe fn decode_b_inner( let pal_idx = if t.frame_thread.pass != 0 { let p = t.frame_thread.pass & 1; let frame_thread = &mut ts.frame_thread[p as usize]; - let len = usize::try_from(bw4 * bh4 * 16).unwrap(); + let len = usize::try_from(bw4 * bh4 * 8).unwrap(); let pal_idx = &mut f.frame_thread.pal_idx[frame_thread.pal_idx..][..len]; frame_thread.pal_idx += len; pal_idx } else { - &mut t.scratch.c2rust_unnamed_0.pal_idx + &mut t.scratch.c2rust_unnamed_0.pal_idx_y }; read_pal_indices( ts, &mut t.scratch.c2rust_unnamed_0.c2rust_unnamed.c2rust_unnamed, - pal_idx, + &mut t.scratch.c2rust_unnamed_0.pal_idx_uv, + Some(pal_idx), b, false, w4, @@ -1822,16 +1844,17 @@ unsafe fn decode_b_inner( let pal_idx = if t.frame_thread.pass != 0 { let p = t.frame_thread.pass & 1; let frame_thread = &mut ts.frame_thread[p as usize]; - let len = usize::try_from(cbw4 * cbh4 * 16).unwrap(); + let len = usize::try_from(cbw4 * cbh4 * 8).unwrap(); let pal_idx = &mut f.frame_thread.pal_idx[frame_thread.pal_idx..][..len]; frame_thread.pal_idx += len; - pal_idx + Some(pal_idx) } else { - &mut t.scratch.c2rust_unnamed_0.pal_idx[(bw4 * bh4 * 16) as usize..] + None }; read_pal_indices( ts, &mut t.scratch.c2rust_unnamed_0.c2rust_unnamed.c2rust_unnamed, + &mut t.scratch.c2rust_unnamed_0.pal_idx_uv, pal_idx, b, true, @@ -3649,7 +3672,7 @@ unsafe fn setup_tile( let size_mul = &ss_size_mul[f.cur.p.layout]; for p in 0..2 { ts.frame_thread[p].pal_idx = if !f.frame_thread.pal_idx.is_empty() { - tile_start_off * size_mul[1] as usize / 4 + tile_start_off * size_mul[1] as usize / 8 } else { 0 }; @@ -4143,7 +4166,7 @@ pub(crate) unsafe fn rav1d_decode_frame_init( // TODO: Fallible allocation f.frame_thread .pal_idx - .resize(pal_idx_sz as usize * 128 * 128 / 4, Default::default()); + .resize(pal_idx_sz as usize * 128 * 128 / 8, Default::default()); } else if !f.frame_thread.pal.is_empty() { let _ = mem::take(&mut f.frame_thread.pal); let _ = mem::take(&mut f.frame_thread.pal_idx); diff --git a/src/internal.rs b/src/internal.rs index c8335687e..2f62f930f 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -933,7 +933,8 @@ pub union Rav1dTaskContext_scratch_ac_txtp_map { pub struct Rav1dTaskContext_scratch_levels_pal_ac_interintra_edge { pub c2rust_unnamed: Rav1dTaskContext_scratch_levels_pal, pub ac_txtp_map: Rav1dTaskContext_scratch_ac_txtp_map, - pub pal_idx: [u8; 8192], + pub pal_idx_y: [u8; 32 * 64], + pub pal_idx_uv: [u8; 64 * 64], // also used as pre-pack scratch buffer pub interintra_edge_pal: BitDepthUnion, } diff --git a/src/ipred.rs b/src/ipred.rs index b7f1d2915..1036631c7 100644 --- a/src/ipred.rs +++ b/src/ipred.rs @@ -1442,10 +1442,13 @@ unsafe fn pal_pred_rust( while y < h { let mut x = 0; while x < w { - *dst.offset(x as isize) = *pal.offset(*idx.offset(x as isize) as isize); - x += 1; + let i = *idx; + assert!((i & 0x88) == 0); + *dst.offset(x as isize) = *pal.offset((i & 7) as isize); + *dst.offset(x as isize + 1) = *pal.offset((i >> 4) as isize); + idx = idx.offset(1); + x += 2; } - idx = idx.offset(w as isize); dst = dst.offset(BD::pxstride(stride)); y += 1; } diff --git a/src/recon.rs b/src/recon.rs index 1aaa571ec..cd7321ff1 100644 --- a/src/recon.rs +++ b/src/recon.rs @@ -2449,12 +2449,12 @@ pub(crate) unsafe fn rav1d_recon_b_intra( let pal_idx = if t.frame_thread.pass != 0 { let p = t.frame_thread.pass & 1; let frame_thread = &mut ts.frame_thread[p as usize]; - let len = (bw4 * bh4 * 16) as usize; + let len = (bw4 * bh4 * 8) as usize; let pal_idx = &f.frame_thread.pal_idx[frame_thread.pal_idx..][..len]; frame_thread.pal_idx += len; pal_idx } else { - &t.scratch.c2rust_unnamed_0.pal_idx + &t.scratch.c2rust_unnamed_0.pal_idx_y }; let pal: *const BD::Pixel = if t.frame_thread.pass != 0 { let index = (((t.b.y as isize >> 1) + (t.b.x as isize & 1)) @@ -2833,7 +2833,7 @@ pub(crate) unsafe fn rav1d_recon_b_intra( + ((t.b.x as isize >> 1) as isize + (t.b.y as isize & 1)) as isize) as isize; let pal_idx_offset = &mut ts.frame_thread[p as usize].pal_idx; - let len = (cbw4 * cbh4 * 16) as usize; + let len = (cbw4 * cbh4 * 8) as usize; let pal_idx = &f.frame_thread.pal_idx[*pal_idx_offset..][..len]; *pal_idx_offset += len; ( @@ -2845,7 +2845,7 @@ pub(crate) unsafe fn rav1d_recon_b_intra( BD::select_mut(&mut t.scratch.c2rust_unnamed_0.interintra_edge_pal); ( &interintra_edge_pal.pal, - &t.scratch.c2rust_unnamed_0.pal_idx[(bw4 * bh4 * 16) as usize..], + &t.scratch.c2rust_unnamed_0.pal_idx_uv[..], ) }; (*f.dsp).ipred.pal_pred.call::(