static wedge_masks_*: Move the array length multiplication to outer arrays (#464)

This simplifies the code considerably: the length multiplications are now
done by the compiler, and the nested layout tells the compiler that the
inner arrays are exclusive, so we can hold multiple `&mut`s at once. It also
makes the length of each slice explicit, and it avoids slicing entirely,
which cannot be done in `const fn`s.
kkysen authored Sep 19, 2023
2 parents 8459deb + a6f5dcb commit f4f26c9
Showing 1 changed file with 53 additions and 58 deletions.
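
For illustration, here is a minimal standalone sketch of the layout change (hypothetical names and sizes, not the rav1d code itself): with the flat layout the strides are multiplied by hand and the two sign planes have to be carved out at runtime with `split_at_mut`, while the nested layout encodes the strides in the type, so an array pattern hands out disjoint `&mut`s with no slicing at all.

```rust
const W: usize = 8; // hypothetical mask width
const H: usize = 8; // hypothetical mask height

fn main() {
    // Flat layout: one big buffer, hand-written stride arithmetic, and the
    // two sign planes must be separated at runtime with `split_at_mut`
    // (which cannot be used in `const fn`s).
    let mut flat = [0u8; 2 * 16 * W * H];
    let (plane0, plane1) = flat.split_at_mut(16 * W * H);
    plane1[0] = plane0[0];

    // Nested layout: the compiler does the length multiplication, and an
    // array pattern yields two exclusive `&mut`s at once -- no slicing.
    let mut nested = [[[0u8; W * H]; 16]; 2];
    let [plane0, plane1] = &mut nested;
    plane1[0][0] = plane0[0][0];
}
```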
111 changes: 53 additions & 58 deletions src/wedge.rs
@@ -109,35 +109,42 @@ static wedge_codebook_16_heqw: [wedge_code_type; 16] = [
wedge_code_type::new(6, 4, WEDGE_OBLIQUE117),
];

static mut wedge_masks_444_32x32: Align64<[u8; 2 * 16 * 32 * 32]> = Align64([0; 2 * 16 * 32 * 32]);
static mut wedge_masks_444_32x16: Align64<[u8; 2 * 16 * 32 * 16]> = Align64([0; 2 * 16 * 32 * 16]);
static mut wedge_masks_444_32x8: Align64<[u8; 2 * 16 * 32 * 8]> = Align64([0; 2 * 16 * 32 * 8]);
static mut wedge_masks_444_16x32: Align64<[u8; 2 * 16 * 16 * 32]> = Align64([0; 2 * 16 * 16 * 32]);
static mut wedge_masks_444_16x16: Align64<[u8; 2 * 16 * 16 * 16]> = Align64([0; 2 * 16 * 16 * 16]);
static mut wedge_masks_444_16x8: Align64<[u8; 2 * 16 * 16 * 8]> = Align64([0; 2 * 16 * 16 * 8]);
static mut wedge_masks_444_8x32: Align64<[u8; 2 * 16 * 8 * 32]> = Align64([0; 2 * 16 * 8 * 32]);
static mut wedge_masks_444_8x16: Align64<[u8; 2 * 16 * 8 * 16]> = Align64([0; 2 * 16 * 8 * 16]);
static mut wedge_masks_444_8x8: Align64<[u8; 2 * 16 * 8 * 8]> = Align64([0; 2 * 16 * 8 * 8]);

static mut wedge_masks_422_16x32: Align64<[u8; 2 * 16 * 16 * 32]> = Align64([0; 2 * 16 * 16 * 32]);
static mut wedge_masks_422_16x16: Align64<[u8; 2 * 16 * 16 * 16]> = Align64([0; 2 * 16 * 16 * 16]);
static mut wedge_masks_422_16x8: Align64<[u8; 2 * 16 * 16 * 8]> = Align64([0; 2 * 16 * 16 * 8]);
static mut wedge_masks_422_8x32: Align64<[u8; 2 * 16 * 8 * 32]> = Align64([0; 2 * 16 * 8 * 32]);
static mut wedge_masks_422_8x16: Align64<[u8; 2 * 16 * 8 * 16]> = Align64([0; 2 * 16 * 8 * 16]);
static mut wedge_masks_422_8x8: Align64<[u8; 2 * 16 * 8 * 8]> = Align64([0; 2 * 16 * 8 * 8]);
static mut wedge_masks_422_4x32: Align64<[u8; 2 * 16 * 4 * 32]> = Align64([0; 2 * 16 * 4 * 32]);
static mut wedge_masks_422_4x16: Align64<[u8; 2 * 16 * 4 * 16]> = Align64([0; 2 * 16 * 4 * 16]);
static mut wedge_masks_422_4x8: Align64<[u8; 2 * 16 * 4 * 8]> = Align64([0; 2 * 16 * 4 * 8]);

static mut wedge_masks_420_16x16: Align64<[u8; 2 * 16 * 16 * 16]> = Align64([0; 2 * 16 * 16 * 16]);
static mut wedge_masks_420_16x8: Align64<[u8; 2 * 16 * 16 * 8]> = Align64([0; 2 * 16 * 16 * 8]);
static mut wedge_masks_420_16x4: Align64<[u8; 2 * 16 * 16 * 4]> = Align64([0; 2 * 16 * 16 * 4]);
static mut wedge_masks_420_8x16: Align64<[u8; 2 * 16 * 8 * 16]> = Align64([0; 2 * 16 * 8 * 16]);
static mut wedge_masks_420_8x8: Align64<[u8; 2 * 16 * 8 * 8]> = Align64([0; 2 * 16 * 8 * 8]);
static mut wedge_masks_420_8x4: Align64<[u8; 2 * 16 * 8 * 4]> = Align64([0; 2 * 16 * 8 * 4]);
static mut wedge_masks_420_4x16: Align64<[u8; 2 * 16 * 4 * 16]> = Align64([0; 2 * 16 * 4 * 16]);
static mut wedge_masks_420_4x8: Align64<[u8; 2 * 16 * 4 * 8]> = Align64([0; 2 * 16 * 4 * 8]);
static mut wedge_masks_420_4x4: Align64<[u8; 2 * 16 * 4 * 4]> = Align64([0; 2 * 16 * 4 * 4]);
static mut wedge_masks_444_32x32: Align64<[[[u8; 32 * 32]; 16]; 2]> =
Align64([[[0; 32 * 32]; 16]; 2]);
static mut wedge_masks_444_32x16: Align64<[[[u8; 32 * 16]; 16]; 2]> =
Align64([[[0; 32 * 16]; 16]; 2]);
static mut wedge_masks_444_32x8: Align64<[[[u8; 32 * 8]; 16]; 2]> = Align64([[[0; 32 * 8]; 16]; 2]);
static mut wedge_masks_444_16x32: Align64<[[[u8; 16 * 32]; 16]; 2]> =
Align64([[[0; 16 * 32]; 16]; 2]);
static mut wedge_masks_444_16x16: Align64<[[[u8; 16 * 16]; 16]; 2]> =
Align64([[[0; 16 * 16]; 16]; 2]);
static mut wedge_masks_444_16x8: Align64<[[[u8; 16 * 8]; 16]; 2]> = Align64([[[0; 16 * 8]; 16]; 2]);
static mut wedge_masks_444_8x32: Align64<[[[u8; 8 * 32]; 16]; 2]> = Align64([[[0; 8 * 32]; 16]; 2]);
static mut wedge_masks_444_8x16: Align64<[[[u8; 8 * 16]; 16]; 2]> = Align64([[[0; 8 * 16]; 16]; 2]);
static mut wedge_masks_444_8x8: Align64<[[[u8; 8 * 8]; 16]; 2]> = Align64([[[0; 8 * 8]; 16]; 2]);

static mut wedge_masks_422_16x32: Align64<[[[u8; 16 * 32]; 16]; 2]> =
Align64([[[0; 16 * 32]; 16]; 2]);
static mut wedge_masks_422_16x16: Align64<[[[u8; 16 * 16]; 16]; 2]> =
Align64([[[0; 16 * 16]; 16]; 2]);
static mut wedge_masks_422_16x8: Align64<[[[u8; 16 * 8]; 16]; 2]> = Align64([[[0; 16 * 8]; 16]; 2]);
static mut wedge_masks_422_8x32: Align64<[[[u8; 8 * 32]; 16]; 2]> = Align64([[[0; 8 * 32]; 16]; 2]);
static mut wedge_masks_422_8x16: Align64<[[[u8; 8 * 16]; 16]; 2]> = Align64([[[0; 8 * 16]; 16]; 2]);
static mut wedge_masks_422_8x8: Align64<[[[u8; 8 * 8]; 16]; 2]> = Align64([[[0; 8 * 8]; 16]; 2]);
static mut wedge_masks_422_4x32: Align64<[[[u8; 4 * 32]; 16]; 2]> = Align64([[[0; 4 * 32]; 16]; 2]);
static mut wedge_masks_422_4x16: Align64<[[[u8; 4 * 16]; 16]; 2]> = Align64([[[0; 4 * 16]; 16]; 2]);
static mut wedge_masks_422_4x8: Align64<[[[u8; 4 * 8]; 16]; 2]> = Align64([[[0; 4 * 8]; 16]; 2]);

static mut wedge_masks_420_16x16: Align64<[[[u8; 16 * 16]; 16]; 2]> =
Align64([[[0; 16 * 16]; 16]; 2]);
static mut wedge_masks_420_16x8: Align64<[[[u8; 16 * 8]; 16]; 2]> = Align64([[[0; 16 * 8]; 16]; 2]);
static mut wedge_masks_420_16x4: Align64<[[[u8; 16 * 4]; 16]; 2]> = Align64([[[0; 16 * 4]; 16]; 2]);
static mut wedge_masks_420_8x16: Align64<[[[u8; 8 * 16]; 16]; 2]> = Align64([[[0; 8 * 16]; 16]; 2]);
static mut wedge_masks_420_8x8: Align64<[[[u8; 8 * 8]; 16]; 2]> = Align64([[[0; 8 * 8]; 16]; 2]);
static mut wedge_masks_420_8x4: Align64<[[[u8; 8 * 4]; 16]; 2]> = Align64([[[0; 8 * 4]; 16]; 2]);
static mut wedge_masks_420_4x16: Align64<[[[u8; 4 * 16]; 16]; 2]> = Align64([[[0; 4 * 16]; 16]; 2]);
static mut wedge_masks_420_4x8: Align64<[[[u8; 4 * 8]; 16]; 2]> = Align64([[[0; 4 * 8]; 16]; 2]);
static mut wedge_masks_420_4x4: Align64<[[[u8; 4 * 4]; 16]; 2]> = Align64([[[0; 4 * 4]; 16]; 2]);

pub static mut dav1d_wedge_masks: [[[[*const u8; 16]; 2]; 3]; N_BS_SIZES] =
[[[[0 as *const u8; 16]; 2]; 3]; N_BS_SIZES];
@@ -255,60 +262,48 @@ unsafe fn init_chroma(
}

#[cold]
unsafe fn fill2d_16x2(
dst: &mut [u8],
unsafe fn fill2d_16x2<const LEN_444: usize, const LEN_422: usize, const LEN_420: usize>(
dst: &mut [[[u8; LEN_444]; 16]; 2],
w: usize,
h: usize,
master: &[[u8; 64 * 64]; 6],
cb: &[wedge_code_type; 16],
mut masks_444: &mut [u8],
mut masks_422: &mut [u8],
mut masks_420: &mut [u8],
masks_444: &mut [[[u8; LEN_444]; 16]; 2],
masks_422: &mut [[[u8; LEN_422]; 16]; 2],
masks_420: &mut [[[u8; LEN_420]; 16]; 2],
signs: libc::c_uint,
) -> [[[*const u8; 16]; 2]; 3] {
assert!(dst.len() == 2 * 16 * w * h);
assert!(LEN_444 == (w * h) >> 0);
assert!(LEN_422 == (w * h) >> 1);
assert!(LEN_420 == (w * h) >> 2);

let mut ptr = &mut dst[..];
for n in 0..16 {
copy2d(
ptr.as_mut_ptr(),
dst[0][n].as_mut_ptr(),
master[cb[n].direction as usize].as_ptr(),
w,
h,
32 - (w * cb[n].x_offset as usize >> 3),
32 - (h * cb[n].y_offset as usize >> 3),
);
ptr = &mut ptr[w * h..];
}
let (dst, ptr) = dst.split_at_mut(16 * w * h);
let mut off = 0;
for _ in 0..16 {
invert(ptr[off..].as_mut_ptr(), dst[off..].as_ptr(), w, h);
off += w * h;
for n in 0..16 {
invert(dst[1][n].as_mut_ptr(), dst[0][n].as_ptr(), w, h);
}

let mut masks = [[[0 as *const u8; 16]; 2]; 3];

let n_stride_444 = w * h;
let n_stride_422 = n_stride_444 >> 1;
let n_stride_420 = n_stride_444 >> 2;
let sign_stride_444 = 16 * n_stride_444;
let sign_stride_422 = 16 * n_stride_422;
let sign_stride_420 = 16 * n_stride_420;
// assign pointers in externally visible array
for n in 0..16 {
let sign = (signs >> n & 1) != 0;
masks[0][0][n] = masks_444[sign as usize * sign_stride_444..].as_ptr();
masks[0][0][n] = masks_444[sign as usize][n].as_ptr();
// not using !sign is intentional here, since 444 does not require
// any rounding since no chroma subsampling is applied.
masks[0][1][n] = masks_444[sign as usize * sign_stride_444..].as_ptr();
masks[1][0][n] = masks_422[sign as usize * sign_stride_422..].as_ptr();
masks[1][1][n] = masks_422[!sign as usize * sign_stride_422..].as_ptr();
masks[2][0][n] = masks_420[sign as usize * sign_stride_420..].as_ptr();
masks[2][1][n] = masks_420[!sign as usize * sign_stride_420..].as_ptr();
masks_444 = &mut masks_444[n_stride_444..];
masks_422 = &mut masks_422[n_stride_422..];
masks_420 = &mut masks_420[n_stride_420..];
masks[0][1][n] = masks_444[sign as usize][n].as_ptr();
masks[1][0][n] = masks_422[sign as usize][n].as_ptr();
masks[1][1][n] = masks_422[!sign as usize][n].as_ptr();
masks[2][0][n] = masks_420[sign as usize][n].as_ptr();
masks[2][1][n] = masks_420[!sign as usize][n].as_ptr();

// since the pointers come from inside, we know that
// violation of the const is OK here. Any other approach
