use crate::prelude::*; use anyhow::Result; use std::sync::Arc; struct ShaderBindingTableEntry { group_index: u32, inline_data: Vec, } pub(crate) struct ShaderBindingTableBuilder { ray_gen_entries: Vec, miss_entries: Vec, hit_group_entries: Vec, } pub struct ShaderBindingTable { _sbt_buffer: Arc>, raygen_shader_binding_table: VkStridedDeviceAddressRegionKHR, miss_shader_binding_table: VkStridedDeviceAddressRegionKHR, hit_shader_binding_table: VkStridedDeviceAddressRegionKHR, callable_shader_binding_table: VkStridedDeviceAddressRegionKHR, } impl ShaderBindingTable { pub fn raygen_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR { &self.raygen_shader_binding_table } pub fn miss_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR { &self.miss_shader_binding_table } pub fn hit_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR { &self.hit_shader_binding_table } pub fn callable_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR { &self.callable_shader_binding_table } fn create( sbt_buffer: Arc>, ray_gen_entry_size: VkDeviceSize, ray_gen_entry_count: VkDeviceSize, miss_offset: VkDeviceSize, miss_entry_size: VkDeviceSize, miss_entry_count: VkDeviceSize, hit_group_offset: VkDeviceSize, hit_group_entry_size: VkDeviceSize, hit_group_entry_count: VkDeviceSize, ) -> Self { let device_address: VkDeviceAddress = sbt_buffer.device_address().into(); ShaderBindingTable { raygen_shader_binding_table: VkStridedDeviceAddressRegionKHR { deviceAddress: device_address, stride: ray_gen_entry_size, size: ray_gen_entry_size * ray_gen_entry_count, }, miss_shader_binding_table: VkStridedDeviceAddressRegionKHR { deviceAddress: device_address + miss_offset, stride: miss_entry_size, size: miss_entry_size * miss_entry_count, }, hit_shader_binding_table: VkStridedDeviceAddressRegionKHR { deviceAddress: device_address + hit_group_offset, stride: hit_group_entry_size, size: hit_group_entry_size * hit_group_entry_count, }, callable_shader_binding_table: VkStridedDeviceAddressRegionKHR { deviceAddress: 0, stride: 0, size: 0, }, _sbt_buffer: sbt_buffer, } } } impl ShaderBindingTableBuilder { pub(crate) fn new() -> ShaderBindingTableBuilder { ShaderBindingTableBuilder { ray_gen_entries: Vec::new(), miss_entries: Vec::new(), hit_group_entries: Vec::new(), } } pub(crate) fn add_ray_gen_program(mut self, group_index: u32, data: Option>) -> Self { self.ray_gen_entries.push(ShaderBindingTableEntry { group_index, inline_data: match data { Some(data) => data, None => Vec::new(), }, }); self } pub(crate) fn add_miss_program(mut self, group_index: u32, data: Option>) -> Self { self.miss_entries.push(ShaderBindingTableEntry { group_index, inline_data: match data { Some(data) => data, None => Vec::new(), }, }); self } pub(crate) fn add_hit_group_program(mut self, group_index: u32, data: Option>) -> Self { self.hit_group_entries.push(ShaderBindingTableEntry { group_index, inline_data: match data { Some(data) => data, None => Vec::new(), }, }); self } pub(crate) fn build( &mut self, device: &Arc, pipeline: &Arc, ) -> Result { let ray_tracing_properties = device.physical_device().ray_tracing_properties(); let prog_id_size = ray_tracing_properties.shaderGroupHandleSize; let base_alignment = ray_tracing_properties.shaderGroupBaseAlignment; let ray_gen_entry_size = Self::entry_size(prog_id_size, &self.ray_gen_entries, prog_id_size as u64); let miss_entry_size = Self::entry_size(prog_id_size, &self.miss_entries, prog_id_size as u64); let hit_group_entry_size = Self::entry_size(prog_id_size, &self.hit_group_entries, prog_id_size as u64); let sbt_size = (ray_gen_entry_size * self.ray_gen_entries.len() as VkDeviceSize) .max(base_alignment as VkDeviceSize) + (miss_entry_size * self.miss_entries.len() as VkDeviceSize) .max(base_alignment as VkDeviceSize) + hit_group_entry_size * self.hit_group_entries.len() as VkDeviceSize; let group_count = self.ray_gen_entries.len() + self.miss_entries.len() + self.hit_group_entries.len(); let shader_handle_storage = pipeline.ray_tracing_shader_group_handles(group_count as u32, prog_id_size)?; let mut sbt_data = vec![0; sbt_size as usize]; let mut offset = 0; Self::copy_shader_data( &mut sbt_data, prog_id_size, &mut offset, &self.ray_gen_entries, ray_gen_entry_size, base_alignment, &shader_handle_storage, ); let miss_offset = offset; Self::copy_shader_data( &mut sbt_data, prog_id_size, &mut offset, &self.miss_entries, miss_entry_size, base_alignment, &shader_handle_storage, ); let hit_group_offset = offset; Self::copy_shader_data( &mut sbt_data, prog_id_size, &mut offset, &self.hit_group_entries, hit_group_entry_size, base_alignment, &shader_handle_storage, ); let sbt_buffer = Buffer::builder() .set_usage( VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR | VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, ) .set_memory_usage(MemoryUsage::CpuToGpu) .set_data(&sbt_data) .build(device.clone())?; Ok(ShaderBindingTable::create( sbt_buffer, ray_gen_entry_size, self.ray_gen_entries.len() as VkDeviceSize, miss_offset, miss_entry_size, self.miss_entries.len() as VkDeviceSize, hit_group_offset, hit_group_entry_size, self.hit_group_entries.len() as VkDeviceSize, )) } } impl ShaderBindingTableBuilder { #[inline] fn entry_size( prog_id_size: u32, entries: &[ShaderBindingTableEntry], padding: u64, ) -> VkDeviceSize { let mut max_args = 0; for entry in entries { max_args = max_args.max(entry.inline_data.len()); } let mut entry_size = prog_id_size as VkDeviceSize + max_args as VkDeviceSize; entry_size = Self::round_up(entry_size, padding); entry_size } #[inline] fn round_up(source: u64, value: u64) -> u64 { ((source) + (value) - 1) & !((value) - 1) } #[inline] fn copy_shader_data( sbt_data: &mut Vec, prog_id_size: u32, offset: &mut VkDeviceSize, entries: &[ShaderBindingTableEntry], _entry_size: VkDeviceSize, base_alignment: u32, shader_handle_storage: &[u8], ) { for entry in entries { // copy the shader identifier { let sbt_start = *offset as usize; let sbt_end = sbt_start + prog_id_size as usize; let shs_start = (entry.group_index * prog_id_size) as usize; let shs_end = shs_start + prog_id_size as usize; sbt_data[sbt_start..sbt_end] .copy_from_slice(&shader_handle_storage[shs_start..shs_end]); } // copy data if present if !entry.inline_data.is_empty() { let tmp_offset = *offset + prog_id_size as VkDeviceSize; let sbt_start = tmp_offset as usize; let sbt_end = sbt_start + entry.inline_data.len(); sbt_data[sbt_start..sbt_end].copy_from_slice(&entry.inline_data); } *offset += prog_id_size as VkDeviceSize; } // increase offset with correct alignment let modulo = *offset % base_alignment as VkDeviceSize; if modulo != 0 { *offset += base_alignment as VkDeviceSize - modulo; } } }