Skip to content

Commit

Permalink
fn generate_grain_uv_rust: Elide luma bounds check (#573)
Browse files Browse the repository at this point in the history
This elides the luma bounds check when indexing into `buf_y` by using
`unsafe`, while still statically checking that the `unsafe` is correct.
Going forward, I'll focus on perf when asm is enabled, as the
C/Rust-only perf is not as important, but I had already worked out how
to elide this bounds check before that.
  • Loading branch information
kkysen authored Nov 21, 2023
2 parents a4a98a1 + 3aeff5e commit a8d6823
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 33 deletions.
1 change: 1 addition & 0 deletions lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod include {
} // mod include
pub mod src {
pub mod align;
mod assume;
mod cdef;
#[cfg_attr(not(feature = "bitdepth_16"), allow(dead_code))]
mod cdef_apply_tmpl_16;
Expand Down
9 changes: 9 additions & 0 deletions src/assume.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use std::hint::unreachable_unchecked;

/// A stable version of [`core::intrinsics::assume`].
///
/// Informs the optimizer that `condition` holds at the call site,
/// so that code only needed when it is `false` (e.g. a bounds check) can be removed.
///
/// # Safety
///
/// `condition` must be `true`.
/// Calling this with `false` is undefined behavior;
/// builds with debug assertions enabled panic instead so that misuse is caught early.
#[inline(always)]
pub unsafe fn assume(condition: bool) {
    // Compiled out without debug assertions, so optimized builds keep the pure hint.
    debug_assert!(condition, "`assume` called with a false condition");
    if !condition {
        // SAFETY: The caller guarantees that `condition` is `true`,
        // so this branch can never be executed.
        unsafe { unreachable_unchecked() };
    }
}
121 changes: 88 additions & 33 deletions src/filmgrain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::include::common::intops::iclip;
use crate::include::dav1d::headers::Dav1dFilmGrainData;
use crate::include::dav1d::headers::Rav1dFilmGrainData;
use crate::include::dav1d::headers::Rav1dPixelLayoutSubSampled;
use crate::src::assume::assume;
use crate::src::enum_map::enum_map;
use crate::src::enum_map::DefaultValue;
use crate::src::enum_map::EnumMap;
Expand Down Expand Up @@ -367,6 +368,8 @@ unsafe extern "C" fn generate_grain_y_c_erased<BD: BitDepth>(
generate_grain_y_rust(buf, data, bd)
}

/// Padding, in grain-buffer entries, used to offset the sample being filtered
/// from the start of the auto-regression window in the grain generators.
///
/// `ar_lag` is masked to 2 bits where it is read, so it never exceeds this value
/// and `AR_PAD - ar_lag` cannot underflow.
const AR_PAD: usize = 3;

unsafe fn generate_grain_y_rust<BD: BitDepth>(
buf: &mut GrainLut<BD::Entry>,
data: &Rav1dFilmGrainData,
Expand All @@ -386,17 +389,16 @@ unsafe fn generate_grain_y_rust<BD: BitDepth>(
});
}

let ar_pad = 3;
// `ar_lag` is 2 bits; this tells the compiler it definitely is.
// That also means `ar_lag <= ar_pad`.
let ar_lag = data.ar_coeff_lag as usize & ((1 << 2) - 1);

for y in 0..GRAIN_HEIGHT - ar_pad {
for x in 0..GRAIN_WIDTH - 2 * ar_pad {
for y in 0..GRAIN_HEIGHT - AR_PAD {
for x in 0..GRAIN_WIDTH - 2 * AR_PAD {
let mut coeff = (data.ar_coeffs_y).as_ptr();
let mut sum = 0;
for (dy, buf_row) in buf[y..][ar_pad - ar_lag..=ar_pad].iter().enumerate() {
for (dx, &buf_val) in buf_row[x..][ar_pad - ar_lag..=ar_pad + ar_lag]
for (dy, buf_row) in buf[y..][AR_PAD - ar_lag..=AR_PAD].iter().enumerate() {
for (dx, &buf_val) in buf_row[x..][AR_PAD - ar_lag..=AR_PAD + ar_lag]
.iter()
.enumerate()
{
Expand All @@ -408,7 +410,7 @@ unsafe fn generate_grain_y_rust<BD: BitDepth>(
}
}

let buf_yx = &mut buf[y + ar_pad][x + ar_pad];
let buf_yx = &mut buf[y + AR_PAD][x + AR_PAD];
let grain = (*buf_yx).as_::<c_int>() + round2(sum, data.ar_coeff_shift);
(*buf_yx) = iclip(grain, grain_min, grain_max).as_::<BD::Entry>();
}
Expand All @@ -425,7 +427,67 @@ unsafe fn generate_grain_uv_rust<BD: BitDepth>(
bd: BD,
) {
let uv = is_uv as usize;
let [subx, suby] = [is_subx, is_suby].map(|it| it as u8);

/// Per-axis subsampling flags for the chroma grain,
/// bundled with the index arithmetic that depends on them.
struct IsSub {
    y: bool,
    x: bool,
}

impl IsSub {
    /// `(height, width)` of the chroma grain region for these flags.
    const fn chroma(&self) -> (usize, usize) {
        (
            if self.y {
                SUB_GRAIN_HEIGHT
            } else {
                GRAIN_HEIGHT
            },
            if self.x { SUB_GRAIN_WIDTH } else { GRAIN_WIDTH },
        )
    }

    /// `(rows, columns)` iterated by the auto-regression loops:
    /// the chroma size minus `AR_PAD` rows and `2 * AR_PAD` columns.
    const fn len(&self) -> (usize, usize) {
        let dims = self.chroma();
        (dims.0 - AR_PAD, dims.1 - 2 * AR_PAD)
    }

    /// Maps a chroma loop position `(y, x)` to the first luma position it reads.
    ///
    /// Each coordinate is doubled on a subsampled axis and then offset by `AR_PAD`.
    const fn luma(&self, pos: (usize, usize)) -> (usize, usize) {
        let shift_y = self.y as usize;
        let shift_x = self.x as usize;
        ((pos.0 << shift_y) + AR_PAD, (pos.1 << shift_x) + AR_PAD)
    }

    /// The largest luma `(y, x)` index read for chroma position `pos`:
    /// one past [`Self::luma`] on every subsampled axis.
    const fn buf_index(&self, pos: (usize, usize)) -> (usize, usize) {
        let first = self.luma(pos);
        (first.0 + self.y as usize, first.1 + self.x as usize)
    }

    /// The largest luma index read over the whole loop range given by [`Self::len`].
    const fn max_buf_index(&self) -> (usize, usize) {
        let (rows, cols) = self.len();
        let last = (rows - 1, cols - 1);
        self.buf_index(last)
    }

    /// Asserts that [`Self::max_buf_index`] is in bounds for a `[[T; X]; Y]` buffer.
    ///
    /// The `Option` argument is only a type witness that supplies `X` and `Y`;
    /// its value is never inspected.
    const fn check_buf_index<T, const Y: usize, const X: usize>(
        &self,
        _: &Option<[[T; X]; Y]>,
    ) {
        let (y, x) = self.max_buf_index();
        assert!(y < Y);
        assert!(x < X);
    }

    /// Runs [`Self::check_buf_index`] for every combination of subsampling flags.
    #[allow(dead_code)] // False positive; used in a `const`.
    const fn check_buf_index_all<T, const Y: usize, const X: usize>(buf: &Option<[[T; X]; Y]>) {
        // `(y, x)` flag pairs, in the order they are checked.
        const COMBOS: [(bool, bool); 4] =
            [(true, true), (true, false), (false, true), (false, false)];
        // `for` loops are not usable in a `const fn`, hence the manual counter.
        let mut i = 0;
        while i < 4 {
            let (y, x) = COMBOS[i];
            Self { y, x }.check_buf_index(buf);
            i += 1;
        }
    }
}

let is_sub = IsSub {
y: is_suby,
x: is_subx,
};

let bitdepth_min_8 = bd.bitdepth() - 8;
let mut seed = data.seed ^ if is_uv { 0x49d8 } else { 0xb524 };
Expand All @@ -434,49 +496,42 @@ unsafe fn generate_grain_uv_rust<BD: BitDepth>(
let grain_min = -grain_ctr;
let grain_max = grain_ctr - 1;

let chromaW = if is_subx {
SUB_GRAIN_WIDTH
} else {
GRAIN_WIDTH
};
let chromaH = if is_suby {
SUB_GRAIN_HEIGHT
} else {
GRAIN_HEIGHT
};

for row in &mut buf[..chromaH] {
row[..chromaW].fill_with(|| {
for row in &mut buf[..is_sub.chroma().0] {
row[..is_sub.chroma().1].fill_with(|| {
let value = get_random_number(11, &mut seed);
round2(dav1d_gaussian_sequence[value as usize], shift).as_::<BD::Entry>()
});
}

let ar_pad = 3;
// `ar_lag` is 2 bits; this tells the compiler it definitely is.
// That also means `ar_lag <= ar_pad`.
let ar_lag = data.ar_coeff_lag as usize & ((1 << 2) - 1);

for y in 0..chromaH - ar_pad {
for x in 0..chromaW - 2 * ar_pad {
for y in 0..is_sub.len().0 {
for x in 0..is_sub.len().1 {
let mut coeff = (data.ar_coeffs_uv[uv]).as_ptr();
let mut sum = 0;
for (dy, buf_row) in buf[y..][ar_pad - ar_lag..=ar_pad].iter().enumerate() {
for (dx, &buf_val) in buf_row[x..][ar_pad - ar_lag..=ar_pad + ar_lag]
for (dy, buf_row) in buf[y..][AR_PAD - ar_lag..=AR_PAD].iter().enumerate() {
for (dx, &buf_val) in buf_row[x..][AR_PAD - ar_lag..=AR_PAD + ar_lag]
.iter()
.enumerate()
{
if dx == ar_lag && dy == ar_lag {
let mut luma = 0;
let lumaX = (x << subx) + ar_pad;
let lumaY = (y << suby) + ar_pad;
for i in 0..=suby {
for j in 0..=subx {
luma +=
buf_y[lumaY + i as usize][lumaX + j as usize].as_::<c_int>();
let (luma_y, luma_x) = is_sub.luma((y, x));
const _: () = IsSub::check_buf_index_all(&None::<GrainLut<()>>);
// The optimizer is not smart enough to deduce this on its own.
// Safety: The above static check checks all maximum index possibilities.
unsafe {
assume(luma_y < GRAIN_HEIGHT + 1 - 1);
assume(luma_x < GRAIN_WIDTH - 1);
}
for i in 0..1 + is_sub.y as usize {
for j in 0..1 + is_sub.x as usize {
luma += buf_y[luma_y + i][luma_x + j].as_::<c_int>();
}
}
luma = round2(luma, subx + suby);
luma = round2(luma, is_sub.y as u8 + is_sub.x as u8);

sum += luma * *coeff as c_int;
break;
Expand All @@ -486,7 +541,7 @@ unsafe fn generate_grain_uv_rust<BD: BitDepth>(
}
}

let buf_yx = &mut buf[y + ar_pad][x + ar_pad];
let buf_yx = &mut buf[y + AR_PAD][x + AR_PAD];
let grain = (*buf_yx).as_::<c_int>() + round2(sum, data.ar_coeff_shift);
(*buf_yx) = iclip(grain, grain_min, grain_max).as_::<BD::Entry>();
}
Expand Down

0 comments on commit a8d6823

Please sign in to comment.