Skip to content

Commit

Permalink
Use the latest correlation value.
Browse files Browse the repository at this point in the history
Using the highest increases peaks in SEM images.
Parallelize extraction of GPU results.
  • Loading branch information
zlogic committed Aug 20, 2023
1 parent 25b64a1 commit 3560e46
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 79 deletions.
180 changes: 102 additions & 78 deletions src/correlation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use nalgebra::{DMatrix, Matrix3, Vector3};
use rayon::prelude::*;
use std::{cell::RefCell, error, ops::Range, sync::atomic::AtomicUsize, sync::atomic::Ordering};

use rayon::prelude::*;

const SCALE_MIN_SIZE: usize = 64;
const KERNEL_SIZE: usize = 5;
const KERNEL_WIDTH: usize = KERNEL_SIZE * 2 + 1;
Expand Down Expand Up @@ -295,14 +296,6 @@ impl PointCorrelations {
let out_row = (row as f32 / scale) as usize;
let out_col = (col as f32 / scale) as usize;
let out_point = &mut correlated_points[(out_row, out_col)];
let point_corr = if let Some(point) = point {
point.2
} else {
continue;
};
if out_point.map_or(false, |out_point| out_point.2 < point_corr) {
continue;
}
*out_point = point;
}
}
Expand Down Expand Up @@ -746,6 +739,8 @@ mod gpu {
use pollster::FutureExt;
use std::sync::mpsc;

use rayon::prelude::*;

use super::{
CORRIDOR_MIN_RANGE, CORRIDOR_SEGMENT_LENGTH_HIGHPERFORMANCE,
CORRIDOR_SEGMENT_LENGTH_LOWPOWER, CROSS_CHECK_SEARCH_AREA, KERNEL_SIZE, NEIGHBOR_DISTANCE,
Expand Down Expand Up @@ -782,6 +777,8 @@ mod gpu {
img1_shape: (usize, usize),
img2_shape: (usize, usize),

correlation_values: DMatrix<Option<f32>>,

corridor_segment_length: usize,
search_area_segment_length: usize,
corridor_size: usize,
Expand All @@ -798,7 +795,6 @@ mod gpu {
buffer_out: wgpu::Buffer,
buffer_out_reverse: wgpu::Buffer,
buffer_out_corr: wgpu::Buffer,
buffer_reverse_corr: wgpu::Buffer,

pipeline_configs: HashMap<String, ComputePipelineConfig>,
}
Expand Down Expand Up @@ -830,6 +826,7 @@ mod gpu {

let img1_pixels = img1_dimensions.0 * img1_dimensions.1;
let img2_pixels = img2_dimensions.0 * img2_dimensions.1;
let max_pixels = img1_pixels.max(img2_pixels);

// Init adapter.
let instance = wgpu::Instance::default();
Expand Down Expand Up @@ -865,7 +862,7 @@ mod gpu {
limits.max_bindings_per_bind_group = MAX_BINDINGS;
limits.max_storage_buffers_per_shader_stage = MAX_BINDINGS;
// Ensure there's enough memory for the largest buffer.
let max_buffer_size = (img1_pixels * 3 + img2_pixels * 2) * std::mem::size_of::<f32>();
let max_buffer_size = max_pixels * 4 * std::mem::size_of::<i32>();
limits.max_storage_buffer_binding_size = max_buffer_size as u32;
limits.max_buffer_size = max_buffer_size as u64;
limits.max_push_constant_size = std::mem::size_of::<ShaderParams>() as u32;
Expand Down Expand Up @@ -909,13 +906,13 @@ mod gpu {
);
let buffer_internal_int = init_buffer(
&device,
img1_pixels * 4 * std::mem::size_of::<i32>(),
max_pixels * 4 * std::mem::size_of::<i32>(),
true,
false,
);
let buffer_out = init_buffer(
&device,
img1_pixels * 2 * std::mem::size_of::<i32>(),
max_pixels * 2 * std::mem::size_of::<i32>(),
false,
false,
);
Expand All @@ -927,17 +924,13 @@ mod gpu {
);
let buffer_out_corr = init_buffer(
&device,
img1_pixels * std::mem::size_of::<f32>(),
false,
false,
);
let buffer_reverse_corr = init_buffer(
&device,
img2_pixels * std::mem::size_of::<f32>(),
max_pixels * std::mem::size_of::<f32>(),
false,
false,
);

let correlation_values = DMatrix::from_element(img1_shape.0, img1_shape.1, None);

let result = GpuContext {
min_stdev,
correlation_threshold,
Expand All @@ -946,6 +939,7 @@ mod gpu {
fundamental_matrix,
img1_shape,
img2_shape,
correlation_values,
corridor_segment_length,
search_area_segment_length,
device_name,
Expand All @@ -959,7 +953,6 @@ mod gpu {
buffer_out,
buffer_out_reverse,
buffer_out_corr,
buffer_reverse_corr,
pipeline_configs: HashMap::new(),
};
Ok(result)
Expand Down Expand Up @@ -991,7 +984,7 @@ mod gpu {
let send_progress = |value| {
let value = match dir {
CorrelationDirection::Forward => value * 0.98 / 2.0,
CorrelationDirection::Reverse => 0.5 + value * 0.98 / 2.0,
CorrelationDirection::Reverse => 0.51 + value * 0.98 / 2.0,
};
if let Some(pl) = progress_listener {
pl.report_status(value);
Expand Down Expand Up @@ -1095,7 +1088,7 @@ mod gpu {
}
}

Ok(())
self.save_corr(&dir)
}

pub fn cross_check_filter(&mut self, scale: f32, dir: CorrelationDirection) {
Expand Down Expand Up @@ -1201,6 +1194,61 @@ mod gpu {
);
}

/// Reads the GPU correlation buffer back to the host and stores the values
/// in `self.correlation_values`.
///
/// Only `CorrelationDirection::Forward` is processed; for any other
/// direction this returns `Ok(())` without touching the GPU.
///
/// A cell of `correlation_values` is overwritten only when the freshly read
/// value exceeds `self.correlation_threshold`; otherwise the previously
/// stored entry (if any) is kept.
///
/// # Errors
///
/// Returns an error if the mapping callback's channel is closed before a
/// result arrives, or if mapping the staging buffer for reading fails.
fn save_corr(&mut self, dir: &CorrelationDirection) -> Result<(), Box<dyn error::Error>> {
    if !matches!(dir, CorrelationDirection::Forward) {
        return Ok(());
    }

    // Staging buffer: `buffer_out_corr` is copied into it so the data can be
    // mapped for reading on the host.
    let out_corr_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: self.buffer_out_corr.size(),
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    let out_corr_buffer_slice = out_corr_buffer.slice(..);
    {
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        encoder.copy_buffer_to_buffer(
            &self.buffer_out_corr,
            0,
            &out_corr_buffer,
            0,
            self.buffer_out_corr.size(),
        );
        self.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = mpsc::channel();
        out_corr_buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
        self.device.poll(wgpu::Maintain::Wait);

        // `recv()` yields a nested result: the outer layer is the channel
        // error, the inner layer is the outcome of the buffer mapping itself.
        // Both must be checked, otherwise a failed mapping would go unnoticed
        // and `get_mapped_range()` below would panic.
        receiver.recv()??;
    }

    let out_corr_buffer_slice_mapped = out_corr_buffer_slice.get_mapped_range();
    let out_corr_data: &[f32] = bytemuck::cast_slice(&out_corr_buffer_slice_mapped);

    let ncols = self.correlation_values.ncols();
    // Copy the threshold into a local so the parallel closure does not need
    // to capture anything through `self` while the matrix is mutably borrowed.
    let threshold = self.correlation_threshold;
    self.correlation_values
        .column_iter_mut()
        .enumerate()
        .par_bridge()
        .for_each(|(col, mut out_col)| {
            out_col.iter_mut().enumerate().for_each(|(row, out_point)| {
                // The readback data is indexed row by row: (row, col) is at
                // `row * ncols + col`.
                let corr = out_corr_data[row * ncols + col];
                if corr > threshold {
                    *out_point = Some(corr);
                }
            })
        });

    // The mapped view must be dropped before the buffer can be unmapped.
    drop(out_corr_buffer_slice_mapped);
    out_corr_buffer.unmap();

    Ok(())
}

pub fn complete_process(
&mut self,
) -> Result<DMatrix<Option<super::Match>>, Box<dyn error::Error>> {
Expand All @@ -1209,7 +1257,7 @@ mod gpu {
self.buffer_internal_img2.destroy();
self.buffer_internal_int.destroy();
self.buffer_out_reverse.destroy();
self.buffer_reverse_corr.destroy();
self.buffer_out_corr.destroy();

let mut out_image = DMatrix::from_element(self.img1_shape.0, self.img1_shape.1, None);

Expand All @@ -1219,14 +1267,7 @@ mod gpu {
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let out_corr_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: self.buffer_out_corr.size(),
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let out_buffer_slice = out_buffer.slice(..);
let out_corr_buffer_slice = out_corr_buffer.slice(..);
{
let mut encoder = self
.device
Expand All @@ -1238,51 +1279,40 @@ mod gpu {
0,
self.buffer_out.size(),
);
encoder.copy_buffer_to_buffer(
&self.buffer_out_corr,
0,
&out_corr_buffer,
0,
self.buffer_out_corr.size(),
);
self.queue.submit(Some(encoder.finish()));
let (sender_out, receiver_out) = mpsc::channel();
out_buffer_slice
.map_async(wgpu::MapMode::Read, move |v| sender_out.send(v).unwrap());
let (sender_out_corr, receiver_out_corr) = mpsc::channel();
out_corr_buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
sender_out_corr.send(v).unwrap()
});
let (sender, receiver) = mpsc::channel();
out_buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
self.device.poll(wgpu::Maintain::Wait);

if let Err(err) = receiver_out.recv() {
return Err(err.into());
}
if let Err(err) = receiver_out_corr.recv() {
if let Err(err) = receiver.recv() {
return Err(err.into());
}
}

let out_buffer_slice_mapped = out_buffer_slice.get_mapped_range();
let out_corr_buffer_slice_mapped = out_corr_buffer_slice.get_mapped_range();
let out_data: &[i32] = bytemuck::cast_slice(&out_buffer_slice_mapped);
let out_corr_data: &[f32] = bytemuck::cast_slice(&out_corr_buffer_slice_mapped);
for col in 0..out_image.ncols() {
for row in 0..out_image.nrows() {
let pos = 2 * (row * out_image.ncols() + col);
let point_match = (out_data[pos], out_data[pos + 1]);
let corr = out_corr_data[row * out_image.ncols() + col];
out_image[(row, col)] = if point_match.0 > 0 && point_match.1 > 0 {
Some((point_match.1 as u32, point_match.0 as u32, corr))
} else {
None
};
}
}
let ncols = out_image.ncols();
out_image
.column_iter_mut()
.enumerate()
.par_bridge()
.for_each(|(col, mut out_col)| {
out_col.iter_mut().enumerate().for_each(|(row, out_point)| {
let pos = 2 * (row * ncols + col);
let point_match = (out_data[pos], out_data[pos + 1]);
if let Some(corr) = self.correlation_values[(row, col)] {
*out_point = if point_match.0 > 0 && point_match.1 > 0 {
Some((point_match.1 as u32, point_match.0 as u32, corr))
} else {
None
};
} else {
*out_point = None;
};
})
});
drop(out_buffer_slice_mapped);
drop(out_corr_buffer_slice_mapped);
out_buffer.unmap();
out_corr_buffer.unmap();
Ok(out_image)
}

Expand All @@ -1291,17 +1321,9 @@ mod gpu {
entry_point: &str,
dir: &CorrelationDirection,
) -> ComputePipelineConfig {
let (buffer_out, buffer_out_reverse, buffer_out_corr) = match dir {
CorrelationDirection::Forward => (
&self.buffer_out,
&self.buffer_out_reverse,
&self.buffer_out_corr,
),
CorrelationDirection::Reverse => (
&self.buffer_out_reverse,
&self.buffer_out,
&self.buffer_reverse_corr,
),
let (buffer_out, buffer_out_reverse) = match dir {
CorrelationDirection::Forward => (&self.buffer_out, &self.buffer_out_reverse),
CorrelationDirection::Reverse => (&self.buffer_out_reverse, &self.buffer_out),
};

let correlation_layout =
Expand Down Expand Up @@ -1371,7 +1393,9 @@ mod gpu {
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: wgpu::BufferSize::new(buffer_out_corr.size()),
min_binding_size: wgpu::BufferSize::new(
self.buffer_out_corr.size(),
),
},
count: None,
},
Expand Down Expand Up @@ -1454,7 +1478,7 @@ mod gpu {
},
wgpu::BindGroupEntry {
binding: 5,
resource: buffer_out_corr.as_entire_binding(),
resource: self.buffer_out_corr.as_entire_binding(),
},
],
});
Expand Down
2 changes: 1 addition & 1 deletion src/correlation.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ fn init_out_data(@builtin(global_invocation_id) global_id: vec3<u32>) {

if x < out_width && y < out_height {
result_matches[out_width*y+x] = vec2(-1, -1);
result_corr[out_width*y+x] = -1.0;
}
}

Expand All @@ -58,6 +57,7 @@ fn prepare_initialdata_searchdata(@builtin(global_invocation_id) global_id: vec3
if x < img1_width && y < img1_height {
internals_img1[img1_width*y+x] = vec2(0.0, 0.0);
internals_int[img1_width*y+x] = vec3(-1, -1, 0);
result_corr[out_width*y+x] = -1.0;
}
}

Expand Down

0 comments on commit 3560e46

Please sign in to comment.