290 lines
8.9 KiB
Rust
290 lines
8.9 KiB
Rust
|
use crate::prelude::*;
|
||
|
|
||
|
use anyhow::Result;
|
||
|
|
||
|
use std::sync::Arc;
|
||
|
|
||
|
/// A single record of one shader binding table region.
struct ShaderBindingTableEntry {
    /// Index of the pipeline shader group whose handle is written at the start of the record.
    group_index: u32,
    /// Bytes written into the record right after the shader group handle (may be empty).
    inline_data: Vec<u8>,
}
|
||
|
pub(crate) struct ShaderBindingTableBuilder {
|
||
|
ray_gen_entries: Vec<ShaderBindingTableEntry>,
|
||
|
miss_entries: Vec<ShaderBindingTableEntry>,
|
||
|
hit_group_entries: Vec<ShaderBindingTableEntry>,
|
||
|
}
|
||
|
|
||
|
pub struct ShaderBindingTable {
|
||
|
_sbt_buffer: Arc<Buffer<u8>>,
|
||
|
|
||
|
raygen_shader_binding_table: VkStridedDeviceAddressRegionKHR,
|
||
|
miss_shader_binding_table: VkStridedDeviceAddressRegionKHR,
|
||
|
hit_shader_binding_table: VkStridedDeviceAddressRegionKHR,
|
||
|
callable_shader_binding_table: VkStridedDeviceAddressRegionKHR,
|
||
|
}
|
||
|
|
||
|
impl ShaderBindingTable {
|
||
|
pub fn raygen_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR {
|
||
|
&self.raygen_shader_binding_table
|
||
|
}
|
||
|
|
||
|
pub fn miss_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR {
|
||
|
&self.miss_shader_binding_table
|
||
|
}
|
||
|
|
||
|
pub fn hit_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR {
|
||
|
&self.hit_shader_binding_table
|
||
|
}
|
||
|
|
||
|
pub fn callable_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR {
|
||
|
&self.callable_shader_binding_table
|
||
|
}
|
||
|
|
||
|
fn create(
|
||
|
sbt_buffer: Arc<Buffer<u8>>,
|
||
|
ray_gen_entry_size: VkDeviceSize,
|
||
|
ray_gen_entry_count: VkDeviceSize,
|
||
|
miss_offset: VkDeviceSize,
|
||
|
miss_entry_size: VkDeviceSize,
|
||
|
miss_entry_count: VkDeviceSize,
|
||
|
hit_group_offset: VkDeviceSize,
|
||
|
hit_group_entry_size: VkDeviceSize,
|
||
|
hit_group_entry_count: VkDeviceSize,
|
||
|
) -> Self {
|
||
|
let device_address: VkDeviceAddress = sbt_buffer.device_address().into();
|
||
|
|
||
|
ShaderBindingTable {
|
||
|
raygen_shader_binding_table: VkStridedDeviceAddressRegionKHR {
|
||
|
deviceAddress: device_address,
|
||
|
stride: ray_gen_entry_size,
|
||
|
size: ray_gen_entry_size * ray_gen_entry_count,
|
||
|
},
|
||
|
|
||
|
miss_shader_binding_table: VkStridedDeviceAddressRegionKHR {
|
||
|
deviceAddress: device_address + miss_offset,
|
||
|
stride: miss_entry_size,
|
||
|
size: miss_entry_size * miss_entry_count,
|
||
|
},
|
||
|
|
||
|
hit_shader_binding_table: VkStridedDeviceAddressRegionKHR {
|
||
|
deviceAddress: device_address + hit_group_offset,
|
||
|
stride: hit_group_entry_size,
|
||
|
size: hit_group_entry_size * hit_group_entry_count,
|
||
|
},
|
||
|
|
||
|
callable_shader_binding_table: VkStridedDeviceAddressRegionKHR {
|
||
|
deviceAddress: 0,
|
||
|
stride: 0,
|
||
|
size: 0,
|
||
|
},
|
||
|
|
||
|
_sbt_buffer: sbt_buffer,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl ShaderBindingTableBuilder {
|
||
|
pub(crate) fn new() -> ShaderBindingTableBuilder {
|
||
|
ShaderBindingTableBuilder {
|
||
|
ray_gen_entries: Vec::new(),
|
||
|
miss_entries: Vec::new(),
|
||
|
hit_group_entries: Vec::new(),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub(crate) fn add_ray_gen_program(mut self, group_index: u32, data: Option<Vec<u8>>) -> Self {
|
||
|
self.ray_gen_entries.push(ShaderBindingTableEntry {
|
||
|
group_index,
|
||
|
inline_data: match data {
|
||
|
Some(data) => data,
|
||
|
None => Vec::new(),
|
||
|
},
|
||
|
});
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub(crate) fn add_miss_program(mut self, group_index: u32, data: Option<Vec<u8>>) -> Self {
|
||
|
self.miss_entries.push(ShaderBindingTableEntry {
|
||
|
group_index,
|
||
|
inline_data: match data {
|
||
|
Some(data) => data,
|
||
|
None => Vec::new(),
|
||
|
},
|
||
|
});
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub(crate) fn add_hit_group_program(mut self, group_index: u32, data: Option<Vec<u8>>) -> Self {
|
||
|
self.hit_group_entries.push(ShaderBindingTableEntry {
|
||
|
group_index,
|
||
|
inline_data: match data {
|
||
|
Some(data) => data,
|
||
|
None => Vec::new(),
|
||
|
},
|
||
|
});
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub(crate) fn build(
|
||
|
&mut self,
|
||
|
device: &Arc<Device>,
|
||
|
pipeline: &Arc<Pipeline>,
|
||
|
) -> Result<ShaderBindingTable> {
|
||
|
let ray_tracing_properties = device.physical_device().ray_tracing_properties();
|
||
|
|
||
|
let prog_id_size = ray_tracing_properties.shaderGroupHandleSize;
|
||
|
let base_alignment = ray_tracing_properties.shaderGroupBaseAlignment;
|
||
|
|
||
|
let ray_gen_entry_size =
|
||
|
Self::entry_size(prog_id_size, &self.ray_gen_entries, prog_id_size as u64);
|
||
|
let miss_entry_size =
|
||
|
Self::entry_size(prog_id_size, &self.miss_entries, prog_id_size as u64);
|
||
|
let hit_group_entry_size =
|
||
|
Self::entry_size(prog_id_size, &self.hit_group_entries, prog_id_size as u64);
|
||
|
|
||
|
let sbt_size = (ray_gen_entry_size * self.ray_gen_entries.len() as VkDeviceSize)
|
||
|
.max(base_alignment as VkDeviceSize)
|
||
|
+ (miss_entry_size * self.miss_entries.len() as VkDeviceSize)
|
||
|
.max(base_alignment as VkDeviceSize)
|
||
|
+ hit_group_entry_size * self.hit_group_entries.len() as VkDeviceSize;
|
||
|
|
||
|
let group_count =
|
||
|
self.ray_gen_entries.len() + self.miss_entries.len() + self.hit_group_entries.len();
|
||
|
|
||
|
let shader_handle_storage =
|
||
|
pipeline.ray_tracing_shader_group_handles(group_count as u32, prog_id_size)?;
|
||
|
|
||
|
let mut sbt_data = vec![0; sbt_size as usize];
|
||
|
let mut offset = 0;
|
||
|
|
||
|
Self::copy_shader_data(
|
||
|
&mut sbt_data,
|
||
|
prog_id_size,
|
||
|
&mut offset,
|
||
|
&self.ray_gen_entries,
|
||
|
ray_gen_entry_size,
|
||
|
base_alignment,
|
||
|
&shader_handle_storage,
|
||
|
);
|
||
|
|
||
|
let miss_offset = offset;
|
||
|
|
||
|
Self::copy_shader_data(
|
||
|
&mut sbt_data,
|
||
|
prog_id_size,
|
||
|
&mut offset,
|
||
|
&self.miss_entries,
|
||
|
miss_entry_size,
|
||
|
base_alignment,
|
||
|
&shader_handle_storage,
|
||
|
);
|
||
|
|
||
|
let hit_group_offset = offset;
|
||
|
|
||
|
Self::copy_shader_data(
|
||
|
&mut sbt_data,
|
||
|
prog_id_size,
|
||
|
&mut offset,
|
||
|
&self.hit_group_entries,
|
||
|
hit_group_entry_size,
|
||
|
base_alignment,
|
||
|
&shader_handle_storage,
|
||
|
);
|
||
|
|
||
|
let sbt_buffer = Buffer::builder()
|
||
|
.set_usage(
|
||
|
VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR
|
||
|
| VK_BUFFER_USAGE_TRANSFER_SRC_BIT
|
||
|
| VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
|
||
|
)
|
||
|
.set_memory_usage(MemoryUsage::CpuToGpu)
|
||
|
.set_data(&sbt_data)
|
||
|
.build(device.clone())?;
|
||
|
|
||
|
Ok(ShaderBindingTable::create(
|
||
|
sbt_buffer,
|
||
|
ray_gen_entry_size,
|
||
|
self.ray_gen_entries.len() as VkDeviceSize,
|
||
|
miss_offset,
|
||
|
miss_entry_size,
|
||
|
self.miss_entries.len() as VkDeviceSize,
|
||
|
hit_group_offset,
|
||
|
hit_group_entry_size,
|
||
|
self.hit_group_entries.len() as VkDeviceSize,
|
||
|
))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl ShaderBindingTableBuilder {
|
||
|
#[inline]
|
||
|
fn entry_size(
|
||
|
prog_id_size: u32,
|
||
|
entries: &[ShaderBindingTableEntry],
|
||
|
padding: u64,
|
||
|
) -> VkDeviceSize {
|
||
|
let mut max_args = 0;
|
||
|
|
||
|
for entry in entries {
|
||
|
max_args = max_args.max(entry.inline_data.len());
|
||
|
}
|
||
|
|
||
|
let mut entry_size = prog_id_size as VkDeviceSize + max_args as VkDeviceSize;
|
||
|
|
||
|
entry_size = Self::round_up(entry_size, padding);
|
||
|
|
||
|
entry_size
|
||
|
}
|
||
|
|
||
|
#[inline]
|
||
|
fn round_up(source: u64, value: u64) -> u64 {
|
||
|
((source) + (value) - 1) & !((value) - 1)
|
||
|
}
|
||
|
|
||
|
#[inline]
|
||
|
fn copy_shader_data(
|
||
|
sbt_data: &mut Vec<u8>,
|
||
|
prog_id_size: u32,
|
||
|
offset: &mut VkDeviceSize,
|
||
|
entries: &[ShaderBindingTableEntry],
|
||
|
_entry_size: VkDeviceSize,
|
||
|
base_alignment: u32,
|
||
|
shader_handle_storage: &[u8],
|
||
|
) {
|
||
|
for entry in entries {
|
||
|
// copy the shader identifier
|
||
|
{
|
||
|
let sbt_start = *offset as usize;
|
||
|
let sbt_end = sbt_start + prog_id_size as usize;
|
||
|
|
||
|
let shs_start = (entry.group_index * prog_id_size) as usize;
|
||
|
let shs_end = shs_start + prog_id_size as usize;
|
||
|
|
||
|
sbt_data[sbt_start..sbt_end]
|
||
|
.copy_from_slice(&shader_handle_storage[shs_start..shs_end]);
|
||
|
}
|
||
|
|
||
|
// copy data if present
|
||
|
if !entry.inline_data.is_empty() {
|
||
|
let tmp_offset = *offset + prog_id_size as VkDeviceSize;
|
||
|
|
||
|
let sbt_start = tmp_offset as usize;
|
||
|
let sbt_end = sbt_start + entry.inline_data.len();
|
||
|
|
||
|
sbt_data[sbt_start..sbt_end].copy_from_slice(&entry.inline_data);
|
||
|
}
|
||
|
|
||
|
*offset += prog_id_size as VkDeviceSize;
|
||
|
}
|
||
|
|
||
|
// increase offset with correct alignment
|
||
|
let modulo = *offset % base_alignment as VkDeviceSize;
|
||
|
|
||
|
if modulo != 0 {
|
||
|
*offset += base_alignment as VkDeviceSize - modulo;
|
||
|
}
|
||
|
}
|
||
|
}
|