Use const generics for shader module types
This commit is contained in:
parent
a36ba524ac
commit
a78f938273
7 changed files with 189 additions and 129 deletions
|
@ -11,3 +11,4 @@ vma-rs = { path = "../vma-rs" }
|
||||||
anyhow = { version = "1.0.68", features = ["backtrace"] }
|
anyhow = { version = "1.0.68", features = ["backtrace"] }
|
||||||
cgmath = "0.18.0"
|
cgmath = "0.18.0"
|
||||||
assetpath = { path = "../assetpath" }
|
assetpath = { path = "../assetpath" }
|
||||||
|
safer-ffi = "0.0.10"
|
||||||
|
|
|
@ -1,24 +1,24 @@
|
||||||
macro_rules! impl_vk_handle {
|
macro_rules! impl_vk_handle {
|
||||||
($struct_name:ident, $target_name:ident, $value:ident) => {
|
($struct_name:ident $(<$( $const:ident $name:ident: $type:ident, )*>)?, $target_name:ident, $value:ident) => {
|
||||||
impl VkHandle<$target_name> for $struct_name {
|
impl$(<$( $const $name: $type, )*>)? VkHandle<$target_name> for $struct_name$(<$($name,)?>)? {
|
||||||
fn vk_handle(&self) -> $target_name {
|
fn vk_handle(&self) -> $target_name {
|
||||||
self.$value
|
self.$value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> VkHandle<$target_name> for &'a $struct_name {
|
impl<'a $($(, $const $name: $type)*)?> VkHandle<$target_name> for &'a $struct_name$(<$($name,)?>)? {
|
||||||
fn vk_handle(&self) -> $target_name {
|
fn vk_handle(&self) -> $target_name {
|
||||||
self.$value
|
self.$value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VkHandle<$target_name> for Arc<$struct_name> {
|
impl$(<$( $const $name: $type, )*>)? VkHandle<$target_name> for Arc<$struct_name$(<$($name,)?>)?> {
|
||||||
fn vk_handle(&self) -> $target_name {
|
fn vk_handle(&self) -> $target_name {
|
||||||
self.$value
|
self.$value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> VkHandle<$target_name> for &'a Arc<$struct_name> {
|
impl<'a $($(, $const $name: $type)*)?> VkHandle<$target_name> for &'a Arc<$struct_name$(<$($name,)?>)?> {
|
||||||
fn vk_handle(&self) -> $target_name {
|
fn vk_handle(&self) -> $target_name {
|
||||||
self.$value
|
self.$value
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,22 +6,21 @@ use crate::prelude::*;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub struct ComputePipelineBuilder<'a> {
|
pub struct ComputePipelineBuilder<'a> {
|
||||||
shader_module: Option<&'a Arc<ShaderModule>>,
|
shader_module: Option<&'a Arc<ShaderModule<{ ShaderType::Compute as u8 }>>>,
|
||||||
pipeline_cache: Option<&'a Arc<PipelineCache>>,
|
pipeline_cache: Option<&'a Arc<PipelineCache>>,
|
||||||
flags: VkPipelineCreateFlagBits,
|
flags: VkPipelineCreateFlagBits,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> ComputePipelineBuilder<'a> {
|
impl<'a> ComputePipelineBuilder<'a> {
|
||||||
// TODO: add support for specialization constants
|
// TODO: add support for specialization constants
|
||||||
pub fn set_shader_module(mut self, shader_module: &'a Arc<ShaderModule>) -> Self {
|
pub fn set_shader_module(
|
||||||
|
mut self,
|
||||||
|
shader_module: &'a Arc<ShaderModule<{ ShaderType::Compute as u8 }>>,
|
||||||
|
) -> Self {
|
||||||
if cfg!(debug_assertions) {
|
if cfg!(debug_assertions) {
|
||||||
if self.shader_module.is_some() {
|
if self.shader_module.is_some() {
|
||||||
panic!("shader already set!");
|
panic!("shader already set!");
|
||||||
}
|
}
|
||||||
|
|
||||||
if shader_module.shader_type() != ShaderType::Compute {
|
|
||||||
panic!("shader has wrong type!");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.shader_module = Some(shader_module);
|
self.shader_module = Some(shader_module);
|
||||||
|
|
|
@ -12,18 +12,21 @@ pub struct GraphicsPipelineBuilder {
|
||||||
|
|
||||||
amd_rasterization_order: Option<VkPipelineRasterizationStateRasterizationOrderAMD>,
|
amd_rasterization_order: Option<VkPipelineRasterizationStateRasterizationOrderAMD>,
|
||||||
|
|
||||||
vertex_shader: Option<Arc<ShaderModule>>,
|
vertex_shader: Option<Arc<ShaderModule<{ ShaderType::Vertex as u8 }>>>,
|
||||||
vertex_binding_description: Vec<VkVertexInputBindingDescription>,
|
vertex_binding_description: Vec<VkVertexInputBindingDescription>,
|
||||||
vertex_attribute_description: Vec<VkVertexInputAttributeDescription>,
|
vertex_attribute_description: Vec<VkVertexInputAttributeDescription>,
|
||||||
|
|
||||||
input_assembly: Option<VkPipelineInputAssemblyStateCreateInfo>,
|
input_assembly: Option<VkPipelineInputAssemblyStateCreateInfo>,
|
||||||
|
|
||||||
tesselation_shader: Option<(Arc<ShaderModule>, Arc<ShaderModule>)>,
|
tesselation_shader: Option<(
|
||||||
|
Arc<ShaderModule<{ ShaderType::TesselationControl as u8 }>>,
|
||||||
|
Arc<ShaderModule<{ ShaderType::TesselationEvaluation as u8 }>>,
|
||||||
|
)>,
|
||||||
patch_control_points: u32,
|
patch_control_points: u32,
|
||||||
|
|
||||||
geometry_shader: Option<Arc<ShaderModule>>,
|
geometry_shader: Option<Arc<ShaderModule<{ ShaderType::Geometry as u8 }>>>,
|
||||||
|
|
||||||
fragment_shader: Option<Arc<ShaderModule>>,
|
fragment_shader: Option<Arc<ShaderModule<{ ShaderType::Fragment as u8 }>>>,
|
||||||
|
|
||||||
viewports: Vec<VkViewport>,
|
viewports: Vec<VkViewport>,
|
||||||
scissors: Vec<VkRect2D>,
|
scissors: Vec<VkRect2D>,
|
||||||
|
@ -40,19 +43,13 @@ pub struct GraphicsPipelineBuilder {
|
||||||
|
|
||||||
impl GraphicsPipelineBuilder {
|
impl GraphicsPipelineBuilder {
|
||||||
// TODO: add support for specialization constants
|
// TODO: add support for specialization constants
|
||||||
pub fn set_vertex_shader(
|
pub fn set_vertex_shader<T: VertexInputDescription>(
|
||||||
mut self,
|
mut self,
|
||||||
shader: Arc<ShaderModule>,
|
shader: Arc<ShaderModule<{ ShaderType::Vertex as u8 }>>,
|
||||||
vertex_binding_description: Vec<VkVertexInputBindingDescription>,
|
|
||||||
vertex_attribute_description: Vec<VkVertexInputAttributeDescription>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
if cfg!(debug_assertions) {
|
|
||||||
assert_eq!(shader.shader_type(), ShaderType::Vertex);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.vertex_shader = Some(shader);
|
self.vertex_shader = Some(shader);
|
||||||
self.vertex_binding_description = vertex_binding_description;
|
self.vertex_binding_description = T::bindings();
|
||||||
self.vertex_attribute_description = vertex_attribute_description;
|
self.vertex_attribute_description = T::attributes();
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -60,22 +57,10 @@ impl GraphicsPipelineBuilder {
|
||||||
// TODO: add support for specialization constants
|
// TODO: add support for specialization constants
|
||||||
pub fn set_tesselation_shader(
|
pub fn set_tesselation_shader(
|
||||||
mut self,
|
mut self,
|
||||||
tesselation_control: Arc<ShaderModule>,
|
tesselation_control: Arc<ShaderModule<{ ShaderType::TesselationControl as u8 }>>,
|
||||||
tesselation_evaluation: Arc<ShaderModule>,
|
tesselation_evaluation: Arc<ShaderModule<{ ShaderType::TesselationEvaluation as u8 }>>,
|
||||||
patch_control_points: u32,
|
patch_control_points: u32,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
if cfg!(debug_assertions) {
|
|
||||||
assert_eq!(
|
|
||||||
tesselation_control.shader_type(),
|
|
||||||
ShaderType::TesselationControl
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
tesselation_evaluation.shader_type(),
|
|
||||||
ShaderType::TesselationEvaluation
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.tesselation_shader = Some((tesselation_control, tesselation_evaluation));
|
self.tesselation_shader = Some((tesselation_control, tesselation_evaluation));
|
||||||
self.patch_control_points = patch_control_points;
|
self.patch_control_points = patch_control_points;
|
||||||
|
|
||||||
|
@ -83,22 +68,20 @@ impl GraphicsPipelineBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: add support for specialization constants
|
// TODO: add support for specialization constants
|
||||||
pub fn set_geometry_shader(mut self, shader: Arc<ShaderModule>) -> Self {
|
pub fn set_geometry_shader(
|
||||||
if cfg!(debug_assertions) {
|
mut self,
|
||||||
assert_eq!(shader.shader_type(), ShaderType::Geometry);
|
shader: Arc<ShaderModule<{ ShaderType::Geometry as u8 }>>,
|
||||||
}
|
) -> Self {
|
||||||
|
|
||||||
self.geometry_shader = Some(shader);
|
self.geometry_shader = Some(shader);
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: add support for specialization constants
|
// TODO: add support for specialization constants
|
||||||
pub fn set_fragment_shader(mut self, shader: Arc<ShaderModule>) -> Self {
|
pub fn set_fragment_shader(
|
||||||
if cfg!(debug_assertions) {
|
mut self,
|
||||||
assert_eq!(shader.shader_type(), ShaderType::Fragment);
|
shader: Arc<ShaderModule<{ ShaderType::Fragment as u8 }>>,
|
||||||
}
|
) -> Self {
|
||||||
|
|
||||||
self.fragment_shader = Some(shader);
|
self.fragment_shader = Some(shader);
|
||||||
|
|
||||||
self
|
self
|
||||||
|
|
|
@ -29,8 +29,72 @@ impl<'a> Library<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
macro_rules! impl_from_shader_type {
|
||||||
|
($struct: ident, $shader_type: ident) => {
|
||||||
|
impl From<Arc<ShaderModule<{ ShaderType::$shader_type as u8 }>>> for $struct {
|
||||||
|
fn from(value: Arc<ShaderModule<{ ShaderType::$shader_type as u8 }>>) -> Self {
|
||||||
|
Self::$shader_type(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum RaytracingShader {
|
||||||
|
RayGeneration(Arc<ShaderModule<{ ShaderType::RayGeneration as u8 }>>),
|
||||||
|
ClosestHit(Arc<ShaderModule<{ ShaderType::ClosestHit as u8 }>>),
|
||||||
|
Miss(Arc<ShaderModule<{ ShaderType::Miss as u8 }>>),
|
||||||
|
AnyHit(Arc<ShaderModule<{ ShaderType::AnyHit as u8 }>>),
|
||||||
|
Intersection(Arc<ShaderModule<{ ShaderType::Intersection as u8 }>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_from_shader_type!(RaytracingShader, RayGeneration);
|
||||||
|
impl_from_shader_type!(RaytracingShader, ClosestHit);
|
||||||
|
impl_from_shader_type!(RaytracingShader, Miss);
|
||||||
|
impl_from_shader_type!(RaytracingShader, AnyHit);
|
||||||
|
impl_from_shader_type!(RaytracingShader, Intersection);
|
||||||
|
|
||||||
|
impl From<HitShader> for RaytracingShader {
|
||||||
|
fn from(value: HitShader) -> Self {
|
||||||
|
match value {
|
||||||
|
HitShader::ClosestHit(s) => Self::ClosestHit(s),
|
||||||
|
HitShader::AnyHit(s) => Self::AnyHit(s),
|
||||||
|
HitShader::Intersection(s) => Self::Intersection(s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OtherShader> for RaytracingShader {
|
||||||
|
fn from(value: OtherShader) -> Self {
|
||||||
|
match value {
|
||||||
|
OtherShader::Miss(s) => Self::Miss(s),
|
||||||
|
OtherShader::RayGeneration(s) => Self::RayGeneration(s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum HitShader {
|
||||||
|
ClosestHit(Arc<ShaderModule<{ ShaderType::ClosestHit as u8 }>>),
|
||||||
|
AnyHit(Arc<ShaderModule<{ ShaderType::AnyHit as u8 }>>),
|
||||||
|
Intersection(Arc<ShaderModule<{ ShaderType::Intersection as u8 }>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_from_shader_type!(HitShader, ClosestHit);
|
||||||
|
impl_from_shader_type!(HitShader, AnyHit);
|
||||||
|
impl_from_shader_type!(HitShader, Intersection);
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum OtherShader {
|
||||||
|
RayGeneration(Arc<ShaderModule<{ ShaderType::RayGeneration as u8 }>>),
|
||||||
|
Miss(Arc<ShaderModule<{ ShaderType::Miss as u8 }>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_from_shader_type!(OtherShader, RayGeneration);
|
||||||
|
impl_from_shader_type!(OtherShader, Miss);
|
||||||
|
|
||||||
pub struct RayTracingPipelineBuilder<'a> {
|
pub struct RayTracingPipelineBuilder<'a> {
|
||||||
shader_modules: Vec<(Arc<ShaderModule>, Option<SpecializationConstants>)>,
|
shader_modules: Vec<(RaytracingShader, Option<SpecializationConstants>)>,
|
||||||
|
|
||||||
shader_groups: Vec<VkRayTracingShaderGroupCreateInfoKHR>,
|
shader_groups: Vec<VkRayTracingShaderGroupCreateInfoKHR>,
|
||||||
|
|
||||||
|
@ -82,26 +146,24 @@ impl<'a> RayTracingPipelineBuilder<'a> {
|
||||||
|
|
||||||
pub fn add_shader(
|
pub fn add_shader(
|
||||||
mut self,
|
mut self,
|
||||||
shader_module: Arc<ShaderModule>,
|
shader_module: impl Into<OtherShader>,
|
||||||
data: Option<Vec<u8>>,
|
data: Option<Vec<u8>>,
|
||||||
specialization_constants: Option<SpecializationConstants>,
|
specialization_constants: Option<SpecializationConstants>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
self.shader_binding_table_builder = match shader_module.shader_type() {
|
let shader_module = shader_module.into();
|
||||||
ShaderType::RayGeneration => self
|
|
||||||
|
self.shader_binding_table_builder = match shader_module {
|
||||||
|
OtherShader::RayGeneration(_) => self
|
||||||
.shader_binding_table_builder
|
.shader_binding_table_builder
|
||||||
.add_ray_gen_program(self.shader_groups.len() as u32, data),
|
.add_ray_gen_program(self.shader_groups.len() as u32, data),
|
||||||
ShaderType::Miss => self
|
OtherShader::Miss(_) => self
|
||||||
.shader_binding_table_builder
|
.shader_binding_table_builder
|
||||||
.add_miss_program(self.shader_groups.len() as u32, data),
|
.add_miss_program(self.shader_groups.len() as u32, data),
|
||||||
_ => panic!(
|
|
||||||
"unsupported shader type: {:?}, expected RayGen or Miss Shader",
|
|
||||||
shader_module.shader_type()
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let shader_index = self.shader_modules.len();
|
let shader_index = self.shader_modules.len();
|
||||||
self.shader_modules
|
self.shader_modules
|
||||||
.push((shader_module, specialization_constants));
|
.push((shader_module.into(), specialization_constants));
|
||||||
|
|
||||||
self.shader_groups
|
self.shader_groups
|
||||||
.push(VkRayTracingShaderGroupCreateInfoKHR::new(
|
.push(VkRayTracingShaderGroupCreateInfoKHR::new(
|
||||||
|
@ -117,7 +179,7 @@ impl<'a> RayTracingPipelineBuilder<'a> {
|
||||||
|
|
||||||
pub fn add_hit_shaders(
|
pub fn add_hit_shaders(
|
||||||
mut self,
|
mut self,
|
||||||
shader_modules: impl IntoIterator<Item = (Arc<ShaderModule>, Option<SpecializationConstants>)>,
|
shader_modules: impl IntoIterator<Item = (HitShader, Option<SpecializationConstants>)>,
|
||||||
data: Option<Vec<u8>>,
|
data: Option<Vec<u8>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let mut group = VkRayTracingShaderGroupCreateInfoKHR::new(
|
let mut group = VkRayTracingShaderGroupCreateInfoKHR::new(
|
||||||
|
@ -131,37 +193,39 @@ impl<'a> RayTracingPipelineBuilder<'a> {
|
||||||
for (shader_module, specialization_constant) in shader_modules.into_iter() {
|
for (shader_module, specialization_constant) in shader_modules.into_iter() {
|
||||||
let shader_index = self.shader_modules.len() as u32;
|
let shader_index = self.shader_modules.len() as u32;
|
||||||
|
|
||||||
match shader_module.shader_type() {
|
match shader_module {
|
||||||
ShaderType::AnyHit => {
|
HitShader::AnyHit(_) => {
|
||||||
// sanity check
|
// sanity check
|
||||||
if cfg!(debug_assertions) && group.anyHitShader != VK_SHADER_UNUSED_KHR {
|
debug_assert_ne!(
|
||||||
panic!("any hit shader already used in current hit group");
|
group.anyHitShader, VK_SHADER_UNUSED_KHR,
|
||||||
}
|
"any hit shader already used in current hit group"
|
||||||
|
);
|
||||||
|
|
||||||
group.anyHitShader = shader_index;
|
group.anyHitShader = shader_index;
|
||||||
}
|
}
|
||||||
ShaderType::ClosestHit => {
|
HitShader::ClosestHit(_) => {
|
||||||
// sanity check
|
// sanity check
|
||||||
if cfg!(debug_assertions) && group.closestHitShader != VK_SHADER_UNUSED_KHR {
|
debug_assert_ne!(
|
||||||
panic!("closest hit shader already used in current hit group");
|
group.closestHitShader, VK_SHADER_UNUSED_KHR,
|
||||||
}
|
"closest hit shader already used in current hit group"
|
||||||
|
);
|
||||||
|
|
||||||
group.closestHitShader = shader_index;
|
group.closestHitShader = shader_index;
|
||||||
}
|
}
|
||||||
ShaderType::Intersection => {
|
HitShader::Intersection(_) => {
|
||||||
// sanity check
|
// sanity check
|
||||||
if cfg!(debug_assertions) && group.intersectionShader != VK_SHADER_UNUSED_KHR {
|
debug_assert_ne!(
|
||||||
panic!("intersection shader already used in current hit group");
|
group.intersectionShader, VK_SHADER_UNUSED_KHR,
|
||||||
}
|
"intersection shader already used in current hit group"
|
||||||
|
);
|
||||||
|
|
||||||
group.intersectionShader = shader_index;
|
group.intersectionShader = shader_index;
|
||||||
group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
|
group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
|
||||||
}
|
}
|
||||||
_ => panic!("unsupported shader type: {:?}, expected AnyHit, ClosestHit or Intersection Shader", shader_module.shader_type()),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.shader_modules
|
self.shader_modules
|
||||||
.push((shader_module, specialization_constant));
|
.push((shader_module.into(), specialization_constant));
|
||||||
}
|
}
|
||||||
self.shader_binding_table_builder = self
|
self.shader_binding_table_builder = self
|
||||||
.shader_binding_table_builder
|
.shader_binding_table_builder
|
||||||
|
@ -186,7 +250,14 @@ impl<'a> RayTracingPipelineBuilder<'a> {
|
||||||
.shader_modules
|
.shader_modules
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(shader, specialization_constant)| {
|
.map(|(shader, specialization_constant)| {
|
||||||
let mut stage_info = shader.pipeline_stage_info();
|
let mut stage_info = match shader {
|
||||||
|
RaytracingShader::RayGeneration(s) => s.pipeline_stage_info(),
|
||||||
|
RaytracingShader::ClosestHit(s) => s.pipeline_stage_info(),
|
||||||
|
RaytracingShader::Miss(s) => s.pipeline_stage_info(),
|
||||||
|
RaytracingShader::AnyHit(s) => s.pipeline_stage_info(),
|
||||||
|
RaytracingShader::Intersection(s) => s.pipeline_stage_info(),
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(specialization_constant) = specialization_constant {
|
if let Some(specialization_constant) = specialization_constant {
|
||||||
stage_info.set_specialization_info(specialization_constant.vk_handle());
|
stage_info.set_specialization_info(specialization_constant.vk_handle());
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,8 @@ pub use super::renderpass::RenderPass;
|
||||||
pub use super::sampler_manager::{Sampler, SamplerBuilder};
|
pub use super::sampler_manager::{Sampler, SamplerBuilder};
|
||||||
pub use super::semaphore::Semaphore;
|
pub use super::semaphore::Semaphore;
|
||||||
pub use super::shadermodule::{
|
pub use super::shadermodule::{
|
||||||
AddSpecializationConstant, ShaderModule, ShaderType, SpecializationConstants,
|
AddSpecializationConstant, PipelineStageInfo, ShaderModule, ShaderType,
|
||||||
|
SpecializationConstants, VertexInputDescription,
|
||||||
};
|
};
|
||||||
pub use super::surface::Surface;
|
pub use super::surface::Surface;
|
||||||
pub use super::swapchain::Swapchain;
|
pub use super::swapchain::Swapchain;
|
||||||
|
|
|
@ -7,6 +7,7 @@ use std::io::Read;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[allow(clippy::cast_ptr_alignment)]
|
#[allow(clippy::cast_ptr_alignment)]
|
||||||
|
#[repr(u8)]
|
||||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
pub enum ShaderType {
|
pub enum ShaderType {
|
||||||
None,
|
None,
|
||||||
|
@ -23,35 +24,59 @@ pub enum ShaderType {
|
||||||
Intersection,
|
Intersection,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ShaderType {
|
||||||
|
pub const VERTEX: u8 = Self::Vertex as u8;
|
||||||
|
pub const FRAGMENT: u8 = Self::Fragment as u8;
|
||||||
|
pub const GEOMETRY: u8 = Self::Geometry as u8;
|
||||||
|
pub const TESSELATION_CONTROL: u8 = Self::TesselationControl as u8;
|
||||||
|
pub const TESSELATION_EVALUATION: u8 = Self::TesselationEvaluation as u8;
|
||||||
|
pub const COMPUTE: u8 = Self::Compute as u8;
|
||||||
|
pub const RAY_GENERATION: u8 = Self::RayGeneration as u8;
|
||||||
|
pub const CLOSEST_HIT: u8 = Self::ClosestHit as u8;
|
||||||
|
pub const MISS: u8 = Self::Miss as u8;
|
||||||
|
pub const ANY_HIT: u8 = Self::AnyHit as u8;
|
||||||
|
pub const INTERSECTION: u8 = Self::Intersection as u8;
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for ShaderType {
|
impl Default for ShaderType {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
ShaderType::None
|
ShaderType::None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
pub trait VertexInputDescription {
|
||||||
pub struct ShaderModule {
|
fn bindings() -> Vec<VkVertexInputBindingDescription>;
|
||||||
device: Arc<Device>,
|
fn attributes() -> Vec<VkVertexInputAttributeDescription>;
|
||||||
shader_module: VkShaderModule,
|
|
||||||
shader_type: ShaderType,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ShaderModule {
|
pub trait PipelineStageInfo {
|
||||||
pub fn new(
|
fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo;
|
||||||
device: Arc<Device>,
|
}
|
||||||
path: &str,
|
|
||||||
shader_type: ShaderType,
|
macro_rules! impl_pipeline_stage_info {
|
||||||
) -> Result<Arc<ShaderModule>> {
|
($func:ident, $type:ident) => {
|
||||||
|
impl PipelineStageInfo for ShaderModule<{ ShaderType::$type as u8 }> {
|
||||||
|
fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo {
|
||||||
|
VkPipelineShaderStageCreateInfo::$func(self.shader_module)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ShaderModule<const TYPE: u8> {
|
||||||
|
device: Arc<Device>,
|
||||||
|
shader_module: VkShaderModule,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const TYPE: u8> ShaderModule<TYPE> {
|
||||||
|
pub fn new(device: Arc<Device>, path: &str) -> Result<Arc<ShaderModule<TYPE>>> {
|
||||||
let code = Self::shader_code(path)?;
|
let code = Self::shader_code(path)?;
|
||||||
|
|
||||||
Self::from_slice(device, code.as_slice(), shader_type)
|
Self::from_slice(device, code.as_slice())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_slice(
|
pub fn from_slice(device: Arc<Device>, code: &[u8]) -> Result<Arc<ShaderModule<TYPE>>> {
|
||||||
device: Arc<Device>,
|
|
||||||
code: &[u8],
|
|
||||||
shader_type: ShaderType,
|
|
||||||
) -> Result<Arc<ShaderModule>> {
|
|
||||||
let shader_module_ci =
|
let shader_module_ci =
|
||||||
VkShaderModuleCreateInfo::new(VK_SHADER_MODULE_CREATE_NULL_BIT, code);
|
VkShaderModuleCreateInfo::new(VK_SHADER_MODULE_CREATE_NULL_BIT, code);
|
||||||
|
|
||||||
|
@ -60,7 +85,6 @@ impl ShaderModule {
|
||||||
Ok(Arc::new(ShaderModule {
|
Ok(Arc::new(ShaderModule {
|
||||||
device,
|
device,
|
||||||
shader_module,
|
shader_module,
|
||||||
shader_type,
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,48 +100,29 @@ impl ShaderModule {
|
||||||
|
|
||||||
Ok(code)
|
Ok(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shader_type(&self) -> ShaderType {
|
|
||||||
self.shader_type
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo {
|
|
||||||
match self.shader_type {
|
|
||||||
ShaderType::None => unimplemented!(),
|
|
||||||
ShaderType::Vertex => VkPipelineShaderStageCreateInfo::vertex(self.shader_module),
|
|
||||||
ShaderType::Geometry => VkPipelineShaderStageCreateInfo::geometry(self.shader_module),
|
|
||||||
ShaderType::TesselationControl => {
|
|
||||||
VkPipelineShaderStageCreateInfo::tesselation_control(self.shader_module)
|
|
||||||
}
|
|
||||||
ShaderType::TesselationEvaluation => {
|
|
||||||
VkPipelineShaderStageCreateInfo::tesselation_evaluation(self.shader_module)
|
|
||||||
}
|
|
||||||
ShaderType::Fragment => VkPipelineShaderStageCreateInfo::fragment(self.shader_module),
|
|
||||||
ShaderType::Compute => VkPipelineShaderStageCreateInfo::compute(self.shader_module),
|
|
||||||
ShaderType::AnyHit => VkPipelineShaderStageCreateInfo::any_hit(self.shader_module),
|
|
||||||
ShaderType::Intersection => {
|
|
||||||
VkPipelineShaderStageCreateInfo::intersection(self.shader_module)
|
|
||||||
}
|
|
||||||
ShaderType::ClosestHit => {
|
|
||||||
VkPipelineShaderStageCreateInfo::closest_hit(self.shader_module)
|
|
||||||
}
|
|
||||||
ShaderType::RayGeneration => {
|
|
||||||
VkPipelineShaderStageCreateInfo::ray_generation(self.shader_module)
|
|
||||||
}
|
|
||||||
ShaderType::Miss => VkPipelineShaderStageCreateInfo::miss(self.shader_module),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VulkanDevice for ShaderModule {
|
impl_pipeline_stage_info!(vertex, Vertex);
|
||||||
|
impl_pipeline_stage_info!(geometry, Geometry);
|
||||||
|
impl_pipeline_stage_info!(tesselation_control, TesselationControl);
|
||||||
|
impl_pipeline_stage_info!(tesselation_evaluation, TesselationEvaluation);
|
||||||
|
impl_pipeline_stage_info!(fragment, Fragment);
|
||||||
|
impl_pipeline_stage_info!(compute, Compute);
|
||||||
|
impl_pipeline_stage_info!(any_hit, AnyHit);
|
||||||
|
impl_pipeline_stage_info!(intersection, Intersection);
|
||||||
|
impl_pipeline_stage_info!(closest_hit, ClosestHit);
|
||||||
|
impl_pipeline_stage_info!(ray_generation, RayGeneration);
|
||||||
|
impl_pipeline_stage_info!(miss, Miss);
|
||||||
|
|
||||||
|
impl<const TYPE: u8> VulkanDevice for ShaderModule<TYPE> {
|
||||||
fn device(&self) -> &Arc<Device> {
|
fn device(&self) -> &Arc<Device> {
|
||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_vk_handle!(ShaderModule, VkShaderModule, shader_module);
|
impl_vk_handle!(ShaderModule<const TYPE: u8,>, VkShaderModule, shader_module);
|
||||||
|
|
||||||
impl Drop for ShaderModule {
|
impl<const TYPE: u8> Drop for ShaderModule<TYPE> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.device.destroy_shader_module(self.shader_module);
|
self.device.destroy_shader_module(self.shader_module);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue