vulkan_lib/vulkan-rs/src/pipelines/shader_binding_table.rs
2023-01-14 13:03:01 +01:00

289 lines
8.9 KiB
Rust

use crate::prelude::*;
use anyhow::Result;
use std::sync::Arc;
/// One shader-binding-table record requested through the builder: the shader group whose
/// handle starts the record, plus optional inline bytes written right after that handle.
struct ShaderBindingTableEntry {
    /// Index of the pipeline shader group whose handle is copied into this record.
    group_index: u32,
    /// Extra record bytes placed directly after the handle; empty when none were supplied.
    inline_data: Vec<u8>,
}
/// Collects ray-gen, miss and hit-group records and lays them out into a
/// [`ShaderBindingTable`] when built.
pub(crate) struct ShaderBindingTableBuilder {
    /// Records of the ray-generation section, in insertion order.
    ray_gen_entries: Vec<ShaderBindingTableEntry>,
    /// Records of the miss section, in insertion order.
    miss_entries: Vec<ShaderBindingTableEntry>,
    /// Records of the hit-group section, in insertion order.
    hit_group_entries: Vec<ShaderBindingTableEntry>,
}
/// A built shader binding table: one buffer holding all records plus the strided
/// device-address regions describing each section inside it.
pub struct ShaderBindingTable {
    // Never read, but owned so the buffer that the regions' device addresses point into
    // stays alive for as long as this table does.
    _sbt_buffer: Arc<Buffer<u8>>,
    /// Region of the ray-generation records (starts at the buffer's base address).
    raygen_shader_binding_table: VkStridedDeviceAddressRegionKHR,
    /// Region of the miss records.
    miss_shader_binding_table: VkStridedDeviceAddressRegionKHR,
    /// Region of the hit-group records.
    hit_shader_binding_table: VkStridedDeviceAddressRegionKHR,
    /// Region of callable records; always empty (all zero) as built here.
    callable_shader_binding_table: VkStridedDeviceAddressRegionKHR,
}
impl ShaderBindingTable {
    /// Strided region covering the ray-generation records.
    pub fn raygen_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR {
        &self.raygen_shader_binding_table
    }

    /// Strided region covering the miss records.
    pub fn miss_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR {
        &self.miss_shader_binding_table
    }

    /// Strided region covering the hit-group records.
    pub fn hit_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR {
        &self.hit_shader_binding_table
    }

    /// Strided region for callable records; always the all-zero (empty) region.
    pub fn callable_shader_binding_table(&self) -> &VkStridedDeviceAddressRegionKHR {
        &self.callable_shader_binding_table
    }

