diff --git a/src/correlation.rs b/src/correlation.rs index bd69867d..d588fc3a 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -293,12 +293,18 @@ impl PointCorrelations { for row in 0..nrows { for col in 0..ncols { let point = out_data[(row, col)]; - if point.is_none() { - continue; - } let out_row = (row as f32 / scale) as usize; let out_col = (col as f32 / scale) as usize; - correlated_points[(out_row, out_col)] = point; + let out_point = &mut correlated_points[(out_row, out_col)]; + let point_corr = if let Some(point) = point { + point.2 + } else { + continue; + }; + if out_point.map_or(false, |out_point| out_point.2 < point_corr) { + continue; + } + *out_point = point; } } @@ -891,7 +897,7 @@ mod gpu { ); let buffer_internal_img1 = init_buffer( &device, - (img1_pixels * 4) * std::mem::size_of::(), + (img1_pixels * 2) * std::mem::size_of::(), true, false, ); diff --git a/src/correlation.wgsl b/src/correlation.wgsl index aa375ab6..07d5b7cb 100644 --- a/src/correlation.wgsl +++ b/src/correlation.wgsl @@ -23,9 +23,9 @@ struct Parameters var parameters: Parameters; // Array of image data: [img1, img2] @group(0) @binding(0) var images: array; -// For searchdata: contains [min_corridor, stdev, _] for image1 -// For cross_correlate: contains [avg, stdev, corr] for image1 -@group(0) @binding(1) var internals_img1: array>; +// For searchdata: contains [min_corridor, stdev] for image1 +// For cross_correlate: contains [avg, stdev] for image1 +@group(0) @binding(1) var internals_img1: array>; // Contains [avg, stdev] for image 2 @group(0) @binding(2) var internals_img2: array>; // Contains [min, max, neighbor_count] for the corridor range @@ -56,7 +56,7 @@ fn prepare_initialdata_searchdata(@builtin(global_invocation_id) global_id: vec3 let img1_height = parameters.img1_height; if x < img1_width && y < img1_height { - internals_img1[img1_width*y+x] = vec3(0.0, 0.0, 0.0); + internals_img1[img1_width*y+x] = vec2(0.0, 0.0); internals_int[img1_width*y+x] = vec3(-1, -1, 0); } } @@ -97,9 +97,9 @@ fn prepare_initialdata_correlation(@builtin(global_invocation_id) global_id: vec } } stdev = sqrt(stdev/kernel_point_count); - internals_img1[img1_width*y+x] = vec3(avg, stdev, -1.0); + internals_img1[img1_width*y+x] = vec2(avg, stdev); } else if x < img1_width && y < img1_height { - internals_img1[img1_width*y+x] = vec3(0.0, -1.0, -1.0); + internals_img1[img1_width*y+x] = vec2(0.0, -1.0); } if x >= kernel_size && x < img2_width-kernel_size && y >= kernel_size && y < img2_height-kernel_size { @@ -272,7 +272,7 @@ fn cross_correlate(@builtin(global_invocation_id) global_id: vec3) { var data_img1 = internals_img1[img1_width*y1+x1]; let avg1 = data_img1[0]; let stdev1 = data_img1[1]; - let current_corr = data_img1[2]; + let current_corr = result_corr[img1_width*y1+x1]; if stdev1 < min_stdev { return; } @@ -346,7 +346,6 @@ fn cross_correlate(@builtin(global_invocation_id) global_id: vec3) { if (best_corr >= threshold && best_corr > current_corr) { let out_pos = out_width*u32((f32(y1)/scale)) + u32(f32(x1)/scale); - data_img1[2] = best_corr; internals_img1[img1_width*y1+x1] = data_img1; result_matches[out_pos] = best_match; result_corr[img1_width*y1+x1] = best_corr;