vulkan_lib/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs

271 lines
8.6 KiB
Rust
Raw Normal View History

2023-01-14 12:03:01 +00:00
use crate::pipeline::PipelineType;
use crate::prelude::*;
use anyhow::Result;
use std::sync::Arc;
use super::shader_binding_table::ShaderBindingTableBuilder;
/// A previously created pipeline library that should be linked into a ray
/// tracing pipeline built by `RayTracingPipelineBuilder`.
///
/// The descriptor only borrows the pipeline and stores two `u32` sizes, so it
/// is cheap to copy and can be handed to several builders.
#[derive(Clone, Copy)]
pub struct Library<'a> {
    /// The library pipeline whose handle is linked in.
    pipeline: &'a Arc<Pipeline>,
    /// Largest ray payload size used by this library; contributes to the
    /// pipeline interface's `maxPipelineRayPayloadSize`.
    max_payload_size: u32,
    /// Largest hit attribute size used by this library; contributes to the
    /// pipeline interface's `maxPipelineRayHitAttributeSize`.
    max_attribute_size: u32,
}
impl<'a> Library<'a> {
pub fn new(
pipeline: &'a Arc<Pipeline>,
max_payload_size: u32,
max_attribute_size: u32,
) -> Self {
Library {
pipeline,
max_payload_size,
max_attribute_size,
}
}
}
/// Collects shaders, shader groups, libraries and settings, then creates a
/// ray tracing pipeline together with its shader binding table.
///
/// Every configuration method takes `self` by value and returns the updated
/// builder. A discarded return value therefore loses all configuration made
/// so far, which is why the type is `#[must_use]`.
#[must_use = "builder methods return the updated builder; call `build` to create the pipeline"]
pub struct RayTracingPipelineBuilder<'a> {
    /// All shader stages in insertion order. Shader groups refer to entries
    /// of this list by index.
    shader_modules: Vec<(Arc<ShaderModule>, Option<SpecializationConstants>)>,
    /// One entry per shader group (general groups and hit groups).
    shader_groups: Vec<VkRayTracingShaderGroupCreateInfoKHR>,
    /// Pipeline libraries to link into the pipeline.
    libraries: Vec<Library<'a>>,
    /// Dynamic states enabled on the pipeline.
    dynamic_states: Vec<VkDynamicState>,
    /// Pipeline create flags.
    flags: VkPipelineCreateFlagBits,
    /// Requested maximum recursion depth; clamped to the device limit in `build`.
    max_recursion: u32,
    /// Accumulates per-group records used to build the shader binding table.
    shader_binding_table_builder: ShaderBindingTableBuilder,
    /// Optional pipeline cache used during pipeline creation.
    pipeline_cache: Option<&'a Arc<PipelineCache>>,
}
impl<'a> RayTracingPipelineBuilder<'a> {
pub fn check_max_recursion(device: &Arc<Device>, max_recursion: u32) -> u32 {
max_recursion.min(
device
.physical_device()
.ray_tracing_properties()
.maxRayRecursionDepth,
)
}
pub fn add_dynamic_state(mut self, dynamic_state: VkDynamicState) -> Self {
self.dynamic_states.push(dynamic_state);
self
}
pub fn set_pipeline_cache(mut self, pipeline_cache: &'a Arc<PipelineCache>) -> Self {
self.pipeline_cache = Some(pipeline_cache);
self
}
pub fn set_flags(mut self, flags: impl Into<VkPipelineCreateFlagBits>) -> Self {
self.flags = flags.into();
self
}
pub fn add_library(mut self, library: Library<'a>) -> Self {
self.libraries.push(library);
self
}
pub fn add_shader(
mut self,
shader_module: Arc<ShaderModule>,
data: Option<Vec<u8>>,
specialization_constants: Option<SpecializationConstants>,
) -> Self {
self.shader_binding_table_builder = match shader_module.shader_type() {
ShaderType::RayGeneration => self
.shader_binding_table_builder
.add_ray_gen_program(self.shader_groups.len() as u32, data),
ShaderType::Miss => self
.shader_binding_table_builder
.add_miss_program(self.shader_groups.len() as u32, data),
_ => panic!(
"unsupported shader type: {:?}, expected RayGen or Miss Shader",
shader_module.shader_type()
),
};
let shader_index = self.shader_modules.len();
self.shader_modules
.push((shader_module, specialization_constants));
self.shader_groups
.push(VkRayTracingShaderGroupCreateInfoKHR::new(
VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
shader_index as u32,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
));
self
}
pub fn add_hit_shaders(
mut self,
shader_modules: impl IntoIterator<Item = (Arc<ShaderModule>, Option<SpecializationConstants>)>,
data: Option<Vec<u8>>,
) -> Self {
let mut group = VkRayTracingShaderGroupCreateInfoKHR::new(
VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
);
for (shader_module, specialization_constant) in shader_modules.into_iter() {
let shader_index = self.shader_modules.len() as u32;
match shader_module.shader_type() {
ShaderType::AnyHit => {
// sanity check
if cfg!(debug_assertions) && group.anyHitShader != VK_SHADER_UNUSED_KHR {
panic!("any hit shader already used in current hit group");
}
group.anyHitShader = shader_index;
}
ShaderType::ClosestHit => {
// sanity check
if cfg!(debug_assertions) && group.closestHitShader != VK_SHADER_UNUSED_KHR {
panic!("closest hit shader already used in current hit group");
}
group.closestHitShader = shader_index;
}
ShaderType::Intersection => {
// sanity check
if cfg!(debug_assertions) && group.intersectionShader != VK_SHADER_UNUSED_KHR {
panic!("intersection shader already used in current hit group");
}
group.intersectionShader = shader_index;
group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
}
_ => panic!("unsupported shader type: {:?}, expected AnyHit, ClosestHit or Intersection Shader", shader_module.shader_type()),
}
self.shader_modules
.push((shader_module, specialization_constant));
}
self.shader_binding_table_builder = self
.shader_binding_table_builder
.add_hit_group_program(self.shader_groups.len() as u32, data);
self.shader_groups.push(group);
self
}
pub fn max_recursion_depth(mut self, max_recursion_depth: u32) -> Self {
self.max_recursion = max_recursion_depth;
self
}
pub fn build(
mut self,
device: Arc<Device>,
pipeline_layout: &Arc<PipelineLayout>,
) -> Result<(Arc<Pipeline>, ShaderBindingTable)> {
let shader_stages: Vec<VkPipelineShaderStageCreateInfo> = self
.shader_modules
.iter()
.map(|(shader, specialization_constant)| {
let mut stage_info = shader.pipeline_stage_info();
if let Some(specialization_constant) = specialization_constant {
stage_info.set_specialization_info(specialization_constant.vk_handle());
}
stage_info
})
.collect();
// check that we dont exceed the gpu's capabilities
let max_recursion = Self::check_max_recursion(&device, self.max_recursion);
let pipeline = {
let mut libraries = Vec::with_capacity(self.libraries.len());
let mut library_interface = VkRayTracingPipelineInterfaceCreateInfoKHR::new(0, 0);
for library in self.libraries.iter() {
libraries.push(library.pipeline.vk_handle());
library_interface.maxPipelineRayPayloadSize = library_interface
.maxPipelineRayPayloadSize
.max(library.max_payload_size);
library_interface.maxPipelineRayHitAttributeSize = library_interface
.maxPipelineRayHitAttributeSize
.max(library.max_attribute_size);
}
let lib_create_info = VkPipelineLibraryCreateInfoKHR::new(&libraries);
let dynamic_states = VkPipelineDynamicStateCreateInfo::new(0, &self.dynamic_states);
device.create_ray_tracing_pipelines(
None,
self.pipeline_cache.map(|cache| cache.vk_handle()),
&[VkRayTracingPipelineCreateInfoKHR::new(
self.flags,
&shader_stages, // stages
&self.shader_groups, // groups
max_recursion,
&lib_create_info, // libraries
&library_interface, // library interfaces
&dynamic_states,
pipeline_layout.vk_handle(),
)],
None,
)?[0]
};
let pipeline = Arc::new(Pipeline::new(
device.clone(),
pipeline_layout.clone(),
PipelineType::RayTracing,
pipeline,
));
let sbt = self
.shader_binding_table_builder
.build(&device, &pipeline)?;
Ok((pipeline, sbt))
}
}
impl<'a> Default for RayTracingPipelineBuilder<'a> {
    /// Returns an empty builder.
    ///
    /// It has no shaders, shader groups, libraries, dynamic states or
    /// pipeline cache. The create flags are zero, and the requested maximum
    /// recursion depth is 2.
    fn default() -> Self {
        Self {
            // settings
            flags: 0.into(),
            max_recursion: 2,
            pipeline_cache: None,

            // collected inputs
            shader_modules: Vec::new(),
            shader_groups: Vec::new(),
            libraries: Vec::new(),
            dynamic_states: Vec::new(),

            // shader binding table bookkeeping
            shader_binding_table_builder: ShaderBindingTableBuilder::new(),
        }
    }
}