use std::{
    collections::BTreeMap,
    hash::{Hash, Hasher},
    mem,
    sync::Arc,
};
use crate::{
    image::Image,
    layout::{Affine, Point, Rect, Vector},
    text::Paragraph,
    view::ViewId,
};
use super::{Color, Curve, Stroke};
/// An image used as a shader.
///
/// `image` is sampled through `transform`. `color` defaults to white when built via
/// `From<Image>`; presumably it tints/modulates the image — TODO confirm against the renderer.
#[derive(Clone, Debug, PartialEq)]
pub struct Pattern {
    pub image: Image,
    pub transform: Affine,
    pub color: Color,
}
impl Hash for Pattern {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.image.hash(state);
        self.transform.hash(state);
        self.color.hash(state);
    }
}
impl From<Image> for Pattern {
    fn from(value: Image) -> Self {
        Self {
            image: value,
            transform: Affine::IDENTITY,
            color: Color::WHITE,
        }
    }
}
/// Where a paint gets its color from: a single solid color or an image [`Pattern`].
#[derive(Clone, Debug, PartialEq, Hash)]
pub enum Shader {
    Solid(Color),
    Pattern(Pattern),
}
/// How a painted primitive is composited with what has already been drawn.
///
/// NOTE(review): the variant names look like Porter-Duff compositing operators
/// (e.g. `SourceOver` = source drawn over destination) — confirm against the renderer.
/// `SourceOver` is the blend mode used by `Paint::default`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BlendMode {
    Clear,
    Source,
    Destination,
    SourceOver,
    DestinationOver,
}
/// Anti-aliasing level requested for a primitive. `Fast` is the level used by
/// `Paint::default`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AntiAlias {
    None,
    Fast,
    Full,
}
/// How a filled or stroked curve is colored.
///
/// A paint combines the shader supplying the color, the blend mode used to composite it, and
/// the anti-aliasing level.
#[derive(Clone, Debug, PartialEq, Hash)]
pub struct Paint {
    pub shader: Shader,
    pub blend: BlendMode,
    pub anti_alias: AntiAlias,
}
impl Default for Paint {
    fn default() -> Self {
        Self {
            shader: Shader::Solid(Color::BLACK),
            blend: BlendMode::SourceOver,
            anti_alias: AntiAlias::Fast,
        }
    }
}
impl From<Color> for Paint {
    fn from(value: Color) -> Self {
        Self {
            shader: Shader::Solid(value),
            ..Default::default()
        }
    }
}
impl From<Image> for Paint {
    fn from(value: Image) -> Self {
        Self {
            shader: Shader::Pattern(Pattern::from(value)),
            ..Default::default()
        }
    }
}
impl From<Pattern> for Paint {
    fn from(value: Pattern) -> Self {
        Self {
            shader: Shader::Pattern(value),
            ..Default::default()
        }
    }
}
/// Rule deciding which points lie inside a curve, passed to `Curve::contains`.
///
/// The variants are presumably the standard non-zero winding and even-odd rules — confirm in
/// `Curve::contains`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FillRule {
    NonZero,
    EvenOdd,
}
/// A region that a layer is restricted to: the points inside `curve` according to `fill`.
///
/// `Canvas::view_at` skips a masked layer for points outside this region. Rendering presumably
/// clips to it as well — confirm against the renderer.
#[derive(Clone, Debug, PartialEq)]
pub struct Mask {
    pub curve: Arc<Curve>,
    pub fill: FillRule,
}
impl Mask {
    pub fn new(curve: impl Into<Arc<Curve>>, fill: FillRule) -> Self {
        Self {
            curve: curve.into(),
            fill,
        }
    }
}
impl From<Rect> for Mask {
    fn from(value: Rect) -> Self {
        Self::new(Curve::rect(value), FillRule::NonZero)
    }
}
/// A single drawing command recorded by a [`Canvas`].
#[derive(Clone, Debug, PartialEq)]
pub enum Primitive {
    /// Fills the inside of `curve`, decided by `fill`, with `paint`.
    Fill {
        curve: Arc<Curve>,
        fill: FillRule,
        paint: Paint,
    },
    /// Strokes the outline of `curve` with the `stroke` style and `paint`.
    Stroke {
        curve: Arc<Curve>,
        stroke: Stroke,
        paint: Paint,
    },
    /// Draws a text paragraph.
    ///
    /// `bounds` is the area `Canvas::view_at` uses for hit testing. `rect` is presumably the
    /// rect the paragraph is laid out in — TODO confirm against the renderer.
    Paragraph {
        paragraph: Paragraph,
        bounds: Rect,
        rect: Rect,
    },
    /// A nested list of primitives drawn under `transform`.
    ///
    /// `mask`, when present, restricts the layer: hit testing skips the layer for points outside
    /// it. `view`, when present, is the view reported by hit tests that land on the layer's
    /// contents; an inner layer's `view` overrides an outer one.
    Layer {
        primitives: Arc<Vec<Primitive>>,
        transform: Affine,
        mask: Option<Mask>,
        view: Option<ViewId>,
    },
}
impl Primitive {
    pub fn count(&self) -> usize {
        match self {
            Primitive::Fill { .. } | Primitive::Stroke { .. } | Primitive::Paragraph { .. } => 1,
            Primitive::Layer { primitives, .. } => primitives.iter().map(Self::count).sum(),
        }
    }
}
/// A recorded list of drawing primitives plus z-ordered overlays.
///
/// Both lists are reference-counted, so cloning a canvas shares them rather than copying them.
#[derive(Clone, Debug, PartialEq)]
pub struct Canvas {
    /// Extra primitive lists keyed by index. `Canvas::primitives` yields them after the base
    /// primitives, in ascending index order.
    overlays: BTreeMap<i32, Arc<Vec<Primitive>>>,
    /// The base primitives, in recording order.
    primitives: Arc<Vec<Primitive>>,
}
impl Default for Canvas {
    fn default() -> Self {
        Self::new()
    }
}
impl Canvas {
    pub fn new() -> Self {
        Self {
            overlays: BTreeMap::new(),
            primitives: Arc::new(Vec::new()),
        }
    }
    pub fn primitives(&self) -> impl Iterator<Item = &Primitive> + '_ {
        let overlays = self.overlays.values().flat_map(|p| p.iter());
        self.primitives.iter().chain(overlays)
    }
    pub fn clear(&mut self) {
        self.overlays.clear();
        Arc::make_mut(&mut self.primitives).clear();
    }
    pub fn rect(&mut self, rect: Rect, paint: impl Into<Paint>) {
        let curve = Curve::rect(rect);
        self.fill(curve.clone(), FillRule::NonZero, paint);
    }
    pub fn trigger(&mut self, rect: Rect, view: ViewId) {
        self.hoverable(view, |canvas| {
            let curve = Curve::rect(rect);
            canvas.fill(
                curve,
                FillRule::NonZero,
                Paint {
                    shader: Shader::Solid(Color::TRANSPARENT),
                    blend: BlendMode::Destination,
                    anti_alias: AntiAlias::None,
                },
            );
        });
    }
    pub fn fill(&mut self, curve: impl Into<Arc<Curve>>, fill: FillRule, paint: impl Into<Paint>) {
        let primitives = Arc::make_mut(&mut self.primitives);
        primitives.push(Primitive::Fill {
            curve: curve.into(),
            fill,
            paint: paint.into(),
        });
    }
    pub fn stroke(
        &mut self,
        curve: impl Into<Arc<Curve>>,
        stroke: impl Into<Stroke>,
        paint: impl Into<Paint>,
    ) {
        let primitives = Arc::make_mut(&mut self.primitives);
        primitives.push(Primitive::Stroke {
            curve: curve.into(),
            stroke: stroke.into(),
            paint: paint.into(),
        });
    }
    pub fn paragraph(&mut self, paragraph: Paragraph, rect: Rect, bounds: Rect) {
        let primitives = Arc::make_mut(&mut self.primitives);
        primitives.push(Primitive::Paragraph {
            paragraph,
            bounds,
            rect,
        });
    }
    pub fn draw_canvas(&mut self, canvas: Canvas) {
        self.layer(Affine::IDENTITY, None, None, |ca| *ca = canvas);
    }
    pub fn overlay<T>(&mut self, index: i32, f: impl FnOnce(&mut Self) -> T) -> T {
        let mut overlay = Canvas::new();
        let result = f(&mut overlay);
        for (i, mut others) in overlay.overlays {
            let others = mem::take(Arc::make_mut(&mut others));
            let primitives = Arc::make_mut(self.overlays.entry(i).or_default());
            primitives.extend(others);
        }
        let other = mem::take(Arc::make_mut(&mut overlay.primitives));
        let primitives = Arc::make_mut(self.overlays.entry(index).or_default());
        primitives.extend(other);
        result
    }
    pub fn layer<T>(
        &mut self,
        transform: Affine,
        mask: Option<Mask>,
        view: Option<ViewId>,
        f: impl FnOnce(&mut Self) -> T,
    ) -> T {
        let mut layer = Canvas::new();
        let result = f(&mut layer);
        for (i, mut other) in layer.overlays {
            let other = mem::take(Arc::make_mut(&mut other));
            let primitives = Arc::make_mut(self.overlays.entry(i).or_default());
            primitives.extend(other);
        }
        let primitives = Arc::make_mut(&mut self.primitives);
        primitives.push(Primitive::Layer {
            primitives: layer.primitives,
            transform,
            mask,
            view,
        });
        result
    }
    pub fn transformed<T>(&mut self, transform: Affine, f: impl FnOnce(&mut Self) -> T) -> T {
        self.layer(transform, None, None, f)
    }
    pub fn translated<T>(&mut self, translation: Vector, f: impl FnOnce(&mut Self) -> T) -> T {
        self.transformed(Affine::translate(translation), f)
    }
    pub fn rotated<T>(&mut self, angle: f32, f: impl FnOnce(&mut Self) -> T) -> T {
        self.transformed(Affine::rotate(angle), f)
    }
    pub fn scaled<T>(&mut self, scale: Vector, f: impl FnOnce(&mut Self) -> T) -> T {
        self.transformed(Affine::scale(scale), f)
    }
    pub fn masked<T>(&mut self, mask: Mask, f: impl FnOnce(&mut Self) -> T) -> T {
        self.layer(Affine::IDENTITY, Some(mask), None, f)
    }
    pub fn hoverable<T>(&mut self, view: ViewId, f: impl FnOnce(&mut Self) -> T) -> T {
        self.layer(Affine::IDENTITY, None, Some(view), f)
    }
    pub fn view_at(&self, point: Point) -> Option<ViewId> {
        fn recurse(primitives: &[Primitive], view: Option<ViewId>, point: Point) -> Option<ViewId> {
            for primitive in primitives.iter().rev() {
                match primitive {
                    Primitive::Fill { curve, fill, .. } => {
                        if view.is_none() {
                            continue;
                        }
                        if curve.contains(point, *fill) {
                            return view;
                        }
                    }
                    Primitive::Stroke { curve, stroke, .. } => {
                        if view.is_none() {
                            continue;
                        }
                        if !curve.bounds().expand(stroke.width).contains(point) {
                            continue;
                        }
                        let mut stroked = Curve::new();
                        stroked.stroke_curve(curve, *stroke);
                        if stroked.contains(point, FillRule::NonZero) {
                            return view;
                        }
                    }
                    Primitive::Paragraph { bounds, .. } => {
                        if view.is_none() {
                            continue;
                        }
                        if bounds.contains(point) {
                            return view;
                        }
                    }
                    Primitive::Layer {
                        primitives,
                        transform,
                        mask,
                        view: layer_view,
                    } => {
                        let point = transform.inverse() * point;
                        if let Some(mask) = mask {
                            if !mask.curve.contains(point, mask.fill) {
                                continue;
                            }
                        }
                        let view = match layer_view {
                            Some(view) => recurse(primitives, Some(*view), point),
                            None => recurse(primitives, view, point),
                        };
                        if view.is_some() {
                            return view;
                        }
                    }
                }
            }
            None
        }
        for primitives in self.overlays.values().rev() {
            if let Some(view) = recurse(primitives, None, point) {
                return Some(view);
            }
        }
        recurse(&self.primitives, None, point)
    }
}