Skip to content

Commit

Permalink
struct Rav1dRefmvsDSPContext: Add wrapper methods (#976)
Browse files Browse the repository at this point in the history
`fn splat_mv` is already a method, and `fn rav1d_refmvs_save_tmvs` was
already like a method, but as a free `fn`. So this just makes them all
methods on `Rav1dRefmvsDspContext`, which helps things for `fn
load_tmvs` as I make it and the `mvs` fields safe.
  • Loading branch information
kkysen authored Apr 19, 2024
2 parents 4501b38 + 6b4ef97 commit d862eea
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 77 deletions.
34 changes: 7 additions & 27 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ use crate::src::error::Rav1dError::EINVAL;
use crate::src::error::Rav1dError::ENOMEM;
use crate::src::error::Rav1dError::ENOPROTOOPT;
use crate::src::error::Rav1dResult;
use crate::src::ffi_safe::FFISafe;
use crate::src::filmgrain::Rav1dFilmGrainDSPContext;
use crate::src::internal::Bxy;
use crate::src::internal::Rav1dContext;
Expand Down Expand Up @@ -147,7 +146,6 @@ use crate::src::r#ref::rav1d_ref_inc;
use crate::src::recon::debug_block_info;
use crate::src::refmvs::rav1d_refmvs_find;
use crate::src::refmvs::rav1d_refmvs_init_frame;
use crate::src::refmvs::rav1d_refmvs_save_tmvs;
use crate::src::refmvs::rav1d_refmvs_tile_sbrow_init;
use crate::src::refmvs::refmvs_block;
use crate::src::refmvs::refmvs_mvpair;
Expand Down Expand Up @@ -3913,15 +3911,13 @@ pub(crate) unsafe fn rav1d_decode_tile_sbrow(
}

if c.tc.len() > 1 && frame_hdr.use_ref_frame_mvs != 0 {
let rf = f.rf.as_mut_dav1d();
(c.refmvs_dsp.load_tmvs)(
&rf,
c.refmvs_dsp.load_tmvs(
&f.rf,
ts.tiling.row,
ts.tiling.col_start >> 1,
ts.tiling.col_end >> 1,
t.b.y >> 1,
t.b.y + sb_step >> 1,
FFISafe::new(&f.rf.rp_proj),
);
}
t.pal_sz_uv[1] = Default::default();
Expand Down Expand Up @@ -4024,8 +4020,7 @@ pub(crate) unsafe fn rav1d_decode_tile_sbrow(
&& c.tc.len() > 1
&& f.frame_hdr().frame_type.is_inter_or_switch()
{
rav1d_refmvs_save_tmvs(
&c.refmvs_dsp,
c.refmvs_dsp.save_tmvs(
&t.rt,
&f.rf,
ts.tiling.col_start >> 1,
Expand Down Expand Up @@ -4519,31 +4514,16 @@ unsafe fn rav1d_decode_frame_main(c: &Rav1dContext, f: &mut Rav1dFrameData) -> R
t.b.y = sby << 4 + seq_hdr.sb128;
let by_end = t.b.y + f.sb_step >> 1;
if frame_hdr.use_ref_frame_mvs != 0 {
let rf = f.rf.as_mut_dav1d();
(c.refmvs_dsp.load_tmvs)(
&rf,
tile_row as c_int,
0,
f.bw >> 1,
t.b.y >> 1,
by_end,
FFISafe::new(&f.rf.rp_proj),
);
c.refmvs_dsp
.load_tmvs(&f.rf, tile_row as c_int, 0, f.bw >> 1, t.b.y >> 1, by_end);
}
for col in 0..cols {
t.ts = tile_row * cols + col;
rav1d_decode_tile_sbrow(c, &mut t, f).map_err(|()| EINVAL)?;
}
if f.frame_hdr().frame_type.is_inter_or_switch() {
rav1d_refmvs_save_tmvs(
&c.refmvs_dsp,
&t.rt,
&f.rf,
0,
f.bw >> 1,
t.b.y >> 1,
by_end,
);
c.refmvs_dsp
.save_tmvs(&t.rt, &f.rf, 0, f.bw >> 1, t.b.y >> 1, by_end);
}

// loopfilter + cdef + restoration
Expand Down
121 changes: 71 additions & 50 deletions src/refmvs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,80 @@ pub type splat_mv_fn = unsafe extern "C" fn(

#[repr(C)]
pub(crate) struct Rav1dRefmvsDSPContext {
pub load_tmvs: load_tmvs_fn,
pub save_tmvs: save_tmvs_fn,
pub splat_mv: splat_mv_fn,
load_tmvs: load_tmvs_fn,
save_tmvs: save_tmvs_fn,
splat_mv: splat_mv_fn,
}

impl Rav1dRefmvsDSPContext {
pub unsafe fn load_tmvs(
&self,
rf: &RefMvsFrame,
tile_row_idx: c_int,
col_start8: c_int,
col_end8: c_int,
row_start8: c_int,
row_end8: c_int,
) {
let rf_dav1d = rf.as_mut_dav1d();
(self.load_tmvs)(
&rf_dav1d,
tile_row_idx,
col_start8,
col_end8,
row_start8,
row_end8,
FFISafe::new(&rf.rp_proj),
);
}

// cache the current tile/sbrow (or frame/sbrow)'s projectable motion vectors
// into buffers for use in future frame's temporal MV prediction
pub unsafe fn save_tmvs(
&self,
rt: &refmvs_tile,
rf: &RefMvsFrame,
col_start8: c_int,
col_end8: c_int,
row_start8: c_int,
row_end8: c_int,
) {
assert!(row_start8 >= 0);
assert!((row_end8 - row_start8) as c_uint <= 16);
let row_end8 = cmp::min(row_end8, rf.ih8);
let col_end8 = cmp::min(col_end8, rf.iw8);
let stride = rf.rp_stride as isize;
let ref_sign = &rf.mfmv_sign;
let rp = rf.rp.offset(row_start8 as isize * stride);
let ri = <&[_; 31]>::try_from(&rt.r[6..]).unwrap();

// SAFETY: Note that for asm calls, disjointedness is unchecked here,
// even with `#[cfg(debug_assertions)]`. This is because the disjointedness
// is more fine-grained than the pointers passed to asm.
// For the Rust fallback fn, the extra args `&rf.r` and `ri`
// are passed to do allow for disjointedness checking.
let rr = &ri.map(|ri| {
if ri > rf.r.len() {
return ptr::null();
}
// SAFETY: `.add` is in-bounds; checked above.
unsafe { rf.r.as_mut_ptr().cast_const().add(ri) }
});

(self.save_tmvs)(
rp,
stride,
rr,
ref_sign,
col_end8,
row_end8,
col_start8,
row_start8,
FFISafe::new(&rf.r),
ri,
);
}

pub unsafe fn splat_mv(
&self,
rf: &RefMvsFrame,
Expand Down Expand Up @@ -1271,53 +1339,6 @@ pub(crate) fn rav1d_refmvs_find(
*ctx = refmv_ctx << 4 | globalmv_ctx << 3 | newmv_ctx;
}

// cache the current tile/sbrow (or frame/sbrow)'s projectable motion vectors
// into buffers for use in future frame's temporal MV prediction
pub(crate) unsafe fn rav1d_refmvs_save_tmvs(
dsp: &Rav1dRefmvsDSPContext,
rt: &refmvs_tile,
rf: &RefMvsFrame,
col_start8: c_int,
col_end8: c_int,
row_start8: c_int,
row_end8: c_int,
) {
assert!(row_start8 >= 0);
assert!((row_end8 - row_start8) as c_uint <= 16);
let row_end8 = cmp::min(row_end8, rf.ih8);
let col_end8 = cmp::min(col_end8, rf.iw8);
let stride = rf.rp_stride as isize;
let ref_sign = &rf.mfmv_sign;
let rp = rf.rp.offset(row_start8 as isize * stride);
let ri = <&[_; 31]>::try_from(&rt.r[6..]).unwrap();

// SAFETY: Note that for asm calls, disjointedness is unchecked here,
// even with `#[cfg(debug_assertions)]`. This is because the disjointedness
// is more fine-grained than the pointers passed to asm.
// For the Rust fallback fn, the extra args `&rf.r` and `ri`
// are passed to do allow for disjointedness checking.
let rr = &ri.map(|ri| {
if ri > rf.r.len() {
return ptr::null();
}
// SAFETY: `.add` is in-bounds; checked above.
unsafe { rf.r.as_mut_ptr().cast_const().add(ri) }
});

(dsp.save_tmvs)(
rp,
stride,
rr,
ref_sign,
col_end8,
row_end8,
col_start8,
row_start8,
FFISafe::new(&rf.r),
ri,
);
}

pub(crate) fn rav1d_refmvs_tile_sbrow_init(
rf: &RefMvsFrame,
tile_col_start4: c_int,
Expand Down

0 comments on commit d862eea

Please sign in to comment.