Skip to content

Commit

Permalink
struct WedgeMasks: Encapsulate static wedge_masks_*s and `fn fill…
Browse files Browse the repository at this point in the history
…2d_16x2` in a `struct` (#479)

`fn fill2d_16x2` as it was couldn't get rid of the `&mut`s to make it a
`const fn` and remove the `static mut`s for the `static wedge_masks_*`s,
since it took and store (in other statics) references to those statics.
This separates the steps into creating the statics and then referencing
them in separate functions/methods, and it also significantly simplifies
things as well.
  • Loading branch information
kkysen authored Sep 20, 2023
2 parents 872dee0 + ead1fea commit 248e347
Showing 1 changed file with 99 additions and 177 deletions.
276 changes: 99 additions & 177 deletions src/wedge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,43 +104,6 @@ static wedge_codebook_16_heqw: [wedge_code_type; 16] = [
wedge_code_type::new(6, 4, WEDGE_OBLIQUE117),
];

static mut wedge_masks_444_32x32: Align64<[[[u8; 32 * 32]; 16]; 2]> =
Align64([[[0; 32 * 32]; 16]; 2]);
static mut wedge_masks_444_32x16: Align64<[[[u8; 32 * 16]; 16]; 2]> =
Align64([[[0; 32 * 16]; 16]; 2]);
static mut wedge_masks_444_32x8: Align64<[[[u8; 32 * 8]; 16]; 2]> = Align64([[[0; 32 * 8]; 16]; 2]);
static mut wedge_masks_444_16x32: Align64<[[[u8; 16 * 32]; 16]; 2]> =
Align64([[[0; 16 * 32]; 16]; 2]);
static mut wedge_masks_444_16x16: Align64<[[[u8; 16 * 16]; 16]; 2]> =
Align64([[[0; 16 * 16]; 16]; 2]);
static mut wedge_masks_444_16x8: Align64<[[[u8; 16 * 8]; 16]; 2]> = Align64([[[0; 16 * 8]; 16]; 2]);
static mut wedge_masks_444_8x32: Align64<[[[u8; 8 * 32]; 16]; 2]> = Align64([[[0; 8 * 32]; 16]; 2]);
static mut wedge_masks_444_8x16: Align64<[[[u8; 8 * 16]; 16]; 2]> = Align64([[[0; 8 * 16]; 16]; 2]);
static mut wedge_masks_444_8x8: Align64<[[[u8; 8 * 8]; 16]; 2]> = Align64([[[0; 8 * 8]; 16]; 2]);

static mut wedge_masks_422_16x32: Align64<[[[u8; 16 * 32]; 16]; 2]> =
Align64([[[0; 16 * 32]; 16]; 2]);
static mut wedge_masks_422_16x16: Align64<[[[u8; 16 * 16]; 16]; 2]> =
Align64([[[0; 16 * 16]; 16]; 2]);
static mut wedge_masks_422_16x8: Align64<[[[u8; 16 * 8]; 16]; 2]> = Align64([[[0; 16 * 8]; 16]; 2]);
static mut wedge_masks_422_8x32: Align64<[[[u8; 8 * 32]; 16]; 2]> = Align64([[[0; 8 * 32]; 16]; 2]);
static mut wedge_masks_422_8x16: Align64<[[[u8; 8 * 16]; 16]; 2]> = Align64([[[0; 8 * 16]; 16]; 2]);
static mut wedge_masks_422_8x8: Align64<[[[u8; 8 * 8]; 16]; 2]> = Align64([[[0; 8 * 8]; 16]; 2]);
static mut wedge_masks_422_4x32: Align64<[[[u8; 4 * 32]; 16]; 2]> = Align64([[[0; 4 * 32]; 16]; 2]);
static mut wedge_masks_422_4x16: Align64<[[[u8; 4 * 16]; 16]; 2]> = Align64([[[0; 4 * 16]; 16]; 2]);
static mut wedge_masks_422_4x8: Align64<[[[u8; 4 * 8]; 16]; 2]> = Align64([[[0; 4 * 8]; 16]; 2]);

static mut wedge_masks_420_16x16: Align64<[[[u8; 16 * 16]; 16]; 2]> =
Align64([[[0; 16 * 16]; 16]; 2]);
static mut wedge_masks_420_16x8: Align64<[[[u8; 16 * 8]; 16]; 2]> = Align64([[[0; 16 * 8]; 16]; 2]);
static mut wedge_masks_420_16x4: Align64<[[[u8; 16 * 4]; 16]; 2]> = Align64([[[0; 16 * 4]; 16]; 2]);
static mut wedge_masks_420_8x16: Align64<[[[u8; 8 * 16]; 16]; 2]> = Align64([[[0; 8 * 16]; 16]; 2]);
static mut wedge_masks_420_8x8: Align64<[[[u8; 8 * 8]; 16]; 2]> = Align64([[[0; 8 * 8]; 16]; 2]);
static mut wedge_masks_420_8x4: Align64<[[[u8; 8 * 4]; 16]; 2]> = Align64([[[0; 8 * 4]; 16]; 2]);
static mut wedge_masks_420_4x16: Align64<[[[u8; 4 * 16]; 16]; 2]> = Align64([[[0; 4 * 16]; 16]; 2]);
static mut wedge_masks_420_4x8: Align64<[[[u8; 4 * 8]; 16]; 2]> = Align64([[[0; 4 * 8]; 16]; 2]);
static mut wedge_masks_420_4x4: Align64<[[[u8; 4 * 4]; 16]; 2]> = Align64([[[0; 4 * 4]; 16]; 2]);

pub static mut dav1d_wedge_masks: [[[[&'static [u8]; 16]; 2]; 3]; N_BS_SIZES] =
[[[[&[]; 16]; 2]; 3]; N_BS_SIZES];

Expand Down Expand Up @@ -258,61 +221,88 @@ const fn init_chroma<const LEN_LUMA: usize, const LEN_CHROMA: usize>(
chroma
}

#[cold]
fn fill2d_16x2<const LEN_444: usize, const LEN_422: usize, const LEN_420: usize>(
w: usize,
h: usize,
master: &[[[u8; 64]; 64]; N_WEDGE_DIRECTIONS],
cb: &[wedge_code_type; 16],
masks_444: &'static mut [[[u8; LEN_444]; 16]; 2],
masks_422: &'static mut [[[u8; LEN_422]; 16]; 2],
masks_420: &'static mut [[[u8; LEN_420]; 16]; 2],
struct WedgeMasks<const LEN_444: usize, const LEN_422: usize, const LEN_420: usize> {
masks_444: Align64<[[[u8; LEN_444]; 16]; 2]>,
masks_422: Align64<[[[u8; LEN_422]; 16]; 2]>,
masks_420: Align64<[[[u8; LEN_420]; 16]; 2]>,
signs: u16,
) -> [[[&'static [u8]; 16]; 2]; 3] {
assert!(LEN_444 == (w * h) >> 0);
assert!(LEN_422 == (w * h) >> 1);
assert!(LEN_420 == (w * h) >> 2);

const_for!(n in 0..16 => {
masks_444[0][n] = copy2d(
&master[cb[n].direction as usize],
w,
h,
32 - (w * cb[n].x_offset as usize >> 3),
32 - (h * cb[n].y_offset as usize >> 3),
);
});
const_for!(n in 0..16 => {
masks_444[1][n] = invert(&masks_444[0][n], w, h);
});
}

const_for!(n in 0..16 => {
let sign = (signs >> n & 1) != 0;
let luma = &masks_444[sign as usize][n];
impl<const LEN_444: usize, const LEN_422: usize, const LEN_420: usize>
WedgeMasks<LEN_444, LEN_422, LEN_420>
{
const fn fill2d_16x2(
w: usize,
h: usize,
master: &[[[u8; 64]; 64]; N_WEDGE_DIRECTIONS],
cb: &[wedge_code_type; 16],
signs: u16,
) -> Self {
assert!(LEN_444 == (w * h) >> 0);
assert!(LEN_422 == (w * h) >> 1);
assert!(LEN_420 == (w * h) >> 2);

let mut masks_444 = [[[0; LEN_444]; 16]; 2];
let mut masks_422 = [[[0; LEN_422]; 16]; 2];
let mut masks_420 = [[[0; LEN_420]; 16]; 2];

const_for!(n in 0..16 => {
masks_444[0][n] = copy2d(
&master[cb[n].direction as usize],
w,
h,
32 - (w * cb[n].x_offset as usize >> 3),
32 - (h * cb[n].y_offset as usize >> 3),
);
});
const_for!(n in 0..16 => {
masks_444[1][n] = invert(&masks_444[0][n], w, h);
});

masks_422[sign as usize][n] = init_chroma(luma, false, w, h, false);
masks_422[!sign as usize][n] = init_chroma(luma, true, w, h, false);
masks_420[sign as usize][n] = init_chroma(luma, false, w, h, true);
masks_420[!sign as usize][n] = init_chroma(luma, true, w, h, true);
});
const_for!(n in 0..16 => {
let sign = (signs >> n & 1) != 0;
let luma = &masks_444[sign as usize][n];

let mut masks = [[[&[] as &'static [u8]; 16]; 2]; 3];
masks_422[sign as usize][n] = init_chroma(luma, false, w, h, false);
masks_422[!sign as usize][n] = init_chroma(luma, true, w, h, false);
masks_420[sign as usize][n] = init_chroma(luma, false, w, h, true);
masks_420[!sign as usize][n] = init_chroma(luma, true, w, h, true);
});

// assign pointers in externally visible array
const_for!(n in 0..16 => {
let sign = (signs >> n & 1) != 0;
Self {
masks_444: Align64(masks_444),
masks_422: Align64(masks_422),
masks_420: Align64(masks_420),
signs,
}
}

masks[0][0][n] = &masks_444[sign as usize][n];
// not using !sign is intentional here, since 444 does not require
// any rounding since no chroma subsampling is applied.
masks[0][1][n] = &masks_444[sign as usize][n];
masks[1][0][n] = &masks_422[sign as usize][n];
masks[1][1][n] = &masks_422[!sign as usize][n];
masks[2][0][n] = &masks_420[sign as usize][n];
masks[2][1][n] = &masks_420[!sign as usize][n];
});
const fn slice(&self) -> [[[&[u8]; 16]; 2]; 3] {
let Self {
masks_444: Align64(masks_444),
masks_422: Align64(masks_422),
masks_420: Align64(masks_420),
signs,
} = self;

let mut masks = [[[&[] as &'static [u8]; 16]; 2]; 3];

// assign pointers in externally visible array
const_for!(n in 0..16 => {
let sign = (*signs >> n & 1) != 0;

masks[0][0][n] = &masks_444[sign as usize][n];
// not using !sign is intentional here, since 444 does not require
// any rounding since no chroma subsampling is applied.
masks[0][1][n] = &masks_444[sign as usize][n];
masks[1][0][n] = &masks_422[sign as usize][n];
masks[1][1][n] = &masks_422[!sign as usize][n];
masks[2][0][n] = &masks_420[sign as usize][n];
masks[2][1][n] = &masks_420[!sign as usize][n];
});

masks
masks
}
}

const fn build_master() -> [[[u8; 64]; 64]; N_WEDGE_DIRECTIONS] {
Expand Down Expand Up @@ -366,98 +356,30 @@ const fn build_master() -> [[[u8; 64]; 64]; N_WEDGE_DIRECTIONS] {
pub unsafe fn dav1d_init_wedge_masks() {
// This function is guaranteed to be called only once

let master = build_master();

dav1d_wedge_masks[BS_32x32 as usize] = fill2d_16x2(
32,
32,
&master,
&wedge_codebook_16_heqw,
&mut wedge_masks_444_32x32.0,
&mut wedge_masks_422_16x32.0,
&mut wedge_masks_420_16x16.0,
0x7bfb,
);
dav1d_wedge_masks[BS_32x16 as usize] = fill2d_16x2(
32,
16,
&master,
&wedge_codebook_16_hltw,
&mut wedge_masks_444_32x16.0,
&mut wedge_masks_422_16x16.0,
&mut wedge_masks_420_16x8.0,
0x7beb,
);
dav1d_wedge_masks[BS_32x8 as usize] = fill2d_16x2(
32,
8,
&master,
&wedge_codebook_16_hltw,
&mut wedge_masks_444_32x8.0,
&mut wedge_masks_422_16x8.0,
&mut wedge_masks_420_16x4.0,
0x6beb,
);
dav1d_wedge_masks[BS_16x32 as usize] = fill2d_16x2(
16,
32,
&master,
&wedge_codebook_16_hgtw,
&mut wedge_masks_444_16x32.0,
&mut wedge_masks_422_8x32.0,
&mut wedge_masks_420_8x16.0,
0x7beb,
);
dav1d_wedge_masks[BS_16x16 as usize] = fill2d_16x2(
16,
16,
&master,
&wedge_codebook_16_heqw,
&mut wedge_masks_444_16x16.0,
&mut wedge_masks_422_8x16.0,
&mut wedge_masks_420_8x8.0,
0x7bfb,
);
dav1d_wedge_masks[BS_16x8 as usize] = fill2d_16x2(
16,
8,
&master,
&wedge_codebook_16_hltw,
&mut wedge_masks_444_16x8.0,
&mut wedge_masks_422_8x8.0,
&mut wedge_masks_420_8x4.0,
0x7beb,
);
dav1d_wedge_masks[BS_8x32 as usize] = fill2d_16x2(
8,
32,
&master,
&wedge_codebook_16_hgtw,
&mut wedge_masks_444_8x32.0,
&mut wedge_masks_422_4x32.0,
&mut wedge_masks_420_4x16.0,
0x7aeb,
);
dav1d_wedge_masks[BS_8x16 as usize] = fill2d_16x2(
8,
16,
&master,
&wedge_codebook_16_hgtw,
&mut wedge_masks_444_8x16.0,
&mut wedge_masks_422_4x16.0,
&mut wedge_masks_420_4x8.0,
0x7beb,
);
dav1d_wedge_masks[BS_8x8 as usize] = fill2d_16x2(
8,
8,
&master,
&wedge_codebook_16_heqw,
&mut wedge_masks_444_8x8.0,
&mut wedge_masks_422_4x8.0,
&mut wedge_masks_420_4x4.0,
0x7bfb,
);
static master: [[[u8; 64]; 64]; N_WEDGE_DIRECTIONS] = build_master();

macro_rules! fill {
($w:literal x $h:literal, $cb:expr, $signs:expr) => {{
static wedge_masks: WedgeMasks<
{ $w * $h },
{ ($w / 2) * $h },
{ ($w / 2) * ($h / 2) },
> = WedgeMasks::fill2d_16x2($w, $h, &master, $cb, $signs);
paste! {
dav1d_wedge_masks[[<BS_ $w x $h>] as usize] = wedge_masks.slice();
}
}};
}

fill!(32 x 32, &wedge_codebook_16_heqw, 0x7bfb);
fill!(32 x 16, &wedge_codebook_16_hltw, 0x7beb);
fill!(32 x 8, &wedge_codebook_16_hltw, 0x6beb);
fill!(16 x 32, &wedge_codebook_16_hgtw, 0x7beb);
fill!(16 x 16, &wedge_codebook_16_heqw, 0x7bfb);
fill!(16 x 8, &wedge_codebook_16_hltw, 0x7beb);
fill!( 8 x 32, &wedge_codebook_16_hgtw, 0x7aeb);
fill!( 8 x 16, &wedge_codebook_16_hgtw, 0x7beb);
fill!( 8 x 8, &wedge_codebook_16_heqw, 0x7bfb);
}

static ii_dc_mask: Align64<[u8; 32 * 32]> = Align64([32; 32 * 32]);
Expand Down

0 comments on commit 248e347

Please sign in to comment.