Skip to content

Commit

Permalink
struct Rav1dFrameData::{cur,prev}_segmap: Arcify with an `Option<…
Browse files Browse the repository at this point in the history
…DisjointMutArcSlice<u8>>` (#971)

* Fixes `{cur,prev}_segmap{,_ref}` fields of #713.
* Fixes `segmap{,_pool}` fields of #641.

To do this, I first added a few things to `DisjointMut`:
* e495e00: Allowing `T: ?Sized` so `DisjointMut<[_]>` is possible. This
didn't really require any changes other than expanding the bounds.
* e495e00: `impl AsMutPtr for [_]`
* 631ce49: `impl AsMutPtr for Box<[_]>`
* 704fe93: add `fn DisjointMut::new` to create a non-`default()`
`DisjointMut` from a pre-existing `T`
* 712c759: Add `DisjointMutArcSlice`, which is an
`Arc<DisjointMut<[_]>>` in release mode, and an
`Arc<DisjointMut<Box<[_]>>>` in debug mode. In release mode,
`DisjointMut` is `#[repr(transparent)]`, so we can do this and safe
allocations and indirections. And in debug mode, the overhead is fine.

Then this `Arc`ifies the `segmap`s with the above
`DisjointMutArcSlice<u8>`.
  • Loading branch information
kkysen authored Apr 17, 2024
2 parents 8c8a71e + 683f7c3 commit 85da63c
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 138 deletions.
145 changes: 64 additions & 81 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::src::cdf::CdfMvContext;
use crate::src::ctx::CaseSet;
use crate::src::dequant_tables::dav1d_dq_tbl;
use crate::src::disjoint_mut::DisjointMut;
use crate::src::disjoint_mut::DisjointMutSlice;
use crate::src::enum_map::enum_map;
use crate::src::enum_map::enum_map_ty;
use crate::src::enum_map::DefaultValue;
Expand Down Expand Up @@ -851,16 +852,12 @@ unsafe fn read_vartx_tree(
}

#[inline]
unsafe fn get_prev_frame_segid(
fn get_prev_frame_segid(
frame_hdr: &Rav1dFrameHeader,
b: Bxy,
w4: c_int,
h4: c_int,
// It's very difficult to make this safe (a slice),
// as it comes from [`Dav1dFrameContext::prev_segmap`],
// which is set to [`Dav1dFrameContext::prev_segmap_ref`],
// which is a [`Dav1dRef`], which has no size and is refcounted.
ref_seg_map: *const u8,
ref_seg_map: &DisjointMutSlice<u8>,
stride: ptrdiff_t,
) -> u8 {
assert!(frame_hdr.primary_ref_frame != RAV1D_PRIMARY_REF_NONE);
Expand All @@ -872,10 +869,9 @@ unsafe fn get_prev_frame_segid(
let stride = usize::try_from(stride).unwrap();

let mut prev_seg_id = 8;
let ref_seg_map = std::slice::from_raw_parts(
ref_seg_map.offset(b.y as isize * stride as isize + b.x as isize),
h4 * stride,
);
let offset = b.y as usize * stride as usize + b.x as usize;
let len = h4 as usize * stride;
let ref_seg_map = ref_seg_map.index(offset..offset + len);

assert!(w4 <= stride);
for ref_seg_map in ref_seg_map.chunks_exact(stride) {
Expand Down Expand Up @@ -1343,9 +1339,9 @@ unsafe fn decode_b_inner(
let frame_hdr: &Rav1dFrameHeader = &f.frame_hdr.as_ref().unwrap();
if frame_hdr.segmentation.enabled != 0 {
if frame_hdr.segmentation.update_map == 0 {
if !(f.prev_segmap).is_null() {
if let Some(prev_segmap) = f.prev_segmap.as_ref() {
let seg_id =
get_prev_frame_segid(frame_hdr, t.b, w4, h4, f.prev_segmap, f.b4_stride);
get_prev_frame_segid(frame_hdr, t.b, w4, h4, &prev_segmap.inner, f.b4_stride);
if seg_id >= RAV1D_MAX_SEGMENTS.into() {
return Err(());
}
Expand All @@ -1364,9 +1360,15 @@ unsafe fn decode_b_inner(
seg_pred
} {
// temporal predicted seg_id
if !(f.prev_segmap).is_null() {
let seg_id =
get_prev_frame_segid(frame_hdr, t.b, w4, h4, f.prev_segmap, f.b4_stride);
if let Some(prev_segmap) = f.prev_segmap.as_ref() {
let seg_id = get_prev_frame_segid(
frame_hdr,
t.b,
w4,
h4,
&prev_segmap.inner,
f.b4_stride,
);
if seg_id >= RAV1D_MAX_SEGMENTS.into() {
return Err(());
}
Expand All @@ -1379,7 +1381,7 @@ unsafe fn decode_b_inner(
t.b,
have_top,
have_left,
f.cur_segmap,
&f.cur_segmap.as_ref().unwrap().inner,
f.b4_stride as usize,
);
let diff = rav1d_msac_decode_symbol_adapt8(
Expand Down Expand Up @@ -1453,9 +1455,9 @@ unsafe fn decode_b_inner(
seg_pred
} {
// temporal predicted seg_id
if !(f.prev_segmap).is_null() {
if let Some(prev_segmap) = f.prev_segmap.as_ref() {
let seg_id =
get_prev_frame_segid(frame_hdr, t.b, w4, h4, f.prev_segmap, f.b4_stride);
get_prev_frame_segid(frame_hdr, t.b, w4, h4, &prev_segmap.inner, f.b4_stride);
if seg_id >= RAV1D_MAX_SEGMENTS.into() {
return Err(());
}
Expand All @@ -1464,8 +1466,13 @@ unsafe fn decode_b_inner(
b.seg_id = 0;
}
} else {
let (pred_seg_id, seg_ctx) =
get_cur_frame_segid(t.b, have_top, have_left, f.cur_segmap, f.b4_stride as usize);
let (pred_seg_id, seg_ctx) = get_cur_frame_segid(
t.b,
have_top,
have_left,
&f.cur_segmap.as_ref().unwrap().inner,
f.b4_stride as usize,
);
if b.skip != 0 {
b.seg_id = pred_seg_id as u8;
} else {
Expand Down Expand Up @@ -2993,18 +3000,12 @@ unsafe fn decode_b_inner(
// Need checked casts here because we're using `from_raw_parts_mut` and an overflow would be UB.
let [by, bx, bh4, bw4] = [t.b.y, t.b.x, bh4, bw4].map(|it| usize::try_from(it).unwrap());
let b4_stride = usize::try_from(f.b4_stride).unwrap();
let cur_segmap_len = (by * b4_stride + bx)
+ if bh4 == 0 {
0
} else {
(b4_stride * (bh4 - 1)) + bw4
};
let cur_segmap = std::slice::from_raw_parts_mut(f.cur_segmap, cur_segmap_len);
let seg_ptr = &mut cur_segmap[by * b4_stride + bx..];

let cur_segmap = &f.cur_segmap.as_ref().unwrap().inner;
let offset = by * b4_stride + bx;
CaseSet::<32, false>::one((), bw4, 0, |case, ()| {
for seg_ptr in seg_ptr.chunks_mut(b4_stride).take(bh4) {
case.set(seg_ptr, b.seg_id);
for i in 0..bh4 {
let i = offset + i * b4_stride;
case.set(&mut cur_segmap.index_mut(i..i + bw4), b.seg_id);
}
});
}
Expand Down Expand Up @@ -4585,8 +4586,8 @@ pub(crate) unsafe fn rav1d_decode_frame_exit(
}
}

rav1d_ref_dec(&mut f.cur_segmap_ref);
rav1d_ref_dec(&mut f.prev_segmap_ref);
let _ = mem::take(&mut f.cur_segmap);
let _ = mem::take(&mut f.prev_segmap);
rav1d_ref_dec(&mut f.mvs_ref);
let _ = mem::take(&mut f.seq_hdr);
let _ = mem::take(&mut f.frame_hdr);
Expand Down Expand Up @@ -4955,8 +4956,7 @@ pub unsafe fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
// segmap
if frame_hdr.segmentation.enabled != 0 {
// By default, the previous segmentation map is not initialised.
f.prev_segmap_ref = ptr::null_mut();
f.prev_segmap = ptr::null();
f.prev_segmap = None;

// We might need a previous frame's segmentation map.
// This happens if there is either no update or a temporal update.
Expand All @@ -4966,50 +4966,37 @@ pub unsafe fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
let ref_w = (ref_coded_width[pri_ref] + 7 >> 3) << 1;
let ref_h = (f.refp[pri_ref].p.p.h + 7 >> 3) << 1;
if ref_w == f.bw && ref_h == f.bh {
f.prev_segmap_ref = c.refs[frame_hdr.refidx[pri_ref] as usize].segmap;
if !f.prev_segmap_ref.is_null() {
rav1d_ref_inc(f.prev_segmap_ref);
f.prev_segmap = (*f.prev_segmap_ref).data.cast::<u8>();
}
f.prev_segmap = c.refs[frame_hdr.refidx[pri_ref] as usize].segmap.clone();
}
}

if frame_hdr.segmentation.update_map != 0 {
// We're updating an existing map,
// but need somewhere to put the new values.
// Allocate them here (the data actually gets set elsewhere).
f.cur_segmap_ref = rav1d_ref_create_using_pool(
c.segmap_pool,
::core::mem::size_of::<u8>() * f.b4_stride as usize * 32 * f.sb128h as usize,
);
if f.cur_segmap_ref.is_null() {
rav1d_ref_dec(&mut f.prev_segmap_ref);
on_error(f, c, out);
return Err(ENOMEM);
}
f.cur_segmap = (*f.cur_segmap_ref).data.cast::<u8>();
} else if !f.prev_segmap_ref.is_null() {
// We're not updating an existing map,
// and we have a valid reference. Use that.
f.cur_segmap_ref = f.prev_segmap_ref;
rav1d_ref_inc(f.cur_segmap_ref);
f.cur_segmap = (*f.prev_segmap_ref).data.cast::<u8>();
} else {
// We need to make a new map. Allocate one here and zero it out.
let segmap_size =
::core::mem::size_of::<u8>() * f.b4_stride as usize * 32 * f.sb128h as usize;
f.cur_segmap_ref = rav1d_ref_create_using_pool(c.segmap_pool, segmap_size);
if f.cur_segmap_ref.is_null() {
on_error(f, c, out);
return Err(ENOMEM);
}
f.cur_segmap = (*f.cur_segmap_ref).data.cast::<u8>();
slice::from_raw_parts_mut(f.cur_segmap, segmap_size).fill(0);
}
f.cur_segmap = Some(
match (
frame_hdr.segmentation.update_map != 0,
f.prev_segmap.as_mut(),
) {
(true, _) | (false, None) => {
// If we're updating an existing map,
// we need somewhere to put the new values.
// Allocate them here (the data actually gets set elsewhere).
// Since this is Rust, we have to initialize it anyways.

// Otherwise if there's no previous, we need to make a new map.
// Allocate one here and zero it out.
let segmap_size = f.b4_stride as usize * 32 * f.sb128h as usize;
// TODO fallible allocation
(0..segmap_size).map(|_| 0).collect()
}
(_, Some(prev_segmap)) => {
// We're not updating an existing map,
// and we have a valid reference. Use that.
prev_segmap.clone()
}
},
);
} else {
f.cur_segmap = ptr::null_mut();
f.cur_segmap_ref = ptr::null_mut();
f.prev_segmap_ref = ptr::null_mut();
f.cur_segmap = None;
f.prev_segmap = None;
}

// update references etc.
Expand All @@ -5027,11 +5014,7 @@ pub unsafe fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
c.cdf[i] = f.in_cdf.clone();
}

rav1d_ref_dec(&mut c.refs[i].segmap);
c.refs[i].segmap = f.cur_segmap_ref;
if !f.cur_segmap_ref.is_null() {
rav1d_ref_inc(f.cur_segmap_ref);
}
c.refs[i].segmap = f.cur_segmap.clone();
rav1d_ref_dec(&mut c.refs[i].refmvs);
if !frame_hdr.allow_intrabc {
c.refs[i].refmvs = f.mvs_ref;
Expand All @@ -5053,7 +5036,7 @@ pub unsafe fn rav1d_submit_frame(c: &mut Rav1dContext) -> Rav1dResult {
rav1d_thread_picture_unref(&mut c.refs[i].p);
}
let _ = mem::take(&mut c.cdf[i]);
rav1d_ref_dec(&mut c.refs[i].segmap);
let _ = mem::take(&mut c.refs[i].segmap);
rav1d_ref_dec(&mut c.refs[i].refmvs);
}
}
Expand Down
Loading

0 comments on commit 85da63c

Please sign in to comment.