diff --git a/src/correlation/gpu/vulkan.rs b/src/correlation/gpu/vulkan.rs index a738542..3d74877 100644 --- a/src/correlation/gpu/vulkan.rs +++ b/src/correlation/gpu/vulkan.rs @@ -65,8 +65,7 @@ struct DeviceBuffers { struct DescriptorSets { descriptor_pool: vk::DescriptorPool, - regular_layout: vk::DescriptorSetLayout, - cross_check_layout: vk::DescriptorSetLayout, + layout: vk::DescriptorSetLayout, pipeline_layout: vk::PipelineLayout, descriptor_sets: Vec, } @@ -719,20 +718,13 @@ impl Device { device.destroy_descriptor_pool(descriptor_pool, None); err }; - let regular_layout = create_layout_bindings(6).map_err(cleanup_err)?; + let layout = create_layout_bindings(6).map_err(cleanup_err)?; let cleanup_err = |err| { - device.destroy_descriptor_set_layout(regular_layout, None); + device.destroy_descriptor_set_layout(layout, None); device.destroy_descriptor_pool(descriptor_pool, None); err }; - let cross_check_layout = create_layout_bindings(2).map_err(cleanup_err)?; - let cleanup_err = |err| { - device.destroy_descriptor_set_layout(cross_check_layout, None); - device.destroy_descriptor_set_layout(regular_layout, None); - device.destroy_descriptor_pool(descriptor_pool, None); - err - }; - let layouts = [regular_layout, cross_check_layout]; + let layouts = [layout]; let push_constant_ranges = vk::PushConstantRange::default() .offset(0) .size(std::mem::size_of::() as u32) @@ -747,8 +739,7 @@ impl Device { .map_err(cleanup_err)?; let cleanup_err = |err| { device.destroy_pipeline_layout(pipeline_layout, None); - device.destroy_descriptor_set_layout(cross_check_layout, None); - device.destroy_descriptor_set_layout(regular_layout, None); + device.destroy_descriptor_set_layout(layout, None); device.destroy_descriptor_pool(descriptor_pool, None); err }; @@ -761,8 +752,7 @@ impl Device { Ok(DescriptorSets { descriptor_pool, - regular_layout, - cross_check_layout, + layout, pipeline_layout, descriptor_sets, }) @@ -843,17 +833,6 @@ impl Device { let direction = self.direction; let descriptor_sets = &self.descriptor_sets; let buffers = &self.buffers()?; - let create_buffer_infos = |buffers: &[Buffer]| { - buffers - .iter() - .map(|buf| { - vk::DescriptorBufferInfo::default() - .buffer(buf.buffer) - .offset(0) - .range(vk::WHOLE_SIZE) - }) - .collect::>() - }; let (buffer_internal_img1, buffer_internal_img2, buffer_out, buffer_out_reverse) = match direction { CorrelationDirection::Forward => ( @@ -869,28 +848,29 @@ impl Device { buffers.buffer_out, ), }; - let (buffer_list, descriptor_set) = if matches!(shader, ShaderModuleType::CrossCheckFilter) - { - ( - vec![buffer_out, buffer_out_reverse], - descriptor_sets.descriptor_sets[0], - ) + let buffer_list = if matches!(shader, ShaderModuleType::CrossCheckFilter) { + vec![buffer_out, buffer_out_reverse] } else { - ( - vec![ - buffers.buffer_img, - buffer_internal_img1, - buffer_internal_img2, - buffers.buffer_internal_int, - buffer_out, - buffers.buffer_out_corr, - ], - descriptor_sets.descriptor_sets[1], - ) + vec![ + buffers.buffer_img, + buffer_internal_img1, + buffer_internal_img2, + buffers.buffer_internal_int, + buffer_out, + buffers.buffer_out_corr, + ] }; - let buffer_infos = create_buffer_infos(buffer_list.as_slice()); + let buffer_infos = buffer_list + .iter() + .map(|buf| { + vk::DescriptorBufferInfo::default() + .buffer(buf.buffer) + .offset(0) + .range(vk::WHOLE_SIZE) + }) + .collect::>(); let write_descriptor = vk::WriteDescriptorSet::default() - .dst_set(descriptor_set) + .dst_set(descriptor_sets.descriptor_sets[0]) .dst_binding(0) .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) .buffer_info(buffer_infos.as_slice()); @@ -1122,8 +1102,7 @@ impl DescriptorSets { unsafe fn destroy(&self, device: &ash::Device) { let _ = device.free_descriptor_sets(self.descriptor_pool, self.descriptor_sets.as_slice()); device.destroy_pipeline_layout(self.pipeline_layout, None); - device.destroy_descriptor_set_layout(self.cross_check_layout, None); - device.destroy_descriptor_set_layout(self.regular_layout, None); + device.destroy_descriptor_set_layout(self.layout, None); device.destroy_descriptor_pool(self.descriptor_pool, None); } }