2023-01-14 12:03:01 +00:00
|
|
|
use crate::pipeline::PipelineType;
|
|
|
|
use crate::prelude::*;
|
|
|
|
|
|
|
|
use anyhow::Result;
|
|
|
|
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
|
|
use super::shader_binding_table::ShaderBindingTableBuilder;
|
|
|
|
|
|
|
|
pub struct Library<'a> {
|
|
|
|
pipeline: &'a Arc<Pipeline>,
|
|
|
|
|
|
|
|
max_payload_size: u32,
|
|
|
|
max_attribute_size: u32,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a> Library<'a> {
|
|
|
|
pub fn new(
|
|
|
|
pipeline: &'a Arc<Pipeline>,
|
|
|
|
max_payload_size: u32,
|
|
|
|
max_attribute_size: u32,
|
|
|
|
) -> Self {
|
|
|
|
Library {
|
|
|
|
pipeline,
|
|
|
|
|
|
|
|
max_payload_size,
|
|
|
|
max_attribute_size,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-01-27 09:12:59 +00:00
|
|
|
/// Implements `From<Arc<ShaderModule<{ ShaderType::X as u8 }>>>` for an enum
/// whose variant `X` wraps exactly that shader module type.
///
/// Invocation: `impl_from_shader_type!(EnumName, Variant);`. `Variant` must
/// name both a variant of `EnumName` and a variant of `ShaderType`.
macro_rules! impl_from_shader_type {
    ($target: ident, $variant: ident) => {
        impl From<Arc<ShaderModule<{ ShaderType::$variant as u8 }>>> for $target {
            fn from(module: Arc<ShaderModule<{ ShaderType::$variant as u8 }>>) -> Self {
                $target::$variant(module)
            }
        }
    };
}
|
|
|
|
|
|
|
|
#[derive(Clone)]
|
|
|
|
enum RaytracingShader {
|
|
|
|
RayGeneration(Arc<ShaderModule<{ ShaderType::RayGeneration as u8 }>>),
|
|
|
|
ClosestHit(Arc<ShaderModule<{ ShaderType::ClosestHit as u8 }>>),
|
|
|
|
Miss(Arc<ShaderModule<{ ShaderType::Miss as u8 }>>),
|
|
|
|
AnyHit(Arc<ShaderModule<{ ShaderType::AnyHit as u8 }>>),
|
|
|
|
Intersection(Arc<ShaderModule<{ ShaderType::Intersection as u8 }>>),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl_from_shader_type!(RaytracingShader, RayGeneration);
|
|
|
|
impl_from_shader_type!(RaytracingShader, ClosestHit);
|
|
|
|
impl_from_shader_type!(RaytracingShader, Miss);
|
|
|
|
impl_from_shader_type!(RaytracingShader, AnyHit);
|
|
|
|
impl_from_shader_type!(RaytracingShader, Intersection);
|
|
|
|
|
|
|
|
impl From<HitShader> for RaytracingShader {
|
|
|
|
fn from(value: HitShader) -> Self {
|
|
|
|
match value {
|
|
|
|
HitShader::ClosestHit(s) => Self::ClosestHit(s),
|
|
|
|
HitShader::AnyHit(s) => Self::AnyHit(s),
|
|
|
|
HitShader::Intersection(s) => Self::Intersection(s),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl From<OtherShader> for RaytracingShader {
|
|
|
|
fn from(value: OtherShader) -> Self {
|
|
|
|
match value {
|
|
|
|
OtherShader::Miss(s) => Self::Miss(s),
|
|
|
|
OtherShader::RayGeneration(s) => Self::RayGeneration(s),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone)]
|
|
|
|
pub enum HitShader {
|
|
|
|
ClosestHit(Arc<ShaderModule<{ ShaderType::ClosestHit as u8 }>>),
|
|
|
|
AnyHit(Arc<ShaderModule<{ ShaderType::AnyHit as u8 }>>),
|
|
|
|
Intersection(Arc<ShaderModule<{ ShaderType::Intersection as u8 }>>),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl_from_shader_type!(HitShader, ClosestHit);
|
|
|
|
impl_from_shader_type!(HitShader, AnyHit);
|
|
|
|
impl_from_shader_type!(HitShader, Intersection);
|
|
|
|
|
|
|
|
#[derive(Clone)]
|
|
|
|
pub enum OtherShader {
|
|
|
|
RayGeneration(Arc<ShaderModule<{ ShaderType::RayGeneration as u8 }>>),
|
|
|
|
Miss(Arc<ShaderModule<{ ShaderType::Miss as u8 }>>),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl_from_shader_type!(OtherShader, RayGeneration);
|
|
|
|
impl_from_shader_type!(OtherShader, Miss);
|
|
|
|
|
2023-01-14 12:03:01 +00:00
|
|
|
pub struct RayTracingPipelineBuilder<'a> {
|
2023-01-27 09:12:59 +00:00
|
|
|
shader_modules: Vec<(RaytracingShader, Option<SpecializationConstants>)>,
|
2023-01-14 12:03:01 +00:00
|
|
|
|
|
|
|
shader_groups: Vec<VkRayTracingShaderGroupCreateInfoKHR>,
|
|
|
|
|
|
|
|
libraries: Vec<Library<'a>>,
|
|
|
|
|
|
|
|
dynamic_states: Vec<VkDynamicState>,
|
|
|
|
|
|
|
|
flags: VkPipelineCreateFlagBits,
|
|
|
|
max_recursion: u32,
|
|
|
|
|
|
|
|
shader_binding_table_builder: ShaderBindingTableBuilder,
|
|
|
|
|
|
|
|
pipeline_cache: Option<&'a Arc<PipelineCache>>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a> RayTracingPipelineBuilder<'a> {
|
|
|
|
pub fn check_max_recursion(device: &Arc<Device>, max_recursion: u32) -> u32 {
|
|
|
|
max_recursion.min(
|
|
|
|
device
|
|
|
|
.physical_device()
|
|
|
|
.ray_tracing_properties()
|
|
|
|
.maxRayRecursionDepth,
|
|
|
|
)
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn add_dynamic_state(mut self, dynamic_state: VkDynamicState) -> Self {
|
|
|
|
self.dynamic_states.push(dynamic_state);
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn set_pipeline_cache(mut self, pipeline_cache: &'a Arc<PipelineCache>) -> Self {
|
|
|
|
self.pipeline_cache = Some(pipeline_cache);
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn set_flags(mut self, flags: impl Into<VkPipelineCreateFlagBits>) -> Self {
|
|
|
|
self.flags = flags.into();
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn add_library(mut self, library: Library<'a>) -> Self {
|
|
|
|
self.libraries.push(library);
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn add_shader(
|
|
|
|
mut self,
|
2023-01-27 09:12:59 +00:00
|
|
|
shader_module: impl Into<OtherShader>,
|
2023-01-14 12:03:01 +00:00
|
|
|
data: Option<Vec<u8>>,
|
|
|
|
specialization_constants: Option<SpecializationConstants>,
|
|
|
|
) -> Self {
|
2023-01-27 09:12:59 +00:00
|
|
|
let shader_module = shader_module.into();
|
|
|
|
|
|
|
|
self.shader_binding_table_builder = match shader_module {
|
|
|
|
OtherShader::RayGeneration(_) => self
|
2023-01-14 12:03:01 +00:00
|
|
|
.shader_binding_table_builder
|
|
|
|
.add_ray_gen_program(self.shader_groups.len() as u32, data),
|
2023-01-27 09:12:59 +00:00
|
|
|
OtherShader::Miss(_) => self
|
2023-01-14 12:03:01 +00:00
|
|
|
.shader_binding_table_builder
|
|
|
|
.add_miss_program(self.shader_groups.len() as u32, data),
|
|
|
|
};
|
|
|
|
|
|
|
|
let shader_index = self.shader_modules.len();
|
|
|
|
self.shader_modules
|
2023-01-27 09:12:59 +00:00
|
|
|
.push((shader_module.into(), specialization_constants));
|
2023-01-14 12:03:01 +00:00
|
|
|
|
|
|
|
self.shader_groups
|
|
|
|
.push(VkRayTracingShaderGroupCreateInfoKHR::new(
|
|
|
|
VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
|
|
|
|
shader_index as u32,
|
|
|
|
VK_SHADER_UNUSED_KHR,
|
|
|
|
VK_SHADER_UNUSED_KHR,
|
|
|
|
VK_SHADER_UNUSED_KHR,
|
|
|
|
));
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn add_hit_shaders(
|
|
|
|
mut self,
|
2023-01-27 09:12:59 +00:00
|
|
|
shader_modules: impl IntoIterator<Item = (HitShader, Option<SpecializationConstants>)>,
|
2023-01-14 12:03:01 +00:00
|
|
|
data: Option<Vec<u8>>,
|
|
|
|
) -> Self {
|
|
|
|
let mut group = VkRayTracingShaderGroupCreateInfoKHR::new(
|
|
|
|
VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
|
|
|
|
VK_SHADER_UNUSED_KHR,
|
|
|
|
VK_SHADER_UNUSED_KHR,
|
|
|
|
VK_SHADER_UNUSED_KHR,
|
|
|
|
VK_SHADER_UNUSED_KHR,
|
|
|
|
);
|
|
|
|
|
|
|
|
for (shader_module, specialization_constant) in shader_modules.into_iter() {
|
|
|
|
let shader_index = self.shader_modules.len() as u32;
|
|
|
|
|
2023-01-27 09:12:59 +00:00
|
|
|
match shader_module {
|
|
|
|
HitShader::AnyHit(_) => {
|
2023-01-14 12:03:01 +00:00
|
|
|
// sanity check
|
2023-01-28 14:55:40 +00:00
|
|
|
debug_assert_eq!(
|
2023-01-27 09:12:59 +00:00
|
|
|
group.anyHitShader, VK_SHADER_UNUSED_KHR,
|
|
|
|
"any hit shader already used in current hit group"
|
|
|
|
);
|
2023-01-14 12:03:01 +00:00
|
|
|
|
|
|
|
group.anyHitShader = shader_index;
|
|
|
|
}
|
2023-01-27 09:12:59 +00:00
|
|
|
HitShader::ClosestHit(_) => {
|
2023-01-14 12:03:01 +00:00
|
|
|
// sanity check
|
2023-01-28 14:55:40 +00:00
|
|
|
debug_assert_eq!(
|
2023-01-27 09:12:59 +00:00
|
|
|
group.closestHitShader, VK_SHADER_UNUSED_KHR,
|
|
|
|
"closest hit shader already used in current hit group"
|
|
|
|
);
|
2023-01-14 12:03:01 +00:00
|
|
|
|
|
|
|
group.closestHitShader = shader_index;
|
|
|
|
}
|
2023-01-27 09:12:59 +00:00
|
|
|
HitShader::Intersection(_) => {
|
2023-01-14 12:03:01 +00:00
|
|
|
// sanity check
|
2023-01-28 14:55:40 +00:00
|
|
|
debug_assert_eq!(
|
2023-01-27 09:12:59 +00:00
|
|
|
group.intersectionShader, VK_SHADER_UNUSED_KHR,
|
|
|
|
"intersection shader already used in current hit group"
|
|
|
|
);
|
2023-01-14 12:03:01 +00:00
|
|
|
|
|
|
|
group.intersectionShader = shader_index;
|
|
|
|
group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
self.shader_modules
|
2023-01-27 09:12:59 +00:00
|
|
|
.push((shader_module.into(), specialization_constant));
|
2023-01-14 12:03:01 +00:00
|
|
|
}
|
|
|
|
self.shader_binding_table_builder = self
|
|
|
|
.shader_binding_table_builder
|
|
|
|
.add_hit_group_program(self.shader_groups.len() as u32, data);
|
|
|
|
self.shader_groups.push(group);
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn max_recursion_depth(mut self, max_recursion_depth: u32) -> Self {
|
|
|
|
self.max_recursion = max_recursion_depth;
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn build(
|
|
|
|
mut self,
|
|
|
|
device: Arc<Device>,
|
|
|
|
pipeline_layout: &Arc<PipelineLayout>,
|
|
|
|
) -> Result<(Arc<Pipeline>, ShaderBindingTable)> {
|
|
|
|
let shader_stages: Vec<VkPipelineShaderStageCreateInfo> = self
|
|
|
|
.shader_modules
|
|
|
|
.iter()
|
|
|
|
.map(|(shader, specialization_constant)| {
|
2023-01-27 09:12:59 +00:00
|
|
|
let mut stage_info = match shader {
|
|
|
|
RaytracingShader::RayGeneration(s) => s.pipeline_stage_info(),
|
|
|
|
RaytracingShader::ClosestHit(s) => s.pipeline_stage_info(),
|
|
|
|
RaytracingShader::Miss(s) => s.pipeline_stage_info(),
|
|
|
|
RaytracingShader::AnyHit(s) => s.pipeline_stage_info(),
|
|
|
|
RaytracingShader::Intersection(s) => s.pipeline_stage_info(),
|
|
|
|
};
|
|
|
|
|
2023-01-14 12:03:01 +00:00
|
|
|
if let Some(specialization_constant) = specialization_constant {
|
|
|
|
stage_info.set_specialization_info(specialization_constant.vk_handle());
|
|
|
|
}
|
|
|
|
|
|
|
|
stage_info
|
|
|
|
})
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
// check that we dont exceed the gpu's capabilities
|
|
|
|
let max_recursion = Self::check_max_recursion(&device, self.max_recursion);
|
|
|
|
|
|
|
|
let pipeline = {
|
|
|
|
let mut libraries = Vec::with_capacity(self.libraries.len());
|
|
|
|
let mut library_interface = VkRayTracingPipelineInterfaceCreateInfoKHR::new(0, 0);
|
|
|
|
|
|
|
|
for library in self.libraries.iter() {
|
|
|
|
libraries.push(library.pipeline.vk_handle());
|
|
|
|
|
|
|
|
library_interface.maxPipelineRayPayloadSize = library_interface
|
|
|
|
.maxPipelineRayPayloadSize
|
|
|
|
.max(library.max_payload_size);
|
|
|
|
|
|
|
|
library_interface.maxPipelineRayHitAttributeSize = library_interface
|
|
|
|
.maxPipelineRayHitAttributeSize
|
|
|
|
.max(library.max_attribute_size);
|
|
|
|
}
|
|
|
|
|
|
|
|
let lib_create_info = VkPipelineLibraryCreateInfoKHR::new(&libraries);
|
|
|
|
|
|
|
|
let dynamic_states = VkPipelineDynamicStateCreateInfo::new(0, &self.dynamic_states);
|
|
|
|
|
|
|
|
device.create_ray_tracing_pipelines(
|
|
|
|
None,
|
|
|
|
self.pipeline_cache.map(|cache| cache.vk_handle()),
|
|
|
|
&[VkRayTracingPipelineCreateInfoKHR::new(
|
|
|
|
self.flags,
|
|
|
|
&shader_stages, // stages
|
|
|
|
&self.shader_groups, // groups
|
|
|
|
max_recursion,
|
|
|
|
&lib_create_info, // libraries
|
|
|
|
&library_interface, // library interfaces
|
|
|
|
&dynamic_states,
|
|
|
|
pipeline_layout.vk_handle(),
|
|
|
|
)],
|
|
|
|
None,
|
|
|
|
)?[0]
|
|
|
|
};
|
|
|
|
|
|
|
|
let pipeline = Arc::new(Pipeline::new(
|
|
|
|
device.clone(),
|
|
|
|
pipeline_layout.clone(),
|
|
|
|
PipelineType::RayTracing,
|
|
|
|
pipeline,
|
|
|
|
));
|
|
|
|
|
|
|
|
let sbt = self
|
|
|
|
.shader_binding_table_builder
|
|
|
|
.build(&device, &pipeline)?;
|
|
|
|
|
|
|
|
Ok((pipeline, sbt))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a> Default for RayTracingPipelineBuilder<'a> {
|
|
|
|
fn default() -> Self {
|
|
|
|
RayTracingPipelineBuilder {
|
|
|
|
shader_modules: Vec::new(),
|
|
|
|
shader_groups: Vec::new(),
|
|
|
|
|
|
|
|
flags: 0.into(),
|
|
|
|
max_recursion: 2,
|
|
|
|
|
|
|
|
libraries: Vec::new(),
|
|
|
|
dynamic_states: Vec::new(),
|
|
|
|
|
|
|
|
shader_binding_table_builder: ShaderBindingTableBuilder::new(),
|
|
|
|
|
|
|
|
pipeline_cache: None,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|