vulkan_lib/vulkan-rs/src/descriptorsetlayout.rs

183 lines
5.2 KiB
Rust

use crate::prelude::*;
use anyhow::Result;
use std::sync::Arc;
/// Incrementally collects descriptor bindings (and their per-binding indexing
/// flags) for a [`DescriptorSetLayout`], which is created by `build`.
///
/// Obtain an empty builder through [`DescriptorSetLayout::builder`].
pub struct DescriptorSetLayoutBuilder {
    /// Create flags handed to `VkDescriptorSetLayoutCreateInfo`.
    flags: VkDescriptorSetLayoutCreateFlagBits,
    /// Bindings in the order they were added.
    layout_bindings: Vec<VkDescriptorSetLayoutBinding>,
    /// Indexing flags; entry `i` belongs to `layout_bindings[i]`.
    indexing_flags: Vec<VkDescriptorBindingFlagBitsEXT>,
}
impl DescriptorSetLayoutBuilder {
    /// Appends a binding to the layout being built.
    ///
    /// * `binding` - binding number referenced by shaders.
    /// * `descriptor_type` - type of descriptor stored at this binding.
    /// * `stage_flags` - shader stages allowed to access the binding.
    /// * `indexing_flags` - descriptor-indexing flags for this binding. They are
    ///   only handed to Vulkan when the device enabled `descriptor_indexing`.
    ///
    /// The descriptor count is whatever `VkDescriptorSetLayoutBinding::new`
    /// assigns. Call [`Self::change_descriptor_count`] right afterwards to
    /// override it.
    pub fn add_layout_binding(
        mut self,
        binding: u32,
        descriptor_type: VkDescriptorType,
        stage_flags: impl Into<VkShaderStageFlagBits>,
        indexing_flags: impl Into<VkDescriptorBindingFlagBitsEXT>,
    ) -> Self {
        self.layout_bindings.push(VkDescriptorSetLayoutBinding::new(
            binding,
            descriptor_type,
            stage_flags,
        ));
        // Kept index-aligned with `layout_bindings`: Vulkan requires the
        // binding-flags array to be empty or hold exactly one entry per binding.
        // The update-after-bind pool flag this may require is derived in `build`.
        self.indexing_flags.push(indexing_flags.into());
        self
    }

    /// Overrides the descriptor count of the most recently added binding.
    ///
    /// Does nothing if no binding has been added yet.
    pub fn change_descriptor_count(mut self, count: u32) -> Self {
        if let Some(binding) = self.layout_bindings.last_mut() {
            binding.descriptorCount = count;
        }
        self
    }

    /// Replaces the layout create flags.
    ///
    /// `VK_DESCRIPTOR_SET_LAYOUT_CREATE_UPDATE_AFTER_BIND_POOL_BIT_EXT`, which is
    /// required by bindings that use update-after-bind, is re-applied in `build`.
    /// Calling this after `add_layout_binding` therefore cannot drop it.
    pub fn set_flags(mut self, flags: impl Into<VkDescriptorSetLayoutCreateFlagBits>) -> Self {
        self.flags = flags.into();
        self
    }

    /// Creates the Vulkan descriptor set layout on `device`.
    ///
    /// The per-binding indexing flags are chained into the create info only
    /// when the device enabled the `descriptor_indexing` extension. If any
    /// binding requested `VK_DESCRIPTOR_BINDING_UPDATE_AFTER_BIND_BIT_EXT`, the
    /// layout gets `VK_DESCRIPTOR_SET_LAYOUT_CREATE_UPDATE_AFTER_BIND_POOL_BIT_EXT`,
    /// as the Vulkan valid-usage rules demand.
    ///
    /// One pool size is recorded per binding and exposed through
    /// `DescriptorSetLayout::pool_sizes`. Bindings with a descriptor count of
    /// zero are skipped, because `VkDescriptorPoolSize::descriptorCount` must be
    /// greater than zero.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `Device::create_descriptor_set_layout`.
    pub fn build(self, device: Arc<Device>) -> Result<Arc<DescriptorSetLayout>> {
        // Derive the pool flag here rather than when bindings are added, so a
        // later `set_flags` call cannot silently strip it.
        let needs_update_after_bind_pool = self
            .indexing_flags
            .iter()
            .any(|&flags| (flags & VK_DESCRIPTOR_BINDING_UPDATE_AFTER_BIND_BIT_EXT) != 0);
        let mut create_flags = self.flags;
        if needs_update_after_bind_pool {
            create_flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_UPDATE_AFTER_BIND_POOL_BIT_EXT;
        }

        let mut descriptor_set_ci =
            VkDescriptorSetLayoutCreateInfo::new(create_flags, &self.layout_bindings);
        let binding_flags_ci =
            VkDescriptorSetLayoutBindingFlagsCreateInfoEXT::new(&self.indexing_flags);
        if device.enabled_extensions().descriptor_indexing {
            descriptor_set_ci.chain(&binding_flags_ci);
            // TODO: when `maintenance3` is enabled, query the descriptor set layout
            // support (with `VkDescriptorSetVariableDescriptorCountLayoutSupportEXT`
            // chained) before creating the layout, to validate variable-count bindings.
        }

        let bindings = self
            .layout_bindings
            .iter()
            .cloned()
            .map(DescriptorLayoutBinding::from)
            .collect();

        let descriptor_set_layout = device.create_descriptor_set_layout(&descriptor_set_ci)?;

        let pool_sizes = self
            .layout_bindings
            .into_iter()
            // A zero-count binding is a valid reserved slot in a layout, but a
            // zero-count pool size is invalid, so it must not reach pool creation.
            .filter(|layout_binding| layout_binding.descriptorCount > 0)
            .map(|layout_binding| VkDescriptorPoolSize {
                ty: layout_binding.descriptorType,
                descriptorCount: layout_binding.descriptorCount,
            })
            .collect();

        Ok(Arc::new(DescriptorSetLayout {
            device,
            descriptor_set_layout,
            pool_sizes,
            bindings,
        }))
    }
}
/// A created Vulkan descriptor set layout, together with data derived from its
/// bindings.
///
/// The layout handle is destroyed through `device` when this value is dropped.
#[derive(Debug)]
pub struct DescriptorSetLayout {
    /// Device that created the layout handle and destroys it on drop.
    device: Arc<Device>,
    /// Raw Vulkan layout handle.
    descriptor_set_layout: VkDescriptorSetLayout,
    /// Pool sizes derived from the layout's bindings (crate-internal).
    pool_sizes: Vec<VkDescriptorPoolSize>,
    /// Decoded summary of each binding, in the order they were added.
    bindings: Vec<DescriptorLayoutBinding>,
}
impl DescriptorSetLayout {
    /// Starts an empty [`DescriptorSetLayoutBuilder`] with no bindings and
    /// all create flags cleared.
    pub fn builder() -> DescriptorSetLayoutBuilder {
        DescriptorSetLayoutBuilder {
            flags: 0u32.into(),
            indexing_flags: vec![],
            layout_bindings: vec![],
        }
    }

