use crate::pipeline::PipelineType; use crate::prelude::*; use anyhow::Result; use std::sync::Arc; use super::shader_binding_table::ShaderBindingTableBuilder; pub struct Library<'a> { pipeline: &'a Arc, max_payload_size: u32, max_attribute_size: u32, } impl<'a> Library<'a> { pub fn new( pipeline: &'a Arc, max_payload_size: u32, max_attribute_size: u32, ) -> Self { Library { pipeline, max_payload_size, max_attribute_size, } } } macro_rules! impl_from_shader_type { ($struct: ident, $shader_type: ident) => { impl From>> for $struct { fn from(value: Arc>) -> Self { Self::$shader_type(value) } } }; } #[derive(Clone)] enum RaytracingShader { RayGeneration(Arc>), ClosestHit(Arc>), Miss(Arc>), AnyHit(Arc>), Intersection(Arc>), } impl_from_shader_type!(RaytracingShader, RayGeneration); impl_from_shader_type!(RaytracingShader, ClosestHit); impl_from_shader_type!(RaytracingShader, Miss); impl_from_shader_type!(RaytracingShader, AnyHit); impl_from_shader_type!(RaytracingShader, Intersection); impl From for RaytracingShader { fn from(value: HitShader) -> Self { match value { HitShader::ClosestHit(s) => Self::ClosestHit(s), HitShader::AnyHit(s) => Self::AnyHit(s), HitShader::Intersection(s) => Self::Intersection(s), } } } impl From for RaytracingShader { fn from(value: OtherShader) -> Self { match value { OtherShader::Miss(s) => Self::Miss(s), OtherShader::RayGeneration(s) => Self::RayGeneration(s), } } } #[derive(Clone)] pub enum HitShader { ClosestHit(Arc>), AnyHit(Arc>), Intersection(Arc>), } impl_from_shader_type!(HitShader, ClosestHit); impl_from_shader_type!(HitShader, AnyHit); impl_from_shader_type!(HitShader, Intersection); #[derive(Clone)] pub enum OtherShader { RayGeneration(Arc>), Miss(Arc>), } impl_from_shader_type!(OtherShader, RayGeneration); impl_from_shader_type!(OtherShader, Miss); pub struct RayTracingPipelineBuilder<'a> { shader_modules: Vec<(RaytracingShader, Option)>, shader_groups: Vec, libraries: Vec>, dynamic_states: Vec, flags: VkPipelineCreateFlagBits, max_recursion: u32, shader_binding_table_builder: ShaderBindingTableBuilder, pipeline_cache: Option<&'a Arc>, } impl<'a> RayTracingPipelineBuilder<'a> { pub fn check_max_recursion(device: &Arc, max_recursion: u32) -> u32 { max_recursion.min( device .physical_device() .ray_tracing_properties() .maxRayRecursionDepth, ) } pub fn add_dynamic_state(mut self, dynamic_state: VkDynamicState) -> Self { self.dynamic_states.push(dynamic_state); self } pub fn set_pipeline_cache(mut self, pipeline_cache: &'a Arc) -> Self { self.pipeline_cache = Some(pipeline_cache); self } pub fn set_flags(mut self, flags: impl Into) -> Self { self.flags = flags.into(); self } pub fn add_library(mut self, library: Library<'a>) -> Self { self.libraries.push(library); self } pub fn add_shader( mut self, shader_module: impl Into, data: Option>, specialization_constants: Option, ) -> Self { let shader_module = shader_module.into(); self.shader_binding_table_builder = match shader_module { OtherShader::RayGeneration(_) => self .shader_binding_table_builder .add_ray_gen_program(self.shader_groups.len() as u32, data), OtherShader::Miss(_) => self .shader_binding_table_builder .add_miss_program(self.shader_groups.len() as u32, data), }; let shader_index = self.shader_modules.len(); self.shader_modules .push((shader_module.into(), specialization_constants)); self.shader_groups .push(VkRayTracingShaderGroupCreateInfoKHR::new( VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, shader_index as u32, VK_SHADER_UNUSED_KHR, VK_SHADER_UNUSED_KHR, VK_SHADER_UNUSED_KHR, )); self } pub fn add_hit_shaders( mut self, shader_modules: impl IntoIterator)>, data: Option>, ) -> Self { let mut group = VkRayTracingShaderGroupCreateInfoKHR::new( VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR, VK_SHADER_UNUSED_KHR, VK_SHADER_UNUSED_KHR, VK_SHADER_UNUSED_KHR, VK_SHADER_UNUSED_KHR, ); for (shader_module, specialization_constant) in shader_modules.into_iter() { let shader_index = self.shader_modules.len() as u32; match shader_module { HitShader::AnyHit(_) => { // sanity check debug_assert_eq!( group.anyHitShader, VK_SHADER_UNUSED_KHR, "any hit shader already used in current hit group" ); group.anyHitShader = shader_index; } HitShader::ClosestHit(_) => { // sanity check debug_assert_eq!( group.closestHitShader, VK_SHADER_UNUSED_KHR, "closest hit shader already used in current hit group" ); group.closestHitShader = shader_index; } HitShader::Intersection(_) => { // sanity check debug_assert_eq!( group.intersectionShader, VK_SHADER_UNUSED_KHR, "intersection shader already used in current hit group" ); group.intersectionShader = shader_index; group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR; } } self.shader_modules .push((shader_module.into(), specialization_constant)); } self.shader_binding_table_builder = self .shader_binding_table_builder .add_hit_group_program(self.shader_groups.len() as u32, data); self.shader_groups.push(group); self } pub fn max_recursion_depth(mut self, max_recursion_depth: u32) -> Self { self.max_recursion = max_recursion_depth; self } pub fn build( mut self, device: Arc, pipeline_layout: &Arc, ) -> Result<(Arc, ShaderBindingTable)> { let shader_stages: Vec = self .shader_modules .iter() .map(|(shader, specialization_constant)| { let mut stage_info = match shader { RaytracingShader::RayGeneration(s) => s.pipeline_stage_info(), RaytracingShader::ClosestHit(s) => s.pipeline_stage_info(), RaytracingShader::Miss(s) => s.pipeline_stage_info(), RaytracingShader::AnyHit(s) => s.pipeline_stage_info(), RaytracingShader::Intersection(s) => s.pipeline_stage_info(), }; if let Some(specialization_constant) = specialization_constant { stage_info.set_specialization_info(specialization_constant.vk_handle()); } stage_info }) .collect(); // check that we dont exceed the gpu's capabilities let max_recursion = Self::check_max_recursion(&device, self.max_recursion); let pipeline = { let mut libraries = Vec::with_capacity(self.libraries.len()); let mut library_interface = VkRayTracingPipelineInterfaceCreateInfoKHR::new(0, 0); for library in self.libraries.iter() { libraries.push(library.pipeline.vk_handle()); library_interface.maxPipelineRayPayloadSize = library_interface .maxPipelineRayPayloadSize .max(library.max_payload_size); library_interface.maxPipelineRayHitAttributeSize = library_interface .maxPipelineRayHitAttributeSize .max(library.max_attribute_size); } let lib_create_info = VkPipelineLibraryCreateInfoKHR::new(&libraries); let dynamic_states = VkPipelineDynamicStateCreateInfo::new(0, &self.dynamic_states); device.create_ray_tracing_pipelines( None, self.pipeline_cache.map(|cache| cache.vk_handle()), &[VkRayTracingPipelineCreateInfoKHR::new( self.flags, &shader_stages, // stages &self.shader_groups, // groups max_recursion, &lib_create_info, // libraries &library_interface, // library interfaces &dynamic_states, pipeline_layout.vk_handle(), )], None, )?[0] }; let pipeline = Arc::new(Pipeline::new( device.clone(), pipeline_layout.clone(), PipelineType::RayTracing, pipeline, )); let sbt = self .shader_binding_table_builder .build(&device, &pipeline)?; Ok((pipeline, sbt)) } } impl<'a> Default for RayTracingPipelineBuilder<'a> { fn default() -> Self { RayTracingPipelineBuilder { shader_modules: Vec::new(), shader_groups: Vec::new(), flags: 0.into(), max_recursion: 2, libraries: Vec::new(), dynamic_states: Vec::new(), shader_binding_table_builder: ShaderBindingTableBuilder::new(), pipeline_cache: None, } } }