Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rav1dFrameContext_frame_thread::pal: Make into an AlignedVec #732

Merged
merged 4 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions src/align.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@
//! make them easier to use in common cases, e.g. [`From`] and
//! [`Index`]/[`IndexMut`] (since it's usually array fields that require
//! specific alignment for use with SIMD instructions).

use std::marker::PhantomData;
use std::mem;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ops::Index;
use std::ops::IndexMut;
use std::slice;

/// [`Default`] isn't `impl`emented for all arrays `[T; N]`
/// because they were implemented before `const` generics
Expand Down Expand Up @@ -38,6 +45,7 @@ impl_ArrayDefault!(u8);
impl_ArrayDefault!(i8);
impl_ArrayDefault!(i16);
impl_ArrayDefault!(i32);
impl_ArrayDefault!(u16);

macro_rules! def_align {
($align:literal, $name:ident) => {
Expand Down Expand Up @@ -86,3 +94,99 @@ def_align!(8, Align8);
def_align!(16, Align16);
def_align!(32, Align32);
def_align!(64, Align64);

/// A [`Vec`] that uses a 64-byte aligned allocation.
///
/// Only works with [`Copy`] types so that we don't have to handle drop logic.
pub struct AlignedVec64<T: Copy> {
inner: Vec<MaybeUninit<Align64<[u8; 64]>>>,
randomPoison marked this conversation as resolved.
Show resolved Hide resolved

/// The number of `T`s in [`Self::inner`] currently initialized.
len: usize,
randomPoison marked this conversation as resolved.
Show resolved Hide resolved
randomPoison marked this conversation as resolved.
Show resolved Hide resolved
_phantom: PhantomData<T>,
}

impl<T: Copy> AlignedVec64<T> {
pub const fn new() -> Self {
Self {
inner: Vec::new(),
len: 0,
_phantom: PhantomData,
}
}

/// Returns the number of elements in the vector.
pub fn len(&self) -> usize {
self.len
}

pub fn as_ptr(&self) -> *const T {
self.inner.as_ptr().cast()
}

pub fn as_mut_ptr(&mut self) -> *mut T {
self.inner.as_mut_ptr().cast()
}

/// Extracts a slice containing the entire vector.
kkysen marked this conversation as resolved.
Show resolved Hide resolved
pub fn as_slice(&self) -> &[T] {
// Safety: The first `len` elements have been initialized to `T`s in
// `Self::resize_with`.
unsafe { slice::from_raw_parts(self.as_ptr(), self.len) }
randomPoison marked this conversation as resolved.
Show resolved Hide resolved
}

/// Extracts a mutable slice of the entire vector.
pub fn as_mut_slice(&mut self) -> &mut [T] {
// Safety: The first `len` elements have been initialized to `T`s in
// `Self::resize_with`.
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) }
}

pub fn resize(&mut self, new_len: usize, value: T) {
let old_len = self.len();

// Resize the underlying vector to have enough chunks for the new length.
//
// NOTE: We don't need to `drop` any elements if the `Vec` is truncated since
// `T: Copy`.
let new_bytes = mem::size_of::<T>() * new_len;
let new_chunks = if (new_bytes % 64) == 0 {
new_bytes / 64
} else {
(new_bytes / 64) + 1
};
kkysen marked this conversation as resolved.
Show resolved Hide resolved
self.inner.resize_with(new_chunks, MaybeUninit::uninit);
randomPoison marked this conversation as resolved.
Show resolved Hide resolved

// If we grew the vector, initialize the new elements past `len`.
for offset in old_len..new_len {
// SAFETY: We've allocated enough space to write up to `new_len` elements into
// the buffer.
unsafe {
self.as_mut_ptr().add(offset).write(value);
}
}

self.len = new_len;
}
}

impl<T: Copy> Deref for AlignedVec64<T> {
type Target = [T];

fn deref(&self) -> &Self::Target {
self.as_slice()
}
}

impl<T: Copy> DerefMut for AlignedVec64<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_mut_slice()
}
}

