vulkan_lib/vulkan-rs/src/pipelines/compute_pipeline.rs
2024-03-24 21:29:47 +01:00

83 lines
2.1 KiB
Rust

use anyhow::Result;
use crate::pipeline::PipelineType;
use crate::prelude::*;
use std::sync::Arc;
/// Builder for a compute [`Pipeline`].
///
/// Configure it with the `set_*` methods and finish with
/// [`ComputePipelineBuilder::build`]. Every setter consumes `self` and returns
/// the updated builder, so discarding a setter's return value loses the
/// configuration — hence `#[must_use]`.
#[must_use = "builder setters return the updated builder; call `build` to create the pipeline"]
pub struct ComputePipelineBuilder<'a> {
    /// Compute shader to run; `build` fails if this is still `None`.
    shader_module: Option<&'a Arc<ShaderModule<shader_type::Compute>>>,
    /// Optional pipeline cache whose handle is passed to pipeline creation.
    pipeline_cache: Option<&'a Arc<PipelineCache>>,
    /// Flags forwarded to the compute pipeline create info.
    flags: VkPipelineCreateFlagBits,
}
impl<'a> ComputePipelineBuilder<'a> {
// TODO: add support for specialization constants
pub fn set_shader_module(
mut self,
shader_module: &'a Arc<ShaderModule<shader_type::Compute>>,
) -> Self {
if cfg!(debug_assertions) {
if self.shader_module.is_some() {
panic!("shader already set!");
}
}
self.shader_module = Some(shader_module);
self
}
pub fn set_pipeline_cache(mut self, pipeline_cache: &'a Arc<PipelineCache>) -> Self {
self.pipeline_cache = Some(pipeline_cache);
self
}
pub fn set_flags(mut self, flags: impl Into<VkPipelineCreateFlagBits>) -> Self {
self.flags = flags.into();
self
}
pub fn build(
self,
device: &Arc<Device>,
pipeline_layout: &Arc<PipelineLayout>,
) -> Result<Arc<Pipeline>> {
let pipeline_ci = match self.shader_module {
Some(module) => VkComputePipelineCreateInfo::new(
self.flags,
module.pipeline_stage_info(),
pipeline_layout.vk_handle(),
),
None => {
return Err(anyhow::Error::msg(
"Required shader module could not be found",
))
}
};
let pipeline = device.create_compute_pipelines(
self.pipeline_cache.map(|cache| cache.vk_handle()),
&[pipeline_ci],
)?[0];
Ok(Arc::new(Pipeline::new(
device.clone(),
pipeline_layout.clone(),
PipelineType::Compute,
pipeline,
)))
}
}
impl<'a> Default for ComputePipelineBuilder<'a> {
fn default() -> Self {
ComputePipelineBuilder {
shader_module: None,
pipeline_cache: None,
flags: 0.into(),
}
}
}