From 55dc06dd207eea02ee4bcad746181861ca815eb2 Mon Sep 17 00:00:00 2001 From: Nicole LeGare Date: Tue, 13 Feb 2024 14:47:23 -0800 Subject: [PATCH] Add `AlignedVec64` to supported aligned allocation --- src/align.rs | 101 +++++++++++++++++++++++++++++++++++++++++++++++- src/decode.rs | 2 +- src/internal.rs | 2 +- 3 files changed, 102 insertions(+), 3 deletions(-) diff --git a/src/align.rs b/src/align.rs index d30fdfe3e..461e93f55 100644 --- a/src/align.rs +++ b/src/align.rs @@ -5,8 +5,12 @@ //! make them easier to use in common cases, e.g. [`From`] and //! [`Index`]/[`IndexMut`] (since it's usually array fields that require //! specific aligment for use with SIMD instructions). -use std::ops::Index; + +use std::marker::PhantomData; +use std::mem::{self, MaybeUninit}; use std::ops::IndexMut; +use std::ops::{Deref, DerefMut, Index}; +use std::slice; /// [`Default`] isn't `impl`emented for all arrays `[T; N]` /// because they were implemented before `const` generics @@ -87,3 +91,98 @@ def_align!(8, Align8); def_align!(16, Align16); def_align!(32, Align32); def_align!(64, Align64); + +/// A [`Vec`] that uses a 64-byte aligned allocation. +/// +/// Only works with `Copy` types so that we don't have to handle drop logic. +pub struct AlignedVec64 { + inner: Vec>>, + len: usize, + _phantom: PhantomData, +} + +impl AlignedVec64 { + pub const fn new() -> Self { + Self { + inner: Vec::new(), + len: 0, + _phantom: PhantomData, + } + } + + /// Returns the number of elements in the vector. + pub fn len(&self) -> usize { + self.len + } + + pub fn as_ptr(&self) -> *const T { + self.inner.as_ptr().cast() + } + + pub fn as_mut_ptr(&mut self) -> *mut T { + self.inner.as_mut_ptr().cast() + } + + /// Extracts a slice containing the entire vector. + pub fn as_slice(&self) -> &[T] { + // Safety: The first `len` elements have been initialized to `T`s in + // `Self::resize_with`. + unsafe { slice::from_raw_parts(self.as_ptr(), self.len) } + } + + /// Extracts a mutable slice of the entire vector. + pub fn as_mut_slice(&mut self) -> &mut [T] { + // Safety: The first `len` elements have been initialized to `T`s in + // `Self::resize_with`. + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) } + } + + pub fn resize_with(&mut self, new_len: usize, mut f: F) + where + F: FnMut() -> T, + { + let len = self.len(); + + // Resize the underlying vector to have enough chunks for the new length. + let new_bytes = mem::size_of::() * new_len; + let new_chunks = if (new_bytes % 64) == 0 { + new_bytes / 64 + } else { + (new_bytes / 64) + 1 + }; + self.inner.resize_with(new_chunks, MaybeUninit::uninit); + + // If we grew the vector, initialize the new elements past `len`. + if new_len > len { + for offset in len..new_len { + // SAFETY: We've allocated enough space to write up to `new_len` elements into the buffer. + unsafe { + self.as_mut_ptr().add(offset).write(f()); + } + } + } + + self.len = new_len; + } +} + +impl Deref for AlignedVec64 { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + self.as_slice() + } +} + +impl DerefMut for AlignedVec64 { + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut_slice() + } +} + +// NOTE: Custom impl so that we don't require `T: Default`. +impl Default for AlignedVec64 { + fn default() -> Self { + Self::new() + } +} diff --git a/src/decode.rs b/src/decode.rs index bc2c7b070..2ecafdfe3 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -2263,7 +2263,7 @@ unsafe fn decode_b_inner( let pal = if t.frame_thread.pass != 0 { let index = ((t.by >> 1) + (t.bx & 1)) as isize * (f.b4_stride >> 1) + ((t.bx >> 1) + (t.by & 1)) as isize; - &f.frame_thread.pal[index as usize].0 + &f.frame_thread.pal[index as usize] } else { &t.scratch.c2rust_unnamed_0.pal }; diff --git a/src/internal.rs b/src/internal.rs index 431287bda..60a6389c8 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -409,7 +409,7 @@ pub struct Rav1dFrameContext_frame_thread { pub b: Vec, pub cbi: Vec, // indexed using (t->by >> 1) * (f->b4_stride >> 1) + (t->bx >> 1) - pub pal: Vec>, /* [3 plane][8 idx] */ + pub pal: AlignedVec64<[[u16; 8]; 3]>, /* [3 plane][8 idx] */ // iterated over inside tile state pub pal_idx: *mut u8, pub cf: *mut DynCoef,