Skip to content

Commit

Permalink
Add AlignedVec64 to supported aligned allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
randomPoison committed Feb 21, 2024
1 parent 24badfd commit 55dc06d
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 3 deletions.
101 changes: 100 additions & 1 deletion src/align.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
//! make them easier to use in common cases, e.g. [`From`] and
//! [`Index`]/[`IndexMut`] (since it's usually array fields that require
//! specific aligment for use with SIMD instructions).
use std::ops::Index;

use std::marker::PhantomData;
use std::mem::{self, MaybeUninit};
use std::ops::IndexMut;
use std::ops::{Deref, DerefMut, Index};
use std::slice;

/// [`Default`] isn't `impl`emented for all arrays `[T; N]`
/// because they were implemented before `const` generics
Expand Down Expand Up @@ -87,3 +91,98 @@ def_align!(8, Align8);
def_align!(16, Align16);
def_align!(32, Align32);
def_align!(64, Align64);

/// A [`Vec`] that uses a 64-byte aligned allocation.
///
/// Only works with `Copy` types so that we don't have to handle drop logic.
pub struct AlignedVec64<T: Copy> {
inner: Vec<MaybeUninit<Align64<[u8; 64]>>>,
len: usize,
_phantom: PhantomData<T>,
}

impl<T: Copy> AlignedVec64<T> {
pub const fn new() -> Self {
Self {
inner: Vec::new(),
len: 0,
_phantom: PhantomData,
}
}

/// Returns the number of elements in the vector.
pub fn len(&self) -> usize {
self.len
}

pub fn as_ptr(&self) -> *const T {
self.inner.as_ptr().cast()
}

pub fn as_mut_ptr(&mut self) -> *mut T {
self.inner.as_mut_ptr().cast()
}

/// Extracts a slice containing the entire vector.
pub fn as_slice(&self) -> &[T] {
// Safety: The first `len` elements have been initialized to `T`s in
// `Self::resize_with`.
unsafe { slice::from_raw_parts(self.as_ptr(), self.len) }
}

/// Extracts a mutable slice of the entire vector.
pub fn as_mut_slice(&mut self) -> &mut [T] {
// Safety: The first `len` elements have been initialized to `T`s in
// `Self::resize_with`.
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) }
}

pub fn resize_with<F>(&mut self, new_len: usize, mut f: F)
where
F: FnMut() -> T,
{
let len = self.len();

// Resize the underlying vector to have enough chunks for the new length.
let new_bytes = mem::size_of::<T>() * new_len;
let new_chunks = if (new_bytes % 64) == 0 {
new_bytes / 64
} else {
(new_bytes / 64) + 1
};
self.inner.resize_with(new_chunks, MaybeUninit::uninit);

// If we grew the vector, initialize the new elements past `len`.
if new_len > len {
for offset in len..new_len {
// SAFETY: We've allocated enough space to write up to `new_len` elements into the buffer.
unsafe {
self.as_mut_ptr().add(offset).write(f());
}
}
}

self.len = new_len;
}
}

impl<T: Copy> Deref for AlignedVec64<T> {
type Target = [T];

fn deref(&self) -> &Self::Target {
self.as_slice()
}
}

impl<T: Copy> DerefMut for AlignedVec64<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_mut_slice()
}
}

// NOTE: Custom impl so that we don't require `T: Default`.
impl<T: Copy> Default for AlignedVec64<T> {
fn default() -> Self {
Self::new()
}
}
2 changes: 1 addition & 1 deletion src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2263,7 +2263,7 @@ unsafe fn decode_b_inner(
let pal = if t.frame_thread.pass != 0 {
let index = ((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1)
+ ((t.bx >> 1) + (t.by & 1)) as isize;
&f.frame_thread.pal[index as usize].0
&f.frame_thread.pal[index as usize]
} else {
&t.scratch.c2rust_unnamed_0.pal
};
Expand Down
2 changes: 1 addition & 1 deletion src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ pub struct Rav1dFrameContext_frame_thread {
pub b: Vec<Av1Block>,
pub cbi: Vec<CodedBlockInfo>,
// indexed using (t->by >> 1) * (f->b4_stride >> 1) + (t->bx >> 1)
pub pal: Vec<Align64<[[u16; 8]; 3]>>, /* [3 plane][8 idx] */
pub pal: AlignedVec64<[[u16; 8]; 3]>, /* [3 plane][8 idx] */
// iterated over inside tile state
pub pal_idx: *mut u8,
pub cf: *mut DynCoef,
Expand Down

0 comments on commit 55dc06d

Please sign in to comment.