Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fn load_tmvs_c: Add safe rp_proj arg #975

Merged
merged 5 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use crate::src::error::Rav1dError::EINVAL;
use crate::src::error::Rav1dError::ENOMEM;
use crate::src::error::Rav1dError::ENOPROTOOPT;
use crate::src::error::Rav1dResult;
use crate::src::ffi_safe::FFISafe;
randomPoison marked this conversation as resolved.
Show resolved Hide resolved
use crate::src::filmgrain::Rav1dFilmGrainDSPContext;
use crate::src::internal::Bxy;
use crate::src::internal::Rav1dContext;
Expand Down Expand Up @@ -862,20 +863,11 @@ fn get_prev_frame_segid(
) -> u8 {
assert!(frame_hdr.primary_ref_frame != RAV1D_PRIMARY_REF_NONE);

// Need checked casts here because an overflowing cast
// would give a too large `len` to [`std::slice::from_raw_parts`], which would be UB.
let w4 = usize::try_from(w4).unwrap();
let h4 = usize::try_from(h4).unwrap();
let stride = usize::try_from(stride).unwrap();

let mut prev_seg_id = 8;
let offset = b.y as usize * stride as usize + b.x as usize;
let len = h4 as usize * stride;
let ref_seg_map = ref_seg_map.index(offset..offset + len);

assert!(w4 <= stride);
for ref_seg_map in ref_seg_map.chunks_exact(stride) {
prev_seg_id = ref_seg_map[..w4]
for y in 0..h4 as usize {
let offset = (b.y as usize + y) * stride as usize + b.x as usize;
prev_seg_id = ref_seg_map
.index(offset..offset + w4 as usize)
.iter()
.copied()
.fold(prev_seg_id, cmp::min);
Expand Down Expand Up @@ -3929,6 +3921,7 @@ pub(crate) unsafe fn rav1d_decode_tile_sbrow(
ts.tiling.col_end >> 1,
t.b.y >> 1,
t.b.y + sb_step >> 1,
FFISafe::new(&f.rf.rp_proj),
);
}
t.pal_sz_uv[1] = Default::default();
Expand Down Expand Up @@ -4527,7 +4520,15 @@ unsafe fn rav1d_decode_frame_main(c: &Rav1dContext, f: &mut Rav1dFrameData) -> R
let by_end = t.b.y + f.sb_step >> 1;
if frame_hdr.use_ref_frame_mvs != 0 {
let rf = f.rf.as_mut_dav1d();
(c.refmvs_dsp.load_tmvs)(&rf, tile_row as c_int, 0, f.bw >> 1, t.b.y >> 1, by_end);
(c.refmvs_dsp.load_tmvs)(
&rf,
tile_row as c_int,
0,
f.bw >> 1,
t.b.y >> 1,
by_end,
FFISafe::new(&f.rf.rp_proj),
);
}
for col in 0..cols {
t.ts = tile_row * cols + col;
Expand Down
33 changes: 20 additions & 13 deletions src/refmvs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ extern "C" {
col_end8: c_int,
row_start8: c_int,
row_end8: c_int,
_rp_proj: *const FFISafe<DisjointMut<AlignedVec64<refmvs_temporal_block>>>,
);
}

Expand Down Expand Up @@ -312,6 +313,7 @@ pub(crate) type load_tmvs_fn = unsafe extern "C" fn(
col_end8: c_int,
row_start8: c_int,
row_end8: c_int,
rp_proj: *const FFISafe<DisjointMut<AlignedVec64<refmvs_temporal_block>>>,
) -> ();

pub type save_tmvs_fn = unsafe extern "C" fn(
Expand Down Expand Up @@ -1383,7 +1385,9 @@ unsafe extern "C" fn load_tmvs_c(
col_end8: c_int,
row_start8: c_int,
mut row_end8: c_int,
rp_proj: *const FFISafe<DisjointMut<AlignedVec64<refmvs_temporal_block>>>,
) {
let rp_proj = FFISafe::get(rp_proj);
let rf = &*rf;

if rf.n_tile_threads == 1 {
Expand All @@ -1394,17 +1398,16 @@ unsafe extern "C" fn load_tmvs_c(
row_end8 = cmp::min(row_end8, rf.ih8);
let col_start8i = cmp::max(col_start8 - 8, 0);
let col_end8i = cmp::min(col_end8 + 8, rf.iw8);
let stride = rf.rp_stride;
let mut rp_proj = rf
.rp_proj
.offset(16 * stride * tile_row_idx as isize + (row_start8 & 15) as isize * stride);
for _ in row_start8..row_end8 {
for x in col_start8..col_end8 {
(*rp_proj.offset(x as isize)).mv = mv::INVALID;
let stride = rf.rp_stride as usize;
let rp_proj_offset = 16 * stride * tile_row_idx as usize;
for y in row_start8..row_end8 {
let offset = rp_proj_offset + (y & 15) as usize * stride;
for rp_proj in
&mut *rp_proj.index_mut(offset + col_start8 as usize..offset + col_end8 as usize)
{
rp_proj.mv = mv::INVALID;
}
rp_proj = rp_proj.offset(stride as isize);
}
rp_proj = rf.rp_proj.offset(16 * stride * tile_row_idx as isize);
for n in 0..rf.n_mfmvs {
let ref2cur = rf.mfmv_ref2cur[n as usize];
if ref2cur == i32::MIN {
Expand All @@ -1413,7 +1416,7 @@ unsafe extern "C" fn load_tmvs_c(
let r#ref = rf.mfmv_ref[n as usize] as c_int;
let ref_sign = r#ref - 4;
let mut r = (*rf.rp_ref.offset(r#ref as isize))
.offset(row_start8 as isize * stride)
.add(row_start8 as usize * stride)
.cast_const();
for y in row_start8..row_end8 {
let y_sb_align = y & !7;
Expand All @@ -1439,14 +1442,18 @@ unsafe extern "C" fn load_tmvs_c(
let pos_y =
y + apply_sign((offset.y as c_int).abs() >> 6, offset.y as c_int ^ ref_sign);
if pos_y >= y_proj_start && pos_y < y_proj_end {
let pos = (pos_y & 15) as isize * stride;
let pos = (pos_y & 15) as usize * stride;
loop {
let x_sb_align = x & !7;
if pos_x >= cmp::max(x_sb_align - 8, col_start8)
&& pos_x < cmp::min(x_sb_align + 16, col_end8)
{
(*rp_proj.offset(pos + pos_x as isize)).mv = (*rb).mv;
(*rp_proj.offset(pos + pos_x as isize)).r#ref = ref2ref as i8;
*rp_proj.index_mut(
rp_proj_offset + (pos as isize + pos_x as isize) as usize,
) = refmvs_temporal_block {
mv: (*rb).mv,
r#ref: ref2ref as i8,
};
}
x += 1;
if x >= col_end8i {
Expand Down