use crate::prelude::*;

use anyhow::{Context, Result};

use std::fs::File;
use std::io::Read;
use std::marker::PhantomData;
use std::sync::Arc;

/// Zero-sized marker types that tag a `ShaderModule` with the pipeline stage it belongs to.
///
/// `ShaderType` is sealed: only the markers declared in this module can implement it.
pub mod shader_type {
    mod sealed {
        /// Private supertrait of `ShaderType`. Code outside `shader_type` cannot name this
        /// module, so it cannot implement `Sealed`, which in turn seals `ShaderType`.
        pub trait Sealed {}
    }

    /// Implemented by every shader-stage marker type declared in this module.
    pub trait ShaderType: sealed::Sealed {}

    // Declares each marker as a unit struct and implements the sealed trait pair for it.
    // The derives matter beyond convenience: `#[derive(Debug)]` on a type generic over a
    // marker (e.g. `ShaderModule<S>`) requires `S: Debug`, so markers must be `Debug`.
    macro_rules! shader_types {
        ($($(#[$meta:meta])* $name:ident),+ $(,)?) => {
            $(
                $(#[$meta])*
                #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
                pub struct $name;

                impl sealed::Sealed for $name {}
                impl ShaderType for $name {}
            )+
        };
    }

    shader_types!(
        /// Vertex shader stage.
        Vertex,
        /// Fragment shader stage.
        Fragment,
        /// Geometry shader stage.
        Geometry,
        /// Tessellation control shader stage.
        TesselationControl,
        /// Tessellation evaluation shader stage.
        TesselationEvaluation,
        /// Compute shader stage.
        Compute,
        /// Ray-tracing ray generation shader stage.
        RayGeneration,
        /// Ray-tracing closest-hit shader stage.
        ClosestHit,
        /// Ray-tracing miss shader stage.
        Miss,
        /// Ray-tracing any-hit shader stage.
        AnyHit,
        /// Ray-tracing intersection shader stage.
        Intersection,
    );
}

/// Describes how a `ReprC` vertex type is fed to the vertex input stage.
///
/// Implementors must list their attributes. The default `bindings` assumes a single
/// per-vertex binding whose stride is the size of the whole vertex type.
pub trait VertexInputDescription: ReprC + Sized {
    /// Returns one binding at index 0 with stride `size_of::<Self>()` and a per-vertex rate.
    fn bindings() -> Vec<VkVertexInputBindingDescription> {
        let stride = std::mem::size_of::<Self>() as u32;

        let per_vertex = VkVertexInputBindingDescription {
            binding: 0,
            stride,
            inputRate: VK_VERTEX_INPUT_RATE_VERTEX,
        };

        vec![per_vertex]
    }

    /// Returns the attribute descriptions for this vertex type.
    fn attributes() -> Vec<VkVertexInputAttributeDescription>;
}

/// Produces the `VkPipelineShaderStageCreateInfo` that attaches a shader to a pipeline.
///
/// Implemented for `ShaderModule<S>` through `impl_pipeline_stage_info!`, one impl per stage.
pub trait PipelineStageInfo {
    /// Builds the stage create-info describing this shader.
    fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo;
}

/// Implements `PipelineStageInfo` for `ShaderModule<$stage>`.
///
/// `$constructor` names the associated function on `VkPipelineShaderStageCreateInfo` that
/// builds the create-info for that stage from a `VkShaderModule` handle.
macro_rules! impl_pipeline_stage_info {
    ($constructor:ident, $stage:ident) => {
        impl PipelineStageInfo for ShaderModule<$stage> {
            fn pipeline_stage_info(&self) -> VkPipelineShaderStageCreateInfo {
                let module = self.shader_module;
                VkPipelineShaderStageCreateInfo::$constructor(module)
            }
        }
    };
}

/// A Vulkan shader module tagged at the type level with the pipeline stage it is built for.
///
/// The stage marker `ShaderModuleType` only lives in the type (via `PhantomData`). It selects
/// which `PipelineStageInfo` impl applies. The module keeps the `Device` that created it alive.
pub struct ShaderModule<ShaderModuleType: ShaderType> {
    t: PhantomData<ShaderModuleType>,
    device: Arc<Device>,
    shader_module: VkShaderModule,
}

// Hand-written rather than `#[derive(Debug)]`. The derive adds a `ShaderModuleType: Debug`
// bound, so the impl would vanish for any stage marker that is not itself `Debug`. The output
// format is identical to what the derive produces.
impl<ShaderModuleType: ShaderType> std::fmt::Debug for ShaderModule<ShaderModuleType> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ShaderModule")
            .field("t", &self.t)
            .field("device", &self.device)
            .field("shader_module", &self.shader_module)
            .finish()
    }
}

impl<ShaderModuleType: ShaderType> ShaderModule<ShaderModuleType> {
    /// Loads shader bytecode from the file at `path` and creates a shader module from it.
    ///
    /// # Errors
    ///
    /// Fails with `path` as context if the file cannot be opened or read. Also propagates
    /// any error from [`ShaderModule::from_slice`].
    pub fn new(device: Arc<Device>, path: &str) -> Result<Arc<ShaderModule<ShaderModuleType>>> {
        let code = Self::shader_code(path)?;

        Self::from_slice(device, &code)
    }

    /// Creates a shader module on `device` from in-memory shader bytecode.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `device.create_shader_module`.
    pub fn from_slice(
        device: Arc<Device>,
        code: &[u8],
    ) -> Result<Arc<ShaderModule<ShaderModuleType>>> {
        let shader_module_ci =
            VkShaderModuleCreateInfo::new(VK_SHADER_MODULE_CREATE_NULL_BIT, code);

        let shader_module = device.create_shader_module(&shader_module_ci)?;

        Ok(Arc::new(ShaderModule {
            t: PhantomData,
            device,
            shader_module,
        }))
    }

