From 0cdaa34bd612464f310b3fdefa00a75d1324687d Mon Sep 17 00:00:00 2001 From: Alex Saveau Date: Sat, 20 Jan 2024 16:34:01 -0800 Subject: [PATCH] Add even more iterators Signed-off-by: Alex Saveau --- src/types/mod.rs | 1 + src/types/struct_type.rs | 42 ++++++++++++++- src/values/instruction_value.rs | 80 ++++++++++++++++++++++++++++ src/values/mod.rs | 4 +- src/values/phi_value.rs | 42 +++++++++++++++ src/values/struct_value.rs | 42 ++++++++++++++- tests/all/test_instruction_values.rs | 32 +++++++++++ tests/all/test_types.rs | 36 +++++++++---- tests/all/test_values.rs | 15 ++++++ 9 files changed, 281 insertions(+), 13 deletions(-) diff --git a/src/types/mod.rs b/src/types/mod.rs index 1ed6ab5072d..8f76fd7991e 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -30,6 +30,7 @@ pub use crate::types::int_type::{IntType, StringRadix}; pub use crate::types::metadata_type::MetadataType; pub use crate::types::ptr_type::PointerType; pub use crate::types::struct_type::StructType; +pub use crate::types::struct_type::FieldTypesIter; pub use crate::types::traits::{AnyType, AsTypeRef, BasicType, FloatMathType, IntMathType, PointerMathType}; pub use crate::types::vec_type::VectorType; pub use crate::types::void_type::VoidType; diff --git a/src/types/struct_type.rs b/src/types/struct_type.rs index 11d6ddd60e9..baebef6551d 100644 --- a/src/types/struct_type.rs +++ b/src/types/struct_type.rs @@ -61,7 +61,16 @@ impl<'ctx> StructType<'ctx> { return None; } - unsafe { Some(BasicTypeEnum::new(LLVMStructGetTypeAtIndex(self.as_type_ref(), index))) } + Some(unsafe { self.get_field_type_at_index_unchecked(index) }) + } + + /// Gets the type of a field belonging to this `StructType`. + /// + /// # Safety + /// + /// The index must be less than [StructType::count_fields] and the struct must not be opaque. + pub unsafe fn get_field_type_at_index_unchecked(self, index: u32) -> BasicTypeEnum<'ctx> { + unsafe { BasicTypeEnum::new(LLVMStructGetTypeAtIndex(self.as_type_ref(), index)) } } /// Creates a `StructValue` based on this `StructType`'s definition. @@ -328,6 +337,15 @@ impl<'ctx> StructType<'ctx> { raw_vec.iter().map(|val| unsafe { BasicTypeEnum::new(*val) }).collect() } + /// Get a struct field iterator. + pub fn get_field_types_iter(self) -> FieldTypesIter<'ctx> { + FieldTypesIter { + st: self, + i: 0, + count: if self.is_opaque() { 0 } else { self.count_fields() }, + } + } + /// Print the definition of a `StructType` to `LLVMString`. pub fn print_to_string(self) -> LLVMString { self.struct_type.print_to_string() @@ -445,3 +463,25 @@ impl Display for StructType<'_> { write!(f, "{}", self.print_to_string()) } } + +/// Iterate over all `BasicTypeEnum`s in a struct. +#[derive(Debug)] +pub struct FieldTypesIter<'ctx> { + st: StructType<'ctx>, + i: u32, + count: u32, +} + +impl<'ctx> Iterator for FieldTypesIter<'ctx> { + type Item = BasicTypeEnum<'ctx>; + + fn next(&mut self) -> Option { + if self.i < self.count { + let result = unsafe { self.st.get_field_type_at_index_unchecked(self.i) }; + self.i += 1; + Some(result) + } else { + None + } + } +} diff --git a/src/values/instruction_value.rs b/src/values/instruction_value.rs index 646a43dae68..61b484003b8 100644 --- a/src/values/instruction_value.rs +++ b/src/values/instruction_value.rs @@ -503,6 +503,15 @@ impl<'ctx> InstructionValue<'ctx> { return None; } + unsafe { self.get_operand_unchecked(index) } + } + + /// Get the operand of an `InstructionValue`. + /// + /// # Safety + /// + /// The index must be less than [InstructionValue::get_num_operands]. + pub unsafe fn get_operand_unchecked(self, index: u32) -> Option, BasicBlock<'ctx>>> { let operand = unsafe { LLVMGetOperand(self.as_value_ref(), index) }; if operand.is_null() { @@ -520,6 +529,15 @@ impl<'ctx> InstructionValue<'ctx> { } } + /// Get an instruction value operand iterator. + pub fn get_operands(self) -> OperandIter<'ctx> { + OperandIter { + iv: self, + i: 0, + count: self.get_num_operands(), + } + } + /// Sets the operand an `InstructionValue` has at a given index if possible. /// An operand is a `BasicValue` used in an IR instruction. /// @@ -598,6 +616,15 @@ impl<'ctx> InstructionValue<'ctx> { return None; } + unsafe { self.get_operand_use_unchecked(index) } + } + + /// Gets the use of an operand(`BasicValue`), if any. + /// + /// # Safety + /// + /// The index must be smaller than [InstructionValue::get_num_operands]. + pub unsafe fn get_operand_use_unchecked(self, index: u32) -> Option> { let use_ = unsafe { LLVMGetOperandUse(self.as_value_ref(), index) }; if use_.is_null() { @@ -607,6 +634,15 @@ impl<'ctx> InstructionValue<'ctx> { unsafe { Some(BasicValueUse::new(use_)) } } + /// Get an instruction value operand use iterator. + pub fn get_operand_uses(self) -> OperandUseIter<'ctx> { + OperandUseIter { + iv: self, + i: 0, + count: self.get_num_operands(), + } + } + /// Gets the first use of an `InstructionValue` if any. /// /// The following example, @@ -728,3 +764,47 @@ impl Display for InstructionValue<'_> { write!(f, "{}", self.print_to_string()) } } + +/// Iterate over all the operands of an instruction value. +#[derive(Debug)] +pub struct OperandIter<'ctx> { + iv: InstructionValue<'ctx>, + i: u32, + count: u32, +} + +impl<'ctx> Iterator for OperandIter<'ctx> { + type Item = Option, BasicBlock<'ctx>>>; + + fn next(&mut self) -> Option { + if self.i < self.count { + let result = unsafe { self.iv.get_operand_unchecked(self.i) }; + self.i += 1; + Some(result) + } else { + None + } + } +} + +/// Iterate over all the operands of an instruction value. +#[derive(Debug)] +pub struct OperandUseIter<'ctx> { + iv: InstructionValue<'ctx>, + i: u32, + count: u32, +} + +impl<'ctx> Iterator for OperandUseIter<'ctx> { + type Item = Option>; + + fn next(&mut self) -> Option { + if self.i < self.count { + let result = unsafe { self.iv.get_operand_use_unchecked(self.i) }; + self.i += 1; + Some(result) + } else { + None + } + } +} diff --git a/src/values/mod.rs b/src/values/mod.rs index bbf6e90e38f..2d0cd72e0e5 100644 --- a/src/values/mod.rs +++ b/src/values/mod.rs @@ -37,11 +37,13 @@ pub use crate::values::generic_value::GenericValue; pub use crate::values::global_value::GlobalValue; #[llvm_versions(7.0..=latest)] pub use crate::values::global_value::UnnamedAddress; -pub use crate::values::instruction_value::{InstructionOpcode, InstructionValue}; +pub use crate::values::instruction_value::{InstructionOpcode, InstructionValue, OperandIter, OperandUseIter}; pub use crate::values::int_value::IntValue; pub use crate::values::metadata_value::{MetadataValue, FIRST_CUSTOM_METADATA_KIND_ID}; +pub use crate::values::phi_value::IncomingIter; pub use crate::values::phi_value::PhiValue; pub use crate::values::ptr_value::PointerValue; +pub use crate::values::struct_value::FieldValueIter; pub use crate::values::struct_value::StructValue; pub use crate::values::traits::AsValueRef; pub use crate::values::traits::{AggregateValue, AnyValue, BasicValue, FloatMathValue, IntMathValue, PointerMathValue}; diff --git a/src/values/phi_value.rs b/src/values/phi_value.rs index 4be15e40fbf..bd7c1bbc911 100644 --- a/src/values/phi_value.rs +++ b/src/values/phi_value.rs @@ -67,6 +67,26 @@ impl<'ctx> PhiValue<'ctx> { Some((value, basic_block)) } + /// # Safety + /// + /// The index must be smaller [PhiValue::count_incoming]. + pub unsafe fn get_incoming_unchecked(self, index: u32) -> (BasicValueEnum<'ctx>, BasicBlock<'ctx>) { + let basic_block = + unsafe { BasicBlock::new(LLVMGetIncomingBlock(self.as_value_ref(), index)).expect("Invalid BasicBlock") }; + let value = unsafe { BasicValueEnum::new(LLVMGetIncomingValue(self.as_value_ref(), index)) }; + + (value, basic_block) + } + + /// Get an incoming edge iterator. + pub fn get_incomings(self) -> IncomingIter<'ctx> { + IncomingIter { + pv: self, + i: 0, + count: self.count_incoming(), + } + } + /// Gets the name of a `ArrayValue`. If the value is a constant, this will /// return an empty string. pub fn get_name(&self) -> &CStr { @@ -125,3 +145,25 @@ impl<'ctx> TryFrom> for PhiValue<'ctx> { } } } + +/// Iterate over all the incoming edges of a phi value. +#[derive(Debug)] +pub struct IncomingIter<'ctx> { + pv: PhiValue<'ctx>, + i: u32, + count: u32, +} + +impl<'ctx> Iterator for IncomingIter<'ctx> { + type Item = (BasicValueEnum<'ctx>, BasicBlock<'ctx>); + + fn next(&mut self) -> Option { + if self.i < self.count { + let result = unsafe { self.pv.get_incoming_unchecked(self.i) }; + self.i += 1; + Some(result) + } else { + None + } + } +} diff --git a/src/values/struct_value.rs b/src/values/struct_value.rs index 930ba3875f5..57892b2673a 100644 --- a/src/values/struct_value.rs +++ b/src/values/struct_value.rs @@ -54,7 +54,25 @@ impl<'ctx> StructValue<'ctx> { return None; } - unsafe { Some(BasicValueEnum::new(LLVMGetOperand(self.as_value_ref(), index))) } + Some(unsafe { self.get_field_at_index_unchecked(index) }) + } + + /// Gets the value of a field belonging to this `StructValue`. + /// + /// # Safety + /// + /// The index must be smaller than [StructValue::count_fields]. + pub unsafe fn get_field_at_index_unchecked(self, index: u32) -> BasicValueEnum<'ctx> { + unsafe { BasicValueEnum::new(LLVMGetOperand(self.as_value_ref(), index)) } + } + + /// Get a field value iterator. + pub fn get_fields(self) -> FieldValueIter<'ctx> { + FieldValueIter { + sv: self, + i: 0, + count: self.count_fields(), + } } /// Sets the value of a field belonging to this `StructValue`. @@ -140,3 +158,25 @@ impl Display for StructValue<'_> { write!(f, "{}", self.print_to_string()) } } + +/// Iterate over all the field values of this struct. +#[derive(Debug)] +pub struct FieldValueIter<'ctx> { + sv: StructValue<'ctx>, + i: u32, + count: u32, +} + +impl<'ctx> Iterator for FieldValueIter<'ctx> { + type Item = BasicValueEnum<'ctx>; + + fn next(&mut self) -> Option { + if self.i < self.count { + let result = unsafe { self.sv.get_field_at_index_unchecked(self.i) }; + self.i += 1; + Some(result) + } else { + None + } + } +} diff --git a/tests/all/test_instruction_values.rs b/tests/all/test_instruction_values.rs index 2e54c70e160..ded3378459f 100644 --- a/tests/all/test_instruction_values.rs +++ b/tests/all/test_instruction_values.rs @@ -34,6 +34,8 @@ fn test_operands() { // Test operands assert_eq!(store_instruction.get_num_operands(), 2); assert_eq!(free_instruction.get_num_operands(), 2); + assert_eq!(store_instruction.get_operands().count(), 2); + assert_eq!(free_instruction.get_operands().count(), 2); let store_operand0 = store_instruction.get_operand(0).unwrap(); let store_operand1 = store_instruction.get_operand(1).unwrap(); @@ -44,6 +46,14 @@ fn test_operands() { assert!(store_instruction.get_operand(3).is_none()); assert!(store_instruction.get_operand(4).is_none()); + let mut store_operands = store_instruction.get_operands(); + let store_operand0 = store_operands.next().unwrap().unwrap(); + let store_operand1 = store_operands.next().unwrap().unwrap(); + + assert_eq!(store_operand0.left().unwrap(), f32_val); // f32 const + assert_eq!(store_operand1.left().unwrap(), arg1); // f32* arg1 + assert!(store_operands.next().is_none()); + let free_operand0 = free_instruction.get_operand(0).unwrap().left().unwrap(); let free_operand1 = free_instruction.get_operand(1).unwrap().left().unwrap(); @@ -89,6 +99,7 @@ fn test_operands() { assert!(module.verify().is_ok()); assert_eq!(return_instruction.get_num_operands(), 0); + assert_eq!(return_instruction.get_operands().count(), 0); assert!(return_instruction.get_operand(0).is_none()); assert!(return_instruction.get_operand(1).is_none()); assert!(return_instruction.get_operand(2).is_none()); @@ -140,6 +151,27 @@ fn test_operands() { assert!(store_instruction.get_operand_use(5).is_none()); assert!(store_instruction.get_operand_use(6).is_none()); + // However their operands are used + let mut store_operand_uses = store_instruction.get_operand_uses(); + let store_operand_use0 = store_operand_uses.next().unwrap().unwrap(); + let store_operand_use1 = store_operand_uses.next().unwrap().unwrap(); + + assert!(store_operand_use0.get_next_use().is_none()); + assert!(store_operand_use1.get_next_use().is_none()); + assert_eq!(store_operand_use1, arg1_second_use); + + assert_eq!( + store_operand_use0.get_user().into_instruction_value(), + store_instruction + ); + assert_eq!( + store_operand_use1.get_user().into_instruction_value(), + store_instruction + ); + assert_eq!(store_operand_use0.get_used_value().left().unwrap(), f32_val); + assert_eq!(store_operand_use1.get_used_value().left().unwrap(), arg1); + assert!(store_operand_uses.next().is_none()); + let free_operand_use0 = free_instruction.get_operand_use(0).unwrap(); let free_operand_use1 = free_instruction.get_operand_use(1).unwrap(); diff --git a/tests/all/test_types.rs b/tests/all/test_types.rs index ecbd524e535..9c47068f95d 100644 --- a/tests/all/test_types.rs +++ b/tests/all/test_types.rs @@ -18,7 +18,9 @@ fn test_struct_type() { assert!(av_struct.get_name().is_none()); assert_eq!(av_struct.get_context(), context); assert_eq!(av_struct.count_fields(), 2); - assert_eq!(av_struct.get_field_types(), &[int_vector.into(), float_array.into()]); + for types in [av_struct.get_field_types(), av_struct.get_field_types_iter().collect()] { + assert_eq!(types, &[int_vector.into(), float_array.into()]); + } let field_1 = av_struct.get_field_type_at_index(0).unwrap(); let field_2 = av_struct.get_field_type_at_index(1).unwrap(); @@ -27,7 +29,9 @@ fn test_struct_type() { assert!(field_2.is_array_type()); assert!(av_struct.get_field_type_at_index(2).is_none()); assert!(av_struct.get_field_type_at_index(200).is_none()); - assert_eq!(av_struct.get_field_types(), vec![field_1, field_2]); + for types in [av_struct.get_field_types(), av_struct.get_field_types_iter().collect()] { + assert_eq!(types, &[field_1, field_2]); + } let av_struct = context.struct_type(&[int_vector.into(), float_array.into()], true); @@ -46,7 +50,9 @@ fn test_struct_type() { assert!(field_2.is_array_type()); assert!(av_struct.get_field_type_at_index(2).is_none()); assert!(av_struct.get_field_type_at_index(200).is_none()); - assert_eq!(av_struct.get_field_types(), vec![field_1, field_2]); + for types in [av_struct.get_field_types(), av_struct.get_field_types_iter().collect()] { + assert_eq!(types, &[field_1, field_2]); + } let opaque_struct = context.opaque_struct_type("opaque_struct"); @@ -57,6 +63,7 @@ fn test_struct_type() { assert_eq!(opaque_struct.get_context(), context); assert_eq!(opaque_struct.count_fields(), 0); assert!(opaque_struct.get_field_types().is_empty()); + assert_eq!(opaque_struct.get_field_types_iter().count(), 0); assert!(opaque_struct.get_field_type_at_index(0).is_none()); assert!(opaque_struct.get_field_type_at_index(1).is_none()); assert!(opaque_struct.get_field_type_at_index(2).is_none()); @@ -75,10 +82,12 @@ fn test_struct_type() { ); assert_eq!(no_longer_opaque_struct.get_context(), context); assert_eq!(no_longer_opaque_struct.count_fields(), 2); - assert_eq!( + for types in [ no_longer_opaque_struct.get_field_types(), - &[int_vector.into(), float_array.into()] - ); + no_longer_opaque_struct.get_field_types_iter().collect(), + ] { + assert_eq!(types, &[int_vector.into(), float_array.into()]); + } let field_1 = no_longer_opaque_struct.get_field_type_at_index(0).unwrap(); let field_2 = no_longer_opaque_struct.get_field_type_at_index(1).unwrap(); @@ -87,7 +96,12 @@ fn test_struct_type() { assert!(field_2.is_array_type()); assert!(no_longer_opaque_struct.get_field_type_at_index(2).is_none()); assert!(no_longer_opaque_struct.get_field_type_at_index(200).is_none()); - assert_eq!(no_longer_opaque_struct.get_field_types(), vec![field_1, field_2]); + for types in [ + no_longer_opaque_struct.get_field_types(), + no_longer_opaque_struct.get_field_types_iter().collect(), + ] { + assert_eq!(types, &[field_1, field_2]); + } no_longer_opaque_struct.set_body(&[float_array.into(), int_vector.into(), float_array.into()], false); let fields_changed_struct = no_longer_opaque_struct; @@ -95,10 +109,12 @@ fn test_struct_type() { assert!(!fields_changed_struct.is_opaque()); assert!(fields_changed_struct.is_sized()); assert_eq!(fields_changed_struct.count_fields(), 3); - assert_eq!( + for types in [ fields_changed_struct.get_field_types(), - &[float_array.into(), int_vector.into(), float_array.into(),] - ); + fields_changed_struct.get_field_types_iter().collect(), + ] { + assert_eq!(types, &[float_array.into(), int_vector.into(), float_array.into(),]); + } assert!(fields_changed_struct.get_field_type_at_index(3).is_none()); } diff --git a/tests/all/test_values.rs b/tests/all/test_values.rs index 8d7805b816f..0f45785ac3c 100644 --- a/tests/all/test_values.rs +++ b/tests/all/test_values.rs @@ -995,6 +995,16 @@ fn test_phi_values() { assert_eq!(then_bb, then_block); assert_eq!(else_bb, else_block); assert!(phi.get_incoming(2).is_none()); + + let mut incomings = phi.get_incomings(); + let (then_val, then_bb) = incomings.next().unwrap(); + let (else_val, else_bb) = incomings.next().unwrap(); + + assert_eq!(then_val.into_int_value(), false_val); + assert_eq!(else_val.into_int_value(), true_val); + assert_eq!(then_bb, then_block); + assert_eq!(else_bb, else_block); + assert!(incomings.next().is_none()); } #[test] @@ -1194,7 +1204,12 @@ fn test_consts() { let struct_val = struct_type.const_named_struct(&[i8_val.into(), f32_val.into()]); assert_eq!(struct_val.count_fields(), 2); + assert_eq!(struct_val.get_fields().count(), 2); assert_eq!(struct_val.count_fields(), struct_type.count_fields()); + assert_eq!( + struct_val.get_fields().count(), + struct_type.get_field_types_iter().count() + ); assert!(struct_val.get_field_at_index(0).is_some()); assert!(struct_val.get_field_at_index(1).is_some()); assert!(struct_val.get_field_at_index(3).is_none());