Skip to content

Commit

Permalink
Use correlation value for pose estimation.
Browse files Browse the repository at this point in the history
Only points with the best correlation should be used to estimate the
pose.
  • Loading branch information
zlogic committed Aug 18, 2023
1 parent 582c4a0 commit d9a7ef4
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 54 deletions.
104 changes: 82 additions & 22 deletions src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const KERNEL_WIDTH: usize = KERNEL_SIZE * 2 + 1;
const KERNEL_POINT_COUNT: usize = KERNEL_WIDTH * KERNEL_WIDTH;

const THRESHOLD_AFFINE: f32 = 0.6;
const THRESHOLD_PERSPECTIVE: f32 = 0.7;
const THRESHOLD_PERSPECTIVE: f32 = 0.6;
const MIN_STDEV_AFFINE: f32 = 1.0;
const MIN_STDEV_PERSPECTIVE: f32 = 3.0;
const CORRIDOR_SIZE_AFFINE: usize = 2;
Expand All @@ -24,7 +24,7 @@ const CORRIDOR_EXTEND_RANGE_PERSPECTIVE: f64 = 1.0;
const CORRIDOR_MIN_RANGE: f64 = 2.5;
const CROSS_CHECK_SEARCH_AREA: usize = 2;

type Match = (u32, u32);
type Match = (u32, u32, f32);

#[derive(Debug)]
pub struct PointData<const KPC: usize> {
Expand Down Expand Up @@ -77,7 +77,7 @@ struct EpipolarLine {
}

struct BestMatch {
pos: Option<Match>,
pos: Option<(u32, u32)>,
corr: Option<f32>,
}

Expand Down Expand Up @@ -188,26 +188,29 @@ impl PointCorrelations {
img2: DMatrix<u8>,
scale: f32,
progress_listener: Option<&PL>,
) {
) -> Result<(), Box<dyn error::Error>> {
// Start with reverse direction - so that the last correlation values will be from the forward direction.
self.correlate_images_step(
&img1,
&img2,
&img1,
scale,
progress_listener,
CorrelationDirection::Forward,
);
CorrelationDirection::Reverse,
)?;
self.correlate_images_step(
&img2,
&img1,
&img2,
scale,
progress_listener,
CorrelationDirection::Reverse,
);
CorrelationDirection::Forward,
)?;

self.cross_check_filter(scale, CorrelationDirection::Forward);
self.cross_check_filter(scale, CorrelationDirection::Reverse);

self.first_pass = false;

Ok(())
}

fn correlate_images_step<PL: ProgressListener>(
Expand All @@ -217,21 +220,20 @@ impl PointCorrelations {
scale: f32,
progress_listener: Option<&PL>,
dir: CorrelationDirection,
) {
) -> Result<(), Box<dyn error::Error>> {
if let Some(gpu_context) = &mut self.gpu_context {
let dir = match dir {
CorrelationDirection::Forward => gpu::CorrelationDirection::Forward,
CorrelationDirection::Reverse => gpu::CorrelationDirection::Reverse,
};
gpu_context.correlate_images(
return gpu_context.correlate_images(
img1,
img2,
scale,
self.first_pass,
progress_listener,
dir,
);
return;
};
let img2_data = compute_image_point_data(img2);
let mut out_data: DMatrix<Option<Match>> =
Expand Down Expand Up @@ -299,6 +301,8 @@ impl PointCorrelations {
correlated_points[(out_row, out_col)] = point;
}
}

Ok(())
}

fn correlate_point(
Expand Down Expand Up @@ -364,7 +368,9 @@ impl PointCorrelations {
corridor_range.clone(),
);
}
*out_point = best_match.pos
*out_point = best_match
.pos
.and_then(|m| best_match.corr.map(|corr| (m.0, m.1, corr)))
}

