vulkan_lib/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs

342 lines
11 KiB
Rust
Raw Normal View History

2023-01-14 12:03:01 +00:00
use crate::pipeline::PipelineType;
use crate::prelude::*;
use anyhow::Result;
use std::sync::Arc;
use super::shader_binding_table::ShaderBindingTableBuilder;
/// A pipeline library to be linked into a ray tracing pipeline.
///
/// Besides the borrowed pipeline it carries the payload and hit attribute
/// sizes reported for that library. `RayTracingPipelineBuilder::build` feeds
/// the largest of these values, across all libraries, into the
/// `VkRayTracingPipelineInterfaceCreateInfoKHR` of the new pipeline.
pub struct Library<'a> {
    /// Pipeline whose handle is passed as a library at build time.
    pipeline: &'a Arc<Pipeline>,
    /// Ray payload size reported for this library.
    max_payload_size: u32,
    /// Hit attribute size reported for this library.
    max_attribute_size: u32,
}

impl<'a> Library<'a> {
    /// Bundles `pipeline` with its payload and hit attribute sizes.
    pub fn new(
        pipeline: &'a Arc<Pipeline>,
        max_payload_size: u32,
        max_attribute_size: u32,
    ) -> Self {
        Self {
            max_attribute_size,
            max_payload_size,
            pipeline,
        }
    }
}
/// Generates `From<Arc<ShaderModule<..>>>` for `$target`, wrapping a shader
/// module of stage `ShaderType::$variant` into the `$target` variant that
/// carries the same name.
macro_rules! impl_from_shader_type {
    ($target: ident, $variant: ident) => {
        impl From<Arc<ShaderModule<{ ShaderType::$variant as u8 }>>> for $target {
            fn from(module: Arc<ShaderModule<{ ShaderType::$variant as u8 }>>) -> Self {
                $target::$variant(module)
            }
        }
    };
}

/// Shader modules that can be combined into a hit group.
#[derive(Clone)]
pub enum HitShader {
    /// Closest hit stage.
    ClosestHit(Arc<ShaderModule<{ ShaderType::ClosestHit as u8 }>>),
    /// Any hit stage.
    AnyHit(Arc<ShaderModule<{ ShaderType::AnyHit as u8 }>>),
    /// Intersection stage.
    Intersection(Arc<ShaderModule<{ ShaderType::Intersection as u8 }>>),
}

// Conversions from the individual hit stage modules.
impl_from_shader_type!(HitShader, ClosestHit);
impl_from_shader_type!(HitShader, AnyHit);
impl_from_shader_type!(HitShader, Intersection);

/// Shader modules that are added on their own (ray generation and miss).
#[derive(Clone)]
pub enum OtherShader {
    /// Ray generation stage.
    RayGeneration(Arc<ShaderModule<{ ShaderType::RayGeneration as u8 }>>),
    /// Miss stage.
    Miss(Arc<ShaderModule<{ ShaderType::Miss as u8 }>>),
}

// Conversions from the individual non-hit stage modules.
impl_from_shader_type!(OtherShader, RayGeneration);
impl_from_shader_type!(OtherShader, Miss);

/// Any ray tracing shader stage, as stored internally by the pipeline builder.
#[derive(Clone)]
enum RaytracingShader {
    RayGeneration(Arc<ShaderModule<{ ShaderType::RayGeneration as u8 }>>),
    ClosestHit(Arc<ShaderModule<{ ShaderType::ClosestHit as u8 }>>),
    Miss(Arc<ShaderModule<{ ShaderType::Miss as u8 }>>),
    AnyHit(Arc<ShaderModule<{ ShaderType::AnyHit as u8 }>>),
    Intersection(Arc<ShaderModule<{ ShaderType::Intersection as u8 }>>),
}

// Per-stage conversions; the grouped conversions below delegate to these.
impl_from_shader_type!(RaytracingShader, RayGeneration);
impl_from_shader_type!(RaytracingShader, ClosestHit);
impl_from_shader_type!(RaytracingShader, Miss);
impl_from_shader_type!(RaytracingShader, AnyHit);
impl_from_shader_type!(RaytracingShader, Intersection);

