Skip to content

Commit

Permalink
Add even more iterators
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Saveau <[email protected]>
  • Loading branch information
SUPERCILEX authored and TheDan64 committed Jan 27, 2024
1 parent f5f39bf commit 0cdaa34
Show file tree
Hide file tree
Showing 9 changed files with 281 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub use crate::types::int_type::{IntType, StringRadix};
pub use crate::types::metadata_type::MetadataType;
pub use crate::types::ptr_type::PointerType;
pub use crate::types::struct_type::StructType;
pub use crate::types::struct_type::FieldTypesIter;
pub use crate::types::traits::{AnyType, AsTypeRef, BasicType, FloatMathType, IntMathType, PointerMathType};
pub use crate::types::vec_type::VectorType;
pub use crate::types::void_type::VoidType;
Expand Down
42 changes: 41 additions & 1 deletion src/types/struct_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,16 @@ impl<'ctx> StructType<'ctx> {
return None;
}

unsafe { Some(BasicTypeEnum::new(LLVMStructGetTypeAtIndex(self.as_type_ref(), index))) }
Some(unsafe { self.get_field_type_at_index_unchecked(index) })
}

/// Gets the type of a field belonging to this `StructType`.
///
/// # Safety
///
/// The index must be less than [StructType::count_fields] and the struct must not be opaque.
pub unsafe fn get_field_type_at_index_unchecked(self, index: u32) -> BasicTypeEnum<'ctx> {
unsafe { BasicTypeEnum::new(LLVMStructGetTypeAtIndex(self.as_type_ref(), index)) }
}

/// Creates a `StructValue` based on this `StructType`'s definition.
Expand Down Expand Up @@ -328,6 +337,15 @@ impl<'ctx> StructType<'ctx> {
raw_vec.iter().map(|val| unsafe { BasicTypeEnum::new(*val) }).collect()
}

/// Get a struct field iterator.
pub fn get_field_types_iter(self) -> FieldTypesIter<'ctx> {
FieldTypesIter {
st: self,
i: 0,
count: if self.is_opaque() { 0 } else { self.count_fields() },
}
}

/// Print the definition of a `StructType` to `LLVMString`.
pub fn print_to_string(self) -> LLVMString {
self.struct_type.print_to_string()
Expand Down Expand Up @@ -445,3 +463,25 @@ impl Display for StructType<'_> {
write!(f, "{}", self.print_to_string())
}
}

/// Iterate over all `BasicTypeEnum`s in a struct.
#[derive(Debug)]
pub struct FieldTypesIter<'ctx> {
st: StructType<'ctx>,
i: u32,
count: u32,
}

impl<'ctx> Iterator for FieldTypesIter<'ctx> {
type Item = BasicTypeEnum<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
if self.i < self.count {
let result = unsafe { self.st.get_field_type_at_index_unchecked(self.i) };
self.i += 1;
Some(result)
} else {
None
}
}
}
80 changes: 80 additions & 0 deletions src/values/instruction_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,15 @@ impl<'ctx> InstructionValue<'ctx> {
return None;
}

unsafe { self.get_operand_unchecked(index) }
}

/// Get the operand of an `InstructionValue`.
///
/// # Safety
///
/// The index must be less than [InstructionValue::get_num_operands].
pub unsafe fn get_operand_unchecked(self, index: u32) -> Option<Either<BasicValueEnum<'ctx>, BasicBlock<'ctx>>> {
let operand = unsafe { LLVMGetOperand(self.as_value_ref(), index) };

if operand.is_null() {
Expand All @@ -520,6 +529,15 @@ impl<'ctx> InstructionValue<'ctx> {
}
}

/// Get an instruction value operand iterator.
pub fn get_operands(self) -> OperandIter<'ctx> {
OperandIter {
iv: self,
i: 0,
count: self.get_num_operands(),
}
}

/// Sets the operand an `InstructionValue` has at a given index if possible.
/// An operand is a `BasicValue` used in an IR instruction.
///
Expand Down Expand Up @@ -598,6 +616,15 @@ impl<'ctx> InstructionValue<'ctx> {
return None;
}

unsafe { self.get_operand_use_unchecked(index) }
}

/// Gets the use of an operand(`BasicValue`), if any.
///
/// # Safety
///
/// The index must be smaller than [InstructionValue::get_num_operands].
pub unsafe fn get_operand_use_unchecked(self, index: u32) -> Option<BasicValueUse<'ctx>> {
let use_ = unsafe { LLVMGetOperandUse(self.as_value_ref(), index) };

