Use const generics for shader module types

This commit is contained in:
hodasemi 2023-01-27 10:12:59 +01:00
parent a36ba524ac
commit a78f938273
7 changed files with 189 additions and 129 deletions

View file

@ -11,3 +11,4 @@ vma-rs = { path = "../vma-rs" }
anyhow = { version = "1.0.68", features = ["backtrace"] }
cgmath = "0.18.0"
assetpath = { path = "../assetpath" }
safer-ffi = "0.0.10"

View file

@ -1,24 +1,24 @@
macro_rules! impl_vk_handle {
($struct_name:ident, $target_name:ident, $value:ident) => {
impl VkHandle<$target_name> for $struct_name {
($struct_name:ident $(<$( $const:ident $name:ident: $type:ident, )*>)?, $target_name:ident, $value:ident) => {
impl$(<$( $const $name: $type, )*>)? VkHandle<$target_name> for $struct_name$(<$($name,)?>)? {
fn vk_handle(&self) -> $target_name {
self.$value
}
}
impl<'a> VkHandle<$target_name> for &'a $struct_name {
impl<'a $($(, $const $name: $type)*)?> VkHandle<$target_name> for &'a $struct_name$(<$($name,)?>)? {
fn vk_handle(&self) -> $target_name {
self.$value
}
}
impl VkHandle<$target_name> for Arc<$struct_name> {
impl$(<$( $const $name: $type, )*>)? VkHandle<$target_name> for Arc<$struct_name$(<$($name,)?>)?> {
fn vk_handle(&self) -> $target_name {
self.$value
}
}
impl<'a> VkHandle<$target_name> for &'a Arc<$struct_name> {
impl<'a $($(, $const $name: $type)*)?> VkHandle<$target_name> for &'a Arc<$struct_name$(<$($name,)?>)?> {
fn vk_handle(&self) -> $target_name {
self.$value
}

View file

@ -6,22 +6,21 @@ use crate::prelude::*;
use std::sync::Arc;
pub struct ComputePipelineBuilder<'a> {
shader_module: Option<&'a Arc<ShaderModule>>,
shader_module: Option<&'a Arc<ShaderModule<{ ShaderType::Compute as u8 }>>>,
pipeline_cache: Option<&'a Arc<PipelineCache>>,
flags: VkPipelineCreateFlagBits,
}
impl<'a> ComputePipelineBuilder<'a> {
// TODO: add support for specialization constants
pub fn set_shader_module(mut self, shader_module: &'a Arc<ShaderModule>) -> Self {
pub fn set_shader_module(
mut self,
shader_module: &'a Arc<ShaderModule<{ ShaderType::Compute as u8 }>>,
) -> Self {
if cfg!(debug_assertions) {
if self.shader_module.is_some() {
panic!("shader already set!");
}
if shader_module.shader_type() != ShaderType::Compute {
panic!("shader has wrong type!");
}
}
self.shader_module = Some(shader_module);

View file

@ -12,18 +12,21 @@ pub struct GraphicsPipelineBuilder {
amd_rasterization_order: Option<VkPipelineRasterizationStateRasterizationOrderAMD>,
vertex_shader: Option<Arc<ShaderModule>>,
vertex_shader: Option<Arc<ShaderModule<{ ShaderType::Vertex as u8 }>>>,
vertex_binding_description: Vec<VkVertexInputBindingDescription>,
vertex_attribute_description: Vec<VkVertexInputAttributeDescription>,
input_assembly: Option<VkPipelineInputAssemblyStateCreateInfo>,
tesselation_shader: Option<(Arc<ShaderModule>, Arc<ShaderModule>)>,
tesselation_shader: Option<(
Arc<ShaderModule<{ ShaderType::TesselationControl as u8 }>>,
Arc<ShaderModule<{ ShaderType::TesselationEvaluation as u8 }>>,
)>,
patch_control_points: u32,
geometry_shader: Option<Arc<ShaderModule>>,
geometry_shader: Option<Arc<ShaderModule<{ ShaderType::Geometry as u8 }>>>,
fragment_shader: Option<Arc<ShaderModule>>,
fragment_shader: Option<Arc<ShaderModule<{ ShaderType::Fragment as u8 }>>>,
viewports: Vec<VkViewport>,
scissors: Vec<VkRect2D>,
@ -40,19 +43,13 @@ pub struct GraphicsPipelineBuilder {
impl GraphicsPipelineBuilder {
// TODO: add support for specialization constants
pub fn set_vertex_shader(
pub fn set_vertex_shader<T: VertexInputDescription>(
mut self,
shader: Arc<ShaderModule>,
vertex_binding_description: Vec<VkVertexInputBindingDescription>,
vertex_attribute_description: Vec<VkVertexInputAttributeDescription>,
shader: Arc<ShaderModule<{ ShaderType::Vertex as u8 }>>,
) -> Self {
if cfg!(debug_assertions) {
assert_eq!(shader.shader_type(), ShaderType::Vertex);
}
self.vertex_shader = Some(shader);
self.vertex_binding_description = vertex_binding_description;
self.vertex_attribute_description = vertex_attribute_description;
self.vertex_binding_description = T::bindings();
self.vertex_attribute_description = T::attributes();
self
}
@ -60,22 +57,10 @@ impl GraphicsPipelineBuilder {
// TODO: add support for specialization constants
pub fn set_tesselation_shader(
mut self,
tesselation_control: Arc<ShaderModule>,
tesselation_evaluation: Arc<ShaderModule>,
tesselation_control: Arc<ShaderModule<{ ShaderType::TesselationControl as u8 }>>,
tesselation_evaluation: Arc<ShaderModule<{ ShaderType::TesselationEvaluation as u8 }>>,
patch_control_points: u32,
) -> Self {
if cfg!(debug_assertions) {
assert_eq!(
tesselation_control.shader_type(),
ShaderType::TesselationControl
);
assert_eq!(
tesselation_evaluation.shader_type(),
ShaderType::TesselationEvaluation
);
}
self.tesselation_shader = Some((tesselation_control, tesselation_evaluation));
self.patch_control_points = patch_control_points;
@ -83,22 +68,20 @@ impl GraphicsPipelineBuilder {
}
// TODO: add support for specialization constants
pub fn set_geometry_shader(mut self, shader: Arc<ShaderModule>) -> Self {
if cfg!(debug_assertions) {
assert_eq!(shader.shader_type(), ShaderType::Geometry);
}
pub fn set_geometry_shader(
mut self,
shader: Arc<ShaderModule<{ ShaderType::Geometry as u8 }>>,
) -> Self {
self.geometry_shader = Some(shader);
self
}
// TODO: add support for specialization constants
pub fn set_fragment_shader(mut self, shader: Arc<ShaderModule>) -> Self {
if cfg!(debug_assertions) {
assert_eq!(shader.shader_type(), ShaderType::Fragment);
}
pub fn set_fragment_shader(
mut self,
shader: Arc<ShaderModule<{ ShaderType::Fragment as u8 }>>,
) -> Self {
self.fragment_shader = Some(shader);
self

View file

@ -29,8 +29,72 @@ impl<'a> Library<'a> {
}
}
macro_rules! impl_from_shader_type {
($struct: ident, $shader_type: ident) => {
impl From<Arc<ShaderModule<{ ShaderType::$shader_type as u8 }>>> for $struct {
fn from(value: Arc<ShaderModule<{ ShaderType::$shader_type as u8 }>>) -> Self {
Self::$shader_type(value)
}
}
};
}
#[derive(Clone)]
enum RaytracingShader {
RayGeneration(Arc<ShaderModule<{ ShaderType::RayGeneration as u8 }>>),
ClosestHit(Arc<ShaderModule<{ ShaderType::ClosestHit as u8 }>>),
Miss(Arc<ShaderModule<{ ShaderType::Miss as u8 }>>),
AnyHit(Arc<ShaderModule<{ ShaderType::AnyHit as u8 }>>),
Intersection(Arc<ShaderModule<{ ShaderType::Intersection as u8 }>>),
}
impl_from_shader_type!(RaytracingShader, RayGeneration);
impl_from_shader_type!(RaytracingShader, ClosestHit);
impl_from_shader_type!(RaytracingShader, Miss);
impl_from_shader_type!(RaytracingShader, AnyHit);
impl_from_shader_type!(RaytracingShader, Intersection);
impl From<HitShader> for RaytracingShader {
fn from(value: HitShader) -> Self {
match value {
HitShader::ClosestHit(s) => Self::ClosestHit(s),
HitShader::AnyHit(s) => Self::AnyHit(s),
HitShader::Intersection(s) => Self::Intersection(s),
}
}
}
impl From<OtherShader> for RaytracingShader {
fn from(value: OtherShader) -> Self {
match value {
OtherShader::Miss(s) => Self::Miss(s),
OtherShader::RayGeneration(s) => Self::RayGeneration(s),
}
}
}
#[derive(Clone)]
pub enum HitShader {
ClosestHit(Arc<ShaderModule<{ ShaderType::ClosestHit as u8 }>>),
AnyHit(Arc<ShaderModule<{ ShaderType::AnyHit as u8 }>>),
Intersection(Arc<ShaderModule<{ ShaderType::Intersection as u8 }>>),
}
impl_from_shader_type!(HitShader, ClosestHit);
impl_from_shader_type!(HitShader, AnyHit);
impl_from_shader_type!(HitShader, Intersection);
#[derive(Clone)]
pub enum OtherShader {
RayGeneration(Arc<ShaderModule<{ ShaderType::RayGeneration as u8 }>>),
Miss(Arc<ShaderModule<{ ShaderType::Miss as u8 }>>),
}
impl_from_shader_type!(OtherShader, RayGeneration);
impl_from_shader_type!(OtherShader, Miss);
pub struct RayTracingPipelineBuilder<'a> {
shader_modules: Vec<(Arc<ShaderModule>, Option<SpecializationConstants>)>,
shader_modules: Vec<(RaytracingShader, Option<SpecializationConstants>)>,
shader_groups: Vec<VkRayTracingShaderGroupCreateInfoKHR>,
@ -82,26 +146,24 @@ impl<'a> RayTracingPipelineBuilder<'a> {
pub fn add_shader(
mut self,
shader_module: Arc<ShaderModule>,
shader_module: impl Into<OtherShader>,
data: Option<Vec<u8>>,
specialization_constants: Option<SpecializationConstants>,
) -> Self {
self.shader_binding_table_builder = match shader_module.shader_type() {
ShaderType::RayGeneration => self
let shader_module = shader_module.into();
self.shader_binding_table_builder = match shader_module {
OtherShader::RayGeneration(_) => self
.shader_binding_table_builder
.add_ray_gen_program(self.shader_groups.len() as u32, data),
ShaderType::Miss => self
OtherShader::Miss(_) => self
.shader_binding_table_builder
.add_miss_program(self.shader_groups.len() as u32, data),
_ => panic!(
"unsupported shader type: {:?}, expected RayGen or Miss Shader",
shader_module.shader_type()
),
};
let shader_index = self.shader_modules.len();
self.shader_modules
.push((shader_module, specialization_constants));
.push((shader_module.into(), specialization_constants));
self.shader_groups
.push(VkRayTracingShaderGroupCreateInfoKHR::new(
@ -117,7 +179,7 @@ impl<'a> RayTracingPipelineBuilder<'a> {
pub fn add_hit_shaders(
mut self,
shader_modules: impl IntoIterator<Item = (Arc<ShaderModule>, Option<SpecializationConstants>)>,
shader_modules: impl IntoIterator<Item = (HitShader, Option<SpecializationConstants>)>,
data: Option<Vec<u8>>,
) -> Self {
let mut group = VkRayTracingShaderGroupCreateInfoKHR::new(
@ -131,37 +193,39 @@ impl<'a> RayTracingPipelineBuilder<'a> {
for (shader_module, specialization_constant) in shader_modules.into_iter() {
let shader_index = self.shader_modules.len() as u32;
match shader_module.shader_type() {
ShaderType::AnyHit => {
match shader_module {
HitShader::AnyHit(_) => {
// sanity check
if cfg!(debug_assertions) && group.anyHitShader != VK_SHADER_UNUSED_KHR {
panic!("any hit shader already used in current hit group");
}
debug_assert_ne!(
group.anyHitShader, VK_SHADER_UNUSED_KHR,
"any hit shader already used in current hit group"
);
group.anyHitShader = shader_index;
}
ShaderType::ClosestHit => {
HitShader::ClosestHit(_) => {
// sanity check
if cfg!(debug_assertions) && group.closestHitShader != VK_SHADER_UNUSED_KHR {
panic!("closest hit shader already used in current hit group");
}
debug_assert_ne!(
group.closestHitShader, VK_SHADER_UNUSED_KHR,
"closest hit shader already used in current hit group"
);
group.closestHitShader = shader_index;
}
ShaderType::Intersection => {
HitShader::Intersection(_) => {
// sanity check
if cfg!(debug_assertions) && group.intersectionShader != VK_SHADER_UNUSED_KHR {
panic!("intersection shader already used in current hit group");
}
debug_assert_ne!(
group.intersectionShader, VK_SHADER_UNUSED_KHR,
"intersection shader already used in current hit group"
);
group.intersectionShader = shader_index;
group.r#type = VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;
}
_ => panic!("unsupported shader type: {:?}, expected AnyHit, ClosestHit or Intersection Shader", shader_module.shader_type()),
}
self.shader_modules
.push((shader_module, specialization_constant));
.push((shader_module.into(), specialization_constant));
}
self.shader_binding_table_builder = self
.shader_binding_table_builder
@ -186,7 +250,14 @@ impl<'a> RayTracingPipelineBuilder<'a> {
.shader_modules
.iter()
.map(|(shader, specialization_constant)| {
let mut stage_info = shader.pipeline_stage_info();
let mut stage_info = match shader {
RaytracingShader::RayGeneration(s) => s.pipeline_stage_info(),
RaytracingShader::ClosestHit(s) => s.pipeline_stage_info(),
RaytracingShader::Miss(s) => s.pipeline_stage_info(),
RaytracingShader::AnyHit(s) => s.pipeline_stage_info(),
RaytracingShader::Intersection(s) => s.pipeline_stage_info(),
};
if let Some(specialization_constant) = specialization_constant {
stage_info.set_specialization_info(specialization_constant.vk_handle());
}

View file

@ -24,7 +24,8 @@ pub use super::renderpass::RenderPass;
pub use super::sampler_manager::{Sampler, SamplerBuilder};
pub use super::semaphore::Semaphore;
pub use super::shadermodule::{
AddSpecializationConstant, ShaderModule, ShaderType, SpecializationConstants,
AddSpecializationConstant, PipelineStageInfo, ShaderModule, ShaderType,
SpecializationConstants, VertexInputDescription,
};
pub use super::surface::Surface;
pub use super::swapchain::Swapchain;

View file

@ -7,6 +7,7 @@ use std::io::Read;
use std::sync::Arc;
#[allow(clippy::cast_ptr_alignment)]
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ShaderType {
None,
@ -23,35 +24,59 @@ pub enum ShaderType {
Intersection,
}
impl ShaderType {
pub const VERTEX: u8 = Self::Vertex as u8;
pub const FRAGMENT: u8 = Self::Fragment as u8;
pub const GEOMETRY: u8 = Self::Geometry as u8;
pub const TESSELATION_CONTROL: u8 = Self::TesselationControl as u8;
pub const TESSELATION_EVALUATION: u8 = Self::TesselationEvaluation as u8;
pub const COMPUTE: u8 = Self::Compute as u8;
pub const RAY_GENERATION: u8 = Self::RayGeneration as u8;
pub const CLOSEST_HIT: u8 = Self::ClosestHit as u8;
pub const MISS: u8 = Self::Miss as u8;
pub const ANY_HIT: u8 = Self::AnyHit as u8;
pub const INTERSECTION: u8 = Self::Intersection as u8;
}
impl Default for ShaderType {
fn default() -> Self {
ShaderType::None
}
}
#[derive(Debug)]
pub struct ShaderModule {
device: Arc<Device>,
shader_module: VkShaderModule,
shader_type: ShaderType,
pub trait VertexInputDescription {
fn bindings() -> Vec<VkVertexInputBindingDescription>;
fn attributes() -> Vec<VkVertexInputAttributeDescription>;
}
impl ShaderModule {
pub fn new(
pub trait PipelineStageInfo {
fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo;
}
macro_rules! impl_pipeline_stage_info {
($func:ident, $type:ident) => {
impl PipelineStageInfo for ShaderModule<{ ShaderType::$type as u8 }> {
fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo {
VkPipelineShaderStageCreateInfo::$func(self.shader_module)
}
}
};
}
#[derive(Debug)]
pub struct ShaderModule<const TYPE: u8> {
device: Arc<Device>,
path: &str,
shader_type: ShaderType,
) -> Result<Arc<ShaderModule>> {
shader_module: VkShaderModule,
}
impl<const TYPE: u8> ShaderModule<TYPE> {
pub fn new(device: Arc<Device>, path: &str) -> Result<Arc<ShaderModule<TYPE>>> {
let code = Self::shader_code(path)?;
Self::from_slice(device, code.as_slice(), shader_type)
Self::from_slice(device, code.as_slice())
}
pub fn from_slice(
device: Arc<Device>,
code: &[u8],
shader_type: ShaderType,
) -> Result<Arc<ShaderModule>> {
pub fn from_slice(device: Arc<Device>, code: &[u8]) -> Result<Arc<ShaderModule<TYPE>>> {
let shader_module_ci =
VkShaderModuleCreateInfo::new(VK_SHADER_MODULE_CREATE_NULL_BIT, code);
@ -60,7 +85,6 @@ impl ShaderModule {
Ok(Arc::new(ShaderModule {
device,
shader_module,
shader_type,
}))
}
@ -76,48 +100,29 @@ impl ShaderModule {
Ok(code)
}
pub fn shader_type(&self) -> ShaderType {
self.shader_type
}
pub fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo {
match self.shader_type {
ShaderType::None => unimplemented!(),
ShaderType::Vertex => VkPipelineShaderStageCreateInfo::vertex(self.shader_module),
ShaderType::Geometry => VkPipelineShaderStageCreateInfo::geometry(self.shader_module),
ShaderType::TesselationControl => {
VkPipelineShaderStageCreateInfo::tesselation_control(self.shader_module)
}
ShaderType::TesselationEvaluation => {
VkPipelineShaderStageCreateInfo::tesselation_evaluation(self.shader_module)
}
ShaderType::Fragment => VkPipelineShaderStageCreateInfo::fragment(self.shader_module),
ShaderType::Compute => VkPipelineShaderStageCreateInfo::compute(self.shader_module),
ShaderType::AnyHit => VkPipelineShaderStageCreateInfo::any_hit(self.shader_module),
ShaderType::Intersection => {
VkPipelineShaderStageCreateInfo::intersection(self.shader_module)
}
ShaderType::ClosestHit => {
VkPipelineShaderStageCreateInfo::closest_hit(self.shader_module)
}
ShaderType::RayGeneration => {
VkPipelineShaderStageCreateInfo::ray_generation(self.shader_module)
}
ShaderType::Miss => VkPipelineShaderStageCreateInfo::miss(self.shader_module),
}
}
}
impl VulkanDevice for ShaderModule {
impl_pipeline_stage_info!(vertex, Vertex);
impl_pipeline_stage_info!(geometry, Geometry);
impl_pipeline_stage_info!(tesselation_control, TesselationControl);
impl_pipeline_stage_info!(tesselation_evaluation, TesselationEvaluation);
impl_pipeline_stage_info!(fragment, Fragment);
impl_pipeline_stage_info!(compute, Compute);
impl_pipeline_stage_info!(any_hit, AnyHit);
impl_pipeline_stage_info!(intersection, Intersection);
impl_pipeline_stage_info!(closest_hit, ClosestHit);
impl_pipeline_stage_info!(ray_generation, RayGeneration);
impl_pipeline_stage_info!(miss, Miss);
impl<const TYPE: u8> VulkanDevice for ShaderModule<TYPE> {
fn device(&self) -> &Arc<Device> {
&self.device
}
}
impl_vk_handle!(ShaderModule, VkShaderModule, shader_module);
impl_vk_handle!(ShaderModule<const TYPE: u8,>, VkShaderModule, shader_module);
impl Drop for ShaderModule {
impl<const TYPE: u8> Drop for ShaderModule<TYPE> {
fn drop(&mut self) {
self.device.destroy_shader_module(self.shader_module);
}