From a78f9382731b249d20ead74d3a59d086261aa3db Mon Sep 17 00:00:00 2001 From: hodasemi Date: Fri, 27 Jan 2023 10:12:59 +0100 Subject: [PATCH] Use const generics for shader module types --- vulkan-rs/Cargo.toml | 1 + vulkan-rs/src/macros.rs | 10 +- vulkan-rs/src/pipelines/compute_pipeline.rs | 11 +- vulkan-rs/src/pipelines/graphics_pipeline.rs | 59 +++------ .../src/pipelines/ray_tracing_pipeline.rs | 125 ++++++++++++++---- vulkan-rs/src/prelude.rs | 3 +- vulkan-rs/src/shadermodule.rs | 109 +++++++-------- 7 files changed, 189 insertions(+), 129 deletions(-) diff --git a/vulkan-rs/Cargo.toml b/vulkan-rs/Cargo.toml index 98bb031..5359efa 100644 --- a/vulkan-rs/Cargo.toml +++ b/vulkan-rs/Cargo.toml @@ -11,3 +11,4 @@ vma-rs = { path = "../vma-rs" } anyhow = { version = "1.0.68", features = ["backtrace"] } cgmath = "0.18.0" assetpath = { path = "../assetpath" } +safer-ffi = "0.0.10" diff --git a/vulkan-rs/src/macros.rs b/vulkan-rs/src/macros.rs index fe919ac..6770bfa 100644 --- a/vulkan-rs/src/macros.rs +++ b/vulkan-rs/src/macros.rs @@ -1,24 +1,24 @@ macro_rules! impl_vk_handle { - ($struct_name:ident, $target_name:ident, $value:ident) => { - impl VkHandle<$target_name> for $struct_name { + ($struct_name:ident $(<$( $const:ident $name:ident: $type:ident, )*>)?, $target_name:ident, $value:ident) => { + impl$(<$( $const $name: $type, )*>)? VkHandle<$target_name> for $struct_name$(<$($name,)?>)? { fn vk_handle(&self) -> $target_name { self.$value } } - impl<'a> VkHandle<$target_name> for &'a $struct_name { + impl<'a $($(, $const $name: $type)*)?> VkHandle<$target_name> for &'a $struct_name$(<$($name,)?>)? { fn vk_handle(&self) -> $target_name { self.$value } } - impl VkHandle<$target_name> for Arc<$struct_name> { + impl$(<$( $const $name: $type, )*>)? VkHandle<$target_name> for Arc<$struct_name$(<$($name,)?>)?> { fn vk_handle(&self) -> $target_name { self.$value } } - impl<'a> VkHandle<$target_name> for &'a Arc<$struct_name> { + impl<'a $($(, $const $name: $type)*)?> VkHandle<$target_name> for &'a Arc<$struct_name$(<$($name,)?>)?> { fn vk_handle(&self) -> $target_name { self.$value } diff --git a/vulkan-rs/src/pipelines/compute_pipeline.rs b/vulkan-rs/src/pipelines/compute_pipeline.rs index 0f618ce..dd2dc73 100644 --- a/vulkan-rs/src/pipelines/compute_pipeline.rs +++ b/vulkan-rs/src/pipelines/compute_pipeline.rs @@ -6,22 +6,21 @@ use crate::prelude::*; use std::sync::Arc; pub struct ComputePipelineBuilder<'a> { - shader_module: Option<&'a Arc>, + shader_module: Option<&'a Arc>>, pipeline_cache: Option<&'a Arc>, flags: VkPipelineCreateFlagBits, } impl<'a> ComputePipelineBuilder<'a> { // TODO: add support for specialization constants - pub fn set_shader_module(mut self, shader_module: &'a Arc) -> Self { + pub fn set_shader_module( + mut self, + shader_module: &'a Arc>, + ) -> Self { if cfg!(debug_assertions) { if self.shader_module.is_some() { panic!("shader already set!"); } - - if shader_module.shader_type() != ShaderType::Compute { - panic!("shader has wrong type!"); - } } self.shader_module = Some(shader_module); diff --git a/vulkan-rs/src/pipelines/graphics_pipeline.rs b/vulkan-rs/src/pipelines/graphics_pipeline.rs index c959589..1e25341 100644 --- a/vulkan-rs/src/pipelines/graphics_pipeline.rs +++ b/vulkan-rs/src/pipelines/graphics_pipeline.rs @@ -12,18 +12,21 @@ pub struct GraphicsPipelineBuilder { amd_rasterization_order: Option, - vertex_shader: Option>, + vertex_shader: Option>>, vertex_binding_description: Vec, vertex_attribute_description: Vec, input_assembly: Option, - tesselation_shader: Option<(Arc, Arc)>, + tesselation_shader: Option<( + Arc>, + Arc>, + )>, patch_control_points: u32, - geometry_shader: Option>, + geometry_shader: Option>>, - fragment_shader: Option>, + fragment_shader: Option>>, viewports: Vec, scissors: Vec, @@ -40,19 +43,13 @@ pub struct GraphicsPipelineBuilder { impl GraphicsPipelineBuilder { // TODO: add support for specialization constants - pub fn set_vertex_shader( + pub fn set_vertex_shader( mut self, - shader: Arc, - vertex_binding_description: Vec, - vertex_attribute_description: Vec, + shader: Arc>, ) -> Self { - if cfg!(debug_assertions) { - assert_eq!(shader.shader_type(), ShaderType::Vertex); - } - self.vertex_shader = Some(shader); - self.vertex_binding_description = vertex_binding_description; - self.vertex_attribute_description = vertex_attribute_description; + self.vertex_binding_description = T::bindings(); + self.vertex_attribute_description = T::attributes(); self } @@ -60,22 +57,10 @@ impl GraphicsPipelineBuilder { // TODO: add support for specialization constants pub fn set_tesselation_shader( mut self, - tesselation_control: Arc, - tesselation_evaluation: Arc, + tesselation_control: Arc>, + tesselation_evaluation: Arc>, patch_control_points: u32, ) -> Self { - if cfg!(debug_assertions) { - assert_eq!( - tesselation_control.shader_type(), - ShaderType::TesselationControl - ); - - assert_eq!( - tesselation_evaluation.shader_type(), - ShaderType::TesselationEvaluation - ); - } - self.tesselation_shader = Some((tesselation_control, tesselation_evaluation)); self.patch_control_points = patch_control_points; @@ -83,22 +68,20 @@ impl GraphicsPipelineBuilder { } // TODO: add support for specialization constants - pub fn set_geometry_shader(mut self, shader: Arc) -> Self { - if cfg!(debug_assertions) { - assert_eq!(shader.shader_type(), ShaderType::Geometry); - } - + pub fn set_geometry_shader( + mut self, + shader: Arc>, + ) -> Self { self.geometry_shader = Some(shader); self } // TODO: add support for specialization constants - pub fn set_fragment_shader(mut self, shader: Arc) -> Self { - if cfg!(debug_assertions) { - assert_eq!(shader.shader_type(), ShaderType::Fragment); - } - + pub fn set_fragment_shader( + mut self, + shader: Arc>, + ) -> Self { self.fragment_shader = Some(shader); self diff --git a/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs b/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs index c166af6..9f90063 100644 --- a/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs +++ b/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs @@ -29,8 +29,72 @@ impl<'a> Library<'a> { } } +macro_rules! impl_from_shader_type { + ($struct: ident, $shader_type: ident) => { + impl From>> for $struct { + fn from(value: Arc>) -> Self { + Self::$shader_type(value) + } + } + }; +} + +#[derive(Clone)] +enum RaytracingShader { + RayGeneration(Arc>), + ClosestHit(Arc>), + Miss(Arc>), + AnyHit(Arc>), + Intersection(Arc>), +} + +impl_from_shader_type!(RaytracingShader, RayGeneration); +impl_from_shader_type!(RaytracingShader, ClosestHit); +impl_from_shader_type!(RaytracingShader, Miss); +impl_from_shader_type!(RaytracingShader, AnyHit); +impl_from_shader_type!(RaytracingShader, Intersection); + +impl From for RaytracingShader { + fn from(value: HitShader) -> Self { + match value { + HitShader::ClosestHit(s) => Self::ClosestHit(s), + HitShader::AnyHit(s) => Self::AnyHit(s), + HitShader::Intersection(s) => Self::Intersection(s), + } + } +} + +impl From for RaytracingShader { + fn from(value: OtherShader) -> Self { + match value { + OtherShader::Miss(s) => Self::Miss(s), + OtherShader::RayGeneration(s) => Self::RayGeneration(s), + } + } +} + +#[derive(Clone)] +pub enum HitShader { + ClosestHit(Arc>), + AnyHit(Arc>), + Intersection(Arc>), +} + +impl_from_shader_type!(HitShader, ClosestHit); +impl_from_shader_type!(HitShader, AnyHit); +impl_from_shader_type!(HitShader, Intersection); + +#[derive(Clone)] +pub enum OtherShader { + RayGeneration(Arc>), + Miss(Arc>), +} + +impl_from_shader_type!(OtherShader, RayGeneration); +impl_from_shader_type!(OtherShader, Miss); + pub struct RayTracingPipelineBuilder<'a> { - shader_modules: Vec<(Arc, Option)>, + shader_modules: Vec<(RaytracingShader, Option)>, shader_groups: Vec, @@ -82,26 +146,24 @@ impl<'a> RayTracingPipelineBuilder<'a> { pub fn add_shader( mut self, - shader_module: Arc, + shader_module: impl Into, data: Option>, specialization_constants: Option, ) -> Self { - self.shader_binding_table_builder = match shader_module.shader_type() { - ShaderType::RayGeneration => self + let shader_module = shader_module.into(); + + self.shader_binding_table_builder = match shader_module { + OtherShader::RayGeneration(_) => self .shader_binding_table_builder .add_ray_gen_program(self.shader_groups.len() as u32, data), - ShaderType::Miss => self + OtherShader::Miss(_) => self .shader_binding_table_builder .add_miss_program(self.shader_groups.len() as u32, data), - _ => panic!( - "unsupported shader type: {:?}, expected RayGen or Miss Shader", - shader_module.shader_type() - ), }; let shader_index = self.shader_modules.len(); self.shader_modules - .push((shader_module, specialization_constants)); + .push((shader_module.into(), specialization_constants)); self.shader_groups .push(VkRayTracingShaderGroupCreateInfoKHR::new( @@ -117,7 +179,7 @@ impl<'a> RayTracingPipelineBuilder<'a> { pub fn add_hit_shaders( mut self, - shader_modules: impl IntoIterator, Option)>, + shader_modules: impl IntoIterator)>, data: Option>, ) -> Self { let mut group = VkRayTracingShaderGroupCreateInfoKHR::new( @@ -131,37 +193,39 @@ impl<'a> RayTracingPipelineBuilder<'a> { for (shader_module, specialization_constant) in shader_modules.into_iter() { let shader_index = self.shader_modules.len() as u32; - match shader_module.shader_type() { - ShaderType::AnyHit => { + match shader_module { + HitShader::AnyHit(_) => { // sanity check - if cfg!(debug_assertions) && group.anyHitShader != VK_SHADER_UNUSED_KHR { - panic!("any hit shader already used in current hit group"); - } + debug_assert_ne!( + group.anyHitShader, VK_SHADER_UNUSED_KHR, + "any hit shader already used in current hit group" + ); group.anyHitShader = shader_index; } - ShaderType::ClosestHit => { + HitShader::ClosestHit(_) => { // sanity check - if cfg!(debug_assertions) && group.closestHitShader != VK_SHADER_UNUSED_KHR { - panic!("closest hit shader already used in current hit group"); - } + debug_assert_ne!( + group.closestHitShader, VK_SHADER_UNUSED_KHR, + "closest hit shader already used in current hit group" + ); group.closestHitShader = shader_index; } - ShaderType::Intersection => { + HitShader::Intersection(_) => { // sanity check - if cfg!(debug_assertions) && group.intersectionShader != VK_SHADER_UNUSED_KHR { - panic!("intersection shader already used in current hit group"); - } + debug_assert_ne!( + group.intersectionShader, VK_SHADER_UNUSED_KHR, + "intersection shader already used in current hit group" + ); group.intersectionShader = shader_index; group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR; } - _ => panic!("unsupported shader type: {:?}, expected AnyHit, ClosestHit or Intersection Shader", shader_module.shader_type()), } self.shader_modules - .push((shader_module, specialization_constant)); + .push((shader_module.into(), specialization_constant)); } self.shader_binding_table_builder = self .shader_binding_table_builder @@ -186,7 +250,14 @@ impl<'a> RayTracingPipelineBuilder<'a> { .shader_modules .iter() .map(|(shader, specialization_constant)| { - let mut stage_info = shader.pipeline_stage_info(); + let mut stage_info = match shader { + RaytracingShader::RayGeneration(s) => s.pipeline_stage_info(), + RaytracingShader::ClosestHit(s) => s.pipeline_stage_info(), + RaytracingShader::Miss(s) => s.pipeline_stage_info(), + RaytracingShader::AnyHit(s) => s.pipeline_stage_info(), + RaytracingShader::Intersection(s) => s.pipeline_stage_info(), + }; + if let Some(specialization_constant) = specialization_constant { stage_info.set_specialization_info(specialization_constant.vk_handle()); } diff --git a/vulkan-rs/src/prelude.rs b/vulkan-rs/src/prelude.rs index 8df00f3..efdf95a 100644 --- a/vulkan-rs/src/prelude.rs +++ b/vulkan-rs/src/prelude.rs @@ -24,7 +24,8 @@ pub use super::renderpass::RenderPass; pub use super::sampler_manager::{Sampler, SamplerBuilder}; pub use super::semaphore::Semaphore; pub use super::shadermodule::{ - AddSpecializationConstant, ShaderModule, ShaderType, SpecializationConstants, + AddSpecializationConstant, PipelineStageInfo, ShaderModule, ShaderType, + SpecializationConstants, VertexInputDescription, }; pub use super::surface::Surface; pub use super::swapchain::Swapchain; diff --git a/vulkan-rs/src/shadermodule.rs b/vulkan-rs/src/shadermodule.rs index 7cce834..f933cd5 100644 --- a/vulkan-rs/src/shadermodule.rs +++ b/vulkan-rs/src/shadermodule.rs @@ -7,6 +7,7 @@ use std::io::Read; use std::sync::Arc; #[allow(clippy::cast_ptr_alignment)] +#[repr(u8)] #[derive(Debug, Clone, Copy, PartialEq)] pub enum ShaderType { None, @@ -23,35 +24,59 @@ pub enum ShaderType { Intersection, } +impl ShaderType { + pub const VERTEX: u8 = Self::Vertex as u8; + pub const FRAGMENT: u8 = Self::Fragment as u8; + pub const GEOMETRY: u8 = Self::Geometry as u8; + pub const TESSELATION_CONTROL: u8 = Self::TesselationControl as u8; + pub const TESSELATION_EVALUATION: u8 = Self::TesselationEvaluation as u8; + pub const COMPUTE: u8 = Self::Compute as u8; + pub const RAY_GENERATION: u8 = Self::RayGeneration as u8; + pub const CLOSEST_HIT: u8 = Self::ClosestHit as u8; + pub const MISS: u8 = Self::Miss as u8; + pub const ANY_HIT: u8 = Self::AnyHit as u8; + pub const INTERSECTION: u8 = Self::Intersection as u8; +} + impl Default for ShaderType { fn default() -> Self { ShaderType::None } } -#[derive(Debug)] -pub struct ShaderModule { - device: Arc, - shader_module: VkShaderModule, - shader_type: ShaderType, +pub trait VertexInputDescription { + fn bindings() -> Vec; + fn attributes() -> Vec; } -impl ShaderModule { - pub fn new( - device: Arc, - path: &str, - shader_type: ShaderType, - ) -> Result> { +pub trait PipelineStageInfo { + fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo; +} + +macro_rules! impl_pipeline_stage_info { + ($func:ident, $type:ident) => { + impl PipelineStageInfo for ShaderModule<{ ShaderType::$type as u8 }> { + fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo { + VkPipelineShaderStageCreateInfo::$func(self.shader_module) + } + } + }; +} + +#[derive(Debug)] +pub struct ShaderModule { + device: Arc, + shader_module: VkShaderModule, +} + +impl ShaderModule { + pub fn new(device: Arc, path: &str) -> Result>> { let code = Self::shader_code(path)?; - Self::from_slice(device, code.as_slice(), shader_type) + Self::from_slice(device, code.as_slice()) } - pub fn from_slice( - device: Arc, - code: &[u8], - shader_type: ShaderType, - ) -> Result> { + pub fn from_slice(device: Arc, code: &[u8]) -> Result>> { let shader_module_ci = VkShaderModuleCreateInfo::new(VK_SHADER_MODULE_CREATE_NULL_BIT, code); @@ -60,7 +85,6 @@ impl ShaderModule { Ok(Arc::new(ShaderModule { device, shader_module, - shader_type, })) } @@ -76,48 +100,29 @@ impl ShaderModule { Ok(code) } - - pub fn shader_type(&self) -> ShaderType { - self.shader_type - } - - pub fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo { - match self.shader_type { - ShaderType::None => unimplemented!(), - ShaderType::Vertex => VkPipelineShaderStageCreateInfo::vertex(self.shader_module), - ShaderType::Geometry => VkPipelineShaderStageCreateInfo::geometry(self.shader_module), - ShaderType::TesselationControl => { - VkPipelineShaderStageCreateInfo::tesselation_control(self.shader_module) - } - ShaderType::TesselationEvaluation => { - VkPipelineShaderStageCreateInfo::tesselation_evaluation(self.shader_module) - } - ShaderType::Fragment => VkPipelineShaderStageCreateInfo::fragment(self.shader_module), - ShaderType::Compute => VkPipelineShaderStageCreateInfo::compute(self.shader_module), - ShaderType::AnyHit => VkPipelineShaderStageCreateInfo::any_hit(self.shader_module), - ShaderType::Intersection => { - VkPipelineShaderStageCreateInfo::intersection(self.shader_module) - } - ShaderType::ClosestHit => { - VkPipelineShaderStageCreateInfo::closest_hit(self.shader_module) - } - ShaderType::RayGeneration => { - VkPipelineShaderStageCreateInfo::ray_generation(self.shader_module) - } - ShaderType::Miss => VkPipelineShaderStageCreateInfo::miss(self.shader_module), - } - } } -impl VulkanDevice for ShaderModule { +impl_pipeline_stage_info!(vertex, Vertex); +impl_pipeline_stage_info!(geometry, Geometry); +impl_pipeline_stage_info!(tesselation_control, TesselationControl); +impl_pipeline_stage_info!(tesselation_evaluation, TesselationEvaluation); +impl_pipeline_stage_info!(fragment, Fragment); +impl_pipeline_stage_info!(compute, Compute); +impl_pipeline_stage_info!(any_hit, AnyHit); +impl_pipeline_stage_info!(intersection, Intersection); +impl_pipeline_stage_info!(closest_hit, ClosestHit); +impl_pipeline_stage_info!(ray_generation, RayGeneration); +impl_pipeline_stage_info!(miss, Miss); + +impl VulkanDevice for ShaderModule { fn device(&self) -> &Arc { &self.device } } -impl_vk_handle!(ShaderModule, VkShaderModule, shader_module); +impl_vk_handle!(ShaderModule, VkShaderModule, shader_module); -impl Drop for ShaderModule { +impl Drop for ShaderModule { fn drop(&mut self) { self.device.destroy_shader_module(self.shader_module); }