// vulkan_lib/vulkan-rs/src/descriptorset.rs
use crate::prelude::*;
use anyhow::Result;
use std::any::Any;
use std::collections::HashMap;
use std::slice;
use std::sync::{Arc, Mutex};
/// A single pending descriptor update for one binding of a [`DescriptorSet`].
///
/// Created with one of the constructors (e.g. [`DescriptorWrite::uniform_buffers`]) and
/// applied with [`DescriptorSet::update`].
#[derive(Debug)]
pub struct DescriptorWrite {
    // Binding index the write targets.
    binding: u32,
    // Vulkan descriptor type of every descriptor in this write.
    descriptor_type: VkDescriptorType,
    // Vulkan info structs describing the written resources.
    inner: InnerWrite,
    // Type-erased `Arc` clones of the written resources. `DescriptorSet::update` stores
    // them per binding, which keeps the resources alive while the set may reference them.
    handles: Vec<Arc<dyn Any + Send + Sync>>,
}
/// Vulkan-side description of the resources carried by a [`DescriptorWrite`].
#[derive(Debug)]
enum InnerWrite {
    // Buffer descriptors (uniform or storage buffers).
    Buffers(Vec<VkDescriptorBufferInfo>),
    // Image descriptors (combined image samplers, storage images, input attachments).
    Images(Vec<VkDescriptorImageInfo>),
    // Acceleration structures: the extension struct that is chained into the write, plus
    // the handle array it is pointed at via `set_acceleration_structures`. The array must
    // stay alive (and unreallocated) for as long as the extension struct is used.
    AS(
        (
            VkWriteDescriptorSetAccelerationStructureKHR,
            Vec<VkAccelerationStructureKHR>,
        ),
    ),
}
impl DescriptorWrite {
pub fn uniform_buffers<T: ReprC + Send + Sync + 'static>(
2023-01-14 12:03:01 +00:00
binding: u32,
buffers: &[&Arc<Buffer<T>>],
) -> Self {
DescriptorWrite {
binding,
descriptor_type: VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
inner: InnerWrite::Buffers(
buffers
.iter()
.map(|buffer| VkDescriptorBufferInfo {
buffer: buffer.vk_handle(),
offset: 0,
range: buffer.byte_size(),
})
.collect(),
),
handles: buffers
.iter()
.map(|b| (*b).clone() as Arc<dyn Any + Send + Sync>)
.collect(),
}
}
pub fn storage_buffers<T: ReprC + Send + Sync + 'static>(
2023-01-14 12:03:01 +00:00
binding: u32,
buffers: &[&Arc<Buffer<T>>],
) -> Self {
DescriptorWrite {
binding,
descriptor_type: VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
inner: InnerWrite::Buffers(
buffers
.iter()
.map(|buffer| VkDescriptorBufferInfo {
buffer: buffer.vk_handle(),
offset: 0,
range: buffer.byte_size(),
})
.collect(),
),
handles: buffers
.iter()
.map(|b| (*b).clone() as Arc<dyn Any + Send + Sync>)
.collect(),
}
}
pub fn combined_samplers(binding: u32, images: &[&Arc<Image>]) -> Self {
DescriptorWrite {
binding,
descriptor_type: VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
inner: InnerWrite::Images(
images
.iter()
.map(|image| VkDescriptorImageInfo {
sampler: image
.sampler()
.as_ref()
.expect("image has no sampler attached")
.vk_handle(),
imageView: image.vk_handle(),
imageLayout: image.image_layout(),
})
.collect(),
),
handles: images
.iter()
.map(|i| (*i).clone() as Arc<dyn Any + Send + Sync>)
.collect(),
}
}
pub fn storage_images(binding: u32, images: &[&Arc<Image>]) -> Self {
DescriptorWrite {
binding,
descriptor_type: VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
inner: InnerWrite::Images(
images
.iter()
.map(|image| VkDescriptorImageInfo {
sampler: VkSampler::NULL_HANDLE,
imageView: image.vk_handle(),
imageLayout: image.image_layout(),
})
.collect(),
),
handles: images
.iter()
.map(|i| (*i).clone() as Arc<dyn Any + Send + Sync>)
.collect(),
}
}
pub fn input_attachments(binding: u32, images: &[&Arc<Image>]) -> Self {
DescriptorWrite {
binding,
descriptor_type: VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT,
inner: InnerWrite::Images(
2023-01-14 12:03:01 +00:00
images
.iter()
.map(|image| VkDescriptorImageInfo {
sampler: VkSampler::NULL_HANDLE,
imageView: image.vk_handle(),
imageLayout: image.image_layout(),
})
.collect(),
),
handles: images
.iter()
.map(|i| (*i).clone() as Arc<dyn Any + Send + Sync>)
.collect(),
}
}
pub fn acceleration_structures(
binding: u32,
acceleration_structures: &[&Arc<AccelerationStructure>],
) -> Self {
let vk_as: Vec<VkAccelerationStructureKHR> = acceleration_structures
.iter()
.map(|a| a.vk_handle())
.collect();
let mut write = DescriptorWrite {
binding,
descriptor_type: VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR,
inner: InnerWrite::AS((
VkWriteDescriptorSetAccelerationStructureKHR::default(),
vk_as,
)),
handles: acceleration_structures
.iter()
.map(|a| (*a).clone() as Arc<dyn Any + Send + Sync>)
.collect(),
};
if let InnerWrite::AS((vk_write_as, vk_as)) = &mut write.inner {
vk_write_as.set_acceleration_structures(vk_as);
}
write
}
pub fn change_image_layout(mut self, image_layout: VkImageLayout) -> Self {
if let InnerWrite::Images(ref mut infos) = self.inner {
for info in infos {
info.imageLayout = image_layout;
}
}
self
}
fn vk_write(&self, write: &mut VkWriteDescriptorSet) {
match &self.inner {
InnerWrite::Buffers(buffer_infos) => {
write.set_buffer_infos(buffer_infos);
}
InnerWrite::Images(image_infos) => {
write.set_image_infos(image_infos);
}
InnerWrite::AS((as_write, _)) => {
write.descriptorCount = as_write.accelerationStructureCount;
write.chain(as_write);
}
}
}
}
/// Builder for a single [`DescriptorSet`]; created by `DescriptorSet::builder`.
pub struct DescriptorSetBuilder {
    // Device the set is allocated on.
    device: Arc<Device>,
    // Pool the set is allocated from; it also supplies the set's layout handle.
    descriptor_pool: Arc<DescriptorPool>,
    // Counts for variable-sized bindings; left empty when not used.
    variable_desc_counts: Vec<u32>,
    // Extension struct chained into the allocate info. It is filled from
    // `variable_desc_counts` right before allocation and kept here so it outlives that call.
    variable_descriptor_count: VkDescriptorSetVariableDescriptorCountAllocateInfoEXT,
}
impl DescriptorSetBuilder {
pub fn set_variable_descriptor_counts(mut self, descriptor_counts: &[u32]) -> Self {
self.variable_desc_counts = descriptor_counts.to_vec();
self
}
pub fn allocate(mut self) -> Result<Arc<DescriptorSet>> {
let layout = self.descriptor_pool.vk_handle();
let mut descriptor_set_ci = VkDescriptorSetAllocateInfo::new(
self.descriptor_pool.vk_handle(),
slice::from_ref(&layout),
);
if !self.variable_desc_counts.is_empty() {
self.variable_descriptor_count
.set_descriptor_counts(&self.variable_desc_counts);
descriptor_set_ci.chain(&self.variable_descriptor_count);
}
let descriptor_set = self.device.allocate_descriptor_sets(&descriptor_set_ci)?[0];
Ok(Arc::new(DescriptorSet {
device: self.device,
pool: self.descriptor_pool,
descriptor_set,
handles: Mutex::new(HashMap::new()),
}))
}
}
/// A descriptor set allocated from a [`DescriptorPool`].
///
/// The set is freed back to its pool when dropped.
#[derive(Debug)]
pub struct DescriptorSet {
    // Device the set was allocated on.
    device: Arc<Device>,
    // Pool the set was allocated from; it also supplies the set's layout handle.
    pool: Arc<DescriptorPool>,
    // Raw Vulkan handle of the set.
    descriptor_set: VkDescriptorSet,
    // Resources referenced by the most recent write to each binding, keyed by binding
    // index; holding these `Arc`s keeps the resources alive while this set exists.
    handles: Mutex<HashMap<u32, Vec<Arc<dyn Any + Send + Sync>>>>,
}
impl DescriptorSet {
    /// Starts building a descriptor set that will be allocated from `descriptor_pool` on
    /// `device`.
    ///
    /// No variable descriptor counts are configured initially.
    pub(crate) fn builder(
        device: Arc<Device>,
        descriptor_pool: Arc<DescriptorPool>,
    ) -> DescriptorSetBuilder {
        DescriptorSetBuilder {
            device,
            descriptor_pool,
            variable_desc_counts: Vec::new(),
            variable_descriptor_count: VkDescriptorSetVariableDescriptorCountAllocateInfoEXT::new(
                &[],
            ),
        }
    }

