diff --git a/build.rs b/build.rs
index 705b48546..73f13e00c 100644
--- a/build.rs
+++ b/build.rs
@@ -145,7 +145,7 @@ mod asm {
         fs::write(&config_path, &config_contents).unwrap();

         // Note that avx* is never (at runtime) supported on x86.
-        let x86_generic = &["cdef_sse", "itx_sse", "msac", "refmvs"][..];
+        let x86_generic = &["cdef_sse", "itx_sse", "msac", "pal", "refmvs"][..];
         let x86_64_generic = &[
             "cdef_avx2",
             "itx_avx2",
diff --git a/lib.rs b/lib.rs
index b449c0320..538af4992 100644
--- a/lib.rs
+++ b/lib.rs
@@ -68,6 +68,7 @@ pub mod src {
     mod mem;
     mod msac;
     mod obu;
+    mod pal;
     mod picture;
     mod qm;
     mod recon;
diff --git a/src/decode.rs b/src/decode.rs
index ee78ffbba..68b489dd3 100644
--- a/src/decode.rs
+++ b/src/decode.rs
@@ -124,6 +124,7 @@ use crate::src::msac::rav1d_msac_decode_symbol_adapt4;
 use crate::src::msac::rav1d_msac_decode_symbol_adapt8;
 use crate::src::msac::rav1d_msac_decode_uniform;
 use crate::src::msac::rav1d_msac_init;
+use crate::src::pal::Rav1dPalDSPContext;
 use crate::src::picture::rav1d_picture_alloc_copy;
 use crate::src::picture::rav1d_picture_ref;
 use crate::src::picture::rav1d_picture_unref_internal;
@@ -699,6 +700,7 @@ fn order_palette(

 unsafe fn read_pal_indices(
     ts: &mut Rav1dTileState,
+    pal_dsp: &Rav1dPalDSPContext,
     scratch_pal: &mut Rav1dTaskContext_scratch_pal,
     pal_tmp: &mut [u8],
     pal_idx: Option<&mut [u8]>, // if None, use pal_tmp instead of pal_idx
@@ -743,35 +745,14 @@ unsafe fn read_pal_indices(
             pal_tmp[offset..][..len].fill(filler);
         }
     }
-    if let Some(pal_idx) = pal_idx {
-        for i in 0..bw4 * h4 * 8 {
-            pal_idx[i] = pal_tmp[2 * i + 0] | (pal_tmp[2 * i + 1] << 4);
-        }
-        if h4 < bh4 {
-            let y_start = h4 * 4;
-            let len = bw4 * 2;
-            let packed_stride = bw4 * 2;
-            let (src, dests) = pal_idx.split_at_mut(packed_stride * y_start);
-            let src = &src[bw4 * h4 * 8 - packed_stride..][..len];
-            for y in 0..(bh4 - h4) * 4 {
-                dests[y * packed_stride..][..len].copy_from_slice(src);
-            }
-        }
-    } else {
-        for i in 0..bw4 * h4 * 8 {
-            pal_tmp[i] = pal_tmp[2 * i + 0] | (pal_tmp[2 * i + 1] << 4);
-        }
-        if h4 < bh4 {
-            let y_start = h4 * 4;
-            let len = bw4 * 2;
-            let packed_stride = bw4 * 2;
-            let (src, dests) = pal_tmp.split_at_mut(packed_stride * y_start);
-            let src = &src[bw4 * h4 * 8 - packed_stride..][..len];
-            for y in 0..(bh4 - h4) * 4 {
-                dests[y * packed_stride..][..len].copy_from_slice(src);
-            }
-        }
-    }
+    (pal_dsp.pal_idx_finish)(
+        pal_idx.unwrap_or(pal_tmp).as_mut_ptr(),
+        pal_tmp.as_ptr(),
+        bw4 as c_int * 4,
+        bh4 as c_int * 4,
+        w4 as c_int * 4,
+        h4 as c_int * 4,
+    );
 }

 unsafe fn read_vartx_tree(
@@ -1805,6 +1786,7 @@ unsafe fn decode_b(
                 };
                 read_pal_indices(
                     ts,
+                    &c.pal_dsp,
                     &mut t.scratch.c2rust_unnamed_0.c2rust_unnamed.c2rust_unnamed,
                     &mut t.scratch.c2rust_unnamed_0.pal_idx_uv,
                     Some(pal_idx),
@@ -1837,6 +1819,7 @@ unsafe fn decode_b(
                 };
                 read_pal_indices(
                     ts,
+                    &c.pal_dsp,
                     &mut t.scratch.c2rust_unnamed_0.c2rust_unnamed.c2rust_unnamed,
                     &mut t.scratch.c2rust_unnamed_0.pal_idx_uv,
                     pal_idx,
diff --git a/src/internal.rs b/src/internal.rs
index a9dd47990..7a011bc56 100644
--- a/src/internal.rs
+++ b/src/internal.rs
@@ -55,6 +55,7 @@ use crate::src::looprestoration::Rav1dLoopRestorationDSPContext;
 use crate::src::mc::Rav1dMCDSPContext;
 use crate::src::mem::Rav1dMemPool;
 use crate::src::msac::MsacContext;
+use crate::src::pal::Rav1dPalDSPContext;
 use crate::src::picture::PictureFlags;
 use crate::src::picture::Rav1dThreadPicture;
 use crate::src::recon::backup_ipred_edge_fn;
@@ -339,6 +340,7 @@ pub struct Rav1dContext {
     pub(crate) refs: [Rav1dContext_refs; 8],
     pub(crate) cdf: [CdfThreadContext; 8], // Previously pooled

+    pub(crate) pal_dsp: Rav1dPalDSPContext,
     pub(crate) refmvs_dsp: Rav1dRefmvsDSPContext,

     pub(crate) allocator: Rav1dPicAllocator,
diff --git a/src/lib.rs b/src/lib.rs
index 4253da7fa..295a601e7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -41,6 +41,7 @@ use crate::src::mem::rav1d_mem_pool_end;
 use crate::src::mem::rav1d_mem_pool_init;
 use crate::src::obu::rav1d_parse_obus;
 use crate::src::obu::rav1d_parse_sequence_header;
+use crate::src::pal::rav1d_pal_dsp_init;
 use crate::src::picture::dav1d_default_picture_alloc;
 use crate::src::picture::dav1d_default_picture_release;
 use crate::src::picture::rav1d_picture_alloc_copy;
@@ -334,6 +335,7 @@ pub(crate) unsafe fn rav1d_open(c_out: &mut *mut Rav1dContext, s: &Rav1dSettings
             }
         })
         .collect();
+    rav1d_pal_dsp_init(&mut (*c).pal_dsp);
     rav1d_refmvs_dsp_init(&mut (*c).refmvs_dsp);
     Ok(())
 }
diff --git a/src/pal.rs b/src/pal.rs
new file mode 100644
index 000000000..b9a3dadf4
--- /dev/null
+++ b/src/pal.rs
@@ -0,0 +1,139 @@
+use crate::src::cpu::rav1d_get_cpu_flags;
+use crate::src::cpu::CpuFlags;
+use cfg_if::cfg_if;
+use libc::c_int;
+use std::slice;
+
+pub type pal_idx_finish_fn = unsafe extern "C" fn(
+    dst: *mut u8,
+    src: *const u8,
+    bw: c_int,
+    bh: c_int,
+    w: c_int,
+    h: c_int,
+) -> ();
+
+#[repr(C)]
+pub(crate) struct Rav1dPalDSPContext {
+    pub pal_idx_finish: pal_idx_finish_fn,
+}
+
+// fill invisible edges and pack to 4-bit (2 pixels per byte)
+unsafe extern "C" fn pal_idx_finish_rust(
+    dst: *mut u8,
+    src: *const u8,
+    bw: c_int,
+    bh: c_int,
+    w: c_int,
+    h: c_int,
+) -> () {
+    assert!(bw >= 4 && bw <= 64 && (bw & (bw - 1)) == 0);
+    assert!(bh >= 4 && bh <= 64 && (bh & (bh - 1)) == 0);
+    assert!(w >= 4 && w <= bw && (w & 3) == 0);
+    assert!(h >= 4 && h <= bh && (h & 3) == 0);
+
+    let w = w as usize;
+    let h = h as usize;
+    let bw = bw as usize;
+    let bh = bh as usize;
+    let dst_w = w / 2;
+    let dst_bw = bw / 2;
+
+    let mut dst = slice::from_raw_parts_mut(dst, dst_bw * bh);
+    let mut src = slice::from_raw_parts(src, bw * bh);
+
+    for y in 0..h {
+        for x in 0..dst_w {
+            dst[x] = src[2 * x] | (src[2 * x + 1] << 4);
+        }
+        if dst_w < dst_bw {
+            for x in dst_w..dst_bw {
+                dst[x] = 0x11 * src[w - 1];
+            }
+        }
+        src = &src[bw..];
+        if y < h - 1 {
+            dst = &mut dst[dst_bw..];
+        }
+    }
+
+    if h < bh {
+        let (last_row, dst) = dst.split_at_mut(dst_bw);
+
+        for row in dst.chunks_exact_mut(dst_bw) {
+            row.copy_from_slice(last_row);
+        }
+    }
+}
+
+#[cfg(all(feature = "asm", any(target_arch = "x86", target_arch = "x86_64"),))]
+extern "C" {
+    fn dav1d_pal_idx_finish_ssse3(
+        dst: *mut u8,
+        src: *const u8,
+        bw: c_int,
+        bh: c_int,
+        w: c_int,
+        h: c_int,
+    );
+}
+
+#[cfg(all(feature = "asm", any(target_arch = "x86_64"),))]
+extern "C" {
+    fn dav1d_pal_idx_finish_avx2(
+        dst: *mut u8,
+        src: *const u8,
+        bw: c_int,
+        bh: c_int,
+        w: c_int,
+        h: c_int,
+    );
+
+    fn dav1d_pal_idx_finish_avx512icl(
+        dst: *mut u8,
+        src: *const u8,
+        bw: c_int,
+        bh: c_int,
+        w: c_int,
+        h: c_int,
+    );
+}
+
+#[inline(always)]
+#[cfg(all(feature = "asm", any(target_arch = "x86", target_arch = "x86_64"),))]
+unsafe fn pal_dsp_init_x86(c: *mut Rav1dPalDSPContext) {
+    let flags = rav1d_get_cpu_flags();
+
+    if !flags.contains(CpuFlags::SSSE3) {
+        return;
+    }
+
+    (*c).pal_idx_finish = dav1d_pal_idx_finish_ssse3;
+
+    cfg_if! {
+        if #[cfg(any(target_arch = "x86_64"))] {
+            if !flags.contains(CpuFlags::AVX2) {
+                return;
+            }
+
+            (*c).pal_idx_finish = dav1d_pal_idx_finish_avx2;
+
+            if !flags.contains(CpuFlags::AVX512ICL) {
+                return;
+            }
+
+            (*c).pal_idx_finish = dav1d_pal_idx_finish_avx512icl;
+        }
+    }
+}
+
+pub(crate) unsafe fn rav1d_pal_dsp_init(c: *mut Rav1dPalDSPContext) -> () {
+    (*c).pal_idx_finish = pal_idx_finish_rust;
+
+    #[cfg(feature = "asm")]
+    cfg_if! {
+        if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
+            pal_dsp_init_x86(c);
+        }
+    }
+}
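
A minimal sketch (separate from the patch itself) of what the new pal_idx_finish entry point produces, restated for a hypothetical block with bw = 8, bh = 8, w = h = 4; the buffer setup and names below are illustrative only, not rav1d code. Each output byte packs two 4-bit palette indices, columns beyond w are filled with 0x11 times the last visible index, and rows beyond h repeat the last visible packed row:

    // Illustrative only: restates the pal_idx_finish contract for bw=8, bh=8, w=4, h=4.
    fn main() {
        let (bw, bh, w, h) = (8usize, 8usize, 4usize, 4usize);

        // One palette index (0..=7) per byte, row-major with stride `bw`.
        let mut src = vec![0u8; bw * bh];
        for y in 0..h {
            for x in 0..w {
                src[y * bw + x] = ((x + y) % 8) as u8;
            }
        }

        // Pack two indices per byte, replicating the last visible column and row.
        let (dst_w, dst_bw) = (w / 2, bw / 2);
        let mut dst = vec![0u8; dst_bw * bh];
        for y in 0..h {
            for x in 0..dst_w {
                dst[y * dst_bw + x] = src[y * bw + 2 * x] | (src[y * bw + 2 * x + 1] << 4);
            }
            for x in dst_w..dst_bw {
                dst[y * dst_bw + x] = 0x11 * src[y * bw + w - 1];
            }
        }
        for y in h..bh {
            let (prev, cur) = dst.split_at_mut(y * dst_bw);
            cur[..dst_bw].copy_from_slice(&prev[(y - 1) * dst_bw..]);
        }

        assert_eq!(dst[0], src[0] | (src[1] << 4)); // two indices per output byte
        assert_eq!(dst[dst_w], 0x11 * src[w - 1]); // replicated last visible index
        assert_eq!(&dst[(bh - 1) * dst_bw..], &dst[(h - 1) * dst_bw..h * dst_bw]); // repeated last row
    }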