fn get_epipolar_line(
Expand Down Expand Up @@ -575,7 +581,7 @@ impl PointCorrelations {
search_area: usize,
row: usize,
col: usize,
m: (u32, u32),
m: Match,
) -> bool {
let min_row = (m.0 as usize)
.saturating_sub(search_area)
Expand Down Expand Up @@ -726,7 +732,7 @@ pub fn compute_point_data<const KS: usize, const KPC: usize>(
}

mod gpu {
const MAX_BINDINGS: u32 = 5;
const MAX_BINDINGS: u32 = 6;

use std::{borrow::Cow, collections::HashMap, error, fmt};

Expand Down Expand Up @@ -785,6 +791,7 @@ mod gpu {
buffer_internal_img2: wgpu::Buffer,
buffer_internal_int: wgpu::Buffer,
buffer_out: wgpu::Buffer,
buffer_out_corr: wgpu::Buffer,
buffer_out_reverse: wgpu::Buffer,

pipeline_configs: HashMap<String, ComputePipelineConfig>,
Expand Down Expand Up @@ -884,7 +891,7 @@ mod gpu {
);
let buffer_internal_img1 = init_buffer(
&device,
(img1_pixels * 4) * std::mem::size_of::<f32>(),
(img1_pixels * 2) * std::mem::size_of::<f32>(),
true,
false,
);
Expand All @@ -906,6 +913,12 @@ mod gpu {
false,
false,
);
let buffer_out_corr = init_buffer(
&device,
img1_pixels * 1 * std::mem::size_of::<f32>(),
false,
false,
);
let buffer_out_reverse = init_buffer(
&device,
img2_pixels * 2 * std::mem::size_of::<i32>(),
Expand All @@ -932,6 +945,7 @@ mod gpu {
buffer_internal_img2,
buffer_internal_int,
buffer_out,
buffer_out_corr,
buffer_out_reverse,
pipeline_configs: HashMap::new(),
};
Expand All @@ -950,7 +964,7 @@ mod gpu {
first_pass: bool,
progress_listener: Option<&PL>,
dir: CorrelationDirection,
) {
) -> Result<(), Box<dyn error::Error>> {
let max_width = img1.ncols().max(img2.ncols());
let max_height = img1.nrows().max(img2.nrows());
let max_shape = (max_height, max_width);
Expand Down Expand Up @@ -1067,6 +1081,8 @@ mod gpu {
send_progress(percent_complete);
}
}

Ok(())
}

pub fn cross_check_filter(&mut self, scale: f32, dir: CorrelationDirection) {
Expand Down Expand Up @@ -1180,6 +1196,7 @@ mod gpu {
self.buffer_internal_img2.destroy();
self.buffer_internal_int.destroy();
self.buffer_out_reverse.destroy();
self.buffer_out_corr.destroy();

let mut out_image = DMatrix::from_element(self.img1_shape.0, self.img1_shape.1, None);

Expand All @@ -1189,7 +1206,14 @@ mod gpu {
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let out_corr_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: self.buffer_out_corr.size(),
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let out_buffer_slice = out_buffer.slice(..);
let out_corr_buffer_slice = out_corr_buffer.slice(..);
{
let mut encoder = self
.device
Expand All @@ -1201,31 +1225,51 @@ mod gpu {
0,
self.buffer_out.size(),
);
encoder.copy_buffer_to_buffer(
&self.buffer_out_corr,
0,
&out_corr_buffer,
0,
self.buffer_out_corr.size(),
);
self.queue.submit(Some(encoder.finish()));
let (sender, receiver) = mpsc::channel();
out_buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
let (sender_out, receiver_out) = mpsc::channel();
out_buffer_slice
.map_async(wgpu::MapMode::Read, move |v| sender_out.send(v).unwrap());
let (sender_out_corr, receiver_out_corr) = mpsc::channel();
out_corr_buffer_slice.map_async(wgpu::MapMode::Read, move |v| {
sender_out_corr.send(v).unwrap()
});
self.device.poll(wgpu::Maintain::Wait);

if let Err(err) = receiver.recv() {
if let Err(err) = receiver_out.recv() {
return Err(err.into());
}
if let Err(err) = receiver_out_corr.recv() {
return Err(err.into());
}
}

let out_buffer_slice_mapped = out_buffer_slice.get_mapped_range();
let out_corr_buffer_slice_mapped = out_corr_buffer_slice.get_mapped_range();
let out_data: &[i32] = bytemuck::cast_slice(&out_buffer_slice_mapped);
let out_corr_data: &[f32] = bytemuck::cast_slice(&out_buffer_slice_mapped);
for col in 0..out_image.ncols() {
for row in 0..out_image.nrows() {
let pos = 2 * (row * out_image.ncols() + col);
let point_match = (out_data[pos], out_data[pos + 1]);
let corr = out_corr_data[row * out_image.ncols() + col];
out_image[(row, col)] = if point_match.0 > 0 && point_match.1 > 0 {
Some((point_match.1 as u32, point_match.0 as u32))
Some((point_match.1 as u32, point_match.0 as u32, corr))
} else {
None
};
}
}
drop(out_buffer_slice_mapped);
drop(out_corr_buffer_slice_mapped);
out_buffer.unmap();
out_corr_buffer.unmap();
Ok(out_image)
}

Expand Down Expand Up @@ -1300,6 +1344,18 @@ mod gpu {
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 5,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: wgpu::BufferSize::new(
self.buffer_out_corr.size(),
),
},
count: None,
},
],
});

Expand Down Expand Up @@ -1377,6 +1433,10 @@ mod gpu {
binding: 4,
resource: buffer_out.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: self.buffer_out_corr.as_entire_binding(),
},
],
});

