Skip to content

Commit

Permalink
reentrancy detector test works!
Browse files Browse the repository at this point in the history
  • Loading branch information
TilakMaddy committed Oct 6, 2024
1 parent 68d3a96 commit e73945e
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 9 deletions.
1 change: 0 additions & 1 deletion aderyn_core/src/context/browser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ mod sort_nodes;
mod storage_vars;
pub use ancestral_line::*;
pub use closest_ancestor::*;
pub use external_calls::*;
pub use extractor::*;
pub use immediate_children::*;
pub use location::*;
Expand Down
31 changes: 23 additions & 8 deletions aderyn_core/src/context/flow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ impl Cfg {
impl Cfg {
/// Creates a new CFG from a given FunctionDefinition's body
///
/// * Returns - Tuple containing Cfg, Start Node, End Node
///
/// We don't yet have the ability to derive a CFG for the whole function because that involves
/// combining modifiers with the function body plus resolving internal functions, etc.
Expand All @@ -381,7 +382,7 @@ impl Cfg {
pub fn from_function_body(
context: &WorkspaceContext,
function_definition: &FunctionDefinition,
) -> Option<Cfg> {
) -> Option<(Cfg, CfgNodeId, CfgNodeId)> {
// Verify that the function has a body
let function_body_block = function_definition.body.as_ref()?;

Expand Down Expand Up @@ -409,7 +410,21 @@ impl Cfg {
cfg.callibrate_jump_statements_in_body(start, end);

// Return the CFG
Some(cfg)
Some((cfg, start, end))
}
}

// These methods help with recursion for detectors using the library
impl CfgNodeId {
pub fn children(&self, cfg: &Cfg) -> Vec<CfgNodeId> {
cfg.raw_successors(*self)
}
}

impl CfgNode {
pub fn children<'a>(&self, cfg: &'a Cfg) -> Vec<&'a CfgNode> {
let children_ids = cfg.raw_successors(self.id);
children_ids.into_iter().map(|c| cfg.nodes.get(&c).expect("cfg invalid!")).collect()
}
}

Expand Down Expand Up @@ -655,7 +670,7 @@ mod control_flow_tests {
);
let contract = context.find_contract_by_name("SimpleProgram");
let function = contract.find_function_by_name("function11");
let cfg = Cfg::from_function_body(&context, function).unwrap();
let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap();

output_graph(&context, &cfg, "SimpleProgram_function11");
assert_eq!(cfg.nodes.len(), 26);
Expand All @@ -670,7 +685,7 @@ mod control_flow_tests {
);
let contract = context.find_contract_by_name("SimpleProgram");
let function = contract.find_function_by_name("function12");
let cfg = Cfg::from_function_body(&context, function).unwrap();
let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap();

output_graph(&context, &cfg, "SimpleProgram_function12");
assert_eq!(cfg.nodes.len(), 42);
Expand All @@ -685,7 +700,7 @@ mod control_flow_tests {
);
let contract = context.find_contract_by_name("SimpleProgram");
let function = contract.find_function_by_name("function13");
let cfg = Cfg::from_function_body(&context, function).unwrap();
let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap();

output_graph(&context, &cfg, "SimpleProgram_function13");
assert_eq!(cfg.nodes.len(), 36);
Expand All @@ -700,7 +715,7 @@ mod control_flow_tests {
);
let contract = context.find_contract_by_name("SimpleProgram");
let function = contract.find_function_by_name("function14");
let cfg = Cfg::from_function_body(&context, function).unwrap();
let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap();

output_graph(&context, &cfg, "SimpleProgram_function14");
assert_eq!(cfg.nodes.len(), 46);
Expand All @@ -715,7 +730,7 @@ mod control_flow_tests {
);
let contract = context.find_contract_by_name("SimpleProgram");
let function = contract.find_function_by_name("function15");
let cfg = Cfg::from_function_body(&context, function).unwrap();
let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap();

output_graph(&context, &cfg, "SimpleProgram_function15");
assert_eq!(cfg.nodes.len(), 70);
Expand All @@ -730,7 +745,7 @@ mod control_flow_tests {
);
let contract = context.find_contract_by_name("SimpleProgram");
let function = contract.find_function_by_name("function16");
let cfg = Cfg::from_function_body(&context, function).unwrap();
let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap();

output_graph(&context, &cfg, "SimpleProgram_function16");
assert_eq!(cfg.nodes.len(), 82);
Expand Down
7 changes: 7 additions & 0 deletions aderyn_core/src/detect/detector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use std::{
str::FromStr,
};

use self::state_change_after_ext_call::StateChangeAfterExternalCallDetector;

pub fn get_all_issue_detectors() -> Vec<Box<dyn IssueDetector>> {
vec![
Box::<DelegateCallInLoopDetector>::default(),
Expand Down Expand Up @@ -103,6 +105,7 @@ pub fn get_all_issue_detectors() -> Vec<Box<dyn IssueDetector>> {
Box::<StateVariableChangesWithoutEventDetector>::default(),
Box::<StateVariableCouldBeImmutableDetector>::default(),
Box::<MultiplePlaceholdersDetector>::default(),
Box::<StateChangeAfterExternalCallDetector>::default(),
]
}

Expand All @@ -114,6 +117,7 @@ pub fn get_all_detectors_names() -> Vec<String> {
#[derive(Debug, PartialEq, EnumString, Display)]
#[strum(serialize_all = "kebab-case")]
pub(crate) enum IssueDetectorNamePool {
StateChangeAfterExternalCall,
StateVariableCouldBeDeclaredImmutable,
MultiplePlaceholders,
StateVariableChangesWithoutEvents,
Expand Down Expand Up @@ -211,6 +215,9 @@ pub fn request_issue_detector_by_name(detector_name: &str) -> Option<Box<dyn Iss
// Expects a valid detector_name
let detector_name = IssueDetectorNamePool::from_str(detector_name).ok()?;
match detector_name {
IssueDetectorNamePool::StateChangeAfterExternalCall => {
Some(Box::<StateChangeAfterExternalCallDetector>::default())
}
IssueDetectorNamePool::StateVariableCouldBeDeclaredImmutable => {
Some(Box::<StateVariableCouldBeImmutableDetector>::default())
}
Expand Down
1 change: 1 addition & 0 deletions aderyn_core/src/detect/high/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub(crate) mod reused_contract_name;
pub(crate) mod rtlo;
pub(crate) mod selfdestruct;
pub(crate) mod send_ether_no_checks;
pub(crate) mod state_change_after_ext_call;
pub(crate) mod state_variable_shadowing;
pub(crate) mod storage_array_edit_with_memory;
pub(crate) mod storage_signed_integer_array;
Expand Down
214 changes: 214 additions & 0 deletions aderyn_core/src/detect/high/state_change_after_ext_call.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
use std::collections::{BTreeMap, HashSet};
use std::error::Error;

use crate::ast::NodeID;

use crate::capture;
use crate::context::browser::{ApproximateStorageChangeFinder, ExtractFunctionCalls};
use crate::context::flow::{Cfg, CfgNodeId};
use crate::detect::detector::IssueDetectorNamePool;
use crate::detect::helpers;
use crate::{
context::workspace_context::WorkspaceContext,
detect::detector::{IssueDetector, IssueSeverity},
};
use eyre::{eyre, Result};

#[derive(Default)]
pub struct StateChangeAfterExternalCallDetector {
// Keys are: [0] source file name, [1] line number, [2] character location of node.
// Do not add items manually, use `capture!` to add nodes to this BTreeMap.
found_instances: BTreeMap<(String, usize, String), NodeID>,
hints: BTreeMap<(String, usize, String), String>,
}

impl IssueDetector for StateChangeAfterExternalCallDetector {
fn detect(&mut self, context: &WorkspaceContext) -> Result<bool, Box<dyn Error>> {
// When you have found an instance of the issue,
// use the following macro to add it to `found_instances`:
//
// capture!(self, context, item);
// capture!(self, context, item, "hint");

for func in helpers::get_implemented_external_and_public_functions(context) {
let (cfg, start, _) =
Cfg::from_function_body(context, func).ok_or(eyre!("corrupted function"))?;

// Discover external calls
let external_call_sites = find_external_call_sites(context, &cfg, start);

// For each external call, figure out if it's followed by a state change
for external_call_site in external_call_sites {
// Discover state changes that follow the external call
let state_changes = find_state_change_sites(context, &cfg, external_call_site);

for state_change in state_changes {
// There is no way to tell is the state change took place after the event if
// both are found at the same place
if state_change != external_call_site {
// Capture the external call
let external_call_cfg_node =
cfg.nodes.get(&external_call_site).expect("cfg is corrupted!");

if let Some(external_call_ast_node) =
external_call_cfg_node.reflect(context)
{
capture!(self, context, external_call_ast_node);
}
}
}
}
}

Ok(!self.found_instances.is_empty())
}

fn severity(&self) -> IssueSeverity {
IssueSeverity::High
}

fn title(&self) -> String {
String::from("External call is followed by a state variable change")
}

fn description(&self) -> String {
String::from("In most cases it is a best practice to perform the state change before making an external call to avoid a potential re-entrancy attack.")
}

fn instances(&self) -> BTreeMap<(String, usize, String), NodeID> {
self.found_instances.clone()
}

fn hints(&self) -> BTreeMap<(String, usize, String), String> {
self.hints.clone()
}

fn name(&self) -> String {
IssueDetectorNamePool::StateChangeAfterExternalCall.to_string()
}
}

fn find_state_change_sites(
context: &WorkspaceContext,
cfg: &Cfg,
start_node: CfgNodeId,
) -> HashSet<CfgNodeId> {
let mut visited = Default::default();
let mut state_change_sites = Default::default();

fn _find_following_state_change_sites(
context: &WorkspaceContext,
cfg: &Cfg,
visited: &mut HashSet<CfgNodeId>,
curr_node: CfgNodeId,
state_change_sites: &mut HashSet<CfgNodeId>,
) -> Option<()> {
if visited.contains(&curr_node) {
return Some(());
}

visited.insert(curr_node);

// Check if `curr_node` is an external call site
let curr_cfg_node = cfg.nodes.get(&curr_node)?;

// Grab the AST version of the Cfg Node
if let Some(curr_ast_node) = curr_cfg_node.reflect(context) {
let state_changes = ApproximateStorageChangeFinder::from(context, curr_ast_node);

if state_changes.state_variables_have_been_manipulated() {
state_change_sites.insert(curr_node);
}
}

// Continue the recursion
for child in curr_node.children(cfg) {
_find_following_state_change_sites(context, cfg, visited, child, state_change_sites);
}

Some(())
}

_find_following_state_change_sites(
context,
cfg,
&mut visited,
start_node,
&mut state_change_sites,
);

state_change_sites
}

fn find_external_call_sites(
context: &WorkspaceContext,
cfg: &Cfg,
start_node: CfgNodeId,
) -> HashSet<CfgNodeId> {
let mut visited = Default::default();
let mut external_call_sites = Default::default();

fn _find_external_call_sites(
context: &WorkspaceContext,
cfg: &Cfg,
visited: &mut HashSet<CfgNodeId>,
curr_node: CfgNodeId,
external_call_sites: &mut HashSet<CfgNodeId>,
) -> Option<()> {
if visited.contains(&curr_node) {
return Some(());
}

visited.insert(curr_node);

// Check if `curr_node` is an external call site
let curr_cfg_node = cfg.nodes.get(&curr_node)?;

// Grab the AST version of the Cfg Node
if let Some(curr_ast_node) = curr_cfg_node.reflect(context) {
let function_calls = ExtractFunctionCalls::from(curr_ast_node).extracted;

if function_calls.iter().any(|f| f.is_external_call()) {
external_call_sites.insert(curr_node);
}
}

// Continue the recursion
for child in curr_node.children(cfg) {
_find_external_call_sites(context, cfg, visited, child, external_call_sites);
}

Some(())
}

_find_external_call_sites(context, cfg, &mut visited, start_node, &mut external_call_sites);

external_call_sites
}

#[cfg(test)]
mod state_change_after_external_call_tests {
use serial_test::serial;

use crate::detect::{
detector::IssueDetector,
high::state_change_after_ext_call::StateChangeAfterExternalCallDetector,
};

#[test]
#[serial]
fn test_state_change_after_external_call() {
let context = crate::detect::test_utils::load_solidity_source_unit(
"../tests/contract-playground/src/StateChangeAfterExternalCall.sol",
);

let mut detector = StateChangeAfterExternalCallDetector::default();
let found = detector.detect(&context).unwrap();
// assert that the detector found an issue
assert!(found);
// assert that the detector found the correct number of instances
assert_eq!(detector.instances().len(), 3);
// assert the severity is high
assert_eq!(detector.severity(), crate::detect::detector::IssueSeverity::High);
}
}
Loading

0 comments on commit e73945e

Please sign in to comment.