    /// Reads the whole file at `path` into memory.
    ///
    /// Open and read failures both carry `path` as context, so the caller can tell which
    /// shader file was at fault. Previously only the open failure did.
    fn shader_code(path: &str) -> Result<Vec<u8>> {
        let mut file = File::open(path).with_context(|| path.to_string())?;

        let mut code = Vec::new();
        file.read_to_end(&mut code)
            .with_context(|| path.to_string())?;

        Ok(code)
    }
}

use shader_type::*;

// Graphics pipeline stages.
impl_pipeline_stage_info!(vertex, Vertex);
impl_pipeline_stage_info!(tesselation_control, TesselationControl);
impl_pipeline_stage_info!(tesselation_evaluation, TesselationEvaluation);
impl_pipeline_stage_info!(geometry, Geometry);
impl_pipeline_stage_info!(fragment, Fragment);

// Compute pipeline stage.
impl_pipeline_stage_info!(compute, Compute);

// Ray-tracing pipeline stages.
impl_pipeline_stage_info!(ray_generation, RayGeneration);
impl_pipeline_stage_info!(intersection, Intersection);
impl_pipeline_stage_info!(any_hit, AnyHit);
impl_pipeline_stage_info!(closest_hit, ClosestHit);
impl_pipeline_stage_info!(miss, Miss);

impl<S> VulkanDevice for ShaderModule<S>
where
    S: ShaderType,
{
    /// Returns the device this shader module was created on.
    fn device(&self) -> &Arc<Device> {
        &self.device
    }
}

// Generates the crate's Vulkan-handle glue for every `ShaderModule<_>`. The macro is defined
// elsewhere; presumably it exposes the `shader_module` field as the `VkShaderModule` handle
// (TODO confirm against `impl_vk_handle!`). The trailing comma in the generic list is part of
// that macro's input syntax.
impl_vk_handle!(ShaderModule<ShaderModuleType: ShaderType,>, VkShaderModule, shader_module);

/// Releases the Vulkan shader module through the device it was created on.
impl<S> Drop for ShaderModule<S>
where
    S: ShaderType,
{
    fn drop(&mut self) {
        let handle = self.shader_module;
        self.device.destroy_shader_module(handle);
    }
}

/// Appends a specialization constant of type `T` to a constant set.
pub trait AddSpecializationConstant<T> {
    /// Records `value` as the specialization constant whose constant ID is `id`.
    fn add(&mut self, value: T, id: u32);
}

/// A set of shader specialization constants, stored as one packed byte buffer plus one map
/// entry per constant.
pub struct SpecializationConstants {
    // Native-endian bytes of every constant, appended back to back in the order they were added.
    data: Vec<u8>,
    // One entry per constant: its ID plus the offset/size of its bytes within `data`.
    entries: Vec<VkSpecializationMapEntry>,

    // Vulkan-facing description of `data`/`entries`, handed out by `vk_handle()`.
    // NOTE(review): filled via `set_data`/`set_map_entries`, which presumably record raw
    // pointers + lengths into the two vectors above; it must be refreshed whenever they change.
    info: VkSpecializationInfo,
}

impl SpecializationConstants {
    /// Creates an empty constant set whose `VkSpecializationInfo` describes no data and no
    /// map entries.
    pub fn new() -> Self {
        let mut me = SpecializationConstants {
            data: Vec::new(),
            entries: Vec::new(),

            info: VkSpecializationInfo::empty(),
        };

        me.info.set_data(&me.data);
        me.info.set_map_entries(&me.entries);

        me
    }

    /// Returns the `VkSpecializationInfo` describing this constant set.
    ///
    /// NOTE(review): `set_data`/`set_map_entries` presumably store pointers into `self`'s
    /// buffers. If so, the returned info is only valid while `self` is alive and unmodified.
    pub fn vk_handle(&self) -> &VkSpecializationInfo {
        &self.info
    }
}

// `new()` takes no arguments, so `Default` is the idiomatic way to expose it
// (clippy `new_without_default`). It cannot be derived because `info` must be wired up.
impl Default for SpecializationConstants {
    fn default() -> Self {
        Self::new()
    }
}

/// Implements `AddSpecializationConstant<T>` for `SpecializationConstants` for each listed
/// numeric type `T`.
///
/// Each value is appended to the byte buffer in native-endian order. A map entry records its
/// ID, byte offset and size. Afterwards the cached `VkSpecializationInfo` is re-pointed at
/// the grown buffers.
macro_rules! impl_add_specialization_constant {
    ($($type: ty),+ $(,)?) => {
        $(
            impl AddSpecializationConstant<$type> for SpecializationConstants {
                fn add(&mut self, value: $type, id: u32) {
                    let bytes = value.to_ne_bytes();

                    self.entries.push(VkSpecializationMapEntry {
                        constantID: id,
                        // Constant data realistically never approaches 4 GiB, so the
                        // narrowing cast cannot truncate in practice.
                        offset: self.data.len() as u32,
                        size: bytes.len(),
                    });

                    self.data.extend_from_slice(&bytes);

                    // `info` was filled from `data`/`entries` when they were empty. Both
                    // have now changed length and may have reallocated, so refresh it.
                    // Otherwise `vk_handle()` would describe zero constants (or dangling
                    // buffers).
                    self.info.set_data(&self.data);
                    self.info.set_map_entries(&self.entries);
                }
            }
        )+
    };
}

// Every primitive numeric type may be used as a specialization constant value.
impl_add_specialization_constant!(
    f32, f64, // floating point
    u64, i64, u32, i32, u16, i16, u8, i8, // fixed-width integers
    usize, isize // pointer-sized integers
);