Expand Down
27 changes: 15 additions & 12 deletions src/correlation.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ struct Parameters
};

var<push_constant> parameters: Parameters;
// Array of image data: [img1, img2]
@group(0) @binding(0) var<storage> images: array<f32>;
// For searchdata: contains [min_corridor, stdev, _] for image1
// For cross_correlate: contains [avg, stdev, corr] for image1
@group(0) @binding(1) var<storage, read_write> internals_img1: array<vec3<f32>>;
// For searchdata: contains [min_corridor, stdev] for image1
// For cross_correlate: contains [avg, stdev] for image1
@group(0) @binding(1) var<storage, read_write> internals_img1: array<vec2<f32>>;
// Contains [avg, stdev] for image 2
@group(0) @binding(2) var<storage, read_write> internals_img2: array<vec2<f32>>;
// Contains [min, max, neighbor_count] for the corridor range
@group(0) @binding(3) var<storage, read_write> internals_int: array<vec3<i32>>;
@group(0) @binding(4) var<storage, read_write> result: array<vec2<i32>>;
@group(0) @binding(4) var<storage, read_write> result_matches: array<vec2<i32>>;
@group(0) @binding(5) var<storage, read_write> result_corr: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn init_out_data(@builtin(global_invocation_id) global_id: vec3<u32>) {
Expand All @@ -40,7 +42,8 @@ fn init_out_data(@builtin(global_invocation_id) global_id: vec3<u32>) {
let out_height = parameters.out_height;

if x < out_width && y < out_height {
result[out_width*y+x] = vec2(-1, -1);
result_matches[out_width*y+x] = vec2(-1, -1);
result_corr[out_width*y+x] = -1.0;
}
}

Expand All @@ -53,7 +56,7 @@ fn prepare_initialdata_searchdata(@builtin(global_invocation_id) global_id: vec3
let img1_height = parameters.img1_height;

if x < img1_width && y < img1_height {
internals_img1[img1_width*y+x] = vec3(0.0, 0.0, 0.0);
internals_img1[img1_width*y+x] = vec2(0.0, 0.0);
internals_int[img1_width*y+x] = vec3(-1, -1, 0);
}
}
Expand Down Expand Up @@ -94,9 +97,9 @@ fn prepare_initialdata_correlation(@builtin(global_invocation_id) global_id: vec
}
}
stdev = sqrt(stdev/kernel_point_count);
internals_img1[img1_width*y+x] = vec3(avg, stdev, -1.0);
internals_img1[img1_width*y+x] = vec2(avg, stdev);
} else if x < img1_width && y < img1_height {
internals_img1[img1_width*y+x] = vec3(0.0, -1.0, -1.0);
internals_img1[img1_width*y+x] = vec2(0.0, -1.0);
}

if x >= kernel_size && x < img2_width-kernel_size && y >= kernel_size && y < img2_height-kernel_size {
Expand Down Expand Up @@ -194,7 +197,7 @@ fn prepare_searchdata(@builtin(global_invocation_id) global_id: vec3<u32>) {
continue;
}

let coord2 = vec2<f32>(result[u32(y_out)*out_width + u32(x_out)]) * scale;
let coord2 = vec2<f32>(result_corr[u32(y_out)*out_width + u32(x_out)]) * scale;
if coord2.x<0.0 || coord2.y<0.0 {
continue;
}
Expand Down Expand Up @@ -269,7 +272,7 @@ fn cross_correlate(@builtin(global_invocation_id) global_id: vec3<u32>) {
var data_img1 = internals_img1[img1_width*y1+x1];
let avg1 = data_img1[0];
let stdev1 = data_img1[1];
let current_corr = data_img1[2];
let current_corr = result_corr[img1_width*y1+x1];
if stdev1 < min_stdev {
return;
}
Expand Down Expand Up @@ -343,9 +346,9 @@ fn cross_correlate(@builtin(global_invocation_id) global_id: vec3<u32>) {
if (best_corr >= threshold && best_corr > current_corr)
{
let out_pos = out_width*u32((f32(y1)/scale)) + u32(f32(x1)/scale);
data_img1[2] = best_corr;
internals_img1[img1_width*y1+x1] = data_img1;
result[out_pos] = best_match;
result_matches[out_pos] = best_match;
result_corr[img1_width*y1+x1] = best_corr;
}
}

Expand Down
Loading

0 comments on commit d9a7ef4

Please sign in to comment.