diff --git a/src/msac.rs b/src/msac.rs index ffd758726..88f6f1a65 100644 --- a/src/msac.rs +++ b/src/msac.rs @@ -14,14 +14,16 @@ extern "C" { fn dav1d_msac_decode_bool_equi_sse2(s: *mut MsacContext) -> libc::c_uint; fn dav1d_msac_decode_bool_adapt_sse2(s: *mut MsacContext, cdf: *mut uint16_t) -> libc::c_uint; fn dav1d_msac_decode_symbol_adapt16_avx2( - s: *mut MsacContext, + s: &mut MsacContext, cdf: *mut uint16_t, n_symbols: size_t, + _cdf_len: usize, ) -> libc::c_uint; fn dav1d_msac_decode_symbol_adapt16_sse2( - s: *mut MsacContext, + s: &mut MsacContext, cdf: *mut uint16_t, n_symbols: size_t, + _cdf_len: usize, ) -> libc::c_uint; fn dav1d_msac_decode_symbol_adapt8_sse2( s: *mut MsacContext, @@ -60,14 +62,6 @@ extern "C" { pub type ec_win = size_t; -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum FnSymbolAdapt16 { - Rust, - Sse2, - Avx2, -} - #[derive(Copy, Clone)] #[repr(C)] pub struct MsacContext { @@ -78,7 +72,8 @@ pub struct MsacContext { pub cnt: libc::c_int, allow_update_cdf: libc::c_int, #[cfg(all(feature = "asm", target_arch = "x86_64"))] - pub symbol_adapt16: FnSymbolAdapt16, + symbol_adapt16: + unsafe extern "C" fn(&mut MsacContext, *mut uint16_t, size_t, usize) -> libc::c_uint, } impl MsacContext { @@ -141,10 +136,10 @@ fn msac_init_x86(s: &mut MsacContext) { let flags = dav1d_get_cpu_flags(); if flags & DAV1D_X86_CPU_FLAG_SSE2 != 0 { - s.symbol_adapt16 = FnSymbolAdapt16::Sse2; + s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2; } if flags & DAV1D_X86_CPU_FLAG_AVX2 != 0 { - s.symbol_adapt16 = FnSymbolAdapt16::Avx2; + s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2; } } @@ -272,6 +267,22 @@ fn dav1d_msac_decode_symbol_adapt_rust( val } +#[deny(unsafe_op_in_unsafe_fn)] +unsafe extern "C" fn dav1d_msac_decode_symbol_adapt_c( + s: &mut MsacContext, + cdf: *mut u16, + n_symbols: size_t, + cdf_len: usize, +) -> libc::c_uint { + // # Safety + // + // This is only called from [`dav1d_msac_decode_symbol_adapt16`], + // where it comes from `cdf.len()`. + let cdf = unsafe { std::slice::from_raw_parts_mut(cdf, cdf_len) }; + + dav1d_msac_decode_symbol_adapt_rust(s, cdf, n_symbols) +} + fn dav1d_msac_decode_bool_adapt_rust(s: &mut MsacContext, cdf: &mut [u16; 2]) -> bool { let bit = dav1d_msac_decode_bool(s, cdf[0] as libc::c_uint); if s.allow_update_cdf() { @@ -323,7 +334,7 @@ pub unsafe fn dav1d_msac_init( #[cfg(all(feature = "asm", target_arch = "x86_64"))] { - s.symbol_adapt16 = FnSymbolAdapt16::Rust; + s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c; msac_init_x86(s); } } @@ -379,12 +390,9 @@ pub fn dav1d_msac_decode_symbol_adapt16( ) -> libc::c_uint { cfg_if! { if #[cfg(all(feature = "asm", target_arch = "x86_64"))] { - match s.symbol_adapt16 { - FnSymbolAdapt16::Rust => dav1d_msac_decode_symbol_adapt_rust(s, cdf, n_symbols), - // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. - FnSymbolAdapt16::Sse2 => unsafe { dav1d_msac_decode_symbol_adapt16_sse2(s, cdf.as_mut_ptr(), n_symbols) }, - // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. - FnSymbolAdapt16::Avx2 => unsafe { dav1d_msac_decode_symbol_adapt16_avx2(s, cdf.as_mut_ptr(), n_symbols) } , + // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. + unsafe { + (s.symbol_adapt16)(s, cdf.as_mut_ptr(), n_symbols, cdf.len()) } } else if #[cfg(all(feature = "asm", target_arch = "aarch64"))] { // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`].