From 4909eaa78824f669c3c2b4ac357f50d4aa61d793 Mon Sep 17 00:00:00 2001 From: Alex Saveau Date: Sun, 14 Jan 2024 11:55:24 -0800 Subject: [PATCH] Add missing iterators Signed-off-by: Alex Saveau --- src/basic_block.rs | 22 ++++++++++ src/module.rs | 78 +++++++---------------------------- src/values/fn_value.rs | 21 ++++++++++ tests/all/test_basic_block.rs | 50 +++++++++++----------- 4 files changed, 83 insertions(+), 88 deletions(-) diff --git a/src/basic_block.rs b/src/basic_block.rs index 231169a8dbb..7dbacb5f3b7 100644 --- a/src/basic_block.rs +++ b/src/basic_block.rs @@ -346,6 +346,11 @@ impl<'ctx> BasicBlock<'ctx> { unsafe { Some(InstructionValue::new(value)) } } + /// Get an instruction iterator + pub fn get_instructions(self) -> InstructionIter<'ctx> { + InstructionIter(self.get_first_instruction()) + } + /// Removes this `BasicBlock` from its parent `FunctionValue`. /// It returns `Err(())` when it has no parent to remove from. /// @@ -597,3 +602,20 @@ impl fmt::Debug for BasicBlock<'_> { .finish() } } + +/// Iterate over all `InstructionValue`s in a basic block. +#[derive(Debug)] +pub struct InstructionIter<'ctx>(Option>); + +impl<'ctx> Iterator for InstructionIter<'ctx> { + type Item = InstructionValue<'ctx>; + + fn next(&mut self) -> Option { + if let Some(instr) = self.0 { + self.0 = instr.get_next_instruction(); + Some(instr) + } else { + None + } + } +} diff --git a/src/module.rs b/src/module.rs index 08a3bf91a7b..770f01b43c3 100644 --- a/src/module.rs +++ b/src/module.rs @@ -1573,24 +1573,11 @@ pub enum FlagBehavior { /// Iterate over all `FunctionValue`s in an llvm module #[derive(Debug)] -pub struct FunctionIterator<'ctx>(FunctionIteratorInner<'ctx>); - -/// Inner type so the variants are not publicly visible -#[derive(Debug)] -enum FunctionIteratorInner<'ctx> { - Empty, - Start(FunctionValue<'ctx>), - Previous(FunctionValue<'ctx>), -} +pub struct FunctionIterator<'ctx>(Option>); impl<'ctx> FunctionIterator<'ctx> { fn from_module(module: &Module<'ctx>) -> Self { - use FunctionIteratorInner::*; - - match module.get_first_function() { - None => Self(Empty), - Some(first) => Self(Start(first)), - } + Self(module.get_first_function()) } } @@ -1598,47 +1585,22 @@ impl<'ctx> Iterator for FunctionIterator<'ctx> { type Item = FunctionValue<'ctx>; fn next(&mut self) -> Option { - use FunctionIteratorInner::*; - - match self.0 { - Empty => None, - Start(first) => { - self.0 = Previous(first); - - Some(first) - }, - Previous(prev) => match prev.get_next_function() { - Some(current) => { - self.0 = Previous(current); - - Some(current) - }, - None => None, - }, + if let Some(func) = self.0 { + self.0 = func.get_next_function(); + Some(func) + } else { + None } } } /// Iterate over all `GlobalValue`s in an llvm module #[derive(Debug)] -pub struct GlobalIterator<'ctx>(GlobalIteratorInner<'ctx>); - -/// Inner type so the variants are not publicly visible -#[derive(Debug)] -enum GlobalIteratorInner<'ctx> { - Empty, - Start(GlobalValue<'ctx>), - Previous(GlobalValue<'ctx>), -} +pub struct GlobalIterator<'ctx>(Option>); impl<'ctx> GlobalIterator<'ctx> { fn from_module(module: &Module<'ctx>) -> Self { - use GlobalIteratorInner::*; - - match module.get_first_global() { - None => Self(Empty), - Some(first) => Self(Start(first)), - } + Self(module.get_first_global()) } } @@ -1646,23 +1608,11 @@ impl<'ctx> Iterator for GlobalIterator<'ctx> { type Item = GlobalValue<'ctx>; fn next(&mut self) -> Option { - use GlobalIteratorInner::*; - - match self.0 { - Empty => None, - Start(first) => { - self.0 = Previous(first); - - Some(first) - }, - Previous(prev) => match prev.get_next_global() { - Some(current) => { - self.0 = Previous(current); - - Some(current) - }, - None => None, - }, + if let Some(global) = self.0 { + self.0 = global.get_next_global(); + Some(global) + } else { + None } } } diff --git a/src/values/fn_value.rs b/src/values/fn_value.rs index 131725dfe4a..6818b513666 100644 --- a/src/values/fn_value.rs +++ b/src/values/fn_value.rs @@ -131,6 +131,10 @@ impl<'ctx> FunctionValue<'ctx> { unsafe { LLVMCountBasicBlocks(self.as_value_ref()) } } + pub fn get_basic_block_iter(self) -> BasicBlockIter<'ctx> { + BasicBlockIter(self.get_first_basic_block()) + } + pub fn get_basic_blocks(self) -> Vec> { let count = self.count_basic_blocks(); let mut raw_vec: Vec = Vec::with_capacity(count as usize); @@ -552,6 +556,23 @@ impl fmt::Debug for FunctionValue<'_> { } } +/// Iterate over all `BasicBlock`s in a function. +#[derive(Debug)] +pub struct BasicBlockIter<'ctx>(Option>); + +impl<'ctx> Iterator for BasicBlockIter<'ctx> { + type Item = BasicBlock<'ctx>; + + fn next(&mut self) -> Option { + if let Some(bb) = self.0 { + self.0 = bb.get_next_basic_block(); + Some(bb) + } else { + None + } + } +} + #[derive(Debug)] pub struct ParamValueIter<'ctx> { param_iter_value: LLVMValueRef, diff --git a/tests/all/test_basic_block.rs b/tests/all/test_basic_block.rs index f4c52da0727..39ad35fa859 100644 --- a/tests/all/test_basic_block.rs +++ b/tests/all/test_basic_block.rs @@ -21,32 +21,33 @@ fn test_basic_block_ordering() { let basic_block2 = context.insert_basic_block_after(basic_block, "block2"); let basic_block3 = context.prepend_basic_block(basic_block4, "block3"); - let basic_blocks = function.get_basic_blocks(); - - assert_eq!(basic_blocks.len(), 4); - assert_eq!(basic_blocks[0], basic_block); - assert_eq!(basic_blocks[1], basic_block2); - assert_eq!(basic_blocks[2], basic_block3); - assert_eq!(basic_blocks[3], basic_block4); + for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] { + assert_eq!(basic_blocks.len(), 4); + assert_eq!(basic_blocks[0], basic_block); + assert_eq!(basic_blocks[1], basic_block2); + assert_eq!(basic_blocks[2], basic_block3); + assert_eq!(basic_blocks[3], basic_block4); + } assert!(basic_block3.move_before(basic_block2).is_ok()); assert!(basic_block.move_after(basic_block4).is_ok()); let basic_block5 = context.prepend_basic_block(basic_block, "block5"); - let basic_blocks = function.get_basic_blocks(); - - assert_eq!(basic_blocks.len(), 5); - assert_eq!(basic_blocks[0], basic_block3); - assert_eq!(basic_blocks[1], basic_block2); - assert_eq!(basic_blocks[2], basic_block4); - assert_eq!(basic_blocks[3], basic_block5); - assert_eq!(basic_blocks[4], basic_block); - assert_ne!(basic_blocks[0], basic_block); - assert_ne!(basic_blocks[1], basic_block3); - assert_ne!(basic_blocks[2], basic_block2); - assert_ne!(basic_blocks[3], basic_block4); - assert_ne!(basic_blocks[4], basic_block5); + for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] { + assert_eq!(basic_blocks.len(), 5); + assert_eq!(basic_blocks[0], basic_block3); + assert_eq!(basic_blocks[1], basic_block2); + assert_eq!(basic_blocks[2], basic_block4); + assert_eq!(basic_blocks[3], basic_block5); + assert_eq!(basic_blocks[4], basic_block); + + assert_ne!(basic_blocks[0], basic_block); + assert_ne!(basic_blocks[1], basic_block3); + assert_ne!(basic_blocks[2], basic_block2); + assert_ne!(basic_blocks[3], basic_block4); + assert_ne!(basic_blocks[4], basic_block5); + } context.append_basic_block(function, "block6"); @@ -89,6 +90,7 @@ fn test_get_basic_blocks() { assert!(function.get_last_basic_block().is_none()); assert_eq!(function.get_basic_blocks().len(), 0); + assert_eq!(function.get_basic_block_iter().count(), 0); let basic_block = context.append_basic_block(function, "entry"); @@ -98,10 +100,10 @@ fn test_get_basic_blocks() { assert_eq!(last_basic_block, basic_block); - let basic_blocks = function.get_basic_blocks(); - - assert_eq!(basic_blocks.len(), 1); - assert_eq!(basic_blocks[0], basic_block); + for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] { + assert_eq!(basic_blocks.len(), 1); + assert_eq!(basic_blocks[0], basic_block); + } } #[test]