2023-01-14 12:03:01 +00:00
|
|
|
use anyhow::Result;
|
|
|
|
|
|
|
|
use crate::pipeline::PipelineType;
|
|
|
|
use crate::prelude::*;
|
|
|
|
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
|
|
pub struct ComputePipelineBuilder<'a> {
|
2023-01-27 09:12:59 +00:00
|
|
|
shader_module: Option<&'a Arc<ShaderModule<{ ShaderType::Compute as u8 }>>>,
|
2023-01-14 12:03:01 +00:00
|
|
|
pipeline_cache: Option<&'a Arc<PipelineCache>>,
|
|
|
|
flags: VkPipelineCreateFlagBits,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a> ComputePipelineBuilder<'a> {
|
|
|
|
// TODO: add support for specialization constants
|
2023-01-27 09:12:59 +00:00
|
|
|
pub fn set_shader_module(
|
|
|
|
mut self,
|
|
|
|
shader_module: &'a Arc<ShaderModule<{ ShaderType::Compute as u8 }>>,
|
|
|
|
) -> Self {
|
2023-01-14 12:03:01 +00:00
|
|
|
if cfg!(debug_assertions) {
|
|
|
|
if self.shader_module.is_some() {
|
|
|
|
panic!("shader already set!");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
self.shader_module = Some(shader_module);
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn set_pipeline_cache(mut self, pipeline_cache: &'a Arc<PipelineCache>) -> Self {
|
|
|
|
self.pipeline_cache = Some(pipeline_cache);
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn set_flags(mut self, flags: impl Into<VkPipelineCreateFlagBits>) -> Self {
|
|
|
|
self.flags = flags.into();
|
|
|
|
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn build(
|
|
|
|
self,
|
|
|
|
device: &Arc<Device>,
|
|
|
|
pipeline_layout: &Arc<PipelineLayout>,
|
|
|
|
) -> Result<Arc<Pipeline>> {
|
|
|
|
let pipeline_ci = match self.shader_module {
|
|
|
|
Some(module) => VkComputePipelineCreateInfo::new(
|
|
|
|
self.flags,
|
|
|
|
module.pipeline_stage_info(),
|
|
|
|
pipeline_layout.vk_handle(),
|
|
|
|
),
|
|
|
|
None => {
|
|
|
|
return Err(anyhow::Error::msg(
|
|
|
|
"Required shader module could not be found",
|
|
|
|
))
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
let pipeline = device.create_compute_pipelines(
|
|
|
|
self.pipeline_cache.map(|cache| cache.vk_handle()),
|
|
|
|
&[pipeline_ci],
|
|
|
|
)?[0];
|
|
|
|
|
|
|
|
Ok(Arc::new(Pipeline::new(
|
|
|
|
device.clone(),
|
|
|
|
pipeline_layout.clone(),
|
|
|
|
PipelineType::Compute,
|
|
|
|
pipeline,
|
|
|
|
)))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'a> Default for ComputePipelineBuilder<'a> {
|
|
|
|
fn default() -> Self {
|
|
|
|
ComputePipelineBuilder {
|
|
|
|
shader_module: None,
|
|
|
|
pipeline_cache: None,
|
|
|
|
flags: 0.into(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|