diff --git a/vulkan-rs/src/macros.rs b/vulkan-rs/src/macros.rs index 9eec117..6c969c5 100644 --- a/vulkan-rs/src/macros.rs +++ b/vulkan-rs/src/macros.rs @@ -24,6 +24,31 @@ macro_rules! impl_vk_handle { } } }; + ($struct_name:ident $(<$( $name:ident: $type:ident, )*>)?, $target_name:ident, $value:ident) => { + impl$(<$( $name: $type, )*>)? VkHandle<$target_name> for $struct_name$(<$($name,)?>)? { + fn vk_handle(&self) -> $target_name { + self.$value + } + } + + impl<'a $($(, $name: $type)*)?> VkHandle<$target_name> for &'a $struct_name$(<$($name,)?>)? { + fn vk_handle(&self) -> $target_name { + self.$value + } + } + + impl$(<$( $name: $type, )*>)? VkHandle<$target_name> for Arc<$struct_name$(<$($name,)?>)?> { + fn vk_handle(&self) -> $target_name { + self.$value + } + } + + impl<'a $($(, $name: $type)*)?> VkHandle<$target_name> for &'a Arc<$struct_name$(<$($name,)?>)?> { + fn vk_handle(&self) -> $target_name { + self.$value + } + } + }; } macro_rules! impl_vk_handle_t { diff --git a/vulkan-rs/src/pipelines/compute_pipeline.rs b/vulkan-rs/src/pipelines/compute_pipeline.rs index dd2dc73..bd98ecd 100644 --- a/vulkan-rs/src/pipelines/compute_pipeline.rs +++ b/vulkan-rs/src/pipelines/compute_pipeline.rs @@ -6,7 +6,7 @@ use crate::prelude::*; use std::sync::Arc; pub struct ComputePipelineBuilder<'a> { - shader_module: Option<&'a Arc>>, + shader_module: Option<&'a Arc>>, pipeline_cache: Option<&'a Arc>, flags: VkPipelineCreateFlagBits, } @@ -15,7 +15,7 @@ impl<'a> ComputePipelineBuilder<'a> { // TODO: add support for specialization constants pub fn set_shader_module( mut self, - shader_module: &'a Arc>, + shader_module: &'a Arc>, ) -> Self { if cfg!(debug_assertions) { if self.shader_module.is_some() { diff --git a/vulkan-rs/src/pipelines/graphics_pipeline.rs b/vulkan-rs/src/pipelines/graphics_pipeline.rs index 1e25341..276e940 100644 --- a/vulkan-rs/src/pipelines/graphics_pipeline.rs +++ b/vulkan-rs/src/pipelines/graphics_pipeline.rs @@ -12,21 +12,21 @@ pub struct GraphicsPipelineBuilder { amd_rasterization_order: Option, - vertex_shader: Option>>, + vertex_shader: Option>>, vertex_binding_description: Vec, vertex_attribute_description: Vec, input_assembly: Option, tesselation_shader: Option<( - Arc>, - Arc>, + Arc>, + Arc>, )>, patch_control_points: u32, - geometry_shader: Option>>, + geometry_shader: Option>>, - fragment_shader: Option>>, + fragment_shader: Option>>, viewports: Vec, scissors: Vec, @@ -45,7 +45,7 @@ impl GraphicsPipelineBuilder { // TODO: add support for specialization constants pub fn set_vertex_shader( mut self, - shader: Arc>, + shader: Arc>, ) -> Self { self.vertex_shader = Some(shader); self.vertex_binding_description = T::bindings(); @@ -57,8 +57,8 @@ impl GraphicsPipelineBuilder { // TODO: add support for specialization constants pub fn set_tesselation_shader( mut self, - tesselation_control: Arc>, - tesselation_evaluation: Arc>, + tesselation_control: Arc>, + tesselation_evaluation: Arc>, patch_control_points: u32, ) -> Self { self.tesselation_shader = Some((tesselation_control, tesselation_evaluation)); @@ -68,20 +68,14 @@ impl GraphicsPipelineBuilder { } // TODO: add support for specialization constants - pub fn set_geometry_shader( - mut self, - shader: Arc>, - ) -> Self { + pub fn set_geometry_shader(mut self, shader: Arc>) -> Self { self.geometry_shader = Some(shader); self } // TODO: add support for specialization constants - pub fn set_fragment_shader( - mut self, - shader: Arc>, - ) -> Self { + pub fn set_fragment_shader(mut self, shader: Arc>) -> Self { self.fragment_shader = Some(shader); self diff --git a/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs b/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs index b8e58e4..5fc0108 100644 --- a/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs +++ b/vulkan-rs/src/pipelines/ray_tracing_pipeline.rs @@ -31,8 +31,8 @@ impl<'a> Library<'a> { macro_rules! impl_from_shader_type { ($struct: ident, $shader_type: ident) => { - impl From>> for $struct { - fn from(value: Arc>) -> Self { + impl From>> for $struct { + fn from(value: Arc>) -> Self { Self::$shader_type(value) } } @@ -41,11 +41,11 @@ macro_rules! impl_from_shader_type { #[derive(Clone)] enum RaytracingShader { - RayGeneration(Arc>), - ClosestHit(Arc>), - Miss(Arc>), - AnyHit(Arc>), - Intersection(Arc>), + RayGeneration(Arc>), + ClosestHit(Arc>), + Miss(Arc>), + AnyHit(Arc>), + Intersection(Arc>), } impl_from_shader_type!(RaytracingShader, RayGeneration); @@ -75,9 +75,9 @@ impl From for RaytracingShader { #[derive(Clone)] pub enum HitShader { - ClosestHit(Arc>), - AnyHit(Arc>), - Intersection(Arc>), + ClosestHit(Arc>), + AnyHit(Arc>), + Intersection(Arc>), } impl_from_shader_type!(HitShader, ClosestHit); @@ -86,8 +86,8 @@ impl_from_shader_type!(HitShader, Intersection); #[derive(Clone)] pub enum OtherShader { - RayGeneration(Arc>), - Miss(Arc>), + RayGeneration(Arc>), + Miss(Arc>), } impl_from_shader_type!(OtherShader, RayGeneration); diff --git a/vulkan-rs/src/prelude.rs b/vulkan-rs/src/prelude.rs index ab0460f..d712780 100644 --- a/vulkan-rs/src/prelude.rs +++ b/vulkan-rs/src/prelude.rs @@ -24,8 +24,9 @@ pub use super::renderpass::RenderPass; pub use super::sampler_manager::{Sampler, SamplerBuilder}; pub use super::semaphore::Semaphore; pub use super::shadermodule::{ - AddSpecializationConstant, PipelineStageInfo, ShaderModule, ShaderType, - SpecializationConstants, VertexInputDescription, + shader_type::{self, ShaderType}, + AddSpecializationConstant, PipelineStageInfo, ShaderModule, SpecializationConstants, + VertexInputDescription, }; pub use super::surface::Surface; pub use super::swapchain::Swapchain; diff --git a/vulkan-rs/src/shadermodule.rs b/vulkan-rs/src/shadermodule.rs index a322202..c878050 100644 --- a/vulkan-rs/src/shadermodule.rs +++ b/vulkan-rs/src/shadermodule.rs @@ -4,51 +4,59 @@ use anyhow::{Context, Result}; use std::fs::File; use std::io::Read; +use std::marker::PhantomData; use std::sync::Arc; -#[allow(clippy::cast_ptr_alignment)] -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum ShaderType { - None, - Vertex, - Fragment, - Geometry, - TesselationControl, - TesselationEvaluation, - Compute, - RayGeneration, - ClosestHit, - Miss, - AnyHit, - Intersection, -} - -impl From for ShaderType { - fn from(value: u8) -> Self { - match value { - 0 => Self::None, - 1 => Self::Vertex, - 2 => Self::Fragment, - 3 => Self::Geometry, - 4 => Self::TesselationControl, - 5 => Self::TesselationEvaluation, - 6 => Self::Compute, - 7 => Self::RayGeneration, - 8 => Self::ClosestHit, - 9 => Self::Miss, - 10 => Self::AnyHit, - 11 => Self::Intersection, - - _ => panic!("can't convert ShaderType from {}", value), - } +pub mod shader_type { + mod sealed { + pub trait Sealed {} + impl Sealed for super::Vertex {} + impl Sealed for super::Fragment {} + impl Sealed for super::Geometry {} + impl Sealed for super::TesselationControl {} + impl Sealed for super::TesselationEvaluation {} + impl Sealed for super::Compute {} + impl Sealed for super::RayGeneration {} + impl Sealed for super::ClosestHit {} + impl Sealed for super::Miss {} + impl Sealed for super::AnyHit {} + impl Sealed for super::Intersection {} } -} -impl Default for ShaderType { - fn default() -> Self { - ShaderType::None - } + pub trait ShaderType: sealed::Sealed {} + + pub struct Vertex; + impl ShaderType for Vertex {} + + pub struct Fragment; + impl ShaderType for Fragment {} + + pub struct Geometry; + impl ShaderType for Geometry {} + + pub struct TesselationControl; + impl ShaderType for TesselationControl {} + + pub struct TesselationEvaluation; + impl ShaderType for TesselationEvaluation {} + + pub struct Compute; + impl ShaderType for Compute {} + + pub struct RayGeneration; + impl ShaderType for RayGeneration {} + + pub struct ClosestHit; + impl ShaderType for ClosestHit {} + + pub struct Miss; + impl ShaderType for Miss {} + + pub struct AnyHit; + impl ShaderType for AnyHit {} + + pub struct Intersection; + impl ShaderType for Intersection {} } pub trait VertexInputDescription: ReprC + Sized { @@ -69,7 +77,7 @@ pub trait PipelineStageInfo { macro_rules! impl_pipeline_stage_info { ($func:ident, $type:ident) => { - impl PipelineStageInfo for ShaderModule<{ ShaderType::$type as u8 }> { + impl PipelineStageInfo for ShaderModule<$type> { fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo { VkPipelineShaderStageCreateInfo::$func(self.shader_module) } @@ -78,25 +86,30 @@ macro_rules! impl_pipeline_stage_info { } #[derive(Debug)] -pub struct ShaderModule { +pub struct ShaderModule { + t: PhantomData, device: Arc, shader_module: VkShaderModule, } -impl ShaderModule { - pub fn new(device: Arc, path: &str) -> Result>> { +impl ShaderModule { + pub fn new(device: Arc, path: &str) -> Result>> { let code = Self::shader_code(path)?; Self::from_slice(device, code.as_slice()) } - pub fn from_slice(device: Arc, code: &[u8]) -> Result>> { + pub fn from_slice( + device: Arc, + code: &[u8], + ) -> Result>> { let shader_module_ci = VkShaderModuleCreateInfo::new(VK_SHADER_MODULE_CREATE_NULL_BIT, code); let shader_module = device.create_shader_module(&shader_module_ci)?; Ok(Arc::new(ShaderModule { + t: PhantomData, device, shader_module, })) @@ -116,6 +129,8 @@ impl ShaderModule { } } +use shader_type::*; + impl_pipeline_stage_info!(vertex, Vertex); impl_pipeline_stage_info!(geometry, Geometry); impl_pipeline_stage_info!(tesselation_control, TesselationControl); @@ -128,15 +143,15 @@ impl_pipeline_stage_info!(closest_hit, ClosestHit); impl_pipeline_stage_info!(ray_generation, RayGeneration); impl_pipeline_stage_info!(miss, Miss); -impl VulkanDevice for ShaderModule { +impl VulkanDevice for ShaderModule { fn device(&self) -> &Arc { &self.device } } -impl_vk_handle!(ShaderModule, VkShaderModule, shader_module); +impl_vk_handle!(ShaderModule, VkShaderModule, shader_module); -impl Drop for ShaderModule { +impl Drop for ShaderModule { fn drop(&mut self) { self.device.destroy_shader_module(self.shader_module); }