Skip to content

Commit

Permalink
struct GetBits: Make safe (#632)
Browse files Browse the repository at this point in the history
`GetBits` is made safe by storing an index and a slice instead of a
current, start, and end ptr.
  • Loading branch information
kkysen authored Jan 4, 2024
2 parents f07d467 + 44816ae commit 634cd97
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 52 deletions.
52 changes: 25 additions & 27 deletions src/getbits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,37 @@ use std::ffi::c_int;
use std::ffi::c_uint;

#[repr(C)]
pub struct GetBits {
pub struct GetBits<'a> {
state: u64,
bits_left: c_int,
error: c_int,
ptr: *const u8,
ptr_start: *const u8,
ptr_end: *const u8,
index: usize,
data: &'a [u8],
}

impl GetBits {
pub const unsafe fn new(data: *const u8, sz: usize) -> Self {
assert!(sz != 0);
impl<'a> GetBits<'a> {
pub const fn new(data: &'a [u8]) -> Self {
assert!(!data.is_empty());
Self {
ptr_start: data,
ptr: data,
ptr_end: data.add(sz),
state: 0,
bits_left: 0,
error: 0,
index: 0,
data,
}
}

pub const fn has_error(&self) -> c_int {
self.error
}

pub unsafe fn get_bit(&mut self) -> c_uint {
pub fn get_bit(&mut self) -> c_uint {
if self.bits_left == 0 {
if self.ptr >= self.ptr_end {
if self.index >= self.data.len() {
self.error = 1;
} else {
let state = *self.ptr as c_uint;
self.ptr = self.ptr.add(1);
let state = self.data[self.index] as c_uint;
self.index += 1;
self.bits_left = 7;
self.state = (state as u64) << 57;
return state >> 7;
Expand All @@ -49,19 +47,19 @@ impl GetBits {
}

#[inline]
unsafe fn refill(&mut self, n: c_int) {
fn refill(&mut self, n: c_int) {
assert!(self.bits_left >= 0 && self.bits_left < 32);
let mut state = 0;
loop {
if self.ptr >= self.ptr_end {
if self.index >= self.data.len() {
self.error = 1;
if state != 0 {
break;
}
return;
} else {
state = state << 8 | *self.ptr as c_uint;
self.ptr = self.ptr.add(1);
state = (state << 8) | self.data[self.index] as c_uint;
self.index += 1;
self.bits_left += 8;
if !(n > self.bits_left) {
break;
Expand All @@ -71,7 +69,7 @@ impl GetBits {
self.state |= (state as u64) << 64 - self.bits_left;
}

pub unsafe fn get_bits(&mut self, n: c_int) -> c_uint {
pub fn get_bits(&mut self, n: c_int) -> c_uint {
assert!(n > 0 && n <= 32);
// Unsigned cast avoids refill after eob.
if n as c_uint > self.bits_left as c_uint {
Expand All @@ -83,7 +81,7 @@ impl GetBits {
(state as u64 >> 64 - n) as c_uint
}

pub unsafe fn get_sbits(&mut self, n: c_int) -> c_int {
pub fn get_sbits(&mut self, n: c_int) -> c_int {
assert!(n > 0 && n <= 32);
// Unsigned cast avoids refill after eob.
if n as c_uint > self.bits_left as c_uint {
Expand All @@ -95,7 +93,7 @@ impl GetBits {
(state as i64 >> 64 - n) as c_int
}

pub unsafe fn get_uleb128(&mut self) -> c_uint {
pub fn get_uleb128(&mut self) -> c_uint {
let mut val = 0;
let mut i = 0 as c_uint;
let mut more;
Expand All @@ -115,7 +113,7 @@ impl GetBits {
val as c_uint
}

pub unsafe fn get_uniform(&mut self, max: c_uint) -> c_uint {
pub fn get_uniform(&mut self, max: c_uint) -> c_uint {
assert!(max > 1);
let l = ulog2(max) + 1;
assert!(l > 1);
Expand All @@ -128,7 +126,7 @@ impl GetBits {
}
}

pub unsafe fn get_vlc(&mut self) -> c_uint {
pub fn get_vlc(&mut self) -> c_uint {
if self.get_bit() != 0 {
return 0;
}
Expand All @@ -145,7 +143,7 @@ impl GetBits {
(1 << n_bits) - 1 + self.get_bits(n_bits)
}

unsafe fn get_bits_subexp_u(&mut self, r#ref: c_uint, n: c_uint) -> c_uint {
fn get_bits_subexp_u(&mut self, r#ref: c_uint, n: c_uint) -> c_uint {
let mut v = 0 as c_uint;
let mut i = 0;
loop {
Expand All @@ -168,7 +166,7 @@ impl GetBits {
}
}

pub unsafe fn get_bits_subexp(&mut self, r#ref: c_int, n: c_uint) -> c_int {
pub fn get_bits_subexp(&mut self, r#ref: c_int, n: c_uint) -> c_int {
self.get_bits_subexp_u((r#ref + (1 << n)) as c_uint, 2 << n) as c_int - (1 << n)
}

Expand All @@ -179,7 +177,7 @@ impl GetBits {
}

#[inline]
pub const unsafe fn pos(&self) -> c_uint {
self.ptr.offset_from(self.ptr_start) as c_uint * 8 - self.bits_left as c_uint
pub const fn pos(&self) -> c_uint {
self.index as c_uint * u8::BITS - self.bits_left as c_uint
}
}
40 changes: 15 additions & 25 deletions src/obu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ use std::ffi::c_int;
use std::ffi::c_uint;
use std::fmt;
use std::mem::MaybeUninit;
use std::slice;

struct Debug {
enabled: bool,
Expand All @@ -126,7 +127,7 @@ struct Debug {
}

impl Debug {
pub const unsafe fn new(enabled: bool, name: &'static str, gb: &GetBits) -> Self {
pub const fn new(enabled: bool, name: &'static str, gb: &GetBits) -> Self {
Self {
enabled,
name,
Expand All @@ -147,7 +148,7 @@ impl Debug {
}
}

pub unsafe fn log(&self, gb: &GetBits, msg: fmt::Arguments) {
pub fn log(&self, gb: &GetBits, msg: fmt::Arguments) {
let &Self {
enabled,
name,
Expand All @@ -160,15 +161,12 @@ impl Debug {
println!("{name}: {msg} [off={offset}]");
}

pub unsafe fn post(&self, gb: &GetBits, post: &str) {
pub fn post(&self, gb: &GetBits, post: &str) {
self.log(gb, format_args!("post-{post}"));
}
}

unsafe fn parse_seq_hdr(
c: &mut Rav1dContext,
gb: &mut GetBits,
) -> Rav1dResult<Rav1dSequenceHeader> {
fn parse_seq_hdr(c: &mut Rav1dContext, gb: &mut GetBits) -> Rav1dResult<Rav1dSequenceHeader> {
let debug = Debug::new(false, "SEQHDR", gb);

let profile = gb.get_bits(3) as c_int;
Expand Down Expand Up @@ -818,7 +816,7 @@ unsafe fn parse_refidx(
Ok(refidx)
}

unsafe fn parse_tiling(
fn parse_tiling(
seqhdr: &Rav1dSequenceHeader,
size: &Rav1dFrameSize,
debug: &Debug,
Expand Down Expand Up @@ -939,7 +937,7 @@ unsafe fn parse_tiling(
})
}

unsafe fn parse_quant(
fn parse_quant(
seqhdr: &Rav1dSequenceHeader,
debug: &Debug,
gb: &mut GetBits,
Expand Down Expand Up @@ -1029,7 +1027,7 @@ unsafe fn parse_quant(
}
}

unsafe fn parse_seg_data(gb: &mut GetBits) -> Rav1dSegmentationDataSet {
fn parse_seg_data(gb: &mut GetBits) -> Rav1dSegmentationDataSet {
let mut preskip = 0;
let mut last_active_segid = -1;
let d = array::from_fn(|i| {
Expand Down Expand Up @@ -1186,7 +1184,7 @@ unsafe fn parse_segmentation(
})
}

unsafe fn parse_delta(
fn parse_delta(
quant: &Rav1dFrameHeader_quant,
allow_intrabc: c_int,
debug: &Debug,
Expand Down Expand Up @@ -1308,7 +1306,7 @@ unsafe fn parse_loopfilter(
})
}

unsafe fn parse_cdef(
fn parse_cdef(
seqhdr: &Rav1dSequenceHeader,
all_lossless: c_int,
allow_intrabc: c_int,
Expand Down Expand Up @@ -1345,7 +1343,7 @@ unsafe fn parse_cdef(
}
}

unsafe fn parse_restoration(
fn parse_restoration(
seqhdr: &Rav1dSequenceHeader,
all_lossless: c_int,
super_res_enabled: c_int,
Expand Down Expand Up @@ -1567,7 +1565,7 @@ unsafe fn parse_gmv(
Ok(gmv)
}

unsafe fn parse_film_grain_data(
fn parse_film_grain_data(
seqhdr: &Rav1dSequenceHeader,
seed: c_uint,
gb: &mut GetBits,
Expand Down Expand Up @@ -2076,10 +2074,7 @@ unsafe fn parse_frame_hdr(
})
}

unsafe fn parse_tile_hdr(
tiling: &Rav1dFrameHeader_tiling,
gb: &mut GetBits,
) -> Rav1dTileGroupHeader {
fn parse_tile_hdr(tiling: &Rav1dFrameHeader_tiling, gb: &mut GetBits) -> Rav1dTileGroupHeader {
let n_tiles = tiling.cols * tiling.rows;
let have_tile_pos = if n_tiles > 1 {
gb.get_bit() as c_int
Expand All @@ -2102,7 +2097,7 @@ unsafe fn parse_tile_hdr(

/// Check that we haven't read more than `obu_len`` bytes
/// from the buffer since `init_bit_pos`.
unsafe fn check_for_overrun(
fn check_for_overrun(
c: &mut Rav1dContext,
gb: &mut GetBits,
init_bit_pos: c_uint,
Expand Down Expand Up @@ -2154,7 +2149,7 @@ unsafe fn parse_obus(
len + init_byte_pos
}

let mut gb = GetBits::new(r#in.data, r#in.sz);
let mut gb = GetBits::new(slice::from_raw_parts(r#in.data, r#in.sz));

// obu header
gb.get_bit(); // obu_forbidden_bit
Expand Down Expand Up @@ -2189,11 +2184,6 @@ unsafe fn parse_obus(
// when reading the leb128 length field).
assert!(init_bit_pos & 7 == 0);

// We also know that we haven't tried to read more than `r#in.sz`
// bytes yet (otherwise the error flag would have been set
// by the code in [`crate::src::getbits`]).
assert!(r#in.sz >= init_byte_pos as usize);

// Make sure that there are enough bits left in the buffer
// for the rest of the OBU.
if len as usize > r#in.sz - init_byte_pos as usize {
Expand Down

0 comments on commit 634cd97

Please sign in to comment.