    // TODO: add update function for VkCopyDescriptorSet

    /// Applies `writes` to this descriptor set with a single `update_descriptor_sets` call.
    ///
    /// Every write starts at array element 0 of its binding. The resources referenced by
    /// each write are retained per binding; a later write to the same binding replaces
    /// (and thereby releases) the previously retained resources.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    ///
    /// # Panics
    /// In debug builds, panics if `writes` is empty. Also panics if the internal handle
    /// mutex is poisoned.
    pub fn update(&self, writes: &[DescriptorWrite]) -> Result<()> {
        debug_assert!(!writes.is_empty());

        // The guard lives until the end of this function, so it also serializes concurrent
        // `update` calls on this set (Vulkan requires host access to the destination set to
        // be externally synchronized).
        let mut handles = self.handles.lock().unwrap();
        let mut vk_writes = Vec::with_capacity(writes.len());

        for write in writes {
            let mut write_desc = VkWriteDescriptorSet::new(
                self.descriptor_set,
                write.binding,
                0,
                write.descriptor_type,
            );
            // `write_desc` now refers to data owned by `write`; `writes` stays borrowed for
            // the whole call, so that data outlives the device call below.
            write.vk_write(&mut write_desc);
            vk_writes.push(write_desc);

            // Replaces any resources retained by an earlier write to this binding.
            handles.insert(write.binding, write.handles.clone());
        }

        self.device
            .update_descriptor_sets(vk_writes.as_slice(), &[]);

        Ok(())
    }
}
impl VulkanDevice for DescriptorSet {
    /// Returns the logical device this descriptor set was allocated on.
    fn device(&self) -> &Arc<Device> {
        let Self { device, .. } = self;
        device
    }
}
impl_vk_handle!(DescriptorSet, VkDescriptorSet, descriptor_set);
impl VkHandle<VkDescriptorSetLayout> for DescriptorSet {
fn vk_handle(&self) -> VkDescriptorSetLayout {
self.pool.vk_handle()
}
}
impl<'a> VkHandle<VkDescriptorSetLayout> for &'a DescriptorSet {
fn vk_handle(&self) -> VkDescriptorSetLayout {
self.pool.vk_handle()
}
}
impl VkHandle<VkDescriptorSetLayout> for Arc<DescriptorSet> {
fn vk_handle(&self) -> VkDescriptorSetLayout {
self.pool.vk_handle()
}
}
impl<'a> VkHandle<VkDescriptorSetLayout> for &'a Arc<DescriptorSet> {
fn vk_handle(&self) -> VkDescriptorSetLayout {
self.pool.vk_handle()
}
}
impl Drop for DescriptorSet {
    /// Frees the descriptor set back to the pool it was allocated from.
    ///
    /// `Drop` cannot return errors, so a failure is reported on stderr instead of being
    /// propagated.
    fn drop(&mut self) {
        if let Err(error) = self
            .device
            .free_descriptor_sets(self.pool.vk_handle(), &[self.descriptor_set])
        {
            eprintln!("failed to free descriptor set: {}", error);
        }
    }
}
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use anyhow::Result;
    use std::sync::Arc;

    /// Allocates `DESCRIPTOR_COUNT` sets from one pool and checks that every allocation
    /// succeeds.
    #[test]
    fn create_multiple_sets_from_one_pool() -> Result<()> {
        const DESCRIPTOR_COUNT: u32 = 2;

        let (device, _queue) = crate::create_vk_handles()?;

        let descriptor_layout = DescriptorSetLayout::builder()
            .add_layout_binding(
                0,
                VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
                VK_SHADER_STAGE_FRAGMENT_BIT,
                0,
            )
            .change_descriptor_count(DESCRIPTOR_COUNT)
            .build(device.clone())?;

        let descriptor_pool = DescriptorPool::builder()
            .set_layout(descriptor_layout.clone())
            .build(device.clone())?;

        // Collecting into a `Result` stops at the first failed allocation, and `?` fails
        // the test with that allocation's error.
        let descriptors = (0..DESCRIPTOR_COUNT)
            .map(|_| descriptor_pool.prepare_set().allocate())
            .collect::<Result<Vec<Arc<DescriptorSet>>>>()?;

        assert_eq!(descriptors.len(), DESCRIPTOR_COUNT as usize);

        Ok(())
    }
}