if use_.is_null() {
Expand All @@ -607,6 +634,15 @@ impl<'ctx> InstructionValue<'ctx> {
unsafe { Some(BasicValueUse::new(use_)) }
}

/// Get an instruction value operand use iterator.
pub fn get_operand_uses(self) -> OperandUseIter<'ctx> {
OperandUseIter {
iv: self,
i: 0,
count: self.get_num_operands(),
}
}

/// Gets the first use of an `InstructionValue` if any.
///
/// The following example,
Expand Down Expand Up @@ -728,3 +764,47 @@ impl Display for InstructionValue<'_> {
write!(f, "{}", self.print_to_string())
}
}

/// Iterate over all the operands of an instruction value.
#[derive(Debug)]
pub struct OperandIter<'ctx> {
iv: InstructionValue<'ctx>,
i: u32,
count: u32,
}

impl<'ctx> Iterator for OperandIter<'ctx> {
type Item = Option<Either<BasicValueEnum<'ctx>, BasicBlock<'ctx>>>;

fn next(&mut self) -> Option<Self::Item> {
if self.i < self.count {
let result = unsafe { self.iv.get_operand_unchecked(self.i) };
self.i += 1;
Some(result)
} else {
None
}
}
}

/// Iterate over all the operands of an instruction value.
#[derive(Debug)]
pub struct OperandUseIter<'ctx> {
iv: InstructionValue<'ctx>,
i: u32,
count: u32,
}

impl<'ctx> Iterator for OperandUseIter<'ctx> {
type Item = Option<BasicValueUse<'ctx>>;

fn next(&mut self) -> Option<Self::Item> {
if self.i < self.count {
let result = unsafe { self.iv.get_operand_use_unchecked(self.i) };
self.i += 1;
Some(result)
} else {
None
}
}
}
4 changes: 3 additions & 1 deletion src/values/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ pub use crate::values::generic_value::GenericValue;
pub use crate::values::global_value::GlobalValue;
#[llvm_versions(7.0..=latest)]
pub use crate::values::global_value::UnnamedAddress;
pub use crate::values::instruction_value::{InstructionOpcode, InstructionValue};
pub use crate::values::instruction_value::{InstructionOpcode, InstructionValue, OperandIter, OperandUseIter};
pub use crate::values::int_value::IntValue;
pub use crate::values::metadata_value::{MetadataValue, FIRST_CUSTOM_METADATA_KIND_ID};
pub use crate::values::phi_value::IncomingIter;
pub use crate::values::phi_value::PhiValue;
pub use crate::values::ptr_value::PointerValue;
pub use crate::values::struct_value::FieldValueIter;
pub use crate::values::struct_value::StructValue;
pub use crate::values::traits::AsValueRef;
pub use crate::values::traits::{AggregateValue, AnyValue, BasicValue, FloatMathValue, IntMathValue, PointerMathValue};
Expand Down
42 changes: 42 additions & 0 deletions src/values/phi_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ impl<'ctx> PhiValue<'ctx> {
Some((value, basic_block))
}

/// # Safety
///
/// The index must be smaller [PhiValue::count_incoming].
pub unsafe fn get_incoming_unchecked(self, index: u32) -> (BasicValueEnum<'ctx>, BasicBlock<'ctx>) {
let basic_block =
unsafe { BasicBlock::new(LLVMGetIncomingBlock(self.as_value_ref(), index)).expect("Invalid BasicBlock") };
let value = unsafe { BasicValueEnum::new(LLVMGetIncomingValue(self.as_value_ref(), index)) };

(value, basic_block)
}

/// Get an incoming edge iterator.
pub fn get_incomings(self) -> IncomingIter<'ctx> {
IncomingIter {
pv: self,
i: 0,
count: self.count_incoming(),
}
}

/// Gets the name of a `ArrayValue`. If the value is a constant, this will
/// return an empty string.
pub fn get_name(&self) -> &CStr {
Expand Down Expand Up @@ -125,3 +145,25 @@ impl<'ctx> TryFrom<InstructionValue<'ctx>> for PhiValue<'ctx> {
}
}
}

/// Iterate over all the incoming edges of a phi value.
#[derive(Debug)]
pub struct IncomingIter<'ctx> {
pv: PhiValue<'ctx>,
i: u32,
count: u32,
}

impl<'ctx> Iterator for IncomingIter<'ctx> {
type Item = (BasicValueEnum<'ctx>, BasicBlock<'ctx>);

fn next(&mut self) -> Option<Self::Item> {
if self.i < self.count {
let result = unsafe { self.pv.get_incoming_unchecked(self.i) };
self.i += 1;
Some(result)
} else {
None
}
}
}
42 changes: 41 additions & 1 deletion src/values/struct_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,25 @@ impl<'ctx> StructValue<'ctx> {
return None;
}

