From eca5a4addd48b54b7005e040b82137ab2e1a66d6 Mon Sep 17 00:00:00 2001 From: Khyber Sen Date: Tue, 25 Jul 2023 16:00:32 -0700 Subject: [PATCH 1/5] `MsacContext::symbol_adapt16`: Re-virtualize (revert 4bea6a04b09ffae62530d04c46791f0f1f4384f3, 908703ea1bc70a3de24e2c130fe72be6a64cb006). --- src/msac.rs | 49 +++++++++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/src/msac.rs b/src/msac.rs index ffd758726..05c845e88 100644 --- a/src/msac.rs +++ b/src/msac.rs @@ -60,14 +60,6 @@ extern "C" { pub type ec_win = size_t; -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum FnSymbolAdapt16 { - Rust, - Sse2, - Avx2, -} - #[derive(Copy, Clone)] #[repr(C)] pub struct MsacContext { @@ -78,7 +70,8 @@ pub struct MsacContext { pub cnt: libc::c_int, allow_update_cdf: libc::c_int, #[cfg(all(feature = "asm", target_arch = "x86_64"))] - pub symbol_adapt16: FnSymbolAdapt16, + pub symbol_adapt16: + Option libc::c_uint>, } impl MsacContext { @@ -141,10 +134,10 @@ fn msac_init_x86(s: &mut MsacContext) { let flags = dav1d_get_cpu_flags(); if flags & DAV1D_X86_CPU_FLAG_SSE2 != 0 { - s.symbol_adapt16 = FnSymbolAdapt16::Sse2; + s.symbol_adapt16 = Some(dav1d_msac_decode_symbol_adapt16_sse2); } if flags & DAV1D_X86_CPU_FLAG_AVX2 != 0 { - s.symbol_adapt16 = FnSymbolAdapt16::Avx2; + s.symbol_adapt16 = Some(dav1d_msac_decode_symbol_adapt16_avx2); } } @@ -272,6 +265,28 @@ fn dav1d_msac_decode_symbol_adapt_rust( val } +#[deny(unsafe_op_in_unsafe_fn)] +unsafe extern "C" fn dav1d_msac_decode_symbol_adapt_c( + s: *mut MsacContext, + cdf: *mut u16, + n_symbols: size_t, +) -> libc::c_uint { + // # Safety + // + // This is only called from [`dav1d_msac_decode_symbol_adapt16`], + // where it comes from a valid `&mut`. + let s = unsafe { &mut *s }; + + // # Safety + // + // This is only called from [`dav1d_msac_decode_symbol_adapt16`], + // where there is an `assert!(n_symbols < cdf.len());`. + // Thus, `n_symbols + 1` is a valid length for the slice `cdf` came from. + let cdf = unsafe { std::slice::from_raw_parts_mut(cdf, n_symbols + 1) }; + + dav1d_msac_decode_symbol_adapt_rust(s, cdf, n_symbols) +} + fn dav1d_msac_decode_bool_adapt_rust(s: &mut MsacContext, cdf: &mut [u16; 2]) -> bool { let bit = dav1d_msac_decode_bool(s, cdf[0] as libc::c_uint); if s.allow_update_cdf() { @@ -323,7 +338,7 @@ pub unsafe fn dav1d_msac_init( #[cfg(all(feature = "asm", target_arch = "x86_64"))] { - s.symbol_adapt16 = FnSymbolAdapt16::Rust; + s.symbol_adapt16 = Some(dav1d_msac_decode_symbol_adapt_c); msac_init_x86(s); } } @@ -379,12 +394,10 @@ pub fn dav1d_msac_decode_symbol_adapt16( ) -> libc::c_uint { cfg_if! { if #[cfg(all(feature = "asm", target_arch = "x86_64"))] { - match s.symbol_adapt16 { - FnSymbolAdapt16::Rust => dav1d_msac_decode_symbol_adapt_rust(s, cdf, n_symbols), - // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. - FnSymbolAdapt16::Sse2 => unsafe { dav1d_msac_decode_symbol_adapt16_sse2(s, cdf.as_mut_ptr(), n_symbols) }, - // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. - FnSymbolAdapt16::Avx2 => unsafe { dav1d_msac_decode_symbol_adapt16_avx2(s, cdf.as_mut_ptr(), n_symbols) } , + assert!(n_symbols < cdf.len()); + // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. + unsafe { + (s.symbol_adapt16).expect("non-null function pointer")(s, cdf.as_mut_ptr(), n_symbols) } } else if #[cfg(all(feature = "asm", target_arch = "aarch64"))] { // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. From eb04185936117f26b2ebb5e04d0cb54328222661 Mon Sep 17 00:00:00 2001 From: Khyber Sen Date: Tue, 25 Jul 2023 16:01:08 -0700 Subject: [PATCH 2/5] `MsacContext::symbol_adapt16`: Remove the `fn` ptr `Option` that adds no safety. --- src/msac.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/msac.rs b/src/msac.rs index 05c845e88..960edbce7 100644 --- a/src/msac.rs +++ b/src/msac.rs @@ -71,7 +71,7 @@ pub struct MsacContext { allow_update_cdf: libc::c_int, #[cfg(all(feature = "asm", target_arch = "x86_64"))] pub symbol_adapt16: - Option libc::c_uint>, + unsafe extern "C" fn(*mut MsacContext, *mut uint16_t, size_t) -> libc::c_uint, } impl MsacContext { @@ -134,10 +134,10 @@ fn msac_init_x86(s: &mut MsacContext) { let flags = dav1d_get_cpu_flags(); if flags & DAV1D_X86_CPU_FLAG_SSE2 != 0 { - s.symbol_adapt16 = Some(dav1d_msac_decode_symbol_adapt16_sse2); + s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2; } if flags & DAV1D_X86_CPU_FLAG_AVX2 != 0 { - s.symbol_adapt16 = Some(dav1d_msac_decode_symbol_adapt16_avx2); + s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2; } } @@ -338,7 +338,7 @@ pub unsafe fn dav1d_msac_init( #[cfg(all(feature = "asm", target_arch = "x86_64"))] { - s.symbol_adapt16 = Some(dav1d_msac_decode_symbol_adapt_c); + s.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c; msac_init_x86(s); } } @@ -397,7 +397,7 @@ pub fn dav1d_msac_decode_symbol_adapt16( assert!(n_symbols < cdf.len()); // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. unsafe { - (s.symbol_adapt16).expect("non-null function pointer")(s, cdf.as_mut_ptr(), n_symbols) + (s.symbol_adapt16)(s, cdf.as_mut_ptr(), n_symbols) } } else if #[cfg(all(feature = "asm", target_arch = "aarch64"))] { // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. From d10a2084dbf7ed51f0ebb77276c9fd4e7e9a5b30 Mon Sep 17 00:00:00 2001 From: Khyber Sen Date: Tue, 25 Jul 2023 16:05:21 -0700 Subject: [PATCH 3/5] `MsacContext::symbol_adapt16`: Pass `cdf.len()` as an extra argument, removing the need for an `assert!`. --- src/msac.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/msac.rs b/src/msac.rs index 960edbce7..7501c6510 100644 --- a/src/msac.rs +++ b/src/msac.rs @@ -17,11 +17,13 @@ extern "C" { s: *mut MsacContext, cdf: *mut uint16_t, n_symbols: size_t, + _cdf_len: usize, ) -> libc::c_uint; fn dav1d_msac_decode_symbol_adapt16_sse2( s: *mut MsacContext, cdf: *mut uint16_t, n_symbols: size_t, + _cdf_len: usize, ) -> libc::c_uint; fn dav1d_msac_decode_symbol_adapt8_sse2( s: *mut MsacContext, @@ -71,7 +73,7 @@ pub struct MsacContext { allow_update_cdf: libc::c_int, #[cfg(all(feature = "asm", target_arch = "x86_64"))] pub symbol_adapt16: - unsafe extern "C" fn(*mut MsacContext, *mut uint16_t, size_t) -> libc::c_uint, + unsafe extern "C" fn(*mut MsacContext, *mut uint16_t, size_t, usize) -> libc::c_uint, } impl MsacContext { @@ -270,6 +272,7 @@ unsafe extern "C" fn dav1d_msac_decode_symbol_adapt_c( s: *mut MsacContext, cdf: *mut u16, n_symbols: size_t, + cdf_len: usize, ) -> libc::c_uint { // # Safety // @@ -280,9 +283,8 @@ unsafe extern "C" fn dav1d_msac_decode_symbol_adapt_c( // # Safety // // This is only called from [`dav1d_msac_decode_symbol_adapt16`], - // where there is an `assert!(n_symbols < cdf.len());`. - // Thus, `n_symbols + 1` is a valid length for the slice `cdf` came from. - let cdf = unsafe { std::slice::from_raw_parts_mut(cdf, n_symbols + 1) }; + // where it comes from `cdf.len()`. + let cdf = unsafe { std::slice::from_raw_parts_mut(cdf, cdf_len) }; dav1d_msac_decode_symbol_adapt_rust(s, cdf, n_symbols) } @@ -394,10 +396,9 @@ pub fn dav1d_msac_decode_symbol_adapt16( ) -> libc::c_uint { cfg_if! { if #[cfg(all(feature = "asm", target_arch = "x86_64"))] { - assert!(n_symbols < cdf.len()); // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. unsafe { - (s.symbol_adapt16)(s, cdf.as_mut_ptr(), n_symbols) + (s.symbol_adapt16)(s, cdf.as_mut_ptr(), n_symbols, cdf.len()) } } else if #[cfg(all(feature = "asm", target_arch = "aarch64"))] { // Safety: `checkasm` has verified that it is equivalent to [`dav1d_msac_decode_symbol_adapt_rust`]. From e82488c9c179b852da98ee3a93946978c53778a7 Mon Sep 17 00:00:00 2001 From: Khyber Sen Date: Tue, 25 Jul 2023 16:20:47 -0700 Subject: [PATCH 4/5] `MsacContext::symbol_adapt16`: Make the `*mut` ptr a `&mut` ref, as it seems to be allowed. --- src/msac.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/msac.rs b/src/msac.rs index 7501c6510..4675bbf72 100644 --- a/src/msac.rs +++ b/src/msac.rs @@ -14,13 +14,13 @@ extern "C" { fn dav1d_msac_decode_bool_equi_sse2(s: *mut MsacContext) -> libc::c_uint; fn dav1d_msac_decode_bool_adapt_sse2(s: *mut MsacContext, cdf: *mut uint16_t) -> libc::c_uint; fn dav1d_msac_decode_symbol_adapt16_avx2( - s: *mut MsacContext, + s: &mut MsacContext, cdf: *mut uint16_t, n_symbols: size_t, _cdf_len: usize, ) -> libc::c_uint; fn dav1d_msac_decode_symbol_adapt16_sse2( - s: *mut MsacContext, + s: &mut MsacContext, cdf: *mut uint16_t, n_symbols: size_t, _cdf_len: usize, @@ -73,7 +73,7 @@ pub struct MsacContext { allow_update_cdf: libc::c_int, #[cfg(all(feature = "asm", target_arch = "x86_64"))] pub symbol_adapt16: - unsafe extern "C" fn(*mut MsacContext, *mut uint16_t, size_t, usize) -> libc::c_uint, + unsafe extern "C" fn(&mut MsacContext, *mut uint16_t, size_t, usize) -> libc::c_uint, } impl MsacContext { @@ -269,17 +269,11 @@ fn dav1d_msac_decode_symbol_adapt_rust( #[deny(unsafe_op_in_unsafe_fn)] unsafe extern "C" fn dav1d_msac_decode_symbol_adapt_c( - s: *mut MsacContext, + s: &mut MsacContext, cdf: *mut u16, n_symbols: size_t, cdf_len: usize, ) -> libc::c_uint { - // # Safety - // - // This is only called from [`dav1d_msac_decode_symbol_adapt16`], - // where it comes from a valid `&mut`. - let s = unsafe { &mut *s }; - // # Safety // // This is only called from [`dav1d_msac_decode_symbol_adapt16`], From f3908e0076a336ba5915503310900ef930cc89f8 Mon Sep 17 00:00:00 2001 From: Khyber Sen Date: Wed, 26 Jul 2023 00:39:08 -0700 Subject: [PATCH 5/5] `MsacContext::symbol_adapt16`: Make private so safety invariants are guaranteed. --- src/msac.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/msac.rs b/src/msac.rs index 4675bbf72..88f6f1a65 100644 --- a/src/msac.rs +++ b/src/msac.rs @@ -72,7 +72,7 @@ pub struct MsacContext { pub cnt: libc::c_int, allow_update_cdf: libc::c_int, #[cfg(all(feature = "asm", target_arch = "x86_64"))] - pub symbol_adapt16: + symbol_adapt16: unsafe extern "C" fn(&mut MsacContext, *mut uint16_t, size_t, usize) -> libc::c_uint, }