    /// Pool sizes recorded for this layout when it was built.
    pub(crate) fn pool_sizes(&self) -> &[VkDescriptorPoolSize] {
        &self.pool_sizes
    }

    /// Summaries of the layout's bindings, in the order they were added.
    pub fn bindings(&self) -> &[DescriptorLayoutBinding] {
        self.bindings.as_slice()
    }
}
impl VulkanDevice for DescriptorSetLayout {
    /// The device this layout was created on; `Drop` destroys the handle through it.
    fn device(&self) -> &Arc<Device> {
        let Self { device, .. } = self;
        device
    }
}
// Wires `DescriptorSetLayout` to its raw `VkDescriptorSetLayout` handle, which is
// stored in the `descriptor_set_layout` field. Presumably this generates the
// crate's handle-accessor impls; see the macro definition for the exact surface.
impl_vk_handle!(
    DescriptorSetLayout,
    VkDescriptorSetLayout,
    descriptor_set_layout
);
impl Drop for DescriptorSetLayout {
    /// Hands the layout handle back to the owning device for destruction.
    fn drop(&mut self) {
        let handle = self.descriptor_set_layout;
        self.device.destroy_descriptor_set_layout(handle);
    }
}
/// Summary of a single binding in a [`DescriptorSetLayout`].
#[derive(Debug)]
pub struct DescriptorLayoutBinding {
    /// Binding number the shaders refer to.
    pub binding: u32,
    /// Type of descriptor stored at this binding.
    pub desc_type: VkDescriptorType,
    /// Shader stage flags decoded from the binding's stage mask.
    pub stage_flags: Vec<VkShaderStageFlags>,
}
impl From<VkDescriptorSetLayoutBinding> for DescriptorLayoutBinding {
    /// Converts a raw Vulkan binding into a [`DescriptorLayoutBinding`].
    ///
    /// The binding's stage mask is decoded into the individual shader-stage bits
    /// it contains.
    fn from(value: VkDescriptorSetLayoutBinding) -> Self {
        // Only single-stage bits belong in this list. Composite masks such as
        // `VK_SHADER_STAGE_ALL_GRAPHICS` and `VK_SHADER_STAGE_ALL` are excluded
        // on purpose. The overlap test below (`!= 0`) would match them whenever
        // *any* one of their stages is set, so a vertex-only binding would
        // wrongly report ALL_GRAPHICS and ALL as well.
        let single_stage_bits = [
            VK_SHADER_STAGE_VERTEX_BIT,
            VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
            VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
            VK_SHADER_STAGE_GEOMETRY_BIT,
            VK_SHADER_STAGE_FRAGMENT_BIT,
            VK_SHADER_STAGE_COMPUTE_BIT,
            VK_SHADER_STAGE_RAYGEN_BIT_KHR,
            VK_SHADER_STAGE_ANY_HIT_BIT_KHR,
            VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR,
            VK_SHADER_STAGE_MISS_BIT_KHR,
            VK_SHADER_STAGE_INTERSECTION_BIT_KHR,
            VK_SHADER_STAGE_CALLABLE_BIT_KHR,
            VK_SHADER_STAGE_TASK_BIT_NV,
            VK_SHADER_STAGE_MESH_BIT_NV,
        ];
        Self {
            binding: value.binding,
            desc_type: value.descriptorType,
            stage_flags: single_stage_bits
                .into_iter()
                .filter(|&flag| (flag & value.stageFlagBits) != 0)
                .collect(),
        }
    }
}