unsafe { Some(BasicValueEnum::new(LLVMGetOperand(self.as_value_ref(), index))) }
Some(unsafe { self.get_field_at_index_unchecked(index) })
}

/// Gets the value of a field belonging to this `StructValue`.
///
/// # Safety
///
/// The index must be smaller than [StructValue::count_fields].
pub unsafe fn get_field_at_index_unchecked(self, index: u32) -> BasicValueEnum<'ctx> {
unsafe { BasicValueEnum::new(LLVMGetOperand(self.as_value_ref(), index)) }
}

/// Get a field value iterator.
pub fn get_fields(self) -> FieldValueIter<'ctx> {
FieldValueIter {
sv: self,
i: 0,
count: self.count_fields(),
}
}

/// Sets the value of a field belonging to this `StructValue`.
Expand Down Expand Up @@ -140,3 +158,25 @@ impl Display for StructValue<'_> {
write!(f, "{}", self.print_to_string())
}
}

/// Iterate over all the field values of this struct.
#[derive(Debug)]
pub struct FieldValueIter<'ctx> {
sv: StructValue<'ctx>,
i: u32,
count: u32,
}

impl<'ctx> Iterator for FieldValueIter<'ctx> {
type Item = BasicValueEnum<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
if self.i < self.count {
let result = unsafe { self.sv.get_field_at_index_unchecked(self.i) };
self.i += 1;
Some(result)
} else {
None
}
}
}
32 changes: 32 additions & 0 deletions tests/all/test_instruction_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ fn test_operands() {
// Test operands
assert_eq!(store_instruction.get_num_operands(), 2);
assert_eq!(free_instruction.get_num_operands(), 2);
assert_eq!(store_instruction.get_operands().count(), 2);
assert_eq!(free_instruction.get_operands().count(), 2);

let store_operand0 = store_instruction.get_operand(0).unwrap();
let store_operand1 = store_instruction.get_operand(1).unwrap();
Expand All @@ -44,6 +46,14 @@ fn test_operands() {
assert!(store_instruction.get_operand(3).is_none());
assert!(store_instruction.get_operand(4).is_none());

let mut store_operands = store_instruction.get_operands();
let store_operand0 = store_operands.next().unwrap().unwrap();
let store_operand1 = store_operands.next().unwrap().unwrap();

assert_eq!(store_operand0.left().unwrap(), f32_val); // f32 const
assert_eq!(store_operand1.left().unwrap(), arg1); // f32* arg1
assert!(store_operands.next().is_none());

let free_operand0 = free_instruction.get_operand(0).unwrap().left().unwrap();
let free_operand1 = free_instruction.get_operand(1).unwrap().left().unwrap();

Expand Down Expand Up @@ -89,6 +99,7 @@ fn test_operands() {
assert!(module.verify().is_ok());

assert_eq!(return_instruction.get_num_operands(), 0);
assert_eq!(return_instruction.get_operands().count(), 0);
assert!(return_instruction.get_operand(0).is_none());
assert!(return_instruction.get_operand(1).is_none());
assert!(return_instruction.get_operand(2).is_none());
Expand Down Expand Up @@ -140,6 +151,27 @@ fn test_operands() {
assert!(store_instruction.get_operand_use(5).is_none());
assert!(store_instruction.get_operand_use(6).is_none());

// However their operands are used
let mut store_operand_uses = store_instruction.get_operand_uses();
let store_operand_use0 = store_operand_uses.next().unwrap().unwrap();
let store_operand_use1 = store_operand_uses.next().unwrap().unwrap();

assert!(store_operand_use0.get_next_use().is_none());
assert!(store_operand_use1.get_next_use().is_none());
assert_eq!(store_operand_use1, arg1_second_use);

assert_eq!(
store_operand_use0.get_user().into_instruction_value(),
store_instruction
);
assert_eq!(
store_operand_use1.get_user().into_instruction_value(),
store_instruction
);
assert_eq!(store_operand_use0.get_used_value().left().unwrap(), f32_val);
assert_eq!(store_operand_use1.get_used_value().left().unwrap(), arg1);
assert!(store_operand_uses.next().is_none());

let free_operand_use0 = free_instruction.get_operand_use(0).unwrap();
let free_operand_use1 = free_instruction.get_operand_use(1).unwrap();

Expand Down
Loading

0 comments on commit 0cdaa34

Please sign in to comment.