Skip to content

Commit

Permalink
MsacContext::symbol_adapt16: Re-virtualize and simplify by passing …
Browse files Browse the repository at this point in the history
…extra args (#342)

eca5a4a reverts
4bea6a0 and
908703e from #310. The rest of the
commits simplify the unsafe by passing the extra `cdf.len()` arg so an
`assert!` isn't needed and the unsafety is extremely simpler, as well as
passing `&mut` instead of `*mut`, which seems to work (if not, I can
remove this commit).
  • Loading branch information
kkysen authored Jul 27, 2023
2 parents b156be6 + f3908e0 commit 51017e5
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions src/msac.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ extern "C" {
fn dav1d_msac_decode_bool_equi_sse2(s: *mut MsacContext) -> libc::c_uint;
fn dav1d_msac_decode_bool_adapt_sse2(s: *mut MsacContext, cdf: *mut uint16_t) -> libc::c_uint;
fn dav1d_msac_decode_symbol_adapt16_avx2(
s: *mut MsacContext,
s: &mut MsacContext,
cdf: *mut uint16_t,
n_symbols: size_t,
_cdf_len: usize,
) -> libc::c_uint;
fn dav1d_msac_decode_symbol_adapt16_sse2(
s: *mut MsacContext,
s: &mut MsacContext,
cdf: *mut uint16_t,
n_symbols: size_t,
_cdf_len: usize,
) -> libc::c_uint;
fn dav1d_msac_decode_symbol_adapt8_sse2(
s: *mut MsacContext,
Expand Down Expand Up @@ -60,14 +62,6 @@ extern "C" {

pub type ec_win = size_t;

#[derive(Copy, Clone)]
#[repr(u8)]
pub enum FnSymbolAdapt16 {
Rust,
Sse2,
Avx2,
}

#[derive(Copy, Clone)]
#[repr(C)]
pub struct MsacContext {
Expand All @@ -78,7 +72,8 @@ pub struct MsacContext {
pub cnt: libc::c_int,
allow_update_cdf: libc::c_int,
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
pub symbol_adapt16: FnSymbolAdapt16,
symbol_adapt16:
unsafe extern "C" fn(&mut MsacContext, *mut uint16_t, size_t, usize) -> libc::c_uint,
}

impl MsacContext {
Expand Down Expand Up @@ -141,10 +136,10 @@ fn msac_init_x86(s: &mut MsacContext) {

let flags = dav1d_get_cpu_flags();
if flags & DAV1D_X86_CPU_FLAG_SSE2 != 0 {
s.symbol_adapt16 = FnSymbolAdapt16::Sse2;
s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
}
if flags & DAV1D_X86_CPU_FLAG_AVX2 != 0 {
s.symbol_adapt16 = FnSymbolAdapt16::Avx2;
s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2;
}
}

Expand Down Expand Up @@ -272,6 +267,22 @@ fn dav1d_msac_decode_symbol_adapt_rust(
val
}

#[deny(unsafe_op_in_unsafe_fn)]
unsafe extern "C" fn dav1d_msac_decode_symbol_adapt_c(
s: &mut MsacContext,
cdf: *mut u16,
n_symbols: size_t,
cdf_len: usize,
) -> libc::c_uint {
// # Safety
//
// This is only called from [`dav1d_msac_decode_symbol_adapt16`],
// where it comes from `cdf.len()`.
let cdf = unsafe { std::slice::from_raw_parts_mut(cdf, cdf_len) };

dav1d_msac_decode_symbol_adapt_rust(s, cdf, n_symbols)
}

fn dav1d_msac_decode_bool_adapt_rust(s: &mut MsacContext, cdf: &mut [u16; 2]) -> bool {
let bit = dav1d_msac_decode_bool(s, cdf[0] as libc::c_uint);
if s.allow_update_cdf() {
Expand Down Expand Up @@ -323,7 +334,7 @@ pub unsafe fn dav1d_msac_init(

#[cfg(all(feature = "asm", target_arch = "x86_64"))]
{
s.symbol_adapt16 = FnSymbolAdapt16::Rust;
s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
msac_init_x86(s);
}
}
Expand Down Expand Up @@ -379,12 +390,9 @@ pub fn dav1d_msac_decode_symbol_adapt16(
) -> libc::c_uint {
cfg_if! {
if #[cfg(all(feature = "asm", target_arch = "x86_64"))] {
match s.symbol_adapt16 {
FnSymbolAdapt16::Rust => dav1d_msac_decode_symbol_adapt_rust(s, cdf, n_symbols),
// Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`].
FnSymbolAdapt16::Sse2 => unsafe { dav1d_msac_decode_symbol_adapt16_sse2(s, cdf.as_mut_ptr(), n_symbols) },
// Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`].
FnSymbolAdapt16::Avx2 => unsafe { dav1d_msac_decode_symbol_adapt16_avx2(s, cdf.as_mut_ptr(), n_symbols) } ,
// Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`].
unsafe {
(s.symbol_adapt16)(s, cdf.as_mut_ptr(), n_symbols, cdf.len())
}
} else if #[cfg(all(feature = "asm", target_arch = "aarch64"))] {
// Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`].
Expand Down

0 comments on commit 51017e5

Please sign in to comment.