// NOTE: Custom impl so that we don't require `T: Default`.
impl<T: Copy> Default for AlignedVec64<T> {
fn default() -> Self {
Self::new()
}
}
39 changes: 12 additions & 27 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -788,10 +788,8 @@ unsafe fn read_pal_plane(

// parse new entries
let pal = if t.frame_thread.pass != 0 {
&mut (*(f.frame_thread.pal).offset(
((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize,
))[pli]
&mut f.frame_thread.pal[(((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize) as usize][pli]
} else {
&mut t.scratch.c2rust_unnamed_0.pal[pli]
};
Expand Down Expand Up @@ -872,10 +870,8 @@ unsafe fn read_pal_uv(
let ts = &mut *t.ts;

let pal = if t.frame_thread.pass != 0 {
&mut (*(f.frame_thread.pal).offset(
((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize,
))[2]
&mut f.frame_thread.pal[(((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize) as usize][2]
} else {
&mut t.scratch.c2rust_unnamed_0.pal[2]
};
Expand Down Expand Up @@ -2242,7 +2238,7 @@ unsafe fn decode_b_inner(
let pal = if t.frame_thread.pass != 0 {
let index = ((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize;
&(*f.frame_thread.pal.offset(index))[0]
&f.frame_thread.pal[index as usize][0]
} else {
&t.scratch.c2rust_unnamed_0.pal[0]
};
Expand All @@ -2266,7 +2262,7 @@ unsafe fn decode_b_inner(
let pal = if t.frame_thread.pass != 0 {
let index = ((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize;
&*f.frame_thread.pal.offset(index)
&f.frame_thread.pal[index as usize]
} else {
&t.scratch.c2rust_unnamed_0.pal
};
Expand Down Expand Up @@ -4387,20 +4383,10 @@ pub(crate) unsafe fn rav1d_decode_frame_init(
}

if frame_hdr.allow_screen_content_tools != 0 {
if num_sb128 != f.frame_thread.pal_sz {
rav1d_freep_aligned(
&mut f.frame_thread.pal as *mut *mut [[u16; 8]; 3] as *mut c_void,
);
f.frame_thread.pal = rav1d_alloc_aligned(
::core::mem::size_of::<[[u16; 8]; 3]>() * num_sb128 as usize * 16 * 16,
64,
) as *mut [[u16; 8]; 3];
if f.frame_thread.pal.is_null() {
f.frame_thread.pal_sz = 0;
return Err(ENOMEM);
}
f.frame_thread.pal_sz = num_sb128;
}
// TODO: Fallible allocation
f.frame_thread
.pal
.resize(num_sb128 as usize * 16 * 16, Default::default());

let pal_idx_sz = num_sb128 * size_mul[1] as c_int;
if pal_idx_sz != f.frame_thread.pal_idx_sz {
Expand All @@ -4415,11 +4401,10 @@ pub(crate) unsafe fn rav1d_decode_frame_init(
}
f.frame_thread.pal_idx_sz = pal_idx_sz;
}
} else if !f.frame_thread.pal.is_null() {
rav1d_freep_aligned(&mut f.frame_thread.pal as *mut *mut [[u16; 8]; 3] as *mut c_void);
} else if !f.frame_thread.pal.is_empty() {
let _ = mem::take(&mut f.frame_thread.pal);
rav1d_freep_aligned(&mut f.frame_thread.pal_idx as *mut *mut u8 as *mut c_void);
rinon marked this conversation as resolved.
Show resolved Hide resolved
f.frame_thread.pal_idx_sz = 0;
f.frame_thread.pal_sz = f.frame_thread.pal_idx_sz;
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,10 @@ pub struct Rav1dFrameContext_frame_thread {
pub b: Vec<Av1Block>,
pub cbi: Vec<CodedBlockInfo>,
// indexed using (t->by >> 1) * (f->b4_stride >> 1) + (t->bx >> 1)
pub pal: *mut [[u16; 8]; 3], /* [3 plane][8 idx] */
pub pal: AlignedVec64<[[u16; 8]; 3]>, /* [3 plane][8 idx] */
// iterated over inside tile state
pub pal_idx: *mut u8,
pub cf: *mut DynCoef,
pub pal_sz: c_int,
pub pal_idx_sz: c_int,
pub cf_sz: c_int,
// start offsets per tile
Expand Down
4 changes: 1 addition & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -907,9 +907,7 @@ impl Drop for Rav1dContext {
rav1d_freep_aligned(&mut f.frame_thread.pal_idx as *mut *mut u8 as *mut c_void);
rav1d_freep_aligned(&mut f.frame_thread.cf as *mut *mut DynCoef as *mut c_void);
freep(&mut f.frame_thread.tile_start_off as *mut *mut u32 as *mut c_void);
rav1d_freep_aligned(
&mut f.frame_thread.pal as *mut *mut [[u16; 8]; 3] as *mut c_void,
);
let _ = mem::take(&mut f.frame_thread.pal); // TODO: remove when context is owned
let _ = mem::take(&mut f.frame_thread.cbi); // TODO: remove when context is owned
}
if self.tc.len() > 1 {
Expand Down
21 changes: 9 additions & 12 deletions src/recon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2573,13 +2573,12 @@ pub(crate) unsafe fn rav1d_recon_b_intra<BD: BitDepth>(
pal_idx = (t.scratch.c2rust_unnamed_0.pal_idx).as_mut_ptr();
}
let pal: *const u16 = if t.frame_thread.pass != 0 {
((*(f.frame_thread.pal).offset(
(((t.by as isize >> 1) + (t.bx as isize & 1)) * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize) as isize,
))[0])
.as_mut_ptr()
let index = (((t.by as isize >> 1) + (t.bx as isize & 1)) * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize)
as isize;
f.frame_thread.pal[index as usize][0].as_ptr()
} else {
(t.scratch.c2rust_unnamed_0.pal[0]).as_mut_ptr()
(t.scratch.c2rust_unnamed_0.pal[0]).as_ptr()
};
(*f.dsp).ipred.pal_pred.call::<BD>(
dst,
Expand Down Expand Up @@ -2958,12 +2957,10 @@ pub(crate) unsafe fn rav1d_recon_b_intra<BD: BitDepth>(
if ((*ts).frame_thread[p as usize].pal_idx).is_null() {
unreachable!();
}
pal = (*(f.frame_thread.pal).offset(
(((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx as isize >> 1) as isize + (t.by as isize & 1)) as isize)
as isize,
))
.as_mut_ptr() as *const [u16; 8];
let index = (((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx as isize >> 1) as isize + (t.by as isize & 1)) as isize)
as isize;
pal = &f.frame_thread.pal[index as usize][0] as *const [u16; 8];
pal_idx = (*ts).frame_thread[p as usize].pal_idx;
(*ts).frame_thread[p as usize].pal_idx = ((*ts).frame_thread[p as usize]
.pal_idx)
Expand Down
Loading