diff --git a/aderyn_core/src/context/investigator/standard.rs b/aderyn_core/src/context/graph/callgraph.rs similarity index 60% rename from aderyn_core/src/context/investigator/standard.rs rename to aderyn_core/src/context/graph/callgraph.rs index 70329a92..7e8ae518 100644 --- a/aderyn_core/src/context/investigator/standard.rs +++ b/aderyn_core/src/context/graph/callgraph.rs @@ -1,6 +1,6 @@ //! This module helps with strategies on performing different types of investigations. //! -//! Our first kind of investigator is [`StandardInvestigator`] it comes bundled with actions to help +//! Our first kind of investigator is [`CallGraph`] it comes bundled with actions to help //! application modules "hook in" and consume the graphs. //! //! @@ -16,53 +16,53 @@ use crate::{ }, }; -use super::StandardInvestigatorVisitor; +use super::traits::CallGraphVisitor; #[derive(PartialEq)] -pub enum StandardInvestigationStyle { - /// Picks the regular call graph (forward) - Downstream, +pub enum CallGraphDirection { + /// Depper into the callgraph + Inward, - /// Picks the reverse call graph - Upstream, + /// Opposite of Inward + Outward, - /// Picks both the call graphs (choose this if upstream side effects also need to be tracked) + /// Both inward and outward (If outward side effects also need to be tracked) BothWays, } -pub struct StandardInvestigator { - /// Ad-hoc Nodes that we would like to explore downstream from. +pub struct CallGraph { + /// Ad-hoc Nodes that we would like to explore inward from. pub entry_points: Vec, /// Surface points are calculated based on the entry points (input) /// and only consists of [`crate::ast::FunctionDefinition`] and [`crate::ast::ModifierDefinition`] /// These are nodes that are the *actual* starting points for traversal in the graph - pub forward_surface_points: Vec, + pub inward_surface_points: Vec, - /// Same as the forward one, but acts on reverse graph. - pub backward_surface_points: Vec, + /// Same as the inward one, but acts on reverse graph. + pub outward_surface_points: Vec, /// Decides what graph type to chose from [`WorkspaceContext`] - pub investigation_style: StandardInvestigationStyle, + pub direction: CallGraphDirection, } #[derive(PartialEq, Clone, Copy)] enum CurrentDFSVector { - Forward, // Going downstream - Backward, // Going upstream - UpstreamSideEffect, // Going downstream from upstream nodes + Inward, + Outward, + OutwardSideEffect, } -impl StandardInvestigator { - /// Creates a [`StandardInvestigator`] by exploring paths from given nodes. This is the starting point. - pub fn for_specific_nodes( +impl CallGraph { + /// Creates a [`CallGraph`] by exploring paths from given nodes. This is the starting point. + pub fn from_nodes( context: &WorkspaceContext, nodes: &[&ASTNode], - investigation_style: StandardInvestigationStyle, - ) -> super::Result { + direction: CallGraphDirection, + ) -> super::Result { let mut entry_points = vec![]; - let mut forward_surface_points = vec![]; - let mut backward_surface_points = vec![]; + let mut inward_surface_points = vec![]; + let mut outward_surface_points = vec![]; // Construct entry points for &node in nodes { @@ -72,30 +72,30 @@ impl StandardInvestigator { entry_points.push(node_id); } - // Construct forward surface points + // Construct inward surface points for &node in nodes { let referenced_declarations = ExtractReferencedDeclarations::from(node).extracted; for declared_id in referenced_declarations { if let Some(node) = context.nodes.get(&declared_id) { if node.node_type() == NodeType::ModifierDefinition { - forward_surface_points.push(declared_id); + inward_surface_points.push(declared_id); } else if let ASTNode::FunctionDefinition(function_definition) = node { if function_definition.implemented { - forward_surface_points.push(declared_id); + inward_surface_points.push(declared_id); } } } } } - // Construct backward surface points + // Construct outward surface points for &node in nodes { if node.node_type() == NodeType::FunctionDefinition || node.node_type() == NodeType::ModifierDefinition { if let Some(id) = node.id() { - backward_surface_points.push(id); + outward_surface_points.push(id); } } else { let parent_surface_point = node @@ -105,44 +105,44 @@ impl StandardInvestigator { }); if let Some(parent_surface_point) = parent_surface_point { if let Some(parent_surface_point_id) = parent_surface_point.id() { - backward_surface_points.push(parent_surface_point_id); + outward_surface_points.push(parent_surface_point_id); } } } } - Ok(StandardInvestigator { + Ok(CallGraph { entry_points, - forward_surface_points, - backward_surface_points, - investigation_style, + inward_surface_points, + outward_surface_points, + direction, }) } pub fn new( context: &WorkspaceContext, nodes: &[&ASTNode], - investigation_style: StandardInvestigationStyle, - ) -> super::Result { - Self::for_specific_nodes(context, nodes, investigation_style) + direction: CallGraphDirection, + ) -> super::Result { + Self::from_nodes(context, nodes, direction) } /// Visit the entry points and all the plausible function definitions and modifier definitions that /// EVM may encounter during execution. - pub fn investigate(&self, context: &WorkspaceContext, visitor: &mut T) -> super::Result<()> + pub fn accept(&self, context: &WorkspaceContext, visitor: &mut T) -> super::Result<()> where - T: StandardInvestigatorVisitor, + T: CallGraphVisitor, { - self._investigate( + self._accept( context, context - .forward_callgraph + .inward_callgraph .as_ref() - .ok_or(super::Error::ForwardCallgraphNotAvailable)?, + .ok_or(super::Error::InwardCallgraphNotAvailable)?, context - .reverse_callgraph + .outward_callgraph .as_ref() - .ok_or(super::Error::BackwardCallgraphNotAvailable)?, + .ok_or(super::Error::OutwardCallgraphNotAvailable)?, visitor, ) } @@ -151,15 +151,15 @@ impl StandardInvestigator { /// First, we visit the entry points. Then, we derive the subgraph from the [`WorkspaceCallGraph`] /// which consists of all the nodes that can be reached by traversing the edges starting /// from the surface points. - fn _investigate( + fn _accept( &self, context: &WorkspaceContext, - forward_callgraph: &WorkspaceCallGraph, - reverse_callgraph: &WorkspaceCallGraph, + inward_callgraph: &WorkspaceCallGraph, + outward_callgraph: &WorkspaceCallGraph, visitor: &mut T, ) -> super::Result<()> where - T: StandardInvestigatorVisitor, + T: CallGraphVisitor, { // Visit entry point nodes (so that trackers can track the state across all code regions in 1 place) for entry_point_id in &self.entry_points { @@ -167,40 +167,40 @@ impl StandardInvestigator { } // Keep track of visited node IDs during DFS from surface nodes - let mut visited_downstream = HashSet::new(); - let mut visited_upstream = HashSet::new(); - let mut visited_upstream_side_effects = HashSet::new(); + let mut visited_inward = HashSet::new(); + let mut visited_outward = HashSet::new(); + let mut visited_outward_side_effects = HashSet::new(); - // Now decide, which points to visit upstream or downstream - if self.investigation_style == StandardInvestigationStyle::BothWays - || self.investigation_style == StandardInvestigationStyle::Downstream + // Now decide, which points to visit outward or inward + if self.direction == CallGraphDirection::BothWays + || self.direction == CallGraphDirection::Inward { // Visit the subgraph starting from surface points - for surface_point_id in &self.forward_surface_points { + for surface_point_id in &self.inward_surface_points { self.dfs_and_visit_subgraph( *surface_point_id, - &mut visited_downstream, + &mut visited_inward, context, - forward_callgraph, + inward_callgraph, visitor, - CurrentDFSVector::Forward, + CurrentDFSVector::Inward, None, )?; } } - if self.investigation_style == StandardInvestigationStyle::BothWays - || self.investigation_style == StandardInvestigationStyle::Upstream + if self.direction == CallGraphDirection::BothWays + || self.direction == CallGraphDirection::Outward { // Visit the subgraph starting from surface points - for surface_point_id in &self.backward_surface_points { + for surface_point_id in &self.outward_surface_points { self.dfs_and_visit_subgraph( *surface_point_id, - &mut visited_upstream, + &mut visited_outward, context, - reverse_callgraph, + outward_callgraph, visitor, - CurrentDFSVector::Backward, + CurrentDFSVector::Outward, None, )?; } @@ -209,22 +209,22 @@ impl StandardInvestigator { // Collect already visited nodes so that we don't repeat visit calls on them // while traversing through side effect nodes. let mut blacklisted = HashSet::new(); - blacklisted.extend(visited_downstream.iter()); - blacklisted.extend(visited_upstream.iter()); + blacklisted.extend(visited_inward.iter()); + blacklisted.extend(visited_outward.iter()); blacklisted.extend(self.entry_points.iter()); - if self.investigation_style == StandardInvestigationStyle::BothWays { - // Visit the subgraph from the upstream points (go downstream in forward graph) - // but do not re-visit the upstream nodes or the downstream nodes again + if self.direction == CallGraphDirection::BothWays { + // Visit the subgraph from the outward points (go inward in inward graph) + // but do not re-visit the outward nodes or the inward nodes again - for surface_point_id in &visited_upstream { + for surface_point_id in &visited_outward { self.dfs_and_visit_subgraph( *surface_point_id, - &mut visited_upstream_side_effects, + &mut visited_outward_side_effects, context, - forward_callgraph, + inward_callgraph, visitor, - CurrentDFSVector::UpstreamSideEffect, + CurrentDFSVector::OutwardSideEffect, Some(&blacklisted), )?; } @@ -245,7 +245,7 @@ impl StandardInvestigator { blacklist: Option<&HashSet>, ) -> super::Result<()> where - T: StandardInvestigatorVisitor, + T: CallGraphVisitor, { if visited.contains(&node_id) { return Ok(()); @@ -271,7 +271,7 @@ impl StandardInvestigator { )?; } - if let Some(pointing_to) = callgraph.graph.get(&node_id) { + if let Some(pointing_to) = callgraph.raw_callgraph.get(&node_id) { for destination in pointing_to { self.dfs_and_visit_subgraph( *destination, @@ -295,7 +295,7 @@ impl StandardInvestigator { current_investigation_direction: CurrentDFSVector, ) -> super::Result<()> where - T: StandardInvestigatorVisitor, + T: CallGraphVisitor, { if let Some(node) = context.nodes.get(&node_id) { if node.node_type() != NodeType::FunctionDefinition @@ -305,43 +305,43 @@ impl StandardInvestigator { } match current_investigation_direction { - CurrentDFSVector::Forward => { + CurrentDFSVector::Inward => { if let ASTNode::FunctionDefinition(function) = node { visitor - .visit_downstream_function_definition(function) - .map_err(|_| super::Error::DownstreamFunctionDefinitionVisitError)?; + .visit_inward_function_definition(function) + .map_err(|_| super::Error::InwardFunctionDefinitionVisitError)?; } if let ASTNode::ModifierDefinition(modifier) = node { visitor - .visit_downstream_modifier_definition(modifier) - .map_err(|_| super::Error::DownstreamModifierDefinitionVisitError)?; + .visit_inward_modifier_definition(modifier) + .map_err(|_| super::Error::InwardModifierDefinitionVisitError)?; } } - CurrentDFSVector::Backward => { + CurrentDFSVector::Outward => { if let ASTNode::FunctionDefinition(function) = node { visitor - .visit_upstream_function_definition(function) - .map_err(|_| super::Error::UpstreamFunctionDefinitionVisitError)?; + .visit_outward_function_definition(function) + .map_err(|_| super::Error::OutwardFunctionDefinitionVisitError)?; } if let ASTNode::ModifierDefinition(modifier) = node { visitor - .visit_upstream_modifier_definition(modifier) - .map_err(|_| super::Error::UpstreamModifierDefinitionVisitError)?; + .visit_outward_modifier_definition(modifier) + .map_err(|_| super::Error::OutwardModifierDefinitionVisitError)?; } } - CurrentDFSVector::UpstreamSideEffect => { + CurrentDFSVector::OutwardSideEffect => { if let ASTNode::FunctionDefinition(function) = node { visitor - .visit_upstream_side_effect_function_definition(function) + .visit_outward_side_effect_function_definition(function) .map_err(|_| { - super::Error::UpstreamSideEffectFunctionDefinitionVisitError + super::Error::OutwardSideEffectFunctionDefinitionVisitError })?; } if let ASTNode::ModifierDefinition(modifier) = node { visitor - .visit_upstream_side_effect_modifier_definition(modifier) + .visit_outward_side_effect_modifier_definition(modifier) .map_err(|_| { - super::Error::UpstreamSideEffectModifierDefinitionVisitError + super::Error::OutwardSideEffectModifierDefinitionVisitError })?; } } @@ -358,7 +358,7 @@ impl StandardInvestigator { visitor: &mut T, ) -> super::Result<()> where - T: StandardInvestigatorVisitor, + T: CallGraphVisitor, { let node = context .nodes diff --git a/aderyn_core/src/context/graph/callgraph_tests.rs b/aderyn_core/src/context/graph/callgraph_tests.rs new file mode 100644 index 00000000..49be47ea --- /dev/null +++ b/aderyn_core/src/context/graph/callgraph_tests.rs @@ -0,0 +1,272 @@ +#![allow(clippy::collapsible_match)] + +#[cfg(test)] +mod callgraph_tests { + use crate::{ + ast::{FunctionDefinition, ModifierDefinition}, + context::{ + graph::{callgraph::CallGraph, traits::CallGraphVisitor}, + workspace_context::{ASTNode, WorkspaceContext}, + }, + }; + + use crate::context::graph::callgraph::CallGraphDirection::{BothWays, Inward, Outward}; + use serial_test::serial; + + fn get_function_by_name(context: &WorkspaceContext, name: &str) -> ASTNode { + ASTNode::from( + context + .function_definitions() + .into_iter() + .find(|&x| x.name == *name) + .unwrap(), + ) + } + + fn get_modifier_definition_by_name(context: &WorkspaceContext, name: &str) -> ASTNode { + ASTNode::from( + context + .modifier_definitions() + .into_iter() + .find(|&x| x.name == *name) + .unwrap(), + ) + } + + #[test] + #[serial] + fn test_callgraph_is_not_none() { + let context = crate::detect::test_utils::load_solidity_source_unit( + "../tests/contract-playground/src/CallGraphTests.sol", + ); + assert!(context.inward_callgraph.is_some()); + assert!(context.outward_callgraph.is_some()); + } + + #[test] + #[serial] + fn test_tower1_modifier_has_no_inward() { + let context = crate::detect::test_utils::load_solidity_source_unit( + "../tests/contract-playground/src/CallGraphTests.sol", + ); + + let visit_eighth_floor1 = get_function_by_name(&context, "visitEighthFloor1"); + + let callgraph = CallGraph::new(&context, &[&visit_eighth_floor1], Inward).unwrap(); + + let mut tracker = Tracker::new(&context); + callgraph.accept(&context, &mut tracker).unwrap(); + + assert!(tracker.inward_func_definitions_names.is_empty()); + assert!(tracker.inward_modifier_definitions_names.is_empty()); + } + + #[test] + #[serial] + fn test_tower1_modifier_has_outward() { + let context = crate::detect::test_utils::load_solidity_source_unit( + "../tests/contract-playground/src/CallGraphTests.sol", + ); + + let visit_eighth_floor1 = get_function_by_name(&context, "visitEighthFloor1"); + + let callgraph = CallGraph::new(&context, &[&visit_eighth_floor1], Outward).unwrap(); + + let mut tracker = Tracker::new(&context); + callgraph.accept(&context, &mut tracker).unwrap(); + + assert!(tracker.has_found_outward_modifiers_with_names(&["passThroughNinthFloor1"])); + assert!(tracker.has_found_outward_functions_with_names(&["enterTenthFloor1"])); + } + + #[test] + #[serial] + fn test_tower2_modifier_has_both_outward_and_inward() { + let context = crate::detect::test_utils::load_solidity_source_unit( + "../tests/contract-playground/src/CallGraphTests.sol", + ); + + let pass_through_ninth_floor2 = + get_modifier_definition_by_name(&context, "passThroughNinthFloor2"); + + let callgraph = CallGraph::new(&context, &[&pass_through_ninth_floor2], BothWays).unwrap(); + + let mut tracker = Tracker::new(&context); + callgraph.accept(&context, &mut tracker).unwrap(); + + assert!(tracker.has_found_inward_functions_with_names(&["visitEighthFloor2"])); + assert!(tracker.has_found_outward_functions_with_names(&["enterTenthFloor2"])); + } + + #[test] + #[serial] + fn test_tower3_modifier_has_both_outward_and_inward() { + let context = crate::detect::test_utils::load_solidity_source_unit( + "../tests/contract-playground/src/CallGraphTests.sol", + ); + + let pass_through_ninth_floor3 = + get_modifier_definition_by_name(&context, "passThroughNinthFloor3"); + + let callgraph = CallGraph::new(&context, &[&pass_through_ninth_floor3], BothWays).unwrap(); + + let mut tracker = Tracker::new(&context); + callgraph.accept(&context, &mut tracker).unwrap(); + + assert!(tracker.has_found_outward_functions_with_names(&["enterTenthFloor3"])); + assert!(tracker.has_found_inward_functions_with_names(&["visitEighthFloor3"])); + assert!(tracker.has_not_found_any_outward_functions_with_name("visitSeventhFloor3")); + assert!(tracker.has_found_outward_side_effect_functions_with_name(&["visitSeventhFloor3"])); + } + + #[test] + #[serial] + fn test_tower3_functions_has_outward() { + let context = crate::detect::test_utils::load_solidity_source_unit( + "../tests/contract-playground/src/CallGraphTests.sol", + ); + + let visit_eighth_floor3 = get_function_by_name(&context, "visitSeventhFloor3"); + + let callgraph = CallGraph::new(&context, &[&visit_eighth_floor3], Outward).unwrap(); + + let mut tracker = Tracker::new(&context); + callgraph.accept(&context, &mut tracker).unwrap(); + + assert!(tracker.has_found_outward_functions_with_names(&["enterTenthFloor3"])); + } + + #[test] + #[serial] + fn test_tower4_functions_has_outward_and_inward() { + let context = crate::detect::test_utils::load_solidity_source_unit( + "../tests/contract-playground/src/CallGraphTests.sol", + ); + + let recurse = get_function_by_name(&context, "recurse"); + + let callgraph = CallGraph::new(&context, &[&recurse], BothWays).unwrap(); + + let mut tracker = Tracker::new(&context); + callgraph.accept(&context, &mut tracker).unwrap(); + + assert!(tracker.has_found_outward_functions_with_names(&["recurse"])); + assert!(tracker.has_found_inward_functions_with_names(&["recurse"])); + } + + struct Tracker<'a> { + context: &'a WorkspaceContext, + entry_points: Vec<(String, usize, String)>, + inward_func_definitions_names: Vec, + outward_func_definitions_names: Vec, + inward_modifier_definitions_names: Vec, + outward_modifier_definitions_names: Vec, + outward_side_effects_func_definitions_names: Vec, + outward_side_effects_modifier_definitions_names: Vec, + } + + impl<'a> Tracker<'a> { + fn new(context: &WorkspaceContext) -> Tracker { + Tracker { + context, + entry_points: vec![], + inward_func_definitions_names: vec![], + inward_modifier_definitions_names: vec![], + outward_func_definitions_names: vec![], + outward_modifier_definitions_names: vec![], + outward_side_effects_func_definitions_names: vec![], + outward_side_effects_modifier_definitions_names: vec![], + } + } + + // inward functions + fn has_found_inward_functions_with_names(&self, name: &[&str]) -> bool { + name.iter() + .all(|&n| self.inward_func_definitions_names.contains(&n.to_string())) + } + + // outward functions + fn has_found_outward_functions_with_names(&self, name: &[&str]) -> bool { + name.iter() + .all(|&n| self.outward_func_definitions_names.contains(&n.to_string())) + } + + fn has_not_found_any_outward_functions_with_name(&self, name: &str) -> bool { + !self + .outward_func_definitions_names + .contains(&name.to_string()) + } + + // outward modifiers + fn has_found_outward_modifiers_with_names(&self, name: &[&str]) -> bool { + name.iter().all(|&n| { + self.outward_modifier_definitions_names + .contains(&n.to_string()) + }) + } + + // outward side effects + fn has_found_outward_side_effect_functions_with_name(&self, name: &[&str]) -> bool { + name.iter().all(|&n| { + self.outward_side_effects_func_definitions_names + .contains(&n.to_string()) + }) + } + } + + impl CallGraphVisitor for Tracker<'_> { + fn visit_entry_point(&mut self, node: &ASTNode) -> eyre::Result<()> { + self.entry_points + .push(self.context.get_node_sort_key_pure(node)); + Ok(()) + } + fn visit_inward_function_definition( + &mut self, + node: &crate::ast::FunctionDefinition, + ) -> eyre::Result<()> { + self.inward_func_definitions_names + .push(node.name.to_string()); + Ok(()) + } + fn visit_inward_modifier_definition( + &mut self, + node: &crate::ast::ModifierDefinition, + ) -> eyre::Result<()> { + self.inward_modifier_definitions_names + .push(node.name.to_string()); + Ok(()) + } + fn visit_outward_function_definition( + &mut self, + node: &crate::ast::FunctionDefinition, + ) -> eyre::Result<()> { + self.outward_func_definitions_names + .push(node.name.to_string()); + Ok(()) + } + fn visit_outward_modifier_definition( + &mut self, + node: &crate::ast::ModifierDefinition, + ) -> eyre::Result<()> { + self.outward_modifier_definitions_names + .push(node.name.to_string()); + Ok(()) + } + fn visit_outward_side_effect_function_definition( + &mut self, + node: &FunctionDefinition, + ) -> eyre::Result<()> { + self.outward_side_effects_func_definitions_names + .push(node.name.to_string()); + Ok(()) + } + fn visit_outward_side_effect_modifier_definition( + &mut self, + node: &ModifierDefinition, + ) -> eyre::Result<()> { + self.outward_side_effects_modifier_definitions_names + .push(node.name.to_string()); + Ok(()) + } + } +} diff --git a/aderyn_core/src/context/graph/mod.rs b/aderyn_core/src/context/graph/mod.rs index 3d688c44..e6ca14e6 100644 --- a/aderyn_core/src/context/graph/mod.rs +++ b/aderyn_core/src/context/graph/mod.rs @@ -1,10 +1,16 @@ -pub mod traits; +mod callgraph; +mod callgraph_tests; +mod traits; mod workspace_callgraph; +pub use callgraph::*; +pub use traits::*; pub use workspace_callgraph::*; use derive_more::From; +use crate::ast::{ASTNode, NodeID}; + pub type Result = core::result::Result; #[derive(Debug, From)] @@ -14,6 +20,17 @@ pub enum Error { // region: -- standard::* errors WorkspaceCallGraphDFSError, + InwardCallgraphNotAvailable, + OutwardCallgraphNotAvailable, + UnidentifiedEntryPointNode(ASTNode), + InvalidEntryPointId(NodeID), + EntryPointVisitError, + OutwardFunctionDefinitionVisitError, + OutwardModifierDefinitionVisitError, + InwardFunctionDefinitionVisitError, + InwardModifierDefinitionVisitError, + OutwardSideEffectFunctionDefinitionVisitError, + OutwardSideEffectModifierDefinitionVisitError, // endregion } diff --git a/aderyn_core/src/context/graph/traits.rs b/aderyn_core/src/context/graph/traits.rs index 661bbf57..f6476f49 100644 --- a/aderyn_core/src/context/graph/traits.rs +++ b/aderyn_core/src/context/graph/traits.rs @@ -1,4 +1,62 @@ +use crate::ast::{ASTNode, FunctionDefinition, ModifierDefinition}; + /// Trait to support reversing of callgraph. (Because, direct impl is not allowed on Foreign Types) pub trait Transpose { fn reverse(&self) -> Self; } + +/// Use with [`super::CallGraph`] +pub trait CallGraphVisitor { + /// Shift all logic to tracker otherwise, you would track state at 2 different places + /// One at the tracker level, and other at the application level. Instead, we must + /// contain all of the tracking logic in the tracker. Therefore, visit entry point + /// is essential because the tracker can get to take a look at not just the + /// inward functions and modifiers, but also the entry points that have invoked it. + fn visit_entry_point(&mut self, node: &ASTNode) -> eyre::Result<()> { + self.visit_any(node) + } + + /// Meant to be invoked while traversing [`crate::context::workspace_context::WorkspaceContext::inward_callgraph`] + fn visit_inward_function_definition(&mut self, node: &FunctionDefinition) -> eyre::Result<()> { + self.visit_any(&(node.into())) + } + + /// Meant to be invoked while traversing [`crate::context::workspace_context::WorkspaceContext::outward_callgraph`] + fn visit_outward_function_definition(&mut self, node: &FunctionDefinition) -> eyre::Result<()> { + self.visit_any(&(node.into())) + } + + /// Meant to be invoked while traversing [`crate::context::workspace_context::WorkspaceContext::inward_callgraph`] + fn visit_inward_modifier_definition(&mut self, node: &ModifierDefinition) -> eyre::Result<()> { + self.visit_any(&(node.into())) + } + + /// Meant to be invoked while traversing [`crate::context::workspace_context::WorkspaceContext::outward_callgraph`] + fn visit_outward_modifier_definition(&mut self, node: &ModifierDefinition) -> eyre::Result<()> { + self.visit_any(&(node.into())) + } + + /// Read as "outward's inward-side-effect" function definition + /// These are function definitions that are inward from the outward nodes + /// but are themselves neither outward nor inward to the entry points + fn visit_outward_side_effect_function_definition( + &mut self, + node: &FunctionDefinition, + ) -> eyre::Result<()> { + self.visit_any(&(node.into())) + } + + /// Read as "outward's inward-side-effect" modifier definition + /// These are modifier definitions that are inward from the outward nodes + /// but are themselves neither outward nor inward to the entry points + fn visit_outward_side_effect_modifier_definition( + &mut self, + node: &ModifierDefinition, + ) -> eyre::Result<()> { + self.visit_any(&(node.into())) + } + + fn visit_any(&mut self, _node: &ASTNode) -> eyre::Result<()> { + Ok(()) + } +} diff --git a/aderyn_core/src/context/graph/workspace_callgraph.rs b/aderyn_core/src/context/graph/workspace_callgraph.rs index 74517547..95f12fe5 100644 --- a/aderyn_core/src/context/graph/workspace_callgraph.rs +++ b/aderyn_core/src/context/graph/workspace_callgraph.rs @@ -12,18 +12,18 @@ use super::traits::Transpose; #[derive(Debug)] pub struct WorkspaceCallGraph { - pub graph: CallGraph, + pub raw_callgraph: RawCallGraph, } /** -* Every NodeID in CallGraph should corresponds to [`crate::ast::FunctionDefinition`] or [`crate::ast::ModifierDefinition`] +* Every NodeID in RawCallGraph should corresponds to [`crate::ast::FunctionDefinition`] or [`crate::ast::ModifierDefinition`] */ -pub type CallGraph = HashMap>; +pub type RawCallGraph = HashMap>; impl WorkspaceCallGraph { /// Formula to create [`WorkspaceCallGraph`] for global preprocessing . pub fn from_context(context: &WorkspaceContext) -> super::Result { - let mut graph: CallGraph = HashMap::new(); + let mut raw_callgraph: RawCallGraph = HashMap::new(); let mut visited: HashSet = HashSet::new(); let funcs = context @@ -35,16 +35,16 @@ impl WorkspaceCallGraph { let modifier_definitions = context.modifier_definitions(); for func in funcs { - dfs_to_create_graph(func.id, &mut graph, &mut visited, context) + dfs_to_create_graph(func.id, &mut raw_callgraph, &mut visited, context) .map_err(|_| super::Error::WorkspaceCallGraphDFSError)?; } for modifier in modifier_definitions { - dfs_to_create_graph(modifier.id, &mut graph, &mut visited, context) + dfs_to_create_graph(modifier.id, &mut raw_callgraph, &mut visited, context) .map_err(|_| super::Error::WorkspaceCallGraphDFSError)?; } - Ok(WorkspaceCallGraph { graph }) + Ok(WorkspaceCallGraph { raw_callgraph }) } } @@ -52,7 +52,7 @@ impl WorkspaceCallGraph { /// with their connected counterparts. fn dfs_to_create_graph( id: NodeID, - graph: &mut CallGraph, + raw_callgraph: &mut RawCallGraph, visited: &mut HashSet, context: &WorkspaceContext, ) -> super::Result<()> { @@ -76,8 +76,8 @@ fn dfs_to_create_graph( for function_call in function_calls { if let Expression::Identifier(identifier) = function_call.expression.as_ref() { if let Some(referenced_function_id) = identifier.referenced_declaration { - create_connection_if_not_exsits(id, referenced_function_id, graph); - dfs_to_create_graph(referenced_function_id, graph, visited, context)?; + create_connection_if_not_exsits(id, referenced_function_id, raw_callgraph); + dfs_to_create_graph(referenced_function_id, raw_callgraph, visited, context)?; } } } @@ -88,14 +88,28 @@ fn dfs_to_create_graph( match &modifier_invocation.modifier_name { IdentifierOrIdentifierPath::Identifier(identifier) => { if let Some(reference_modifier_id) = identifier.referenced_declaration { - create_connection_if_not_exsits(id, reference_modifier_id, graph); - dfs_to_create_graph(reference_modifier_id, graph, visited, context)?; + create_connection_if_not_exsits(id, reference_modifier_id, raw_callgraph); + dfs_to_create_graph( + reference_modifier_id, + raw_callgraph, + visited, + context, + )?; } } IdentifierOrIdentifierPath::IdentifierPath(identifier_path) => { let referenced_modifier_id = identifier_path.referenced_declaration; - create_connection_if_not_exsits(id, referenced_modifier_id as i64, graph); - dfs_to_create_graph(referenced_modifier_id as i64, graph, visited, context)?; + create_connection_if_not_exsits( + id, + referenced_modifier_id as i64, + raw_callgraph, + ); + dfs_to_create_graph( + referenced_modifier_id as i64, + raw_callgraph, + visited, + context, + )?; } } } @@ -107,8 +121,12 @@ fn dfs_to_create_graph( Ok(()) } -fn create_connection_if_not_exsits(from_id: NodeID, to_id: NodeID, graph: &mut CallGraph) { - match graph.entry(from_id) { +fn create_connection_if_not_exsits( + from_id: NodeID, + to_id: NodeID, + raw_callgraph: &mut RawCallGraph, +) { + match raw_callgraph.entry(from_id) { hash_map::Entry::Occupied(mut o) => { // Performance Tip: Maybe later use binary search (it requires keeping ascending order while inserting tho) if !o.get().contains(&to_id) { @@ -121,9 +139,9 @@ fn create_connection_if_not_exsits(from_id: NodeID, to_id: NodeID, graph: &mut C } } -impl Transpose for CallGraph { +impl Transpose for RawCallGraph { fn reverse(&self) -> Self { - let mut reversed_callgraph = CallGraph::default(); + let mut reversed_callgraph = RawCallGraph::default(); for (from_id, tos) in self { for to_id in tos { create_connection_if_not_exsits(*to_id, *from_id, &mut reversed_callgraph); diff --git a/aderyn_core/src/context/investigator/callgraph_tests.rs b/aderyn_core/src/context/investigator/callgraph_tests.rs deleted file mode 100644 index 3f5c1b26..00000000 --- a/aderyn_core/src/context/investigator/callgraph_tests.rs +++ /dev/null @@ -1,283 +0,0 @@ -#![allow(clippy::collapsible_match)] - -#[cfg(test)] -mod callgraph_tests { - use crate::{ - ast::{FunctionDefinition, ModifierDefinition}, - context::{ - investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, - }, - workspace_context::{ASTNode, WorkspaceContext}, - }, - }; - - use serial_test::serial; - use StandardInvestigationStyle::*; - - fn get_function_by_name(context: &WorkspaceContext, name: &str) -> ASTNode { - ASTNode::from( - context - .function_definitions() - .into_iter() - .find(|&x| x.name == *name) - .unwrap(), - ) - } - - fn get_modifier_definition_by_name(context: &WorkspaceContext, name: &str) -> ASTNode { - ASTNode::from( - context - .modifier_definitions() - .into_iter() - .find(|&x| x.name == *name) - .unwrap(), - ) - } - - #[test] - #[serial] - fn test_callgraph_is_not_none() { - let context = crate::detect::test_utils::load_solidity_source_unit( - "../tests/contract-playground/src/CallGraphTests.sol", - ); - assert!(context.forward_callgraph.is_some()); - assert!(context.reverse_callgraph.is_some()); - } - - #[test] - #[serial] - fn test_tower1_modifier_has_no_downstream() { - let context = crate::detect::test_utils::load_solidity_source_unit( - "../tests/contract-playground/src/CallGraphTests.sol", - ); - - let visit_eighth_floor1 = get_function_by_name(&context, "visitEighthFloor1"); - - let investigator = - StandardInvestigator::new(&context, &[&visit_eighth_floor1], Downstream).unwrap(); - - let mut tracker = Tracker::new(&context); - investigator.investigate(&context, &mut tracker).unwrap(); - - assert!(tracker.downstream_func_definitions_names.is_empty()); - assert!(tracker.downstream_modifier_definitions_names.is_empty()); - } - - #[test] - #[serial] - fn test_tower1_modifier_has_upstream() { - let context = crate::detect::test_utils::load_solidity_source_unit( - "../tests/contract-playground/src/CallGraphTests.sol", - ); - - let visit_eighth_floor1 = get_function_by_name(&context, "visitEighthFloor1"); - - let investigator = - StandardInvestigator::new(&context, &[&visit_eighth_floor1], Upstream).unwrap(); - - let mut tracker = Tracker::new(&context); - investigator.investigate(&context, &mut tracker).unwrap(); - - assert!(tracker.has_found_upstream_modifiers_with_names(&["passThroughNinthFloor1"])); - assert!(tracker.has_found_upstream_functions_with_names(&["enterTenthFloor1"])); - } - - #[test] - #[serial] - fn test_tower2_modifier_has_both_upstream_and_downstream() { - let context = crate::detect::test_utils::load_solidity_source_unit( - "../tests/contract-playground/src/CallGraphTests.sol", - ); - - let pass_through_ninth_floor2 = - get_modifier_definition_by_name(&context, "passThroughNinthFloor2"); - - let investigator = - StandardInvestigator::new(&context, &[&pass_through_ninth_floor2], BothWays).unwrap(); - - let mut tracker = Tracker::new(&context); - investigator.investigate(&context, &mut tracker).unwrap(); - - assert!(tracker.has_found_downstream_functions_with_names(&["visitEighthFloor2"])); - assert!(tracker.has_found_upstream_functions_with_names(&["enterTenthFloor2"])); - } - - #[test] - #[serial] - fn test_tower3_modifier_has_both_upstream_and_downstream() { - let context = crate::detect::test_utils::load_solidity_source_unit( - "../tests/contract-playground/src/CallGraphTests.sol", - ); - - let pass_through_ninth_floor3 = - get_modifier_definition_by_name(&context, "passThroughNinthFloor3"); - - let investigator = - StandardInvestigator::new(&context, &[&pass_through_ninth_floor3], BothWays).unwrap(); - - let mut tracker = Tracker::new(&context); - investigator.investigate(&context, &mut tracker).unwrap(); - - assert!(tracker.has_found_upstream_functions_with_names(&["enterTenthFloor3"])); - assert!(tracker.has_found_downstream_functions_with_names(&["visitEighthFloor3"])); - assert!(tracker.has_not_found_any_upstream_functions_with_name("visitSeventhFloor3")); - assert!(tracker.has_found_upstream_side_effect_functions_with_name(&["visitSeventhFloor3"])); - } - - #[test] - #[serial] - fn test_tower3_functions_has_upstream() { - let context = crate::detect::test_utils::load_solidity_source_unit( - "../tests/contract-playground/src/CallGraphTests.sol", - ); - - let visit_eighth_floor3 = get_function_by_name(&context, "visitSeventhFloor3"); - - let investigator = - StandardInvestigator::new(&context, &[&visit_eighth_floor3], Upstream).unwrap(); - - let mut tracker = Tracker::new(&context); - investigator.investigate(&context, &mut tracker).unwrap(); - - assert!(tracker.has_found_upstream_functions_with_names(&["enterTenthFloor3"])); - } - - #[test] - #[serial] - fn test_tower4_functions_has_upstream_and_downstream() { - let context = crate::detect::test_utils::load_solidity_source_unit( - "../tests/contract-playground/src/CallGraphTests.sol", - ); - - let recurse = get_function_by_name(&context, "recurse"); - - let investigator = StandardInvestigator::new(&context, &[&recurse], BothWays).unwrap(); - - let mut tracker = Tracker::new(&context); - investigator.investigate(&context, &mut tracker).unwrap(); - - assert!(tracker.has_found_upstream_functions_with_names(&["recurse"])); - assert!(tracker.has_found_downstream_functions_with_names(&["recurse"])); - } - - struct Tracker<'a> { - context: &'a WorkspaceContext, - entry_points: Vec<(String, usize, String)>, - downstream_func_definitions_names: Vec, - upstream_func_definitions_names: Vec, - downstream_modifier_definitions_names: Vec, - upstream_modifier_definitions_names: Vec, - upstream_side_effects_func_definitions_names: Vec, - upstream_side_effects_modifier_definitions_names: Vec, - } - - impl<'a> Tracker<'a> { - fn new(context: &WorkspaceContext) -> Tracker { - Tracker { - context, - entry_points: vec![], - downstream_func_definitions_names: vec![], - downstream_modifier_definitions_names: vec![], - upstream_func_definitions_names: vec![], - upstream_modifier_definitions_names: vec![], - upstream_side_effects_func_definitions_names: vec![], - upstream_side_effects_modifier_definitions_names: vec![], - } - } - - // downstream functions - fn has_found_downstream_functions_with_names(&self, name: &[&str]) -> bool { - name.iter().all(|&n| { - self.downstream_func_definitions_names - .contains(&n.to_string()) - }) - } - - // upstream functions - fn has_found_upstream_functions_with_names(&self, name: &[&str]) -> bool { - name.iter().all(|&n| { - self.upstream_func_definitions_names - .contains(&n.to_string()) - }) - } - - fn has_not_found_any_upstream_functions_with_name(&self, name: &str) -> bool { - !self - .upstream_func_definitions_names - .contains(&name.to_string()) - } - - // upstream modifiers - fn has_found_upstream_modifiers_with_names(&self, name: &[&str]) -> bool { - name.iter().all(|&n| { - self.upstream_modifier_definitions_names - .contains(&n.to_string()) - }) - } - - // upstream side effects - fn has_found_upstream_side_effect_functions_with_name(&self, name: &[&str]) -> bool { - name.iter().all(|&n| { - self.upstream_side_effects_func_definitions_names - .contains(&n.to_string()) - }) - } - } - - impl StandardInvestigatorVisitor for Tracker<'_> { - fn visit_entry_point(&mut self, node: &ASTNode) -> eyre::Result<()> { - self.entry_points - .push(self.context.get_node_sort_key_pure(node)); - Ok(()) - } - fn visit_downstream_function_definition( - &mut self, - node: &crate::ast::FunctionDefinition, - ) -> eyre::Result<()> { - self.downstream_func_definitions_names - .push(node.name.to_string()); - Ok(()) - } - fn visit_downstream_modifier_definition( - &mut self, - node: &crate::ast::ModifierDefinition, - ) -> eyre::Result<()> { - self.downstream_modifier_definitions_names - .push(node.name.to_string()); - Ok(()) - } - fn visit_upstream_function_definition( - &mut self, - node: &crate::ast::FunctionDefinition, - ) -> eyre::Result<()> { - self.upstream_func_definitions_names - .push(node.name.to_string()); - Ok(()) - } - fn visit_upstream_modifier_definition( - &mut self, - node: &crate::ast::ModifierDefinition, - ) -> eyre::Result<()> { - self.upstream_modifier_definitions_names - .push(node.name.to_string()); - Ok(()) - } - fn visit_upstream_side_effect_function_definition( - &mut self, - node: &FunctionDefinition, - ) -> eyre::Result<()> { - self.upstream_side_effects_func_definitions_names - .push(node.name.to_string()); - Ok(()) - } - fn visit_upstream_side_effect_modifier_definition( - &mut self, - node: &ModifierDefinition, - ) -> eyre::Result<()> { - self.upstream_side_effects_modifier_definitions_names - .push(node.name.to_string()); - Ok(()) - } - } -} diff --git a/aderyn_core/src/context/investigator/mod.rs b/aderyn_core/src/context/investigator/mod.rs deleted file mode 100644 index 79f66015..00000000 --- a/aderyn_core/src/context/investigator/mod.rs +++ /dev/null @@ -1,46 +0,0 @@ -mod callgraph_tests; -mod standard; -mod traits; - -pub use standard::*; -pub use traits::*; - -use derive_more::From; - -use crate::ast::{ASTNode, NodeID}; - -pub type Result = core::result::Result; - -#[derive(Debug, From)] -pub enum Error { - #[from] - Custom(String), - - // region: -- standard::* errors - ForwardCallgraphNotAvailable, - BackwardCallgraphNotAvailable, - UnidentifiedEntryPointNode(ASTNode), - InvalidEntryPointId(NodeID), - EntryPointVisitError, - UpstreamFunctionDefinitionVisitError, - UpstreamModifierDefinitionVisitError, - DownstreamFunctionDefinitionVisitError, - DownstreamModifierDefinitionVisitError, - UpstreamSideEffectFunctionDefinitionVisitError, - UpstreamSideEffectModifierDefinitionVisitError, - // endregion -} - -impl core::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl From<&str> for Error { - fn from(value: &str) -> Self { - Error::Custom(value.to_string()) - } -} - -impl std::error::Error for Error {} diff --git a/aderyn_core/src/context/investigator/traits.rs b/aderyn_core/src/context/investigator/traits.rs deleted file mode 100644 index c0c9a9b6..00000000 --- a/aderyn_core/src/context/investigator/traits.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Trackers can implement the following traits to interact with investigators -//! -//! NOTE -//! Upstream and downstream here is relative to [`super::StandardInvestigator::entry_points`] -//! which is initialized with [`super::StandardInvestigator::new`] function. - -use crate::{ - ast::{FunctionDefinition, ModifierDefinition}, - context::workspace_context::ASTNode, -}; - -/// Use with [`super::StandardInvestigator`] -pub trait StandardInvestigatorVisitor { - /// Shift all logic to tracker otherwise, you would track state at 2 different places - /// One at the tracker level, and other at the application level. Instead, we must - /// contain all of the tracking logic in the tracker. Therefore, visit entry point - /// is essential because the tracker can get to take a look at not just the - /// downstream functions and modifiers, but also the entry points that have invoked it. - fn visit_entry_point(&mut self, node: &ASTNode) -> eyre::Result<()> { - self.visit_any(node) - } - - /// Meant to be invoked while traversing [`crate::context::workspace_context::WorkspaceContext::forward_callgraph`] - fn visit_downstream_function_definition( - &mut self, - node: &FunctionDefinition, - ) -> eyre::Result<()> { - self.visit_any(&(node.into())) - } - - /// Meant to be invoked while traversing [`crate::context::workspace_context::WorkspaceContext::reverse_callgraph`] - fn visit_upstream_function_definition( - &mut self, - node: &FunctionDefinition, - ) -> eyre::Result<()> { - self.visit_any(&(node.into())) - } - - /// Meant to be invoked while traversing [`crate::context::workspace_context::WorkspaceContext::forward_callgraph`] - fn visit_downstream_modifier_definition( - &mut self, - node: &ModifierDefinition, - ) -> eyre::Result<()> { - self.visit_any(&(node.into())) - } - - /// Meant to be invoked while traversing [`crate::context::workspace_context::WorkspaceContext::reverse_callgraph`] - fn visit_upstream_modifier_definition( - &mut self, - node: &ModifierDefinition, - ) -> eyre::Result<()> { - self.visit_any(&(node.into())) - } - - /// Read as "upstream's downstream-side-effect" function definition - /// These are function definitions that are downstream from the upstream nodes - /// but are themselves neither upstream nor downstream to the entry points - fn visit_upstream_side_effect_function_definition( - &mut self, - node: &FunctionDefinition, - ) -> eyre::Result<()> { - self.visit_any(&(node.into())) - } - - /// Read as "upstream's downstream-side-effect" modifier definition - /// These are modifier definitions that are downstream from the upstream nodes - /// but are themselves neither upstream nor downstream to the entry points - fn visit_upstream_side_effect_modifier_definition( - &mut self, - node: &ModifierDefinition, - ) -> eyre::Result<()> { - self.visit_any(&(node.into())) - } - - fn visit_any(&mut self, _node: &ASTNode) -> eyre::Result<()> { - Ok(()) - } -} diff --git a/aderyn_core/src/context/mod.rs b/aderyn_core/src/context/mod.rs index ca73f66c..6ff7b956 100644 --- a/aderyn_core/src/context/mod.rs +++ b/aderyn_core/src/context/mod.rs @@ -1,7 +1,6 @@ pub mod browser; pub mod capturable; pub mod graph; -pub mod investigator; pub mod macros; pub mod meta_workspace; pub mod workspace_context; diff --git a/aderyn_core/src/context/workspace_context.rs b/aderyn_core/src/context/workspace_context.rs index 4d475993..0b463746 100644 --- a/aderyn_core/src/context/workspace_context.rs +++ b/aderyn_core/src/context/workspace_context.rs @@ -28,8 +28,8 @@ pub struct WorkspaceContext { pub src_filepaths: Vec, pub sloc_stats: HashMap, pub ignore_lines_stats: HashMap>>, - pub forward_callgraph: Option, - pub reverse_callgraph: Option, + pub inward_callgraph: Option, + pub outward_callgraph: Option, pub nodes: HashMap, // Hashmaps of all nodes => source_unit_id diff --git a/aderyn_core/src/detect/high/contract_locks_ether.rs b/aderyn_core/src/detect/high/contract_locks_ether.rs index 808e0df0..dbf2eef0 100644 --- a/aderyn_core/src/detect/high/contract_locks_ether.rs +++ b/aderyn_core/src/detect/high/contract_locks_ether.rs @@ -66,7 +66,8 @@ mod contract_eth_helper { use crate::{ ast::{ASTNode, ContractDefinition, StateMutability, Visibility}, context::{ - browser::ExtractFunctionDefinitions, investigator::*, + browser::ExtractFunctionDefinitions, + graph::{CallGraph, CallGraphDirection, CallGraphVisitor}, workspace_context::WorkspaceContext, }, detect::helpers, @@ -111,14 +112,14 @@ mod contract_eth_helper { let mut tracker = EthWithdrawalAllowerTracker::default(); - let investigator = StandardInvestigator::new( + let callgraph = CallGraph::new( context, funcs.iter().collect::>().as_slice(), - StandardInvestigationStyle::Downstream, + CallGraphDirection::Inward, ) .ok()?; - investigator.investigate(context, &mut tracker).ok()?; + callgraph.accept(context, &mut tracker).ok()?; if tracker.has_calls_that_sends_native_eth { return Some(true); @@ -137,7 +138,7 @@ mod contract_eth_helper { has_calls_that_sends_native_eth: bool, } - impl StandardInvestigatorVisitor for EthWithdrawalAllowerTracker { + impl CallGraphVisitor for EthWithdrawalAllowerTracker { fn visit_any(&mut self, ast_node: &ASTNode) -> eyre::Result<()> { if !self.has_calls_that_sends_native_eth && helpers::has_calls_that_sends_native_eth(ast_node) diff --git a/aderyn_core/src/detect/high/delegate_call_no_address_check.rs b/aderyn_core/src/detect/high/delegate_call_no_address_check.rs index 19a698cc..77d9ef8f 100644 --- a/aderyn_core/src/detect/high/delegate_call_no_address_check.rs +++ b/aderyn_core/src/detect/high/delegate_call_no_address_check.rs @@ -4,9 +4,8 @@ use std::error::Error; use crate::ast::NodeID; use crate::capture; -use crate::context::investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, -}; + +use crate::context::graph::{CallGraph, CallGraphDirection, CallGraphVisitor}; use crate::detect::detector::IssueDetectorNamePool; use crate::detect::helpers; use crate::{ @@ -30,12 +29,8 @@ impl IssueDetector for DelegateCallOnUncheckedAddressDetector { has_delegate_call_on_non_state_variable_address: false, context, }; - let investigator = StandardInvestigator::new( - context, - &[&(func.into())], - StandardInvestigationStyle::Downstream, - )?; - investigator.investigate(context, &mut tracker)?; + let callgraph = CallGraph::new(context, &[&(func.into())], CallGraphDirection::Inward)?; + callgraph.accept(context, &mut tracker)?; if tracker.has_delegate_call_on_non_state_variable_address && !tracker.has_address_checks @@ -74,7 +69,7 @@ struct DelegateCallNoAddressChecksTracker<'a> { context: &'a WorkspaceContext, } -impl<'a> StandardInvestigatorVisitor for DelegateCallNoAddressChecksTracker<'a> { +impl<'a> CallGraphVisitor for DelegateCallNoAddressChecksTracker<'a> { fn visit_any(&mut self, node: &crate::context::workspace_context::ASTNode) -> eyre::Result<()> { if !self.has_address_checks && helpers::has_binary_checks_on_some_address(node) { self.has_address_checks = true; diff --git a/aderyn_core/src/detect/high/msg_value_in_loops.rs b/aderyn_core/src/detect/high/msg_value_in_loops.rs index b95ed64d..d8644961 100644 --- a/aderyn_core/src/detect/high/msg_value_in_loops.rs +++ b/aderyn_core/src/detect/high/msg_value_in_loops.rs @@ -6,9 +6,7 @@ use crate::ast::{ASTNode, Expression, NodeID}; use crate::capture; use crate::context::browser::ExtractMemberAccesses; -use crate::context::investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, -}; +use crate::context::graph::{CallGraph, CallGraphDirection, CallGraphVisitor}; use crate::detect::detector::IssueDetectorNamePool; use crate::{ context::workspace_context::WorkspaceContext, @@ -72,11 +70,9 @@ impl IssueDetector for MsgValueUsedInLoopDetector { fn uses_msg_value(context: &WorkspaceContext, ast_node: &ASTNode) -> Option { let mut tracker = MsgValueTracker::default(); - let investigator = - StandardInvestigator::new(context, &[ast_node], StandardInvestigationStyle::Downstream) - .ok()?; + let callgraph = CallGraph::new(context, &[ast_node], CallGraphDirection::Inward).ok()?; - investigator.investigate(context, &mut tracker).ok()?; + callgraph.accept(context, &mut tracker).ok()?; Some(tracker.has_msg_value) } @@ -85,7 +81,7 @@ struct MsgValueTracker { has_msg_value: bool, } -impl StandardInvestigatorVisitor for MsgValueTracker { +impl CallGraphVisitor for MsgValueTracker { fn visit_any(&mut self, node: &crate::ast::ASTNode) -> eyre::Result<()> { if !self.has_msg_value && ExtractMemberAccesses::from(node) diff --git a/aderyn_core/src/detect/high/out_of_order_retryable.rs b/aderyn_core/src/detect/high/out_of_order_retryable.rs index a3910226..433a70cd 100644 --- a/aderyn_core/src/detect/high/out_of_order_retryable.rs +++ b/aderyn_core/src/detect/high/out_of_order_retryable.rs @@ -5,9 +5,7 @@ use crate::ast::{Expression, MemberAccess, NodeID}; use crate::capture; use crate::context::browser::ExtractFunctionCalls; -use crate::context::investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, -}; +use crate::context::graph::{CallGraph, CallGraphDirection, CallGraphVisitor}; use crate::detect::detector::IssueDetectorNamePool; use crate::detect::helpers; use crate::{ @@ -29,12 +27,8 @@ impl IssueDetector for OutOfOrderRetryableDetector { let mut tracker = OutOfOrderRetryableTracker { number_of_retry_calls: 0, }; - let investigator = StandardInvestigator::new( - context, - &[&(func.into())], - StandardInvestigationStyle::Downstream, - )?; - investigator.investigate(context, &mut tracker)?; + let callgraph = CallGraph::new(context, &[&(func.into())], CallGraphDirection::Inward)?; + callgraph.accept(context, &mut tracker)?; if tracker.number_of_retry_calls >= 2 { capture!(self, context, func); } @@ -77,7 +71,7 @@ const SEQUENCER_FUNCTIONS: [&str; 3] = [ "unsafeCreateRetryableTicket", ]; -impl StandardInvestigatorVisitor for OutOfOrderRetryableTracker { +impl CallGraphVisitor for OutOfOrderRetryableTracker { fn visit_any(&mut self, node: &crate::ast::ASTNode) -> eyre::Result<()> { if self.number_of_retry_calls >= 2 { return Ok(()); diff --git a/aderyn_core/src/detect/high/send_ether_no_checks.rs b/aderyn_core/src/detect/high/send_ether_no_checks.rs index 964c694a..5358e7b1 100644 --- a/aderyn_core/src/detect/high/send_ether_no_checks.rs +++ b/aderyn_core/src/detect/high/send_ether_no_checks.rs @@ -4,9 +4,7 @@ use std::error::Error; use crate::ast::NodeID; use crate::capture; -use crate::context::investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, -}; +use crate::context::graph::{CallGraph, CallGraphDirection, CallGraphVisitor}; use crate::context::workspace_context::ASTNode; use crate::detect::detector::IssueDetectorNamePool; use crate::detect::helpers; @@ -27,12 +25,9 @@ impl IssueDetector for SendEtherNoChecksDetector { fn detect(&mut self, context: &WorkspaceContext) -> Result> { for func in helpers::get_implemented_external_and_public_functions(context) { let mut tracker = MsgSenderAndCallWithValueTracker::default(); - let investigator = StandardInvestigator::new( - context, - &[&(func.into())], - StandardInvestigationStyle::Downstream, - )?; - investigator.investigate(context, &mut tracker)?; + let investigator = + CallGraph::new(context, &[&(func.into())], CallGraphDirection::Inward)?; + investigator.accept(context, &mut tracker)?; if tracker.sends_native_eth && !tracker.has_msg_sender_checks { capture!(self, context, func); @@ -108,7 +103,7 @@ pub struct MsgSenderAndCallWithValueTracker { pub sends_native_eth: bool, } -impl StandardInvestigatorVisitor for MsgSenderAndCallWithValueTracker { +impl CallGraphVisitor for MsgSenderAndCallWithValueTracker { fn visit_any(&mut self, node: &ASTNode) -> eyre::Result<()> { if !self.has_msg_sender_checks && helpers::has_msg_sender_binary_operation(node) { self.has_msg_sender_checks = true; diff --git a/aderyn_core/src/detect/high/tx_origin_used_for_auth.rs b/aderyn_core/src/detect/high/tx_origin_used_for_auth.rs index f733de66..11fd5ab9 100644 --- a/aderyn_core/src/detect/high/tx_origin_used_for_auth.rs +++ b/aderyn_core/src/detect/high/tx_origin_used_for_auth.rs @@ -5,9 +5,7 @@ use crate::ast::{ASTNode, Expression, Identifier, NodeID}; use crate::capture; use crate::context::browser::ExtractMemberAccesses; -use crate::context::investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, -}; +use crate::context::graph::{CallGraph, CallGraphDirection, CallGraphVisitor}; use crate::detect::detector::IssueDetectorNamePool; use crate::{ context::workspace_context::WorkspaceContext, @@ -85,12 +83,8 @@ impl TxOriginUsedForAuthDetector { ) -> Result<(), Box> { // Boilerplate let mut tracker = MsgSenderAndTxOriginTracker::default(); - let investigator = StandardInvestigator::new( - context, - check_nodes, - StandardInvestigationStyle::Downstream, - )?; - investigator.investigate(context, &mut tracker)?; + let callgraph = CallGraph::new(context, check_nodes, CallGraphDirection::Inward)?; + callgraph.accept(context, &mut tracker)?; if tracker.satisifed() { capture!(self, context, capture_node); @@ -113,7 +107,7 @@ impl MsgSenderAndTxOriginTracker { } } -impl StandardInvestigatorVisitor for MsgSenderAndTxOriginTracker { +impl CallGraphVisitor for MsgSenderAndTxOriginTracker { fn visit_any(&mut self, node: &crate::ast::ASTNode) -> eyre::Result<()> { let member_accesses = ExtractMemberAccesses::from(node).extracted; diff --git a/aderyn_core/src/detect/low/constant_funcs_assembly.rs b/aderyn_core/src/detect/low/constant_funcs_assembly.rs index 3608ba2b..a2a9aefa 100644 --- a/aderyn_core/src/detect/low/constant_funcs_assembly.rs +++ b/aderyn_core/src/detect/low/constant_funcs_assembly.rs @@ -8,9 +8,8 @@ use crate::capture; use crate::context::browser::{ ExtractInlineAssemblys, ExtractPragmaDirectives, GetClosestAncestorOfTypeX, }; -use crate::context::investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, -}; + +use crate::context::graph::{CallGraph, CallGraphDirection, CallGraphVisitor}; use crate::detect::detector::IssueDetectorNamePool; use crate::detect::helpers::{self, pragma_directive_to_semver}; use crate::{ @@ -50,12 +49,12 @@ impl IssueDetector for ConstantFunctionContainsAssemblyDetector { let mut tracker = AssemblyTracker { has_assembly: false, }; - let investigator = StandardInvestigator::new( + let callgraph = CallGraph::new( context, &[&(function.into())], - StandardInvestigationStyle::Downstream, + CallGraphDirection::Inward, )?; - investigator.investigate(context, &mut tracker)?; + callgraph.accept(context, &mut tracker)?; if tracker.has_assembly { capture!(self, context, function); @@ -110,7 +109,7 @@ struct AssemblyTracker { has_assembly: bool, } -impl StandardInvestigatorVisitor for AssemblyTracker { +impl CallGraphVisitor for AssemblyTracker { fn visit_any(&mut self, node: &crate::ast::ASTNode) -> eyre::Result<()> { // If we are already satisifed, do not bother checking if self.has_assembly { diff --git a/aderyn_core/src/detect/low/function_init_state_vars.rs b/aderyn_core/src/detect/low/function_init_state_vars.rs index 67fc8718..f40dffa4 100644 --- a/aderyn_core/src/detect/low/function_init_state_vars.rs +++ b/aderyn_core/src/detect/low/function_init_state_vars.rs @@ -5,9 +5,7 @@ use crate::ast::{ASTNode, Expression, FunctionCall, Identifier, NodeID}; use crate::capture; use crate::context::browser::ExtractReferencedDeclarations; -use crate::context::investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, -}; +use crate::context::graph::{CallGraph, CallGraphDirection, CallGraphVisitor}; use crate::detect::detector::IssueDetectorNamePool; use crate::{ context::workspace_context::WorkspaceContext, @@ -46,13 +44,10 @@ impl IssueDetector for FunctionInitializingStateDetector { let mut tracker = NonConstantStateVariableReferenceDeclarationTracker::new(context); - let investigator = StandardInvestigator::new( - context, - &[&(func.into())], - StandardInvestigationStyle::Downstream, - )?; + let investigator = + CallGraph::new(context, &[&(func.into())], CallGraphDirection::Inward)?; - investigator.investigate(context, &mut tracker)?; + investigator.accept(context, &mut tracker)?; if tracker.makes_a_reference { capture!(self, context, variable_declaration); @@ -102,7 +97,7 @@ impl<'a> NonConstantStateVariableReferenceDeclarationTracker<'a> { } } -impl<'a> StandardInvestigatorVisitor for NonConstantStateVariableReferenceDeclarationTracker<'a> { +impl<'a> CallGraphVisitor for NonConstantStateVariableReferenceDeclarationTracker<'a> { fn visit_any(&mut self, node: &ASTNode) -> eyre::Result<()> { // We already know the condition is satisifed if self.makes_a_reference { diff --git a/aderyn_core/src/detect/low/return_bomb.rs b/aderyn_core/src/detect/low/return_bomb.rs index aa32ca6e..c49ef4c9 100644 --- a/aderyn_core/src/detect/low/return_bomb.rs +++ b/aderyn_core/src/detect/low/return_bomb.rs @@ -6,9 +6,7 @@ use crate::ast::{ASTNode, MemberAccess, NodeID}; use crate::ast::NodeType; use crate::capture; use crate::context::browser::GetClosestAncestorOfTypeX; -use crate::context::investigator::{ - StandardInvestigationStyle, StandardInvestigator, StandardInvestigatorVisitor, -}; +use crate::context::graph::{CallGraph, CallGraphDirection, CallGraphVisitor}; use crate::detect::detector::IssueDetectorNamePool; use crate::detect::helpers; use crate::{ @@ -39,12 +37,8 @@ impl IssueDetector for ReturnBombDetector { calls_on_non_state_variable_addresses: vec![], // collection of all `address.call` Member Accesses where address is not a state variable context, }; - let investigator = StandardInvestigator::new( - context, - &[&(func.into())], - StandardInvestigationStyle::Downstream, - )?; - investigator.investigate(context, &mut tracker)?; + let callgraph = CallGraph::new(context, &[&(func.into())], CallGraphDirection::Inward)?; + callgraph.accept(context, &mut tracker)?; if !tracker.has_address_checks { // Now we assume that in this region all addresses are unprotected (because they are not involved in any binary ops/checks) @@ -125,7 +119,7 @@ struct CallNoAddressChecksTracker<'a> { context: &'a WorkspaceContext, } -impl<'a> StandardInvestigatorVisitor for CallNoAddressChecksTracker<'a> { +impl<'a> CallGraphVisitor for CallNoAddressChecksTracker<'a> { fn visit_any(&mut self, node: &crate::context::workspace_context::ASTNode) -> eyre::Result<()> { if !self.has_address_checks && helpers::has_binary_checks_on_some_address(node) { self.has_address_checks = true; diff --git a/aderyn_core/src/detect/test_utils/load_source_unit.rs b/aderyn_core/src/detect/test_utils/load_source_unit.rs index c52a15f2..62daf78d 100644 --- a/aderyn_core/src/detect/test_utils/load_source_unit.rs +++ b/aderyn_core/src/detect/test_utils/load_source_unit.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ ast::SourceUnit, - context::{graph::traits::Transpose, workspace_context::WorkspaceContext}, + context::{graph::Transpose, workspace_context::WorkspaceContext}, }; use crate::{context::graph::WorkspaceCallGraph, visitor::ast_visitor::Node}; @@ -92,12 +92,12 @@ pub fn load_solidity_source_unit(filepath: &str) -> WorkspaceContext { } fn load_callgraphs(context: &mut WorkspaceContext) { - let forward_callgraph = WorkspaceCallGraph::from_context(context).unwrap(); - let reverse_callgraph = WorkspaceCallGraph { - graph: forward_callgraph.graph.reverse(), + let inward_callgraph = WorkspaceCallGraph::from_context(context).unwrap(); + let outward_callgraph = WorkspaceCallGraph { + raw_callgraph: inward_callgraph.raw_callgraph.reverse(), }; - context.forward_callgraph = Some(forward_callgraph); - context.reverse_callgraph = Some(reverse_callgraph); + context.inward_callgraph = Some(inward_callgraph); + context.outward_callgraph = Some(outward_callgraph); } fn absorb_ast_content_into_context( diff --git a/aderyn_driver/src/driver.rs b/aderyn_driver/src/driver.rs index 9ff0bb9e..748eef99 100644 --- a/aderyn_driver/src/driver.rs +++ b/aderyn_driver/src/driver.rs @@ -2,9 +2,11 @@ use crate::{ config_helpers::{append_from_foundry_toml, derive_from_aderyn_toml}, ensure_valid_root_path, process_auto, }; -use aderyn_core::context::graph::traits::Transpose; use aderyn_core::{ - context::{graph::WorkspaceCallGraph, workspace_context::WorkspaceContext}, + context::{ + graph::{Transpose, WorkspaceCallGraph}, + workspace_context::WorkspaceContext, + }, detect::detector::{get_all_issue_detectors, IssueDetector, IssueSeverity}, fscloc, report::{ @@ -155,12 +157,12 @@ fn make_context(args: &Args) -> WorkspaceContextWrapper { context.set_sloc_stats(sloc_stats); context.set_ignore_lines_stats(ignore_line_stats); - let forward_callgraph = WorkspaceCallGraph::from_context(context).unwrap(); - let reverse_callgraph = WorkspaceCallGraph { - graph: forward_callgraph.graph.reverse(), + let inward_callgraph = WorkspaceCallGraph::from_context(context).unwrap(); + let outward_callgraph = WorkspaceCallGraph { + raw_callgraph: inward_callgraph.raw_callgraph.reverse(), }; - context.forward_callgraph = Some(forward_callgraph); - context.reverse_callgraph = Some(reverse_callgraph); + context.inward_callgraph = Some(inward_callgraph); + context.outward_callgraph = Some(outward_callgraph); } // Using the source path, calculate the sloc