Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(stack-limiter): correctly calculate max height for functions #81

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
204 changes: 140 additions & 64 deletions src/stack_limiter/max_height.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
use super::resolve_func_type;
use alloc::vec::Vec;
use parity_wasm::elements::{self, BlockType, Type};
use parity_wasm::elements::{self, BlockType, Instruction, Type};

#[cfg(feature = "sign_ext")]
use parity_wasm::elements::SignExtInstruction;

// The cost in stack items that should be charged per call of a function. This is
// is a static cost that is added to each function call. This makes sense because even
// if a function does not use any parameters or locals some stack space on the host
// machine might be consumed to hold some context.
const ACTIVATION_FRAME_COST: u32 = 2;

/// Control stack frame.
#[derive(Debug)]
struct Frame {
Expand Down Expand Up @@ -41,8 +35,8 @@ struct Stack {
}

impl Stack {
fn new() -> Stack {
Stack { height: ACTIVATION_FRAME_COST, control_stack: Vec::new() }
fn new() -> Self {
Self { height: 0, control_stack: Vec::new() }
}

/// Returns current height of the value stack.
Expand Down Expand Up @@ -121,58 +115,135 @@ impl Stack {
}
}

/// This function expects the function to be validated.
pub fn compute(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
use parity_wasm::elements::Instruction::*;

let func_section = module.function_section().ok_or("No function section")?;
let code_section = module.code_section().ok_or("No code section")?;
let type_section = module.type_section().ok_or("No type section")?;

// Get a signature and a body of the specified function.
let func_sig_idx = func_section
.entries()
.get(func_idx as usize)
.ok_or("Function is not found in func section")?
.type_ref();
let Type::Function(func_signature) = type_section
.types()
.get(func_sig_idx as usize)
.ok_or("Function is not found in func section")?;
let body = code_section
.bodies()
.get(func_idx as usize)
.ok_or("Function body for the index isn't found")?;
let instructions = body.code();

let mut stack = Stack::new();
let mut max_height: u32 = 0;
let mut pc = 0;

// Add implicit frame for the function. Breaks to this frame and execution of
// the last end should deal with this frame.
let func_arity = func_signature.results().len() as u32;
stack.push_frame(Frame {
is_polymorphic: false,
end_arity: func_arity,
branch_arity: func_arity,
start_height: 0,
});

loop {
if pc >= instructions.elements().len() {
break
/// This is a helper context that is used by [`MaxStackHeightCounter`].
#[derive(Clone, Copy)]
pub(crate) struct MaxStackHeightCounterContext<'a> {
pub module: &'a elements::Module,
pub func_imports: u32,
pub func_section: &'a elements::FunctionSection,
pub code_section: &'a elements::CodeSection,
pub type_section: &'a elements::TypeSection,
}

impl<'a> TryFrom<&'a elements::Module> for MaxStackHeightCounterContext<'a> {
type Error = &'static str;

fn try_from(module: &'a elements::Module) -> Result<Self, Self::Error> {
Ok(Self {
module,
func_imports: module
.import_count(elements::ImportCountType::Function)
.try_into()
.map_err(|_| "Can't convert func imports count to u32")?,
func_section: module.function_section().ok_or("No function section")?,
code_section: module.code_section().ok_or("No code section")?,
type_section: module.type_section().ok_or("No type section")?,
})
}
}

/// This is a counter for the maximum stack height with the ability to take into account the
/// overhead that is added by the [`instrument_call!`] macro.
pub(crate) struct MaxStackHeightCounter<'a> {
context: MaxStackHeightCounterContext<'a>,
stack: Stack,
max_height: u32,
count_instrumented_calls: bool,
}

impl<'a> MaxStackHeightCounter<'a> {
/// Creates a [`MaxStackHeightCounter`] from [`MaxStackHeightCounterContext`].
pub fn new_with_context(context: MaxStackHeightCounterContext<'a>) -> Self {
Self { context, stack: Stack::new(), max_height: 0, count_instrumented_calls: false }
}

/// Should the overhead of the [`instrument_call!`] macro be taken into account?
pub fn count_instrumented_calls(mut self, count_instrumented_calls: bool) -> Self {
self.count_instrumented_calls = count_instrumented_calls;
self
}

/// Tries to calculate the maximum stack height for the `func_idx` defined in the wasm module.
pub fn compute_for_defined_func(&mut self, func_idx: u32) -> Result<u32, &'static str> {
let MaxStackHeightCounterContext { func_section, code_section, type_section, .. } =
self.context;

// Get a signature and a body of the specified function.
let func_sig_idx = func_section
.entries()
.get(func_idx as usize)
.ok_or("Function is not found in func section")?
.type_ref();
let Type::Function(func_signature) = type_section
.types()
.get(func_sig_idx as usize)
.ok_or("Function is not found in func section")?;
let body = code_section
.bodies()
.get(func_idx as usize)
.ok_or("Function body for the index isn't found")?;
let instructions = body.code();

self.compute_for_raw_func(func_signature, instructions.elements())
}

/// Tries to calculate the maximum stack height for a raw function, which consists of
/// `func_signature` and `instructions`.
pub fn compute_for_raw_func(
&mut self,
func_signature: &elements::FunctionType,
instructions: &[Instruction],
) -> Result<u32, &'static str> {
// Add implicit frame for the function. Breaks to this frame and execution of
// the last end should deal with this frame.
let func_arity = func_signature.results().len() as u32;
self.stack.push_frame(Frame {
is_polymorphic: false,
end_arity: func_arity,
branch_arity: func_arity,
start_height: 0,
});

for instruction in instructions {
let maybe_instructions =
self.count_instrumented_calls
.then_some(instruction)
.and_then(|inst| match inst {
&Instruction::Call(idx) if idx >= self.context.func_imports =>
Some(instrument_call!(idx, 0, 0, 0)),
_ => None,
});

if let Some(instructions) = maybe_instructions.as_ref() {
for instruction in instructions {
self.process_instruction(instruction, func_arity)?;
}
} else {
self.process_instruction(instruction, func_arity)?;
}
}

Ok(self.max_height)
}

/// This function processes all incoming instructions and updates the `self.max_height` field.
fn process_instruction(
&mut self,
opcode: &Instruction,
func_arity: u32,
) -> Result<(), &'static str> {
use Instruction::*;

let Self { stack, max_height, .. } = self;
let MaxStackHeightCounterContext { module, type_section, .. } = self.context;

// If current value stack is higher than maximal height observed so far,
// save the new height.
// However, we don't increase maximal value in unreachable code.
if stack.height() > max_height && !stack.frame(0)?.is_polymorphic {
max_height = stack.height();
if stack.height() > *max_height && !stack.frame(0)?.is_polymorphic {
*max_height = stack.height();
}

let opcode = &instructions.elements()[pc];

match opcode {
Nop => {},
Block(ty) | Loop(ty) | If(ty) => {
Expand Down Expand Up @@ -403,17 +474,22 @@ pub fn compute(func_idx: u32, module: &elements::Module) -> Result<u32, &'static
stack.push_values(1)?;
},
}
pc += 1;
}

Ok(max_height)
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use parity_wasm::elements;

fn compute(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
MaxStackHeightCounter::new_with_context(module.try_into()?)
.count_instrumented_calls(true)
.compute_for_defined_func(func_idx)
}

fn parse_wat(source: &str) -> elements::Module {
elements::deserialize_buffer(&wat::parse_str(source).expect("Failed to wat2wasm"))
.expect("Failed to deserialize the module")
Expand All @@ -437,7 +513,7 @@ mod tests {
);

let height = compute(0, &module).unwrap();
assert_eq!(height, 3 + ACTIVATION_FRAME_COST);
assert_eq!(height, 3);
}

#[test]
Expand All @@ -454,7 +530,7 @@ mod tests {
);

let height = compute(0, &module).unwrap();
assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
assert_eq!(height, 1);
}

