Skip to content

Commit

Permalink
Improve device scoring.
Browse files — browse the repository at this point in the history
Decrease device buffer sizes and improve searching for the right buffer.
  • Loading branch information
zlogic committed Jan 28, 2024
1 parent 9bed704 commit e2e78f7
Showing 1 changed file with 45 additions and 63 deletions.
108 changes: 45 additions & 63 deletions src/correlation/vk.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashMap, error, ffi::CStr, fmt, slice, time::SystemTime};
use std::{cmp::Ordering, collections::HashMap, error, ffi::CStr, fmt, slice, time::SystemTime};

use ash::{prelude::VkResult, vk};
use nalgebra::Matrix3;
Expand Down Expand Up @@ -437,7 +437,7 @@ impl Device {
fn new(img1_pixels: usize, img2_pixels: usize) -> Result<Device, Box<dyn error::Error>> {
// Ensure there's enough memory for the largest buffer.
let max_pixels = img1_pixels.max(img2_pixels);
let max_buffer_size = max_pixels * 4 * std::mem::size_of::<i32>();
let max_buffer_size = max_pixels * 2 * std::mem::size_of::<i32>();

// Init adapter.
let entry = unsafe { ash::Entry::load()? };
Expand Down Expand Up @@ -906,7 +906,7 @@ impl Device {
unsafe fn find_device(
instance: &ash::Instance,
max_buffer_size: usize,
) -> Result<(vk::PhysicalDevice, &'static str, u32), Box<dyn error::Error>> {
) -> Result<(vk::PhysicalDevice, String, u32), Box<dyn error::Error>> {
let devices = instance.enumerate_physical_devices()?;
let device = devices
.iter()
Expand All @@ -923,16 +923,7 @@ impl Device {
let queue_index = Device::find_compute_queue(instance, device)?;

let device_name = CStr::from_ptr(props.device_name.as_ptr());
let device_name = device_name.to_str().unwrap();
println!(
"Device {} type {} {}-{}-{}-{}",
device_name,
props.device_type.as_raw(),
props.limits.max_push_constants_size,
props.limits.max_bound_descriptor_sets,
props.limits.max_storage_buffer_range,
max_buffer_size
);
let device_name = String::from_utf8_lossy(device_name.to_bytes()).to_string();
// TODO: allow specifying a device name filter/regex?
let score = match props.device_type {
vk::PhysicalDeviceType::DISCRETE_GPU => 3,
Expand All @@ -941,24 +932,25 @@ impl Device {
_ => 0,
};
// Prefer real devices instead of dzn emulation.
let dzn_multiplier = if device_name
let is_dzn = device_name
.to_lowercase()
.starts_with("microsoft direct3d12")
{
1
} else {
10
};
Some((device, device_name, queue_index, score * dzn_multiplier))
.starts_with("microsoft direct3d12");
let score = (score, is_dzn);
Some((device, device_name, queue_index, score))
})
.max_by_key(|(_device, _name, _queue_index, score)| *score);
let (device, name, queue_index) = if let Some((device, name, queue_index, _score)) = device
{
.max_by(|(_, _, _, a), (_, _, _, b)| {
if a.1 && !b.1 {
return Ordering::Less;
} else if !a.1 && b.1 {
return Ordering::Greater;
}
return a.0.cmp(&b.0);
});
let (device, name, queue_index) = if let Some((device, name, queue_index, score)) = device {
(device, name, queue_index)
} else {
return Err(GpuError::new("Device not found").into());
};
println!("selected device {}", name);
Ok((device, name, queue_index))
}

Expand Down Expand Up @@ -1009,6 +1001,7 @@ impl Device {
let max_pixels = img1_pixels.max(img2_pixels);
let mut buffers: Vec<Buffer> = vec![];
let cleanup_err = |buffers: &[Buffer], err| {
println!("buffers count is {}", buffers.len());
buffers.iter().for_each(|buffer| {
device.free_memory(buffer.buffer_memory, None);
device.destroy_buffer(buffer.buffer, None)
Expand Down Expand Up @@ -1045,7 +1038,7 @@ impl Device {
let buffer_internal_int = Device::create_buffer(
device,
memory_properties,
max_pixels * 4 * std::mem::size_of::<i32>(),
max_pixels * 2 * std::mem::size_of::<i32>(),
BufferType::GpuOnly,
)
.map_err(|err| cleanup_err(buffers.as_slice(), err))?;
Expand Down Expand Up @@ -1095,13 +1088,13 @@ impl Device {
buffer_type: BufferType,
) -> Result<Buffer, Box<dyn error::Error>> {
let size = size as u64;
let gpu_local = match buffer_type {
BufferType::GpuOnly | BufferType::GpuDestination | BufferType::GpuSource => true,
BufferType::HostSource | BufferType::HostDestination => false,
};
let host_visible = match buffer_type {
BufferType::HostSource | BufferType::HostDestination => true,
BufferType::GpuOnly | BufferType::GpuDestination | BufferType::GpuSource => false,
let required_memory_properties = match buffer_type {
BufferType::GpuOnly | BufferType::GpuDestination | BufferType::GpuSource => {
vk::MemoryPropertyFlags::DEVICE_LOCAL
}
BufferType::HostSource | BufferType::HostDestination => {
vk::MemoryPropertyFlags::HOST_VISIBLE
}
};
let extra_usage_flags = match buffer_type {
BufferType::HostSource => vk::BufferUsageFlags::TRANSFER_SRC,
Expand All @@ -1122,37 +1115,26 @@ impl Device {
};
let buffer = device.create_buffer(&buffer_create_info, None)?;
let memory_requirements = device.get_buffer_memory_requirements(buffer);
let memory_type_index = memory_properties.memory_types
[..memory_properties.memory_type_count as usize]
.iter()
.enumerate()
.find(|(memory_type_index, memory_type)| {
if memory_properties.memory_heaps[memory_type.heap_index as usize].size
< memory_requirements.size
{
return false;
};
if (1 << memory_type_index) & memory_requirements.memory_type_bits == 0 {
return false;
}
let memory_type_index = (0..memory_properties.memory_type_count as usize).find(|i| {
let memory_type = memory_properties.memory_types[*i];
if memory_properties.memory_heaps[memory_type.heap_index as usize].size
< memory_requirements.size
{
return false;
};
if ((1 << i) & memory_requirements.memory_type_bits) == 0 {
return false;
}

if gpu_local
&& memory_type
.property_flags
.contains(vk::MemoryPropertyFlags::DEVICE_LOCAL)
{
return true;
}
if host_visible
&& memory_type
.property_flags
.contains(vk::MemoryPropertyFlags::HOST_VISIBLE)
{
return true;
}
false
});
let memory_type_index = if let Some((index, _)) = memory_type_index {
if memory_type
.property_flags
.contains(required_memory_properties)
{
return true;
}
false
});
let memory_type_index = if let Some(index) = memory_type_index {
index as u32
} else {
return Err(GpuError::new("Cannot find suitable memory").into());
Expand Down

0 comments on commit e2e78f7

Please sign in to comment.