From e73945e7a1755ad48388f7b4844336f1e1b5ae05 Mon Sep 17 00:00:00 2001 From: TilakMaddy Date: Sun, 6 Oct 2024 19:51:01 +0530 Subject: [PATCH] reentrancy detector test works! --- aderyn_core/src/context/browser/mod.rs | 1 - aderyn_core/src/context/flow/mod.rs | 31 ++- aderyn_core/src/detect/detector.rs | 7 + aderyn_core/src/detect/high/mod.rs | 1 + .../high/state_change_after_ext_call.rs | 214 ++++++++++++++++++ .../src/StateChangeAfterExternalCall.sol | 74 ++++++ 6 files changed, 319 insertions(+), 9 deletions(-) create mode 100644 aderyn_core/src/detect/high/state_change_after_ext_call.rs create mode 100644 tests/contract-playground/src/StateChangeAfterExternalCall.sol diff --git a/aderyn_core/src/context/browser/mod.rs b/aderyn_core/src/context/browser/mod.rs index 327320d7a..35cf6eb95 100644 --- a/aderyn_core/src/context/browser/mod.rs +++ b/aderyn_core/src/context/browser/mod.rs @@ -14,7 +14,6 @@ mod sort_nodes; mod storage_vars; pub use ancestral_line::*; pub use closest_ancestor::*; -pub use external_calls::*; pub use extractor::*; pub use immediate_children::*; pub use location::*; diff --git a/aderyn_core/src/context/flow/mod.rs b/aderyn_core/src/context/flow/mod.rs index 64c4abdc2..1d4d83797 100644 --- a/aderyn_core/src/context/flow/mod.rs +++ b/aderyn_core/src/context/flow/mod.rs @@ -373,6 +373,7 @@ impl Cfg { impl Cfg { /// Creates a new CFG from a given FunctionDefinition's body /// + /// * Returns - Tuple containing Cfg, Start Node, End Node /// /// We don't yet have the ability to derive a CFG for the whole function because that involves /// combining modifiers with the function body plus resolving internal functions, etc. @@ -381,7 +382,7 @@ impl Cfg { pub fn from_function_body( context: &WorkspaceContext, function_definition: &FunctionDefinition, - ) -> Option { + ) -> Option<(Cfg, CfgNodeId, CfgNodeId)> { // Verify that the function has a body let function_body_block = function_definition.body.as_ref()?; @@ -409,7 +410,21 @@ impl Cfg { cfg.callibrate_jump_statements_in_body(start, end); // Return the CFG - Some(cfg) + Some((cfg, start, end)) + } +} + +// These methods help with recursion for detectors using the library +impl CfgNodeId { + pub fn children(&self, cfg: &Cfg) -> Vec { + cfg.raw_successors(*self) + } +} + +impl CfgNode { + pub fn children<'a>(&self, cfg: &'a Cfg) -> Vec<&'a CfgNode> { + let children_ids = cfg.raw_successors(self.id); + children_ids.into_iter().map(|c| cfg.nodes.get(&c).expect("cfg invalid!")).collect() } } @@ -655,7 +670,7 @@ mod control_flow_tests { ); let contract = context.find_contract_by_name("SimpleProgram"); let function = contract.find_function_by_name("function11"); - let cfg = Cfg::from_function_body(&context, function).unwrap(); + let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap(); output_graph(&context, &cfg, "SimpleProgram_function11"); assert_eq!(cfg.nodes.len(), 26); @@ -670,7 +685,7 @@ mod control_flow_tests { ); let contract = context.find_contract_by_name("SimpleProgram"); let function = contract.find_function_by_name("function12"); - let cfg = Cfg::from_function_body(&context, function).unwrap(); + let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap(); output_graph(&context, &cfg, "SimpleProgram_function12"); assert_eq!(cfg.nodes.len(), 42); @@ -685,7 +700,7 @@ mod control_flow_tests { ); let contract = context.find_contract_by_name("SimpleProgram"); let function = contract.find_function_by_name("function13"); - let cfg = Cfg::from_function_body(&context, function).unwrap(); + let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap(); output_graph(&context, &cfg, "SimpleProgram_function13"); assert_eq!(cfg.nodes.len(), 36); @@ -700,7 +715,7 @@ mod control_flow_tests { ); let contract = context.find_contract_by_name("SimpleProgram"); let function = contract.find_function_by_name("function14"); - let cfg = Cfg::from_function_body(&context, function).unwrap(); + let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap(); output_graph(&context, &cfg, "SimpleProgram_function14"); assert_eq!(cfg.nodes.len(), 46); @@ -715,7 +730,7 @@ mod control_flow_tests { ); let contract = context.find_contract_by_name("SimpleProgram"); let function = contract.find_function_by_name("function15"); - let cfg = Cfg::from_function_body(&context, function).unwrap(); + let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap(); output_graph(&context, &cfg, "SimpleProgram_function15"); assert_eq!(cfg.nodes.len(), 70); @@ -730,7 +745,7 @@ mod control_flow_tests { ); let contract = context.find_contract_by_name("SimpleProgram"); let function = contract.find_function_by_name("function16"); - let cfg = Cfg::from_function_body(&context, function).unwrap(); + let (cfg, _, _) = Cfg::from_function_body(&context, function).unwrap(); output_graph(&context, &cfg, "SimpleProgram_function16"); assert_eq!(cfg.nodes.len(), 82); diff --git a/aderyn_core/src/detect/detector.rs b/aderyn_core/src/detect/detector.rs index 9002caecf..1c3ced3de 100644 --- a/aderyn_core/src/detect/detector.rs +++ b/aderyn_core/src/detect/detector.rs @@ -14,6 +14,8 @@ use std::{ str::FromStr, }; +use self::state_change_after_ext_call::StateChangeAfterExternalCallDetector; + pub fn get_all_issue_detectors() -> Vec> { vec![ Box::::default(), @@ -103,6 +105,7 @@ pub fn get_all_issue_detectors() -> Vec> { Box::::default(), Box::::default(), Box::::default(), + Box::::default(), ] } @@ -114,6 +117,7 @@ pub fn get_all_detectors_names() -> Vec { #[derive(Debug, PartialEq, EnumString, Display)] #[strum(serialize_all = "kebab-case")] pub(crate) enum IssueDetectorNamePool { + StateChangeAfterExternalCall, StateVariableCouldBeDeclaredImmutable, MultiplePlaceholders, StateVariableChangesWithoutEvents, @@ -211,6 +215,9 @@ pub fn request_issue_detector_by_name(detector_name: &str) -> Option { + Some(Box::::default()) + } IssueDetectorNamePool::StateVariableCouldBeDeclaredImmutable => { Some(Box::::default()) } diff --git a/aderyn_core/src/detect/high/mod.rs b/aderyn_core/src/detect/high/mod.rs index a38523098..bed78a49e 100644 --- a/aderyn_core/src/detect/high/mod.rs +++ b/aderyn_core/src/detect/high/mod.rs @@ -26,6 +26,7 @@ pub(crate) mod reused_contract_name; pub(crate) mod rtlo; pub(crate) mod selfdestruct; pub(crate) mod send_ether_no_checks; +pub(crate) mod state_change_after_ext_call; pub(crate) mod state_variable_shadowing; pub(crate) mod storage_array_edit_with_memory; pub(crate) mod storage_signed_integer_array; diff --git a/aderyn_core/src/detect/high/state_change_after_ext_call.rs b/aderyn_core/src/detect/high/state_change_after_ext_call.rs new file mode 100644 index 000000000..52c40562f --- /dev/null +++ b/aderyn_core/src/detect/high/state_change_after_ext_call.rs @@ -0,0 +1,214 @@ +use std::collections::{BTreeMap, HashSet}; +use std::error::Error; + +use crate::ast::NodeID; + +use crate::capture; +use crate::context::browser::{ApproximateStorageChangeFinder, ExtractFunctionCalls}; +use crate::context::flow::{Cfg, CfgNodeId}; +use crate::detect::detector::IssueDetectorNamePool; +use crate::detect::helpers; +use crate::{ + context::workspace_context::WorkspaceContext, + detect::detector::{IssueDetector, IssueSeverity}, +}; +use eyre::{eyre, Result}; + +#[derive(Default)] +pub struct StateChangeAfterExternalCallDetector { + // Keys are: [0] source file name, [1] line number, [2] character location of node. + // Do not add items manually, use `capture!` to add nodes to this BTreeMap. + found_instances: BTreeMap<(String, usize, String), NodeID>, + hints: BTreeMap<(String, usize, String), String>, +} + +impl IssueDetector for StateChangeAfterExternalCallDetector { + fn detect(&mut self, context: &WorkspaceContext) -> Result> { + // When you have found an instance of the issue, + // use the following macro to add it to `found_instances`: + // + // capture!(self, context, item); + // capture!(self, context, item, "hint"); + + for func in helpers::get_implemented_external_and_public_functions(context) { + let (cfg, start, _) = + Cfg::from_function_body(context, func).ok_or(eyre!("corrupted function"))?; + + // Discover external calls + let external_call_sites = find_external_call_sites(context, &cfg, start); + + // For each external call, figure out if it's followed by a state change + for external_call_site in external_call_sites { + // Discover state changes that follow the external call + let state_changes = find_state_change_sites(context, &cfg, external_call_site); + + for state_change in state_changes { + // There is no way to tell is the state change took place after the event if + // both are found at the same place + if state_change != external_call_site { + // Capture the external call + let external_call_cfg_node = + cfg.nodes.get(&external_call_site).expect("cfg is corrupted!"); + + if let Some(external_call_ast_node) = + external_call_cfg_node.reflect(context) + { + capture!(self, context, external_call_ast_node); + } + } + } + } + } + + Ok(!self.found_instances.is_empty()) + } + + fn severity(&self) -> IssueSeverity { + IssueSeverity::High + } + + fn title(&self) -> String { + String::from("External call is followed by a state variable change") + } + + fn description(&self) -> String { + String::from("In most cases it is a best practice to perform the state change before making an external call to avoid a potential re-entrancy attack.") + } + + fn instances(&self) -> BTreeMap<(String, usize, String), NodeID> { + self.found_instances.clone() + } + + fn hints(&self) -> BTreeMap<(String, usize, String), String> { + self.hints.clone() + } + + fn name(&self) -> String { + IssueDetectorNamePool::StateChangeAfterExternalCall.to_string() + } +} + +fn find_state_change_sites( + context: &WorkspaceContext, + cfg: &Cfg, + start_node: CfgNodeId, +) -> HashSet { + let mut visited = Default::default(); + let mut state_change_sites = Default::default(); + + fn _find_following_state_change_sites( + context: &WorkspaceContext, + cfg: &Cfg, + visited: &mut HashSet, + curr_node: CfgNodeId, + state_change_sites: &mut HashSet, + ) -> Option<()> { + if visited.contains(&curr_node) { + return Some(()); + } + + visited.insert(curr_node); + + // Check if `curr_node` is an external call site + let curr_cfg_node = cfg.nodes.get(&curr_node)?; + + // Grab the AST version of the Cfg Node + if let Some(curr_ast_node) = curr_cfg_node.reflect(context) { + let state_changes = ApproximateStorageChangeFinder::from(context, curr_ast_node); + + if state_changes.state_variables_have_been_manipulated() { + state_change_sites.insert(curr_node); + } + } + + // Continue the recursion + for child in curr_node.children(cfg) { + _find_following_state_change_sites(context, cfg, visited, child, state_change_sites); + } + + Some(()) + } + + _find_following_state_change_sites( + context, + cfg, + &mut visited, + start_node, + &mut state_change_sites, + ); + + state_change_sites +} + +fn find_external_call_sites( + context: &WorkspaceContext, + cfg: &Cfg, + start_node: CfgNodeId, +) -> HashSet { + let mut visited = Default::default(); + let mut external_call_sites = Default::default(); + + fn _find_external_call_sites( + context: &WorkspaceContext, + cfg: &Cfg, + visited: &mut HashSet, + curr_node: CfgNodeId, + external_call_sites: &mut HashSet, + ) -> Option<()> { + if visited.contains(&curr_node) { + return Some(()); + } + + visited.insert(curr_node); + + // Check if `curr_node` is an external call site + let curr_cfg_node = cfg.nodes.get(&curr_node)?; + + // Grab the AST version of the Cfg Node + if let Some(curr_ast_node) = curr_cfg_node.reflect(context) { + let function_calls = ExtractFunctionCalls::from(curr_ast_node).extracted; + + if function_calls.iter().any(|f| f.is_external_call()) { + external_call_sites.insert(curr_node); + } + } + + // Continue the recursion + for child in curr_node.children(cfg) { + _find_external_call_sites(context, cfg, visited, child, external_call_sites); + } + + Some(()) + } + + _find_external_call_sites(context, cfg, &mut visited, start_node, &mut external_call_sites); + + external_call_sites +} + +#[cfg(test)] +mod state_change_after_external_call_tests { + use serial_test::serial; + + use crate::detect::{ + detector::IssueDetector, + high::state_change_after_ext_call::StateChangeAfterExternalCallDetector, + }; + + #[test] + #[serial] + fn test_state_change_after_external_call() { + let context = crate::detect::test_utils::load_solidity_source_unit( + "../tests/contract-playground/src/StateChangeAfterExternalCall.sol", + ); + + let mut detector = StateChangeAfterExternalCallDetector::default(); + let found = detector.detect(&context).unwrap(); + // assert that the detector found an issue + assert!(found); + // assert that the detector found the correct number of instances + assert_eq!(detector.instances().len(), 3); + // assert the severity is high + assert_eq!(detector.severity(), crate::detect::detector::IssueSeverity::High); + } +} diff --git a/tests/contract-playground/src/StateChangeAfterExternalCall.sol b/tests/contract-playground/src/StateChangeAfterExternalCall.sol new file mode 100644 index 000000000..991950982 --- /dev/null +++ b/tests/contract-playground/src/StateChangeAfterExternalCall.sol @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.19; + +contract MaliciousActor { + // LOGIC inside hello() doesn't matter. So long as it's an external call we don't trust! + function hello() external { + (bool s, ) = msg.sender.call(""); + require(s, "attempt failed"); + } +} + +contract StateChangeAfterExternalCall { + uint256 s_useMe; + MaliciousActor s_actor; + + constructor(address actor) { + require(actor != address(0)); + s_actor = MaliciousActor(actor); + } + + // BAD + function badSituation1() external { + // Interaction + s_actor.hello(); + + // Effect + s_useMe += 1; + } + + // BAD + function badSituation2() external { + // Interaction + s_actor.hello(); + + if (msg.sender != address(this)) { + // Effect + s_useMe -= 1; + } + } + + // BAD + function badSituation3() external { + // NOTE: Although this may seem like it's following CEI, because it's inside a loop + //one can imagine that in the second iteration, the effect happens after the interaction + //in the first iteration. + + for (uint256 i = 0; i < s_useMe; ++i) { + // Effect + s_useMe += 4; + + // Interaction + s_actor.hello(); + } + } + + // GOOD + function goodSituation1() external { + // Effect + s_useMe += 1; + + // Interaction + s_actor.hello(); + } + + // GOOD + function goodSituation2() external { + if (msg.sender != address(this)) { + s_actor.hello(); + return; + } + + s_useMe += 1; + } +}