enum {{Comp,}Inter{Pred,}Mode,DRL_PROXIMITY,InterIntraType}: Narrow to `u8` to avoid redundant casts (#393)
kkysen authored Aug 25, 2023
2 parents 8981e44 + 396e9bd commit 08e776f
Showing 4 changed files with 67 additions and 74 deletions.
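For context on the change: these mode constants were previously typed wider than the `u8` block fields they are stored into, so call sites needed `as u8` casts on every assignment and comparison; narrowing the constants' type to `u8` lets call sites use them directly. A minimal sketch of the pattern follows — the names echo the diff, but the definitions are illustrative, not the actual `src/levels.rs` declarations.

```rust
// Hypothetical pre-change alias: wider than the field it feeds.
pub type WideInterPredMode = u32;
pub const WIDE_NEARESTMV: WideInterPredMode = 13;

// Post-change style: the alias is narrowed to `u8`, matching the storage type.
pub type InterPredMode = u8;
pub const NEARESTMV: InterPredMode = 13;

pub struct Block {
    inter_mode: u8,
}

impl Block {
    pub fn set_before(&mut self) {
        // Wider constant: every store and comparison needs a cast.
        self.inter_mode = WIDE_NEARESTMV as u8;
        let _is_nearest = self.inter_mode == WIDE_NEARESTMV as u8;
    }

    pub fn set_after(&mut self) {
        // `u8` constant: stored and compared directly, no casts.
        self.inter_mode = NEARESTMV;
        let _is_nearest = self.inter_mode == NEARESTMV;
    }
}
```

This is why the diff below is almost entirely deletions of `as u8`, `as InterPredMode`, and `as CompInterPredMode` at use sites in `src/decode.rs`.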
102 changes: 47 additions & 55 deletions src/decode.rs
@@ -853,7 +853,6 @@ use crate::src::levels::PARTITION_T_TOP_SPLIT;
use crate::src::levels::PARTITION_V;
use crate::src::levels::PARTITION_V4;

use crate::src::levels::InterPredMode;
use crate::src::levels::MV_JOINT_H;
use crate::src::levels::MV_JOINT_HV;
use crate::src::levels::MV_JOINT_V;
@@ -864,7 +863,6 @@ use crate::src::levels::NEARESTMV;
use crate::src::levels::NEARMV;
use crate::src::levels::NEWMV;

use crate::src::levels::CompInterPredMode;
use crate::src::levels::GLOBALMV_GLOBALMV;
use crate::src::levels::NEARER_DRL;
use crate::src::levels::NEAREST_DRL;
@@ -1885,7 +1883,7 @@ unsafe fn splat_oneref_mv(
bw4: usize,
bh4: usize,
) {
let mode = b.inter_mode() as InterPredMode;
let mode = b.inter_mode();
let tmpl = Align16(refmvs_block(refmvs_block_unaligned {
mv: refmvs_mvpair {
mv: [b.mv()[0], mv::ZERO],
@@ -1944,7 +1942,7 @@ unsafe fn splat_tworef_mv(
bh4: usize,
) {
assert!(bw4 >= 2 && bh4 >= 2);
let mode = b.inter_mode() as CompInterPredMode;
let mode = b.inter_mode();
let tmpl = Align16(refmvs_block(refmvs_block_unaligned {
mv: refmvs_mvpair { mv: *b.mv() },
r#ref: refmvs_refpair {
@@ -3149,8 +3147,8 @@ unsafe fn decode_b(
frame_hdr.skip_mode_refs[1] as i8,
];
*b.comp_type_mut() = COMP_INTER_AVG;
*b.inter_mode_mut() = NEARESTMV_NEARESTMV as u8;
*b.drl_idx_mut() = NEAREST_DRL as u8;
*b.inter_mode_mut() = NEARESTMV_NEARESTMV;
*b.drl_idx_mut() = NEAREST_DRL;
has_subpel_filter = false;

let mut mvstack = [Default::default(); 8];
@@ -3287,16 +3285,16 @@ }
}

let im = &dav1d_comp_inter_pred_modes[b.inter_mode() as usize];
*b.drl_idx_mut() = NEAREST_DRL as u8;
if b.inter_mode() == NEWMV_NEWMV as u8 {
*b.drl_idx_mut() = NEAREST_DRL;
if b.inter_mode() == NEWMV_NEWMV {
if n_mvs > 1 {
// NEARER, NEAR or NEARISH
let drl_ctx_v1 = get_drl_context(&mvstack, 0);
*b.drl_idx_mut() += dav1d_msac_decode_bool_adapt(
&mut ts.msac,
&mut ts.cdf.m.drl_bit[drl_ctx_v1 as usize],
) as u8;
if b.drl_idx() == NEARER_DRL as u8 && n_mvs > 2 {
if b.drl_idx() == NEARER_DRL && n_mvs > 2 {
let drl_ctx_v2 = get_drl_context(&mvstack, 1);
*b.drl_idx_mut() += dav1d_msac_decode_bool_adapt(
&mut ts.msac,
@@ -3312,16 +3310,16 @@ );
);
}
}
} else if im[0] == NEARMV as u8 || im[1] == NEARMV as u8 {
*b.drl_idx_mut() = NEARER_DRL as u8;
} else if im[0] == NEARMV || im[1] == NEARMV {
*b.drl_idx_mut() = NEARER_DRL;
if n_mvs > 2 {
// NEAR or NEARISH
let drl_ctx_v2 = get_drl_context(&mvstack, 1);
*b.drl_idx_mut() += dav1d_msac_decode_bool_adapt(
&mut ts.msac,
&mut ts.cdf.m.drl_bit[drl_ctx_v2 as usize],
) as u8;
if b.drl_idx() == NEAR_DRL as u8 && n_mvs > 3 {
if b.drl_idx() == NEAR_DRL && n_mvs > 3 {
let drl_ctx_v3 = get_drl_context(&mvstack, 2);
*b.drl_idx_mut() += dav1d_msac_decode_bool_adapt(
&mut ts.msac,
@@ -3338,11 +3336,10 @@ }
}
}
}
assert!(b.drl_idx() >= NEAREST_DRL as u8 && b.drl_idx() <= NEARISH_DRL as u8);
assert!(b.drl_idx() >= NEAREST_DRL && b.drl_idx() <= NEARISH_DRL);

has_subpel_filter =
std::cmp::min(bw4, bh4) == 1 || b.inter_mode() != GLOBALMV_GLOBALMV as u8;
let mut assign_comp_mv = |idx: usize| match im[idx] as InterPredMode {
has_subpel_filter = std::cmp::min(bw4, bh4) == 1 || b.inter_mode() != GLOBALMV_GLOBALMV;
let mut assign_comp_mv = |idx: usize| match im[idx] {
NEARMV | NEARESTMV => {
b.mv_mut()[idx] = mvstack[b.drl_idx() as usize].mv.mv[idx];
fix_mv_precision(frame_hdr, &mut b.mv_mut()[idx]);
@@ -3411,15 +3408,15 @@ by4,
by4,
bx4,
);
*b.comp_type_mut() = COMP_INTER_WEIGHTED_AVG as u8
*b.comp_type_mut() = COMP_INTER_WEIGHTED_AVG
+ dav1d_msac_decode_bool_adapt(
&mut ts.msac,
&mut ts.cdf.m.jnt_comp[jnt_ctx as usize],
) as u8;
if DEBUG_BLOCK_INFO(f, t) {
println!(
"Post-jnt_comp[{},ctx={}[ac:{},ar:{},lc:{},lr:{}]]: r={}",
b.comp_type() == COMP_INTER_AVG as u8,
b.comp_type() == COMP_INTER_AVG,
jnt_ctx,
(*t.a).comp_type[bx4 as usize],
(*t.a).r#ref[0][bx4 as usize],
Expand All @@ -3429,37 +3426,37 @@ unsafe fn decode_b(
);
}
} else {
*b.comp_type_mut() = COMP_INTER_AVG as u8;
*b.comp_type_mut() = COMP_INTER_AVG;
}
} else {
if wedge_allowed_mask & (1 << bs) != 0 {
let ctx = dav1d_wedge_ctx_lut[bs as usize] as usize;
*b.comp_type_mut() = COMP_INTER_WEDGE as u8
*b.comp_type_mut() = COMP_INTER_WEDGE
- dav1d_msac_decode_bool_adapt(&mut ts.msac, &mut ts.cdf.m.wedge_comp[ctx])
as u8;
if b.comp_type() == COMP_INTER_WEDGE as u8 {
if b.comp_type() == COMP_INTER_WEDGE {
*b.wedge_idx_mut() = dav1d_msac_decode_symbol_adapt16(
&mut ts.msac,
&mut ts.cdf.m.wedge_idx[ctx],
15,
) as u8;
}
} else {
*b.comp_type_mut() = COMP_INTER_SEG as u8;
*b.comp_type_mut() = COMP_INTER_SEG;
}
*b.mask_sign_mut() = dav1d_msac_decode_bool_equi(&mut ts.msac) as u8;
if DEBUG_BLOCK_INFO(f, t) {
println!(
"Post-seg/wedge[{},wedge_idx={},sign={}]: r={}",
b.comp_type() == COMP_INTER_WEDGE as u8,
b.comp_type() == COMP_INTER_WEDGE,
b.wedge_idx(),
b.mask_sign(),
ts.msac.rng,
);
}
}
} else {
*b.comp_type_mut() = COMP_INTER_NONE as u8;
*b.comp_type_mut() = COMP_INTER_NONE;

// ref
if let Some(seg) = seg.filter(|seg| seg.r#ref > 0) {
@@ -3545,7 +3542,7 @@ &mut ts.cdf.m.globalmv_mode[(ctx >> 3 & 1) as usize],
&mut ts.cdf.m.globalmv_mode[(ctx >> 3 & 1) as usize],
)
{
*b.inter_mode_mut() = GLOBALMV as u8;
*b.inter_mode_mut() = GLOBALMV;
b.mv_mut()[0] = get_gmv_2d(
&frame_hdr.gmv[b.r#ref()[0] as usize],
t.bx,
@@ -3563,8 +3560,8 @@ &mut ts.cdf.m.refmv_mode[(ctx >> 4 & 15) as usize],
&mut ts.cdf.m.refmv_mode[(ctx >> 4 & 15) as usize],
) {
// NEAREST, NEARER, NEAR or NEARISH
*b.inter_mode_mut() = NEARMV as u8;
*b.drl_idx_mut() = NEARER_DRL as u8;
*b.inter_mode_mut() = NEARMV;
*b.drl_idx_mut() = NEARER_DRL;
if n_mvs > 2 {
// NEARER, NEAR or NEARISH
let drl_ctx_v2 = get_drl_context(&mvstack, 1);
@@ -3573,7 +3570,7 @@ &mut ts.msac,
&mut ts.msac,
&mut ts.cdf.m.drl_bit[drl_ctx_v2 as usize],
) as u8;
if b.drl_idx() == NEAR_DRL as u8 && n_mvs > 3 {
if b.drl_idx() == NEAR_DRL && n_mvs > 3 {
// NEAR or NEARISH
let drl_ctx_v3 = get_drl_context(&mvstack, 2);
*b.drl_idx_mut() = b.drl_idx()
@@ -3585,11 +3582,11 @@ }
}
} else {
*b.inter_mode_mut() = NEARESTMV as u8;
*b.drl_idx_mut() = NEAREST_DRL as u8;
*b.drl_idx_mut() = NEAREST_DRL;
}
assert!(b.drl_idx() >= NEAREST_DRL as u8 && b.drl_idx() <= NEARISH_DRL as u8);
assert!(b.drl_idx() >= NEAREST_DRL && b.drl_idx() <= NEARISH_DRL);
b.mv_mut()[0] = mvstack[b.drl_idx() as usize].mv.mv[0];
if b.drl_idx() < NEAR_DRL as u8 {
if b.drl_idx() < NEAR_DRL {
fix_mv_precision(frame_hdr, &mut b.mv_mut()[0]);
}
}
@@ -3607,8 +3604,8 @@ }
}
} else {
has_subpel_filter = true;
*b.inter_mode_mut() = NEWMV as u8;
*b.drl_idx_mut() = NEAREST_DRL as u8;
*b.inter_mode_mut() = NEWMV;
*b.drl_idx_mut() = NEAREST_DRL;
if n_mvs > 1 {
// NEARER, NEAR or NEARISH
let drl_ctx_v1 = get_drl_context(&mvstack, 0);
@@ -3617,7 +3614,7 @@ &mut ts.msac,
&mut ts.msac,
&mut ts.cdf.m.drl_bit[drl_ctx_v1 as usize],
) as u8;
if b.drl_idx() == NEARER_DRL as u8 && n_mvs > 2 {
if b.drl_idx() == NEARER_DRL && n_mvs > 2 {
// NEAR or NEARISH
let drl_ctx_v2 = get_drl_context(&mvstack, 1);
*b.drl_idx_mut() = b.drl_idx()
@@ -3627,7 +3624,7 @@ ) as u8;
) as u8;
}
}
assert!(b.drl_idx() >= NEAREST_DRL as u8 && b.drl_idx() <= NEARISH_DRL as u8);
assert!(b.drl_idx() >= NEAREST_DRL && b.drl_idx() <= NEARISH_DRL);
if n_mvs > 1 {
b.mv_mut()[0] = mvstack[b.drl_idx() as usize].mv.mv[0];
} else {
@@ -3674,20 +3671,20 @@ N_INTER_INTRA_PRED_MODES as size_t - 1,
N_INTER_INTRA_PRED_MODES as size_t - 1,
) as u8;
let wedge_ctx = dav1d_wedge_ctx_lut[bs as usize] as libc::c_int;
*b.interintra_type_mut() = INTER_INTRA_BLEND as u8
*b.interintra_type_mut() = INTER_INTRA_BLEND
+ dav1d_msac_decode_bool_adapt(
&mut ts.msac,
&mut ts.cdf.m.interintra_wedge[wedge_ctx as usize],
) as u8;
if b.interintra_type() == INTER_INTRA_WEDGE as u8 {
if b.interintra_type() == INTER_INTRA_WEDGE {
*b.wedge_idx_mut() = dav1d_msac_decode_symbol_adapt16(
&mut ts.msac,
&mut ts.cdf.m.wedge_idx[wedge_ctx as usize],
15,
) as u8;
}
} else {
*b.interintra_type_mut() = INTER_INTRA_NONE as u8;
*b.interintra_type_mut() = INTER_INTRA_NONE;
}
if DEBUG_BLOCK_INFO(f, t)
&& (*f.seq_hdr).inter_intra != 0
@@ -3704,11 +3701,11 @@

// motion variation
if frame_hdr.switchable_motion_mode != 0
&& b.interintra_type() == INTER_INTRA_NONE as u8
&& b.interintra_type() == INTER_INTRA_NONE
&& std::cmp::min(bw4, bh4) >= 2
// is not warped global motion
&& !(frame_hdr.force_integer_mv == 0
&& b.inter_mode() == GLOBALMV as u8
&& b.inter_mode() == GLOBALMV
&& frame_hdr.gmv[b.r#ref()[0] as usize].type_0 > DAV1D_WM_TYPE_TRANSLATION)
// has overlappable neighbours
&& (have_left && findoddzero(&t.l.intra.0[by4 as usize..][..h4 as usize])
@@ -3794,7 +3791,7 @@ // subpel filter
// subpel filter
let filter = if frame_hdr.subpel_filter_mode == DAV1D_FILTER_SWITCHABLE {
if has_subpel_filter {
let comp = b.comp_type() != COMP_INTER_NONE as u8;
let comp = b.comp_type() != COMP_INTER_NONE;
let ctx1 = get_filter_ctx(&*t.a, &t.l, comp, false, b.r#ref()[0], by4, bx4);
let filter0 = dav1d_msac_decode_symbol_adapt4(
&mut ts.msac,
@@ -3848,12 +3845,8 @@ }
}

if frame_hdr.loopfilter.level_y != [0, 0] {
let is_globalmv = (b.inter_mode()
== if is_comp {
GLOBALMV_GLOBALMV as u8
} else {
GLOBALMV as u8
}) as libc::c_int;
let is_globalmv = (b.inter_mode() == if is_comp { GLOBALMV_GLOBALMV } else { GLOBALMV })
as libc::c_int;
let tx_split = [b.tx_split0() as u16, b.tx_split1()];
let mut ytx = b.max_ytx() as RectTxfmSize;
let mut uvtx = b.uvtx as RectTxfmSize;
@@ -3975,11 +3968,10 @@ let sby = t.by - ts.tiling.row_start >> f.sb_shift;
let sby = t.by - ts.tiling.row_start >> f.sb_shift;
let lowest_px = &mut *ts.lowest_pixel.offset(sby as isize);
// keep track of motion vectors for each reference
if b.comp_type() == COMP_INTER_NONE as u8 {
if b.comp_type() == COMP_INTER_NONE {
// y
if std::cmp::min(bw4, bh4) > 1
&& (b.inter_mode() == GLOBALMV as u8
&& f.gmv_warp_allowed[b.r#ref()[0] as usize] != 0
&& (b.inter_mode() == GLOBALMV && f.gmv_warp_allowed[b.r#ref()[0] as usize] != 0
|| b.motion_mode() == MM_WARP as u8
&& t.warpmv.type_0 > DAV1D_WM_TYPE_TRANSLATION)
{
@@ -4076,7 +4068,7 @@ &f.svc[b.r#ref()[0] as usize][1],
&f.svc[b.r#ref()[0] as usize][1],
);
} else if std::cmp::min(cbw4, cbh4) > 1
&& (b.inter_mode() == GLOBALMV as u8
&& (b.inter_mode() == GLOBALMV
&& f.gmv_warp_allowed[b.r#ref()[0] as usize] != 0
|| b.motion_mode() == MM_WARP as u8
&& t.warpmv.type_0 > DAV1D_WM_TYPE_TRANSLATION)
@@ -4110,7 +4102,7 @@ let refmvs =
let refmvs =
|| std::iter::zip(b.r#ref(), b.mv()).map(|(r#ref, mv)| (r#ref as usize, mv));
for (r#ref, mv) in refmvs() {
if b.inter_mode() == GLOBALMV_GLOBALMV as u8 && f.gmv_warp_allowed[r#ref] != 0 {
if b.inter_mode() == GLOBALMV_GLOBALMV && f.gmv_warp_allowed[r#ref] != 0 {
affine_lowest_px_luma(
t,
&mut lowest_px[r#ref][0],
@@ -4129,7 +4121,7 @@ }
}
}
for (r#ref, mv) in refmvs() {
if b.inter_mode() == GLOBALMV_GLOBALMV as u8 && f.gmv_warp_allowed[r#ref] != 0 {
if b.inter_mode() == GLOBALMV_GLOBALMV && f.gmv_warp_allowed[r#ref] != 0 {
affine_lowest_px_luma(
t,
&mut lowest_px[r#ref][0],
@@ -4151,7 +4143,7 @@ // uv
// uv
if has_chroma {
for (r#ref, mv) in refmvs() {
if b.inter_mode() == GLOBALMV_GLOBALMV as u8
if b.inter_mode() == GLOBALMV_GLOBALMV
&& std::cmp::min(cbw4, cbh4) > 1
&& f.gmv_warp_allowed[r#ref] != 0
{
@@ -4588,7 +4580,7 @@ fn reset_context(ctx: &mut BlockContext, keyframe: bool, pass: libc::c_int) {
r#ref.fill(-1);
}
ctx.comp_type.0.fill(0);
ctx.mode.0.fill(NEARESTMV as u8);
ctx.mode.0.fill(NEARESTMV);
}
ctx.lcoef.0.fill(0x40);
for ccoef in &mut ctx.ccoef.0 {