diff --git a/src/disjoint_mut.rs b/src/disjoint_mut.rs
index f15a15652..26b20cb29 100644
--- a/src/disjoint_mut.rs
+++ b/src/disjoint_mut.rs
@@ -13,6 +13,8 @@ use std::fmt::Debug;
 use std::fmt::Display;
 use std::fmt::Formatter;
 use std::marker::PhantomData;
+use std::mem;
+use std::mem::ManuallyDrop;
 use std::ops::Bound;
 use std::ops::Deref;
 use std::ops::DerefMut;
@@ -27,6 +29,8 @@ use std::ops::RangeToInclusive;
 use std::ptr;
 use std::ptr::addr_of_mut;
 use std::sync::Arc;
+use zerocopy::AsBytes;
+use zerocopy::FromBytes;
 
 /// Wraps an indexable collection to allow unchecked concurrent mutable borrows.
 ///
@@ -70,6 +74,23 @@ pub struct DisjointMutGuard<'a, T: ?Sized + AsMutPtr, V: ?Sized> {
     bounds: Bounds,
 }
 
+impl<'a, T: AsMutPtr> DisjointMutGuard<'a, T, [u8]> {
+    fn cast<V: AsBytes + FromBytes>(self) -> DisjointMutGuard<'a, T, V> {
+        // We don't want to drop the old guard, because we aren't changing or
+        // removing the bounds from parent here.
+        let mut old_guard = ManuallyDrop::new(self);
+        let bytes = mem::take(&mut old_guard.slice);
+        DisjointMutGuard {
+            slice: V::mut_from(bytes).unwrap(),
+            phantom: old_guard.phantom,
+            #[cfg(debug_assertions)]
+            parent: old_guard.parent,
+            #[cfg(debug_assertions)]
+            bounds: old_guard.bounds.clone(),
+        }
+    }
+}
+
 impl<'a, T: ?Sized + AsMutPtr, V: ?Sized> Deref for DisjointMutGuard<'a, T, V> {
     type Target = V;
 
@@ -96,6 +117,23 @@ pub struct DisjointImmutGuard<'a, T: ?Sized + AsMutPtr, V: ?Sized> {
     bounds: Bounds,
 }
 
+impl<'a, T: AsMutPtr> DisjointImmutGuard<'a, T, [u8]> {
+    fn cast<V: FromBytes>(self) -> DisjointImmutGuard<'a, T, V> {
+        // We don't want to drop the old guard, because we aren't changing or
+        // removing the bounds from parent here.
+        let mut old_guard = ManuallyDrop::new(self);
+        let bytes = mem::take(&mut old_guard.slice);
+        DisjointImmutGuard {
+            slice: V::ref_from(bytes).unwrap(),
+            phantom: old_guard.phantom,
+            #[cfg(debug_assertions)]
+            parent: old_guard.parent,
+            #[cfg(debug_assertions)]
+            bounds: old_guard.bounds.clone(),
+        }
+    }
+}
+
 impl<'a, T: ?Sized + AsMutPtr, V: ?Sized> Deref for DisjointImmutGuard<'a, T, V> {
     type Target = V;
 
@@ -258,6 +296,70 @@ impl<T: AsMutPtr + ?Sized> DisjointMut<T> {
     }
 }
 
+impl<T: AsMutPtr<Target = u8>> DisjointMut<T> {
+    /// Mutably borrow a slice or element of a convertible type.
+    ///
+    /// This method accesses an element or slice of elements of a type that
+    /// implements `zerocopy::FromBytes` from a buffer of `u8`.
+    ///
+    /// This mutable borrow may be unchecked and callers must ensure that no
+    /// other borrows from this collection overlap with the mutably borrowed
+    /// region for the lifetime of that mutable borrow.
+    ///
+    /// # Safety
+    ///
+    /// Caller must ensure that no elements of the resulting borrowed slice or
+    /// element are concurrently borrowed (immutably or mutably) at all during
+    /// the lifetime of the returned mutable borrow. We require that the
+    /// referenced data is plain data, containing no pointers or references,
+    /// to avoid other potential memory safety issues due to racy access.
+    #[cfg_attr(debug_assertions, track_caller)]
+    pub unsafe fn index_mut_as<'a, I, V>(&'a self, index: I) -> DisjointMutGuard<'a, T, V>
+    where
+        I: Into<Bounds> + Clone,
+        V: AsBytes + FromBytes,
+    {
+        let bounds = index.into().multiply(mem::size_of::<V>());
+        // SAFETY: Same safety requirements as this method.
+        let byte_guard = unsafe { self.index_mut(bounds.range) };
+        byte_guard.cast()
+    }
+
+    /// Immutably borrow a slice or element of a convertible type.
+    ///
+    /// This method accesses an element or slice of elements of a type that
+    /// implements `zerocopy::FromBytes` from a buffer of `u8`.
+    ///
+    /// This immutable borrow may be unchecked and callers must ensure that no
+    /// other mutable borrows from this collection overlap with the returned
+    /// immutably borrowed region for the lifetime of that borrow.
+    ///
+    /// # Safety
+    ///
+    /// This method is not marked as unsafe, but its safety requires correct
+    /// usage alongside [`index_mut`]. It cannot result in a race
+    /// condition without creating an overlapping mutable range via
+    /// [`index_mut`]. As an internal helper, we ensure that all calls are
+    /// safe and document this when mutating rather than marking each immutable
+    /// reference with virtually identical safety justifications.
+    ///
+    /// Caller must take care that no elements of the resulting borrowed slice
+    /// or element are concurrently mutably borrowed at all by [`index_mut`]
+    /// during the lifetime of the returned borrow.
+    ///
+    /// [`index_mut`]: DisjointMut::index_mut
+    #[cfg_attr(debug_assertions, track_caller)]
+    pub fn index_as<'a, I, V>(&'a self, index: I) -> DisjointImmutGuard<'a, T, V>
+    where
+        I: Into<Bounds> + Clone,
+        V: FromBytes,
+    {
+        let bounds = index.into().multiply(mem::size_of::<V>());
+        self.index(bounds.range).cast()
+    }
+}
+
 /// This trait is a stable implementation of [`std::slice::SliceIndex`] to allow
 /// for indexing into mutable slice raw pointers.
 pub trait DisjointMutIndex<T: ?Sized> {
@@ -309,6 +411,12 @@ impl Bounds {
         let b = &other.range;
         a.start < b.end && b.start < a.end
     }
+
+    fn multiply(self, multiple: usize) -> Bounds {
+        let start = self.range.start * multiple;
+        let end = self.range.end * multiple;
+        Self { range: start..end }
+    }
 }
 
 impl From<usize> for Bounds {
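The two `cast` helpers above delegate the actual reinterpretation to `zerocopy`: `V::mut_from`/`V::ref_from` verify the byte slice's length and alignment at runtime and return `None` on mismatch, which is why the backing buffer must be suitably aligned (the `Pal` change below keeps `AlignedVec64<u8>` as the store). A minimal standalone sketch of just that mechanism, outside `DisjointMut` (the buffer and variable names here are illustrative, not from this patch):

```rust
use zerocopy::{AsBytes, FromBytes};

fn main() {
    // A 2-byte-aligned backing store, viewed as raw bytes. Starting from
    // a `[u16; 8]` guarantees the alignment that `mut_from` will check.
    let mut backing = [0u16; 8];
    let bytes: &mut [u8] = backing.as_bytes_mut();

    // Reinterpret the whole byte slice as one typed element, as
    // `DisjointMutGuard::cast` does; `mut_from` returns `None` if the
    // length or alignment doesn't match `[u16; 8]`.
    let typed: &mut [u16; 8] = <[u16; 8]>::mut_from(bytes).unwrap();
    typed[0] = 0x1234;

    assert_eq!(backing[0], 0x1234);
}
```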
diff --git a/src/internal.rs b/src/internal.rs
index e2efad0d1..7a35f82ba 100644
--- a/src/internal.rs
+++ b/src/internal.rs
@@ -29,8 +29,10 @@ use crate::src::cdf::CdfContext;
 use crate::src::cdf::CdfThreadContext;
 use crate::src::cpu::rav1d_get_cpu_flags;
 use crate::src::cpu::CpuFlags;
+use crate::src::disjoint_mut::DisjointImmutGuard;
 use crate::src::disjoint_mut::DisjointMut;
 use crate::src::disjoint_mut::DisjointMutArcSlice;
+use crate::src::disjoint_mut::DisjointMutGuard;
 use crate::src::env::BlockContext;
 use crate::src::error::Rav1dResult;
 use crate::src::filmgrain::Rav1dFilmGrainDSPContext;
@@ -82,8 +84,6 @@ use crate::src::refmvs::refmvs_temporal_block;
 use crate::src::refmvs::refmvs_tile;
 use crate::src::refmvs::Rav1dRefmvsDSPContext;
 use crate::src::refmvs::RefMvsFrame;
-use crate::src::unstable_extensions::as_chunks;
-use crate::src::unstable_extensions::as_chunks_mut;
 use atomig::Atom;
 use atomig::Atomic;
 use libc::ptrdiff_t;
@@ -497,9 +497,11 @@ impl CodedBlockInfo {
 #[derive(Default)]
 #[repr(C)]
 pub struct Pal {
-    data: AlignedVec64<u8>,
+    data: DisjointMut<AlignedVec64<u8>>,
 }
 
+type PalArray<BD> = [[<BD as BitDepth>::Pixel; 8]; 3];
+
 impl Pal {
     pub fn resize(&mut self, n: usize) {
         self.data.resize(n * 8 * 3, Default::default());
@@ -509,18 +511,33 @@ impl Pal {
         self.data.is_empty()
     }
 
-    pub fn as_slice<BD: BitDepth>(&self) -> &[[[BD::Pixel; 8]; 3]] {
-        as_chunks::<3, [BD::Pixel; 8]>(
-            as_chunks::<8, BD::Pixel>(BD::cast_pixel_slice(&self.data)).0,
-        )
-        .0
+    pub fn index<'a: 'b, 'b, BD: BitDepth>(
+        &'a self,
+        index: usize,
+    ) -> DisjointImmutGuard<'b, AlignedVec64<u8>, PalArray<BD>> {
+        self.data.index_as(index)
     }
 
-    pub fn as_slice_mut<BD: BitDepth>(&mut self) -> &mut [[[BD::Pixel; 8]; 3]] {
-        as_chunks_mut::<3, [BD::Pixel; 8]>(
-            as_chunks_mut::<8, BD::Pixel>(BD::cast_pixel_slice_mut(&mut self.data)).0,
-        )
-        .0
+    /// Mutably borrow a pal array.
+    ///
+    /// This mutable borrow is unchecked and callers must ensure that no other
+    /// borrows of a pal overlap with the mutably borrowed region for the
+    /// lifetime of that mutable borrow.
+    ///
+    /// # Safety
+    ///
+    /// Caller must ensure that no element of the resulting borrowed array is
+    /// concurrently borrowed (immutably or mutably) at all during the lifetime
+    /// of the returned mutable borrow.
+    pub unsafe fn index_mut<'a: 'b, 'b, BD: BitDepth>(
+        &'a self,
+        index: usize,
+    ) -> DisjointMutGuard<'b, AlignedVec64<u8>, PalArray<BD>> {
+        // SAFETY: The preconditions of our `index_mut` safety imply that the
+        // indexed region we are mutably borrowing is not concurrently borrowed
+        // and will not be borrowed during the lifetime of the returned
+        // reference.
+        unsafe { self.data.index_mut_as(index) }
     }
 }
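`Pal::index` and `Pal::index_mut` above come down to `Bounds::multiply`: a single element index `i` of type `V` becomes the byte range `i * size_of::<V>()..(i + 1) * size_of::<V>()` over the `u8` store. A small sketch of that arithmetic (the helper name and the 16 bpc figures are ours, for illustration only):

```rust
use std::mem;
use std::ops::Range;

// A single element index `i` of type `V` first becomes the element range
// `i..i + 1`, which `multiply` then scales by `size_of::<V>()` into bytes.
fn byte_bounds(index: usize, elem_size: usize) -> Range<usize> {
    let elems = index..index + 1;
    elems.start * elem_size..elems.end * elem_size
}

fn main() {
    // One pal entry at 16 bpc: 3 planes x 8 pixels x 2 bytes = 48 bytes.
    let pal_array_size = 3 * 8 * mem::size_of::<u16>();
    assert_eq!(byte_bounds(2, pal_array_size), 96..144);
}
```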
diff --git a/src/lf_apply.rs b/src/lf_apply.rs
index f0d8239bf..3963ead96 100644
--- a/src/lf_apply.rs
+++ b/src/lf_apply.rs
@@ -8,7 +8,6 @@ use libc::ptrdiff_t;
 use std::cmp;
 use std::ffi::c_int;
 use std::ffi::c_uint;
-use std::slice;
 use std::sync::atomic::AtomicU16;
 use std::sync::atomic::Ordering;
 
@@ -156,6 +155,8 @@ pub(crate) unsafe fn rav1d_copy_lpf(
     let seq_hdr = &***f.seq_hdr.as_ref().unwrap();
     let tt_off = have_tt * sby * ((4 as c_int) << seq_hdr.sb128);
 
+    let src_y_stride = BD::pxstride(src_stride[0]);
+    let src_uv_stride = BD::pxstride(src_stride[1]);
     let y_stride = BD::pxstride(lr_stride[0]);
     let uv_stride = BD::pxstride(lr_stride[1]);
 
@@ -185,7 +186,7 @@ pub(crate) unsafe fn rav1d_copy_lpf(
                 dst_offset[0],
                 lr_stride[0],
                 src[0],
-                (src_offset[0] as isize - offset as isize * BD::pxstride(src_stride[0])) as usize,
+                (src_offset[0] as isize - offset as isize * src_y_stride) as usize,
                 src_stride[0],
                 0,
                 seq_hdr.sb128,
@@ -203,24 +204,18 @@ pub(crate) unsafe fn rav1d_copy_lpf(
             );
         }
         if have_tt != 0 && resize != 0 {
-            let cdef_off_y: ptrdiff_t = (sby * 4) as isize * BD::pxstride(src_stride[0]);
-            let cdef_plane_y_sz = 4 * f.sbh as isize * y_stride;
-            let y_span = cdef_plane_y_sz - y_stride;
+            let cdef_off_y: ptrdiff_t = (sby * 4) as isize * src_y_stride;
+            let cdef_plane_y_sz = 4 * f.sbh as isize * src_y_stride;
+            let y_span = cdef_plane_y_sz - src_y_stride;
+            let cdef_line_start = (f.lf.cdef_lpf_line[0] as isize + cmp::min(y_span, 0)) as usize;
            backup_lpf::<BD>(
                 c,
-                slice::from_raw_parts_mut(
-                    cdef_line_buf
-                        .as_mut_ptr()
-                        .add(f.lf.cdef_lpf_line[0])
-                        .offset(cmp::min(y_span, 0)),
-                    cdef_plane_y_sz.unsigned_abs(),
-                ),
+                &mut cdef_line_buf
+                    [cdef_line_start..cdef_line_start + cdef_plane_y_sz.unsigned_abs()],
                 (cdef_off_y - cmp::min(y_span, 0)) as usize,
                 src_stride[0],
                 src[0],
-                (src_offset[0] as isize
-                    - offset as isize * BD::pxstride(src_stride[0] as usize) as isize)
-                    as usize,
+                (src_offset[0] as isize - offset as isize * src_y_stride as isize) as usize,
                 src_stride[0],
                 0,
                 seq_hdr.sb128,
@@ -248,7 +243,7 @@ pub(crate) unsafe fn rav1d_copy_lpf(
         let row_h_0 = cmp::min((sby + 1) << 6 - ss_ver + seq_hdr.sb128, h_0 - 1);
         let offset_uv = offset >> ss_ver;
         let y_stripe_0 = (sby << 6 - ss_ver + seq_hdr.sb128) - offset_uv;
-        let cdef_off_uv: ptrdiff_t = sby as isize * 4 * BD::pxstride(src_stride[1]);
+        let cdef_off_uv: ptrdiff_t = sby as isize * 4 * src_uv_stride;
         if seq_hdr.cdef != 0 || restore_planes & LR_RESTORE_U as c_int != 0 {
             if restore_planes & LR_RESTORE_U as c_int != 0 || resize == 0 {
                 backup_lpf::<BD>(
@@ -257,8 +252,7 @@ pub(crate) unsafe fn rav1d_copy_lpf(
                     dst_offset[1],
                     lr_stride[1],
                     src[1],
-                    (src_offset[1] as isize - offset_uv as isize * BD::pxstride(src_stride[1]))
-                        as usize,
+                    (src_offset[1] as isize - offset_uv as isize * src_uv_stride) as usize,
                     src_stride[1],
                     ss_ver,
                     seq_hdr.sb128,
@@ -276,22 +270,18 @@ pub(crate) unsafe fn rav1d_copy_lpf(
                 );
             }
             if have_tt != 0 && resize != 0 {
-                let cdef_plane_uv_sz = 4 * f.sbh as isize * uv_stride;
-                let uv_span = cdef_plane_uv_sz - uv_stride;
+                let cdef_plane_uv_sz = 4 * f.sbh as isize * src_uv_stride;
+                let uv_span = cdef_plane_uv_sz - src_uv_stride;
+                let cdef_line_start =
+                    (f.lf.cdef_lpf_line[1] as isize + cmp::min(uv_span, 0)) as usize;
                 backup_lpf::<BD>(
                     c,
-                    slice::from_raw_parts_mut(
-                        cdef_line_buf
-                            .as_mut_ptr()
-                            .add(f.lf.cdef_lpf_line[1])
-                            .offset(cmp::min(uv_span, 0)),
-                        cdef_plane_uv_sz.unsigned_abs(),
-                    ),
+                    &mut cdef_line_buf
+                        [cdef_line_start..cdef_line_start + cdef_plane_uv_sz.unsigned_abs()],
                     (cdef_off_uv - cmp::min(uv_span, 0)) as usize,
                     src_stride[1],
                     src[1],
-                    (src_offset[1] as isize - offset_uv as isize * BD::pxstride(src_stride[1]))
-                        as usize,
+                    (src_offset[1] as isize - offset_uv as isize * src_uv_stride) as usize,
                     src_stride[1],
                     ss_ver,
                     seq_hdr.sb128,
@@ -317,8 +307,7 @@ pub(crate) unsafe fn rav1d_copy_lpf(
                     dst_offset[2],
                     lr_stride[1],
                     src[2],
-                    (src_offset[1] as isize - offset_uv as isize * BD::pxstride(src_stride[1]))
-                        as usize,
+                    (src_offset[1] as isize - offset_uv as isize * src_uv_stride) as usize,
                     src_stride[1],
                     ss_ver,
                     seq_hdr.sb128,
@@ -336,22 +325,18 @@ pub(crate) unsafe fn rav1d_copy_lpf(
                 );
             }
             if have_tt != 0 && resize != 0 {
-                let cdef_plane_uv_sz = 4 * f.sbh as isize * uv_stride;
-                let uv_span = cdef_plane_uv_sz - uv_stride;
+                let cdef_plane_uv_sz = 4 * f.sbh as isize * src_uv_stride;
+                let uv_span = cdef_plane_uv_sz - src_uv_stride;
+                let cdef_line_start =
+                    (f.lf.cdef_lpf_line[2] as isize + cmp::min(uv_span, 0)) as usize;
                 backup_lpf::<BD>(
                     c,
-                    slice::from_raw_parts_mut(
-                        cdef_line_buf
-                            .as_mut_ptr()
-                            .add(f.lf.cdef_lpf_line[2])
-                            .offset(cmp::min(uv_span, 0)),
-                        cdef_plane_uv_sz.unsigned_abs(),
-                    ),
+                    &mut cdef_line_buf
+                        [cdef_line_start..cdef_line_start + cdef_plane_uv_sz.unsigned_abs()],
                     (cdef_off_uv - cmp::min(uv_span, 0)) as usize,
                     src_stride[1],
                     src[2],
-                    (src_offset[1] as isize - offset_uv as isize * BD::pxstride(src_stride[1]))
-                        as usize,
+                    (src_offset[1] as isize - offset_uv as isize * src_uv_stride) as usize,
                     src_stride[1],
                     ss_ver,
                     seq_hdr.sb128,
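The pattern adopted above replaces `slice::from_raw_parts_mut` with a plain range index into `cdef_line_buf`: the possibly negative span is folded into the start index, and the slice length is the absolute plane size, so an out-of-range window now panics instead of silently reading or writing out of bounds. A standalone sketch under made-up values (`plane_slice` is a hypothetical helper, not part of this patch):

```rust
use std::cmp;

// `base` is the line offset into the buffer, `plane_sz` is signed
// (negative for bottom-up layouts), and `span = plane_sz - stride`
// shifts the window start backwards when negative, exactly like the old
// `ptr.add(base).offset(cmp::min(span, 0))`.
fn plane_slice(buf: &mut [u8], base: usize, plane_sz: isize, stride: isize) -> &mut [u8] {
    let span = plane_sz - stride;
    let start = (base as isize + cmp::min(span, 0)) as usize;
    &mut buf[start..start + plane_sz.unsigned_abs()]
}

fn main() {
    let mut buf = vec![0u8; 64];
    // Positive strides: the window starts at `base`.
    assert_eq!(plane_slice(&mut buf, 32, 32, 16).len(), 32);
    // Negative strides: the window starts before `base`.
    assert_eq!(plane_slice(&mut buf, 32, -32, -16).len(), 32);
}
```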
diff --git a/src/recon.rs b/src/recon.rs
index 848751c4c..6dcabe43f 100644
--- a/src/recon.rs
+++ b/src/recon.rs
@@ -2455,12 +2455,14 @@ pub(crate) unsafe fn rav1d_recon_b_intra(
                 } else {
                     &t.scratch.c2rust_unnamed_0.pal_idx
                 };
+                let pal_guard;
                 let pal: *const BD::Pixel = if t.frame_thread.pass != 0 {
                     let index = (((t.b.y as isize >> 1) + (t.b.x as isize & 1))
                         * (f.b4_stride >> 1)
                         + ((t.b.x >> 1) + (t.b.y & 1)) as isize)
                         as isize;
-                    f.frame_thread.pal.as_slice::<BD>()[index as usize][0].as_ptr()
+                    pal_guard = f.frame_thread.pal.index::<BD>(index as usize);
+                    pal_guard[0].as_ptr()
                 } else {
                     let interintra_edge_pal =
                         BD::select(&t.scratch.c2rust_unnamed_0.interintra_edge_pal);
@@ -2826,6 +2828,7 @@ pub(crate) unsafe fn rav1d_recon_b_intra(
             let uv_dstoff: ptrdiff_t = 4
                 * ((t.b.x >> ss_hor) as isize
                     + (t.b.y >> ss_ver) as isize * BD::pxstride(f.cur.stride[1]));
+            let pal_guard;
             let (pal, pal_idx) = if t.frame_thread.pass != 0 {
                 let p = t.frame_thread.pass & 1;
                 let index = (((t.b.y >> 1) + (t.b.x & 1)) as isize * (f.b4_stride >> 1)
@@ -2835,10 +2838,8 @@ pub(crate) unsafe fn rav1d_recon_b_intra(
                     + ((t.b.x >> 1) + (t.b.y & 1)) as isize) as usize;
                 let len = (cbw4 * cbh4 * 16) as usize;
                 let pal_idx = &f.frame_thread.pal_idx[*pal_idx_offset..][..len];
                 *pal_idx_offset += len;
-                (
-                    &f.frame_thread.pal.as_slice::<BD>()[index as usize],
-                    pal_idx,
-                )
+                pal_guard = f.frame_thread.pal.index::<BD>(index as usize);
+                (&*pal_guard, pal_idx)
             } else {
                 let interintra_edge_pal =
                     BD::select_mut(&mut t.scratch.c2rust_unnamed_0.interintra_edge_pal);
@@ -4401,10 +4402,12 @@ pub(crate) unsafe fn rav1d_copy_pal_block_y(
     bw4: usize,
     bh4: usize,
 ) {
+    let pal_guard;
     let pal = if t.frame_thread.pass != 0 {
         let index = ((t.b.y >> 1) + (t.b.x & 1)) as isize * (f.b4_stride >> 1)
             + ((t.b.x >> 1) + (t.b.y & 1)) as isize;
-        &f.frame_thread.pal.as_slice::<BD>()[index as usize][0]
+        pal_guard = f.frame_thread.pal.index::<BD>(index as usize);
+        &pal_guard[0]
     } else {
         let interintra_edge_pal = BD::select(&t.scratch.c2rust_unnamed_0.interintra_edge_pal);
         &interintra_edge_pal.pal[0]
@@ -4425,10 +4428,12 @@ pub(crate) unsafe fn rav1d_copy_pal_block_uv(
     bw4: usize,
     bh4: usize,
 ) {
+    let pal_guard;
     let pal = if t.frame_thread.pass != 0 {
         let index = ((t.b.y >> 1) + (t.b.x & 1)) as isize * (f.b4_stride >> 1)
             + ((t.b.x >> 1) + (t.b.y & 1)) as isize;
-        &f.frame_thread.pal.as_slice_mut::<BD>()[index as usize]
+        pal_guard = f.frame_thread.pal.index::<BD>(index as usize);
+        &*pal_guard
     } else {
         let interintra_edge_pal = BD::select(&t.scratch.c2rust_unnamed_0.interintra_edge_pal);
         &interintra_edge_pal.pal
@@ -4553,11 +4558,12 @@ pub(crate) unsafe fn rav1d_read_pal_plane(
     let used_cache = &used_cache[..i];
 
     // parse new entries
+    let mut pal_guard;
     let pal = if t.frame_thread.pass != 0 {
-        &mut f.frame_thread.pal.as_slice_mut::<BD>()[(((t.b.y >> 1) + (t.b.x & 1)) as isize
-            * (f.b4_stride >> 1)
-            + ((t.b.x >> 1) + (t.b.y & 1)) as isize)
-            as usize][pli]
+        let pal_start = (((t.b.y >> 1) + (t.b.x & 1)) as isize * (f.b4_stride >> 1)
+            + ((t.b.x >> 1) + (t.b.y & 1)) as isize) as usize;
+        pal_guard = f.frame_thread.pal.index_mut::<BD>(pal_start);
+        &mut pal_guard[pli]
     } else {
         let interintra_edge_pal =
            BD::select_mut(&mut t.scratch.c2rust_unnamed_0.interintra_edge_pal);
@@ -4647,11 +4653,13 @@ pub(crate) unsafe fn rav1d_read_pal_uv(
 
     // V pal coding
     let ts = &mut *f.ts.offset(t.ts as isize);
+    let mut pal_guard;
     let pal = if t.frame_thread.pass != 0 {
-        &mut f.frame_thread.pal.as_slice_mut::<BD>()[(((t.b.y >> 1) + (t.b.x & 1)) as isize
-            * (f.b4_stride >> 1)
-            + ((t.b.x >> 1) + (t.b.y & 1)) as isize)
-            as usize][2]
+        pal_guard = f.frame_thread.pal.index_mut::<BD>(
+            (((t.b.y >> 1) + (t.b.x & 1)) as isize * (f.b4_stride >> 1)
+                + ((t.b.x >> 1) + (t.b.y & 1)) as isize) as usize,
+        );
+        &mut pal_guard[2]
     } else {
         let interintra_edge_pal =
             BD::select_mut(&mut t.scratch.c2rust_unnamed_0.interintra_edge_pal);
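Every recon.rs call site above uses the same borrow-extension idiom: `let pal_guard;` is declared, uninitialized, in the enclosing scope, so that the guard created inside one `if` arm, and the reference borrowed from it, outlive the `if` expression. A self-contained sketch of the idiom, with `std::sync::Mutex` standing in for `DisjointMut`:

```rust
use std::sync::Mutex;

fn main() {
    let shared = Mutex::new([1u8, 2, 3]);
    let fallback = [0u8; 3];
    let use_shared = true;

    // Declared but not initialized: the binding lives to the end of the
    // scope, so a reference into it can escape the `if` below.
    let guard;
    let data: &[u8] = if use_shared {
        guard = shared.lock().unwrap();
        // Borrow from the long-lived `guard`, not from a temporary that
        // would be dropped at the end of this arm.
        &guard[..]
    } else {
        &fallback[..]
    };
    assert_eq!(data, &[1, 2, 3]);
}
```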
diff --git a/src/unstable_extensions.rs b/src/unstable_extensions.rs
index 74c95b9df..d29a74c29 100644
--- a/src/unstable_extensions.rs
+++ b/src/unstable_extensions.rs
@@ -8,7 +8,6 @@
 use std::mem;
 use std::slice::from_raw_parts;
-use std::slice::from_raw_parts_mut;
 
 /// From `1.75.0`.
 pub const fn flatten<const N: usize, T>(this: &[[T; N]]) -> &[T] {
@@ -54,29 +53,3 @@ pub const fn as_chunks<const N: usize, T>(this: &[T]) -> (&[[T; N]], &[T]) {
     let array_slice = unsafe { as_chunks_unchecked(multiple_of_n) };
     (array_slice, remainder)
 }
-
-#[inline]
-#[must_use]
-pub unsafe fn as_chunks_unchecked_mut<const N: usize, T>(this: &mut [T]) -> &mut [[T; N]] {
-    // SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length
-    let new_len = /* unsafe */ {
-        assert!(N != 0 && this.len() % N == 0);
-        this.len() / N
-    };
-    // SAFETY: We cast a slice of `new_len * N` elements into
-    // a slice of `new_len` many `N` elements chunks.
-    unsafe { from_raw_parts_mut(this.as_mut_ptr().cast(), new_len) }
-}
-
-#[inline]
-#[track_caller]
-#[must_use]
-pub fn as_chunks_mut<const N: usize, T>(this: &mut [T]) -> (&mut [[T; N]], &mut [T]) {
-    assert!(N != 0, "chunk size must be non-zero");
-    let len = this.len() / N;
-    let (multiple_of_n, remainder) = this.split_at_mut(len * N);
-    // SAFETY: We already panicked for zero, and ensured by construction
-    // that the length of the subslice is a multiple of N.
-    let array_slice = unsafe { as_chunks_unchecked_mut(multiple_of_n) };
-    (array_slice, remainder)
-}
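Only the mutable chunking helpers are removed; the read-only `as_chunks` polyfill stays because other callers still use it. For reference, this standalone (non-`const`) mirror of the retained helper shows its behavior: with `N = 2`, five elements split into two chunks plus a one-element remainder.

```rust
use std::slice::from_raw_parts;

// Standalone mirror of the retained `as_chunks` polyfill above.
fn as_chunks<const N: usize, T>(this: &[T]) -> (&[[T; N]], &[T]) {
    assert!(N != 0, "chunk size must be non-zero");
    let len = this.len() / N;
    let (multiple_of_n, remainder) = this.split_at(len * N);
    // SAFETY: `multiple_of_n` has exactly `len * N` elements, which we
    // view as `len` arrays of `N` elements each.
    let array_slice = unsafe { from_raw_parts(multiple_of_n.as_ptr().cast(), len) };
    (array_slice, remainder)
}

fn main() {
    let xs = [1u8, 2, 3, 4, 5];
    let (chunks, rest) = as_chunks::<2, u8>(&xs);
    assert_eq!(chunks, &[[1u8, 2], [3, 4]]);
    assert_eq!(rest, &[5u8]);
}
```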