Skip to content

Commit

Permalink
Use best correlation value from all levels.
Browse files Browse the repository at this point in the history
This also reduces GPU memory usage.
  • Loading branch information
zlogic committed Aug 18, 2023
1 parent 5917a5d commit 101e935
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
16 changes: 11 additions & 5 deletions src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,18 @@ impl PointCorrelations {
for row in 0..nrows {
for col in 0..ncols {
let point = out_data[(row, col)];
if point.is_none() {
continue;
}
let out_row = (row as f32 / scale) as usize;
let out_col = (col as f32 / scale) as usize;
correlated_points[(out_row, out_col)] = point;
let out_point = &mut correlated_points[(out_row, out_col)];
let point_corr = if let Some(point) = point {
point.2
} else {
continue;
};
if out_point.map_or(false, |out_point| out_point.2 < point_corr) {
continue;
}
*out_point = point;
}
}

Expand Down Expand Up @@ -891,7 +897,7 @@ mod gpu {
);
let buffer_internal_img1 = init_buffer(
&device,
(img1_pixels * 4) * std::mem::size_of::<f32>(),
(img1_pixels * 2) * std::mem::size_of::<f32>(),
true,
false,
);
Expand Down
15 changes: 7 additions & 8 deletions src/correlation.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ struct Parameters
var<push_constant> parameters: Parameters;
// Array of image data: [img1, img2]
@group(0) @binding(0) var<storage> images: array<f32>;
// For searchdata: contains [min_corridor, stdev, _] for image1
// For cross_correlate: contains [avg, stdev, corr] for image1
@group(0) @binding(1) var<storage, read_write> internals_img1: array<vec3<f32>>;
// For searchdata: contains [min_corridor, stdev] for image1
// For cross_correlate: contains [avg, stdev] for image1
@group(0) @binding(1) var<storage, read_write> internals_img1: array<vec2<f32>>;
// Contains [avg, stdev] for image 2
@group(0) @binding(2) var<storage, read_write> internals_img2: array<vec2<f32>>;
// Contains [min, max, neighbor_count] for the corridor range
Expand Down Expand Up @@ -56,7 +56,7 @@ fn prepare_initialdata_searchdata(@builtin(global_invocation_id) global_id: vec3
let img1_height = parameters.img1_height;

if x < img1_width && y < img1_height {
internals_img1[img1_width*y+x] = vec3(0.0, 0.0, 0.0);
internals_img1[img1_width*y+x] = vec2(0.0, 0.0);
internals_int[img1_width*y+x] = vec3(-1, -1, 0);
}
}
Expand Down Expand Up @@ -97,9 +97,9 @@ fn prepare_initialdata_correlation(@builtin(global_invocation_id) global_id: vec
}
}
stdev = sqrt(stdev/kernel_point_count);
internals_img1[img1_width*y+x] = vec3(avg, stdev, -1.0);
internals_img1[img1_width*y+x] = vec2(avg, stdev);
} else if x < img1_width && y < img1_height {
internals_img1[img1_width*y+x] = vec3(0.0, -1.0, -1.0);
internals_img1[img1_width*y+x] = vec2(0.0, -1.0);
}

if x >= kernel_size && x < img2_width-kernel_size && y >= kernel_size && y < img2_height-kernel_size {
Expand Down Expand Up @@ -272,7 +272,7 @@ fn cross_correlate(@builtin(global_invocation_id) global_id: vec3<u32>) {
var data_img1 = internals_img1[img1_width*y1+x1];
let avg1 = data_img1[0];
let stdev1 = data_img1[1];
let current_corr = data_img1[2];
let current_corr = result_corr[img1_width*y1+x1];
if stdev1 < min_stdev {
return;
}
Expand Down Expand Up @@ -346,7 +346,6 @@ fn cross_correlate(@builtin(global_invocation_id) global_id: vec3<u32>) {
if (best_corr >= threshold && best_corr > current_corr)
{
let out_pos = out_width*u32((f32(y1)/scale)) + u32(f32(x1)/scale);
data_img1[2] = best_corr;
internals_img1[img1_width*y1+x1] = data_img1;
result_matches[out_pos] = best_match;
result_corr[img1_width*y1+x1] = best_corr;
Expand Down

0 comments on commit 101e935

Please sign in to comment.