#[test]
Expand All @@ -472,7 +548,7 @@ mod tests {
);

let height = compute(0, &module).unwrap();
assert_eq!(height, ACTIVATION_FRAME_COST);
assert_eq!(height, 0);
}

#[test]
Expand Down Expand Up @@ -501,7 +577,7 @@ mod tests {
);

let height = compute(0, &module).unwrap();
assert_eq!(height, 2 + ACTIVATION_FRAME_COST);
assert_eq!(height, 2);
}

#[test]
Expand All @@ -525,7 +601,7 @@ mod tests {
);

let height = compute(0, &module).unwrap();
assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
assert_eq!(height, 1);
}

#[test]
Expand All @@ -547,7 +623,7 @@ mod tests {
);

let height = compute(0, &module).unwrap();
assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
assert_eq!(height, 1);
}

#[test]
Expand All @@ -573,6 +649,6 @@ mod tests {
);

let height = compute(0, &module).unwrap();
assert_eq!(height, 3 + ACTIVATION_FRAME_COST);
assert_eq!(height, 3);
}
}
39 changes: 27 additions & 12 deletions src/stack_limiter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use alloc::{vec, vec::Vec};
use core::mem;
use max_height::{MaxStackHeightCounter, MaxStackHeightCounterContext};
use parity_wasm::{
builder,
elements::{self, Instruction, Instructions, Type},
Expand Down Expand Up @@ -155,16 +156,27 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
///
/// Returns a vector with a stack cost for each function, including imports.
fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, &'static str> {
let func_imports = module.import_count(elements::ImportCountType::Function);
let functions_space = module
.functions_space()
.try_into()
.map_err(|_| "Can't convert functions space to u32")?;

// Don't create context when there are no functions (this will fail).
if functions_space == 0 {
return Ok(Vec::new());
}

// TODO: optimize!
(0..module.functions_space())
// This context already contains the module, number of imports and section references.
// So we can use it to optimize access to these objects.
let context: MaxStackHeightCounterContext = module.try_into()?;

(0..functions_space)
.map(|func_idx| {
if func_idx < func_imports {
if func_idx < context.func_imports {
// We can't calculate stack_cost of the import functions.
Ok(0)
} else {
compute_stack_cost(func_idx as u32, module)
compute_stack_cost(func_idx, context)
}
})
.collect()
Expand All @@ -173,17 +185,18 @@ fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, &'static s
/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
/// number of arguments plus number of local variables) and the maximal stack
/// height.
fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
fn compute_stack_cost(
func_idx: u32,
context: MaxStackHeightCounterContext,
) -> Result<u32, &'static str> {
// To calculate the cost of a function we need to convert index from
// function index space to defined function spaces.
let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
let defined_func_idx = func_idx
.checked_sub(func_imports)
.checked_sub(context.func_imports)
.ok_or("This should be a index of a defined function")?;

let code_section =
module.code_section().ok_or("Due to validation code section should exists")?;
let body = &code_section
let body = context
.code_section
.bodies()
.get(defined_func_idx as usize)
.ok_or("Function body is out of bounds")?;
Expand All @@ -194,7 +207,9 @@ fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, &
locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
}

let max_stack_height = max_height::compute(defined_func_idx, module)?;
let max_stack_height = MaxStackHeightCounter::new_with_context(context)
.count_instrumented_calls(true)
.compute_for_defined_func(defined_func_idx)?;

locals_count
.checked_add(max_stack_height)
Expand Down
Loading
Loading