impl From<HitShader> for RaytracingShader {
    /// Unwraps the hit stage module and re-wraps it as the same stage.
    fn from(shader: HitShader) -> Self {
        match shader {
            HitShader::ClosestHit(module) => Self::from(module),
            HitShader::AnyHit(module) => Self::from(module),
            HitShader::Intersection(module) => Self::from(module),
        }
    }
}

impl From<OtherShader> for RaytracingShader {
    /// Unwraps the ray generation / miss module and re-wraps it as the same stage.
    fn from(shader: OtherShader) -> Self {
        match shader {
            OtherShader::RayGeneration(module) => Self::from(module),
            OtherShader::Miss(module) => Self::from(module),
        }
    }
}
2023-01-14 12:03:01 +00:00
pub struct RayTracingPipelineBuilder<'a> {
shader_modules: Vec<(RaytracingShader, Option<SpecializationConstants>)>,
2023-01-14 12:03:01 +00:00
shader_groups: Vec<VkRayTracingShaderGroupCreateInfoKHR>,
libraries: Vec<Library<'a>>,
dynamic_states: Vec<VkDynamicState>,
flags: VkPipelineCreateFlagBits,
max_recursion: u32,
shader_binding_table_builder: ShaderBindingTableBuilder,
pipeline_cache: Option<&'a Arc<PipelineCache>>,
}
impl<'a> RayTracingPipelineBuilder<'a> {
pub fn check_max_recursion(device: &Arc<Device>, max_recursion: u32) -> u32 {
max_recursion.min(
device
.physical_device()
.ray_tracing_properties()
.maxRayRecursionDepth,
)
}
pub fn add_dynamic_state(mut self, dynamic_state: VkDynamicState) -> Self {
self.dynamic_states.push(dynamic_state);
self
}
pub fn set_pipeline_cache(mut self, pipeline_cache: &'a Arc<PipelineCache>) -> Self {
self.pipeline_cache = Some(pipeline_cache);
self
}
pub fn set_flags(mut self, flags: impl Into<VkPipelineCreateFlagBits>) -> Self {
self.flags = flags.into();
self
}
pub fn add_library(mut self, library: Library<'a>) -> Self {
self.libraries.push(library);
self
}
pub fn add_shader(
mut self,
shader_module: impl Into<OtherShader>,
2023-01-14 12:03:01 +00:00
data: Option<Vec<u8>>,
specialization_constants: Option<SpecializationConstants>,
) -> Self {
let shader_module = shader_module.into();
self.shader_binding_table_builder = match shader_module {
OtherShader::RayGeneration(_) => self
2023-01-14 12:03:01 +00:00
.shader_binding_table_builder
.add_ray_gen_program(self.shader_groups.len() as u32, data),
OtherShader::Miss(_) => self
2023-01-14 12:03:01 +00:00
.shader_binding_table_builder
.add_miss_program(self.shader_groups.len() as u32, data),
};
let shader_index = self.shader_modules.len();
self.shader_modules
.push((shader_module.into(), specialization_constants));
2023-01-14 12:03:01 +00:00
self.shader_groups
.push(VkRayTracingShaderGroupCreateInfoKHR::new(
VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
shader_index as u32,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
));
self
}
pub fn add_hit_shaders(
mut self,
shader_modules: impl IntoIterator<Item = (HitShader, Option<SpecializationConstants>)>,
2023-01-14 12:03:01 +00:00
data: Option<Vec<u8>>,
) -> Self {
let mut group = VkRayTracingShaderGroupCreateInfoKHR::new(
VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
VK_SHADER_UNUSED_KHR,
);
for (shader_module, specialization_constant) in shader_modules.into_iter() {
let shader_index = self.shader_modules.len() as u32;
match shader_module {
HitShader::AnyHit(_) => {
2023-01-14 12:03:01 +00:00
// sanity check
2023-01-28 14:55:40 +00:00
debug_assert_eq!(
group.anyHitShader, VK_SHADER_UNUSED_KHR,
"any hit shader already used in current hit group"
);
2023-01-14 12:03:01 +00:00
group.anyHitShader = shader_index;
}
HitShader::ClosestHit(_) => {
2023-01-14 12:03:01 +00:00
// sanity check
2023-01-28 14:55:40 +00:00
debug_assert_eq!(
group.closestHitShader, VK_SHADER_UNUSED_KHR,
"closest hit shader already used in current hit group"
);
2023-01-14 12:03:01 +00:00
group.closestHitShader = shader_index;
}
HitShader::Intersection(_) => {
2023-01-14 12:03:01 +00:00
// sanity check
2023-01-28 14:55:40 +00:00
debug_assert_eq!(
group.intersectionShader, VK_SHADER_UNUSED_KHR,
"intersection shader already used in current hit group"
);
2023-01-14 12:03:01 +00:00
group.intersectionShader = shader_index;
group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
}
}
self.shader_modules
.push((shader_module.into(), specialization_constant));
2023-01-14 12:03:01 +00:00
}
self.shader_binding_table_builder = self
.shader_binding_table_builder
.add_hit_group_program(self.shader_groups.len() as u32, data);
self.shader_groups.push(group);
self
}
pub fn max_recursion_depth(mut self, max_recursion_depth: u32) -> Self {
self.max_recursion = max_recursion_depth;
self
}
pub fn build(
mut self,
device: Arc<Device>,
pipeline_layout: &Arc<PipelineLayout>,
) -> Result<(Arc<Pipeline>, ShaderBindingTable)> {
let shader_stages: Vec<VkPipelineShaderStageCreateInfo> = self
.shader_modules
.iter()
.map(|(shader, specialization_constant)| {
let mut stage_info = match shader {
RaytracingShader::RayGeneration(s) => s.pipeline_stage_info(),
RaytracingShader::ClosestHit(s) => s.pipeline_stage_info(),
RaytracingShader::Miss(s) => s.pipeline_stage_info(),
RaytracingShader::AnyHit(s) => s.pipeline_stage_info(),
RaytracingShader::Intersection(s) => s.pipeline_stage_info(),
};
2023-01-14 12:03:01 +00:00
if let Some(specialization_constant) = specialization_constant {
stage_info.set_specialization_info(specialization_constant.vk_handle());
}
stage_info
})
.collect();
// check that we dont exceed the gpu's capabilities
let max_recursion = Self::check_max_recursion(&device, self.max_recursion);
let pipeline = {
let mut libraries = Vec::with_capacity(self.libraries.len());
let mut library_interface = VkRayTracingPipelineInterfaceCreateInfoKHR::new(0, 0);
for library in self.libraries.iter() {
libraries.push(library.pipeline.vk_handle());
library_interface.maxPipelineRayPayloadSize = library_interface
.maxPipelineRayPayloadSize
.max(library.max_payload_size);
library_interface.maxPipelineRayHitAttributeSize = library_interface
.maxPipelineRayHitAttributeSize
.max(library.max_attribute_size);
}
let lib_create_info = VkPipelineLibraryCreateInfoKHR::new(&libraries);
let dynamic_states = VkPipelineDynamicStateCreateInfo::new(0, &self.dynamic_states);
device.create_ray_tracing_pipelines(
None,
self.pipeline_cache.map(|cache| cache.vk_handle()),
&[VkRayTracingPipelineCreateInfoKHR::new(
self.flags,
&shader_stages, // stages
&self.shader_groups, // groups
max_recursion,
&lib_create_info, // libraries
&library_interface, // library interfaces
&dynamic_states,
pipeline_layout.vk_handle(),
)],
None,
)?[0]
};
let pipeline = Arc::new(Pipeline::new(
device.clone(),
pipeline_layout.clone(),
PipelineType::RayTracing,
pipeline,
));
let sbt = self
.shader_binding_table_builder
.build(&device, &pipeline)?;
Ok((pipeline, sbt))
}
}
impl<'a> Default for RayTracingPipelineBuilder<'a> {
    /// Creates an empty builder.
    ///
    /// The builder starts with no shaders, groups, libraries or dynamic states,
    /// no pipeline cache, zero create flags, and a requested recursion depth of 2.
    fn default() -> Self {
        // Requested depth; `build` still clamps it to the device limit.
        const DEFAULT_MAX_RECURSION: u32 = 2;

        Self {
            flags: 0.into(),
            max_recursion: DEFAULT_MAX_RECURSION,
            shader_modules: Vec::new(),
            shader_groups: Vec::new(),
            shader_binding_table_builder: ShaderBindingTableBuilder::new(),
            libraries: Vec::new(),
            dynamic_states: Vec::new(),
            pipeline_cache: None,
        }
    }
}