Skip to content

Commit

Permalink
fn w_mask_rust: Cleanup and make mostly safe (#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkysen authored Aug 7, 2023
2 parents a5db00d + e6036f1 commit 71a63ff
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 66 deletions.
100 changes: 46 additions & 54 deletions src/mc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -829,78 +829,70 @@ pub unsafe fn blend_h_rust<BD: BitDepth>(

// TODO(kkysen) temporarily `pub` until `mc` callers are deduplicated
pub unsafe fn w_mask_rust<BD: BitDepth>(
mut dst: *mut BD::Pixel,
dst_stride: libc::ptrdiff_t,
mut tmp1: *const i16,
mut tmp2: *const i16,
w: libc::c_int,
h: libc::c_int,
mut mask: *mut u8,
sign: libc::c_int,
ss_hor: libc::c_int,
ss_ver: libc::c_int,
dst: *mut BD::Pixel,
dst_stride: usize,
tmp1: *const i16,
tmp2: *const i16,
w: usize,
h: usize,
mask: *mut u8,
sign: bool,
ss_hor: bool,
ss_ver: bool,
bd: BD,
) {
let dst_stride = BD::pxstride(dst_stride);
let dst =
std::slice::from_raw_parts_mut(dst, if h == 0 { 0 } else { (h - 1) * dst_stride + w });
let [tmp1, tmp2] = [tmp1, tmp2].map(|tmp| std::slice::from_raw_parts(tmp, h * w));
let mut mask =
std::slice::from_raw_parts_mut(mask, (w >> ss_hor as usize) * (h >> ss_ver as usize));
let sign = sign as u8;

// store mask at 2x2 resolution, i.e. store 2x1 sum for even rows,
// and then load this intermediate to calculate final value for odd rows
let intermediate_bits = bd.get_intermediate_bits();
let bitdepth = bd.bitdepth();
let sh = intermediate_bits + 6;
let rnd = (32 << intermediate_bits) + i32::from(BD::PREP_BIAS) * 64;
let mask_sh = bitdepth + intermediate_bits - 4;
let mask_rnd = 1 << (mask_sh - 5);
for h in 0..h {
for (h, ((tmp1, tmp2), dst)) in iter::zip(tmp1.chunks_exact(w), tmp2.chunks_exact(w))
.zip(dst.chunks_mut(dst_stride))
.enumerate()
{
let mut x = 0;
while x < w {
let m = std::cmp::min(
38 as libc::c_int
+ ((*tmp1.offset(x as isize) as libc::c_int
- *tmp2.offset(x as isize) as libc::c_int)
.abs()
+ mask_rnd
>> mask_sh),
64 as libc::c_int,
);
*dst.offset(x as isize) = bd.iclip_pixel(
*tmp1.offset(x as isize) as libc::c_int * m
+ *tmp2.offset(x as isize) as libc::c_int * (64 - m)
+ rnd
>> sh,
let m =
std::cmp::min(38 + ((tmp1[x].abs_diff(tmp2[x]) + mask_rnd) >> mask_sh), 64) as u8;
dst[x] = bd.iclip_pixel(
(tmp1[x] as i32 * m as i32 + tmp2[x] as i32 * (64 - m as i32) + rnd) >> sh,
);
if ss_hor != 0 {

if ss_hor {
x += 1;
let n = std::cmp::min(
38 as libc::c_int
+ ((*tmp1.offset(x as isize) as libc::c_int
- *tmp2.offset(x as isize) as libc::c_int)
.abs()
+ mask_rnd
>> mask_sh),
64 as libc::c_int,
);
*dst.offset(x as isize) = bd.iclip_pixel(
*tmp1.offset(x as isize) as libc::c_int * n
+ *tmp2.offset(x as isize) as libc::c_int * (64 - n)
+ rnd
>> sh,

let n = std::cmp::min(38 + ((tmp1[x].abs_diff(tmp2[x]) + mask_rnd) >> mask_sh), 64)
as u8;
dst[x] = bd.iclip_pixel(
(tmp1[x] as i32 * n as i32 + tmp2[x] as i32 * (64 - n as i32) + rnd) >> sh,
);
if h & ss_ver != 0 {
*mask.offset((x >> 1) as isize) =
(m + n + *mask.offset((x >> 1) as isize) as libc::c_int + 2 - sign >> 2)
as u8;
} else if ss_ver != 0 {
*mask.offset((x >> 1) as isize) = (m + n) as u8;

mask[x >> 1] = if h & ss_ver as usize != 0 {
((m + n + mask[x >> 1] + 2 - sign) >> 2) as u8
} else if ss_ver {
m + n
} else {
*mask.offset((x >> 1) as isize) = (m + n + 1 - sign >> 1) as u8;
}
((m + n + 1 - sign) >> 1) as u8
};
} else {
*mask.offset(x as isize) = m as u8;
mask[x] = m;
}
x += 1;
}
tmp1 = tmp1.offset(w as isize);
tmp2 = tmp2.offset(w as isize);
dst = dst.offset(BD::pxstride(dst_stride as usize) as isize);
if ss_ver == 0 || h & 1 != 0 {
mask = mask.offset((w >> ss_hor) as isize);

if !ss_ver || h & 1 != 0 {
mask = &mut mask[w >> ss_hor as usize..];
}
}
}
13 changes: 7 additions & 6 deletions src/mc_tmpl_16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3179,17 +3179,18 @@ unsafe extern "C" fn w_mask_c(
ss_ver: libc::c_int,
bitdepth_max: libc::c_int,
) {
debug_assert!(sign == 0 || sign == 1);
w_mask_rust(
dst,
dst_stride,
dst_stride as usize,
tmp1,
tmp2,
w,
h,
w as usize,
h as usize,
mask,
sign,
ss_hor,
ss_ver,
sign != 0,
ss_hor != 0,
ss_ver != 0,
BitDepth16::new(bitdepth_max as u16),
)
}
Expand Down
13 changes: 7 additions & 6 deletions src/mc_tmpl_8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3080,17 +3080,18 @@ unsafe extern "C" fn w_mask_c(
ss_hor: libc::c_int,
ss_ver: libc::c_int,
) {
debug_assert!(sign == 0 || sign == 1);
w_mask_rust(
dst,
dst_stride,
dst_stride as usize,
tmp1,
tmp2,
w,
h,
w as usize,
h as usize,
mask,
sign,
ss_hor,
ss_ver,
sign != 0,
ss_hor != 0,
ss_ver != 0,
BitDepth8::new(()),
)
}
Expand Down

0 comments on commit 71a63ff

Please sign in to comment.