271 lines
8.6 KiB
Rust
271 lines
8.6 KiB
Rust
|
use crate::pipeline::PipelineType;
|
||
|
use crate::prelude::*;
|
||
|
|
||
|
use anyhow::Result;
|
||
|
|
||
|
use std::sync::Arc;
|
||
|
|
||
|
use super::shader_binding_table::ShaderBindingTableBuilder;
|
||
|
|
||
|
pub struct Library<'a> {
|
||
|
pipeline: &'a Arc<Pipeline>,
|
||
|
|
||
|
max_payload_size: u32,
|
||
|
max_attribute_size: u32,
|
||
|
}
|
||
|
|
||
|
impl<'a> Library<'a> {
|
||
|
pub fn new(
|
||
|
pipeline: &'a Arc<Pipeline>,
|
||
|
max_payload_size: u32,
|
||
|
max_attribute_size: u32,
|
||
|
) -> Self {
|
||
|
Library {
|
||
|
pipeline,
|
||
|
|
||
|
max_payload_size,
|
||
|
max_attribute_size,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub struct RayTracingPipelineBuilder<'a> {
|
||
|
shader_modules: Vec<(Arc<ShaderModule>, Option<SpecializationConstants>)>,
|
||
|
|
||
|
shader_groups: Vec<VkRayTracingShaderGroupCreateInfoKHR>,
|
||
|
|
||
|
libraries: Vec<Library<'a>>,
|
||
|
|
||
|
dynamic_states: Vec<VkDynamicState>,
|
||
|
|
||
|
flags: VkPipelineCreateFlagBits,
|
||
|
max_recursion: u32,
|
||
|
|
||
|
shader_binding_table_builder: ShaderBindingTableBuilder,
|
||
|
|
||
|
pipeline_cache: Option<&'a Arc<PipelineCache>>,
|
||
|
}
|
||
|
|
||
|
impl<'a> RayTracingPipelineBuilder<'a> {
|
||
|
pub fn check_max_recursion(device: &Arc<Device>, max_recursion: u32) -> u32 {
|
||
|
max_recursion.min(
|
||
|
device
|
||
|
.physical_device()
|
||
|
.ray_tracing_properties()
|
||
|
.maxRayRecursionDepth,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
pub fn add_dynamic_state(mut self, dynamic_state: VkDynamicState) -> Self {
|
||
|
self.dynamic_states.push(dynamic_state);
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub fn set_pipeline_cache(mut self, pipeline_cache: &'a Arc<PipelineCache>) -> Self {
|
||
|
self.pipeline_cache = Some(pipeline_cache);
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub fn set_flags(mut self, flags: impl Into<VkPipelineCreateFlagBits>) -> Self {
|
||
|
self.flags = flags.into();
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub fn add_library(mut self, library: Library<'a>) -> Self {
|
||
|
self.libraries.push(library);
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub fn add_shader(
|
||
|
mut self,
|
||
|
shader_module: Arc<ShaderModule>,
|
||
|
data: Option<Vec<u8>>,
|
||
|
specialization_constants: Option<SpecializationConstants>,
|
||
|
) -> Self {
|
||
|
self.shader_binding_table_builder = match shader_module.shader_type() {
|
||
|
ShaderType::RayGeneration => self
|
||
|
.shader_binding_table_builder
|
||
|
.add_ray_gen_program(self.shader_groups.len() as u32, data),
|
||
|
ShaderType::Miss => self
|
||
|
.shader_binding_table_builder
|
||
|
.add_miss_program(self.shader_groups.len() as u32, data),
|
||
|
_ => panic!(
|
||
|
"unsupported shader type: {:?}, expected RayGen or Miss Shader",
|
||
|
shader_module.shader_type()
|
||
|
),
|
||
|
};
|
||
|
|
||
|
let shader_index = self.shader_modules.len();
|
||
|
self.shader_modules
|
||
|
.push((shader_module, specialization_constants));
|
||
|
|
||
|
self.shader_groups
|
||
|
.push(VkRayTracingShaderGroupCreateInfoKHR::new(
|
||
|
VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
|
||
|
shader_index as u32,
|
||
|
VK_SHADER_UNUSED_KHR,
|
||
|
VK_SHADER_UNUSED_KHR,
|
||
|
VK_SHADER_UNUSED_KHR,
|
||
|
));
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub fn add_hit_shaders(
|
||
|
mut self,
|
||
|
shader_modules: impl IntoIterator<Item = (Arc<ShaderModule>, Option<SpecializationConstants>)>,
|
||
|
data: Option<Vec<u8>>,
|
||
|
) -> Self {
|
||
|
let mut group = VkRayTracingShaderGroupCreateInfoKHR::new(
|
||
|
VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
|
||
|
VK_SHADER_UNUSED_KHR,
|
||
|
VK_SHADER_UNUSED_KHR,
|
||
|
VK_SHADER_UNUSED_KHR,
|
||
|
VK_SHADER_UNUSED_KHR,
|
||
|
);
|
||
|
|
||
|
for (shader_module, specialization_constant) in shader_modules.into_iter() {
|
||
|
let shader_index = self.shader_modules.len() as u32;
|
||
|
|
||
|
match shader_module.shader_type() {
|
||
|
ShaderType::AnyHit => {
|
||
|
// sanity check
|
||
|
if cfg!(debug_assertions) && group.anyHitShader != VK_SHADER_UNUSED_KHR {
|
||
|
panic!("any hit shader already used in current hit group");
|
||
|
}
|
||
|
|
||
|
group.anyHitShader = shader_index;
|
||
|
}
|
||
|
ShaderType::ClosestHit => {
|
||
|
// sanity check
|
||
|
if cfg!(debug_assertions) && group.closestHitShader != VK_SHADER_UNUSED_KHR {
|
||
|
panic!("closest hit shader already used in current hit group");
|
||
|
}
|
||
|
|
||
|
group.closestHitShader = shader_index;
|
||
|
}
|
||
|
ShaderType::Intersection => {
|
||
|
// sanity check
|
||
|
if cfg!(debug_assertions) && group.intersectionShader != VK_SHADER_UNUSED_KHR {
|
||
|
panic!("intersection shader already used in current hit group");
|
||
|
}
|
||
|
|
||
|
group.intersectionShader = shader_index;
|
||
|
group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
|
||
|
}
|
||
|
_ => panic!("unsupported shader type: {:?}, expected AnyHit, ClosestHit or Intersection Shader", shader_module.shader_type()),
|
||
|
}
|
||
|
|
||
|
self.shader_modules
|
||
|
.push((shader_module, specialization_constant));
|
||
|
}
|
||
|
self.shader_binding_table_builder = self
|
||
|
.shader_binding_table_builder
|
||
|
.add_hit_group_program(self.shader_groups.len() as u32, data);
|
||
|
self.shader_groups.push(group);
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub fn max_recursion_depth(mut self, max_recursion_depth: u32) -> Self {
|
||
|
self.max_recursion = max_recursion_depth;
|
||
|
|
||
|
self
|
||
|
}
|
||
|
|
||
|
pub fn build(
|
||
|
mut self,
|
||
|
device: Arc<Device>,
|
||
|
pipeline_layout: &Arc<PipelineLayout>,
|
||
|
) -> Result<(Arc<Pipeline>, ShaderBindingTable)> {
|
||
|
let shader_stages: Vec<VkPipelineShaderStageCreateInfo> = self
|
||
|
.shader_modules
|
||
|
.iter()
|
||
|
.map(|(shader, specialization_constant)| {
|
||
|
let mut stage_info = shader.pipeline_stage_info();
|
||
|
if let Some(specialization_constant) = specialization_constant {
|
||
|
stage_info.set_specialization_info(specialization_constant.vk_handle());
|
||
|
}
|
||
|
|
||
|
stage_info
|
||
|
})
|
||
|
.collect();
|
||
|
|
||
|
// check that we dont exceed the gpu's capabilities
|
||
|
let max_recursion = Self::check_max_recursion(&device, self.max_recursion);
|
||
|
|
||
|
let pipeline = {
|
||
|
let mut libraries = Vec::with_capacity(self.libraries.len());
|
||
|
let mut library_interface = VkRayTracingPipelineInterfaceCreateInfoKHR::new(0, 0);
|
||
|
|
||
|
for library in self.libraries.iter() {
|
||
|
libraries.push(library.pipeline.vk_handle());
|
||
|
|
||
|
library_interface.maxPipelineRayPayloadSize = library_interface
|
||
|
.maxPipelineRayPayloadSize
|
||
|
.max(library.max_payload_size);
|
||
|
|
||
|
library_interface.maxPipelineRayHitAttributeSize = library_interface
|
||
|
.maxPipelineRayHitAttributeSize
|
||
|
.max(library.max_attribute_size);
|
||
|
}
|
||
|
|
||
|
let lib_create_info = VkPipelineLibraryCreateInfoKHR::new(&libraries);
|
||
|
|
||
|
let dynamic_states = VkPipelineDynamicStateCreateInfo::new(0, &self.dynamic_states);
|
||
|
|
||
|
device.create_ray_tracing_pipelines(
|
||
|
None,
|
||
|
self.pipeline_cache.map(|cache| cache.vk_handle()),
|
||
|
&[VkRayTracingPipelineCreateInfoKHR::new(
|
||
|
self.flags,
|
||
|
&shader_stages, // stages
|
||
|
&self.shader_groups, // groups
|
||
|
max_recursion,
|
||
|
&lib_create_info, // libraries
|
||
|
&library_interface, // library interfaces
|
||
|
&dynamic_states,
|
||
|
pipeline_layout.vk_handle(),
|
||
|
)],
|
||
|
None,
|
||
|
)?[0]
|
||
|
};
|
||
|
|
||
|
let pipeline = Arc::new(Pipeline::new(
|
||
|
device.clone(),
|
||
|
pipeline_layout.clone(),
|
||
|
PipelineType::RayTracing,
|
||
|
pipeline,
|
||
|
));
|
||
|
|
||
|
let sbt = self
|
||
|
.shader_binding_table_builder
|
||
|
.build(&device, &pipeline)?;
|
||
|
|
||
|
Ok((pipeline, sbt))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl<'a> Default for RayTracingPipelineBuilder<'a> {
|
||
|
fn default() -> Self {
|
||
|
RayTracingPipelineBuilder {
|
||
|
shader_modules: Vec::new(),
|
||
|
shader_groups: Vec::new(),
|
||
|
|
||
|
flags: 0.into(),
|
||
|
max_recursion: 2,
|
||
|
|
||
|
libraries: Vec::new(),
|
||
|
dynamic_states: Vec::new(),
|
||
|
|
||
|
shader_binding_table_builder: ShaderBindingTableBuilder::new(),
|
||
|
|
||
|
pipeline_cache: None,
|
||
|
}
|
||
|
}
|
||
|
}
|