    /// Wraps `sbt_buffer` and derives the per-section regions from its device address.
    ///
    /// Every section region starts at the buffer's device address plus the section offset
    /// (the ray-gen section always starts at offset 0), uses the section's entry size as
    /// stride and spans `entry_size * entry_count` bytes. The callable region stays empty.
    fn create(
        sbt_buffer: Arc<Buffer<u8>>,
        ray_gen_entry_size: VkDeviceSize,
        ray_gen_entry_count: VkDeviceSize,
        miss_offset: VkDeviceSize,
        miss_entry_size: VkDeviceSize,
        miss_entry_count: VkDeviceSize,
        hit_group_offset: VkDeviceSize,
        hit_group_entry_size: VkDeviceSize,
        hit_group_entry_count: VkDeviceSize,
    ) -> Self {
        let base_address: VkDeviceAddress = sbt_buffer.device_address().into();

        // One section = `count` records of `stride` bytes starting `offset` bytes in.
        let region_at = |offset: VkDeviceSize, stride: VkDeviceSize, count: VkDeviceSize| {
            VkStridedDeviceAddressRegionKHR {
                deviceAddress: base_address + offset,
                stride,
                size: stride * count,
            }
        };

        let raygen = region_at(0, ray_gen_entry_size, ray_gen_entry_count);
        let miss = region_at(miss_offset, miss_entry_size, miss_entry_count);
        let hit = region_at(hit_group_offset, hit_group_entry_size, hit_group_entry_count);
        let callable = VkStridedDeviceAddressRegionKHR {
            deviceAddress: 0,
            stride: 0,
            size: 0,
        };

        Self {
            _sbt_buffer: sbt_buffer,
            raygen_shader_binding_table: raygen,
            miss_shader_binding_table: miss,
            hit_shader_binding_table: hit,
            callable_shader_binding_table: callable,
        }
    }
}
impl ShaderBindingTableBuilder {
pub(crate) fn new() -> ShaderBindingTableBuilder {
ShaderBindingTableBuilder {
ray_gen_entries: Vec::new(),
miss_entries: Vec::new(),
hit_group_entries: Vec::new(),
}
}
pub(crate) fn add_ray_gen_program(mut self, group_index: u32, data: Option<Vec<u8>>) -> Self {
self.ray_gen_entries.push(ShaderBindingTableEntry {
group_index,
inline_data: match data {
Some(data) => data,
None => Vec::new(),
},
});
self
}
pub(crate) fn add_miss_program(mut self, group_index: u32, data: Option<Vec<u8>>) -> Self {
self.miss_entries.push(ShaderBindingTableEntry {
group_index,
inline_data: match data {
Some(data) => data,
None => Vec::new(),
},
});
self
}
pub(crate) fn add_hit_group_program(mut self, group_index: u32, data: Option<Vec<u8>>) -> Self {
self.hit_group_entries.push(ShaderBindingTableEntry {
group_index,
inline_data: match data {
Some(data) => data,
None => Vec::new(),
},
});
self
}
pub(crate) fn build(
&mut self,
device: &Arc<Device>,
pipeline: &Arc<Pipeline>,
) -> Result<ShaderBindingTable> {
let ray_tracing_properties = device.physical_device().ray_tracing_properties();
let prog_id_size = ray_tracing_properties.shaderGroupHandleSize;
let base_alignment = ray_tracing_properties.shaderGroupBaseAlignment;
let ray_gen_entry_size =
Self::entry_size(prog_id_size, &self.ray_gen_entries, prog_id_size as u64);
let miss_entry_size =
Self::entry_size(prog_id_size, &self.miss_entries, prog_id_size as u64);
let hit_group_entry_size =
Self::entry_size(prog_id_size, &self.hit_group_entries, prog_id_size as u64);
let sbt_size = (ray_gen_entry_size * self.ray_gen_entries.len() as VkDeviceSize)
.max(base_alignment as VkDeviceSize)
+ (miss_entry_size * self.miss_entries.len() as VkDeviceSize)
.max(base_alignment as VkDeviceSize)
+ hit_group_entry_size * self.hit_group_entries.len() as VkDeviceSize;
let group_count =
self.ray_gen_entries.len() + self.miss_entries.len() + self.hit_group_entries.len();
let shader_handle_storage =
pipeline.ray_tracing_shader_group_handles(group_count as u32, prog_id_size)?;
let mut sbt_data = vec![0; sbt_size as usize];
let mut offset = 0;
Self::copy_shader_data(
&mut sbt_data,
prog_id_size,
&mut offset,
&self.ray_gen_entries,
ray_gen_entry_size,
base_alignment,
&shader_handle_storage,
);
let miss_offset = offset;
Self::copy_shader_data(
&mut sbt_data,
prog_id_size,
&mut offset,
&self.miss_entries,
miss_entry_size,
base_alignment,
&shader_handle_storage,
);
let hit_group_offset = offset;
Self::copy_shader_data(
&mut sbt_data,
prog_id_size,
&mut offset,
&self.hit_group_entries,
hit_group_entry_size,
base_alignment,
&shader_handle_storage,
);
let sbt_buffer = Buffer::builder()
.set_usage(
VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR
| VK_BUFFER_USAGE_TRANSFER_SRC_BIT
| VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
)
.set_memory_usage(MemoryUsage::CpuToGpu)
.set_data(&sbt_data)
.build(device.clone())?;
Ok(ShaderBindingTable::create(
sbt_buffer,
ray_gen_entry_size,
self.ray_gen_entries.len() as VkDeviceSize,
miss_offset,
miss_entry_size,
self.miss_entries.len() as VkDeviceSize,
hit_group_offset,
hit_group_entry_size,
self.hit_group_entries.len() as VkDeviceSize,
))
}
}
impl ShaderBindingTableBuilder {
    /// Computes the record stride for one SBT section.
    ///
    /// Every record holds a `prog_id_size`-byte shader group handle followed by inline
    /// data. The stride is sized for the largest inline payload in `entries` and rounded up
    /// to a multiple of `padding`.
    #[inline]
    fn entry_size(
        prog_id_size: u32,
        entries: &[ShaderBindingTableEntry],
        padding: u64,
    ) -> VkDeviceSize {
        let max_args = entries
            .iter()
            .map(|entry| entry.inline_data.len())
            .max()
            .unwrap_or(0);
        Self::round_up(prog_id_size as VkDeviceSize + max_args as VkDeviceSize, padding)
    }

    /// Rounds `source` up to the next multiple of `value`.
    ///
    /// Works for any non-zero `value`, not only powers of two. A `value` of 0 means
    /// "no alignment" and returns `source` unchanged instead of underflowing.
    #[inline]
    fn round_up(source: u64, value: u64) -> u64 {
        if value == 0 {
            return source;
        }
        (source + value - 1) / value * value
    }

    /// Writes the records of one SBT section into `sbt_data`, starting at `*offset`.
    ///
    /// For each entry, the shader group handle is copied first, followed by the entry's
    /// inline data. The handle comes from `shader_handle_storage`, which is indexed as
    /// `prog_id_size` bytes per group in group-index order. Records are placed
    /// `entry_size` bytes apart so they line up with the stride reported for the section.
    /// Afterwards `*offset` is rounded up to `base_alignment` so the next section starts
    /// on an aligned boundary.
    ///
    /// # Panics
    /// Panics if a record does not fit in `sbt_data` or a group index lies outside
    /// `shader_handle_storage`.
    #[inline]
    fn copy_shader_data(
        sbt_data: &mut [u8],
        prog_id_size: u32,
        offset: &mut VkDeviceSize,
        entries: &[ShaderBindingTableEntry],
        entry_size: VkDeviceSize,
        base_alignment: u32,
        shader_handle_storage: &[u8],
    ) {
        let handle_size = prog_id_size as usize;

        for entry in entries {
            let record_start = *offset as usize;

            // Shader group handle at the start of the record. Compute the index in usize
            // so `group_index * handle_size` cannot overflow u32.
            let handle_start = entry.group_index as usize * handle_size;
            sbt_data[record_start..record_start + handle_size]
                .copy_from_slice(&shader_handle_storage[handle_start..handle_start + handle_size]);

            // Inline data directly after the handle (zero-length copy when absent).
            let data_start = record_start + handle_size;
            sbt_data[data_start..data_start + entry.inline_data.len()]
                .copy_from_slice(&entry.inline_data);

            // Advance by the full record stride, not just the handle size; otherwise the
            // next handle overwrites this record's inline data and the layout no longer
            // matches the stride handed to Vulkan.
            *offset += entry_size;
        }

        *offset = Self::round_up(*offset, base_alignment as VkDeviceSize);
    }
}