import React, { useState, useRef, useEffect } from "react";
import { Element } from "../Element";
import { DataFrame } from "../../DSL/DataFrame";
import * as d3 from "d3";

/**
 * Builds a kernel density estimator evaluated on a fixed grid.
 *
 * @param {Function} kernel - Kernel function applied to the distance (x - sample).
 * @param {Array<number>} X - Evaluation points.
 * @returns {Function} Estimator taking an array of samples V and returning one `[x, density]` pair per
 *   point of X, where density is `d3.mean` of `kernel(x - v)` over all samples v.
 */
function kernelDensityEstimator(kernel, X) {
    return (V) => X.map((x) => {
        const density = d3.mean(V, (sample) => kernel(x - sample));
        return [x, density];
    });
}
/**
 * Epanechnikov kernel with bandwidth k: 0.75 * (1 - u^2) / k for |u| <= 1 (u = v / k), otherwise 0.
 *
 * @param {number} k - Bandwidth.
 * @returns {Function} Kernel mapping a distance v to its weight.
 */
function kernelEpanechnikov(k) {
    return (v) => {
        const u = v / k;
        // Written as `<= 1` (not `> 1` → 0) so a NaN distance still yields 0.
        if (Math.abs(u) <= 1) {
            return 0.75 * (1 - u * u) / k;
        }
        return 0;
    };
}

/**
 * Pads a curve with a zero-height point at its first and last x, so a filled path closes down to the baseline.
 *
 * @param {Array<Array<number>>} curve - `[x, y]` points; not mutated.
 * @returns {Array<Array<number>>} A new array `[[x0, 0], ...curve, [xn, 0]]`, or `[]` for an empty curve.
 */
function addZeroBoundaries(curve) {
    if (curve.length === 0) {
        // Nothing to anchor to; previously this threw a TypeError on curve[0][0].
        return [];
    }
    const firstX = curve[0][0];
    const lastX = curve[curve.length - 1][0];
    return [[firstX, 0], ...curve, [lastX, 0]];
}

/**
 * Draws an interactive legend (one coloured square and one label per category) into `svg`.
 *
 * Two highlight states exist: "kdeactive" (hover) and "kdeselected" (click toggles it).
 * - Hovering a label makes only that category's curve, label and square active.
 * - Leaving a label makes everything active when nothing is selected, and nothing active otherwise.
 *
 * Every query is scoped to `svg`, so several KDE plots on the same page do not affect each other.
 *
 * @param {d3.Selection} svg - Container selection; the curves it contains carry ids `plot_<category>`.
 * @param {Array<string>} categories - Category names, one legend row each.
 * @param {number} legendSize - Side length (px) of each legend square.
 * @param {Function} categoryToColorFunction - Maps a category name to a fill colour.
 * @param {number} [xOffset=24] - Left offset of the legend.
 * @param {number} [yOffset=12] - Top offset of the first legend row.
 * @returns {Array} `[dots, labels]` — the square and label selections.
 */
const addLegend = (svg, categories, legendSize, categoryToColorFunction, xOffset = 24, yOffset = 12) => {
    const ELEMENT_PREFIXES = ["plot", "label", "dot"];

    // Top y of legend row i.
    const rowY = (i) => yOffset + i * (legendSize + 5);

    // All elements (curve, label, square) belonging to one category.
    // Ids are compared as plain strings rather than via a CSS attribute selector, so category names
    // containing quotes or backslashes cannot produce an invalid selector.
    function elementsOf(category) {
        const ids = new Set(ELEMENT_PREFIXES.map((prefix) => `${prefix}_${category}`));
        return svg.selectAll(".kdelement").filter((_, i, nodes) => ids.has(nodes[i].id));
    }

    function atLeastOneSelected() {
        return !svg.selectAll(".kdeselected").empty();
    }

    function handleMouseOver(event, category) {
        svg.selectAll(".kdelement").classed("kdeactive", false);
        elementsOf(category).classed("kdeactive", true);
    }

    function handleMouseOut() {
        svg.selectAll(".kdelement").classed("kdeactive", !atLeastOneSelected());
    }

    function handleClick(event, category) {
        // Toggle selection on each of the category's elements independently.
        elementsOf(category).each((_, i, nodes) => {
            const node = d3.select(nodes[i]);
            node.classed("kdeselected", !node.classed("kdeselected"));
        });
        svg.selectAll(".kdelement").classed("kdeactive", !atLeastOneSelected());
    }

    // selectAll(null) is the d3 idiom for an always-empty selection to enter() into.
    const dots = svg.selectAll(null)
        .data(categories)
        .enter()
        .append("rect")
        .attr("class", "kdelement")
        .attr("x", xOffset)
        .attr("y", (d, i) => rowY(i))
        .attr("width", legendSize)
        .attr("height", legendSize)
        .attr("id", (d) => `dot_${d}`)
        .style("fill", categoryToColorFunction);

    const labels = svg.selectAll(null)
        .data(categories)
        .enter()
        .append("text")
        .attr("class", "kdelement")
        .attr("x", xOffset + legendSize * 1.2)
        .attr("y", (d, i) => rowY(i) + (legendSize / 2))
        .style("fill", categoryToColorFunction)
        .style("stroke", "black")
        .style("stroke-width", "0.5px")
        .text((d) => d)
        // "start" is the valid SVG value; the previous "left" was invalid and fell back to "start" anyway.
        .attr("text-anchor", "start")
        .attr("id", (d) => `label_${d}`)
        .style("alignment-baseline", "middle")
        .style("cursor", "pointer")
        .on("mouseover", handleMouseOver)
        .on("mouseout", handleMouseOut)
        .on("click", handleClick);

    return [dots, labels];
};

/**
 * Linear x scale mapping the data domain onto [0, width] pixels.
 *
 * @param {number} width - Plot width in px.
 * @param {Array<number>} [domain=[0, 100]] - Data-space extent; defaults to the previously hard-coded [0, 100].
 * @returns {d3.ScaleLinear} The x scale.
 */
function getAxisX(width, domain = [0, 100]) {
    return d3.scaleLinear()
        .domain(domain)
        .range([0, width]);
}

/**
 * Estimates one density curve per group (or one overall) and builds a y scale that fits them.
 *
 * @param {number} height - Plot height in px; density 0 maps to `height` (the bottom of the plot).
 * @param {DataFrame} df - Source data; densities are estimated from its "value" column.
 * @param {string} [groupByCol] - If set, one curve per bucket of this column; otherwise a single curve.
 * @param {Function} kde - Estimator mapping an array of samples to `[x, density]` pairs.
 * @returns {Array} `[y, curves]` — the y scale and the zero-padded curves.
 */
function getAxisY(height, df, groupByCol, kde) {
    const estimate = (frame) => addZeroBoundaries(kde(frame.getColumn("value")));
    // NOTE(review): curve order follows df.bucketize(); callers pair curves with categories by index
    // (e.g. df.uniques()), so both are assumed to enumerate groups in the same order — confirm.
    const curves = groupByCol
        ? [...df.bucketize(groupByCol).values()].map(estimate)
        : [estimate(df)];
    // d3.max skips undefined/NaN densities (e.g. d3.mean over an empty group) instead of poisoning the
    // result to NaN as Math.max(...spread) did, and does not spread every point onto the call stack.
    const maxY = d3.max(curves, (curve) => d3.max(curve, (point) => point[1]));
    // Fall back to 1 when there is no positive density, so the domain never collapses to [0, 0] or [0, NaN].
    const top = maxY > 0 ? maxY : 1;
    const y = d3.scaleLinear()
        .range([height, 0])
        // 2.2x headroom above the tallest curve — presumably leaves room for the legend at the top; confirm.
        .domain([0, top * 2.2]);
    return [y, curves];
}

/**
 * Builds the x/y scales and the density curves for a KDE plot.
 *
 * @param {DataFrame} df - Source data (its "value" column is estimated).
 * @param {number} width - Plot width in px.
 * @param {number} height - Plot height in px.
 * @param {string} [groupByCol] - Optional column to split the data into one curve per group.
 * @param {Object} [options]
 * @param {number} [options.bandwidth=7] - Epanechnikov kernel bandwidth, in x-domain units.
 * @param {number} [options.sampleCount=60] - Approximate number of x positions (d3 "nice" ticks) at which
 *   the density is evaluated.
 * @returns {{xAxis: d3.ScaleLinear, yAxis: d3.ScaleLinear, curves: Array}} Scales and curves.
 */
function getAxisCurveData(df, width, height, groupByCol, { bandwidth = 7, sampleCount = 60 } = {}) {
    const xAxis = getAxisX(width);
    const kde = kernelDensityEstimator(kernelEpanechnikov(bandwidth), xAxis.ticks(sampleCount));
    const [yAxis, curves] = getAxisY(height, df, groupByCol, kde);
    return { xAxis, yAxis, curves };
}

/**
 * Appends one filled, smoothed path per category. Each path gets id `plot_<category>` and a <title> tooltip.
 *
 * @param {d3.Selection} svg - Container to append the paths to.
 * @param {Array<string>} categories - Category names; `categories[i]` is drawn with `curves[i]`.
 * @param {{xAxis: Function, yAxis: Function, curves: Array}} axisCurveData - Scales and `[x, y]` curves.
 * @param {Function} categoryToColorFunction - Maps a category name to its fill colour.
 */
function renderCurves(svg, categories, { xAxis, yAxis, curves }, categoryToColorFunction) {
    // The generator depends only on the scales, so one instance serves every curve.
    const line = d3.line()
        .curve(d3.curveBasis)
        .x((d) => xAxis(d[0]))
        .y((d) => yAxis(d[1]));

    categories.forEach((categoryName, i) => {
        svg.append("path")
            .attr("class", "kdelement")
            .datum(curves[i])
            .attr("fill", categoryToColorFunction(categoryName))
            .attr("stroke", "#000")
            .attr("stroke-width", 1)
            .attr("stroke-linejoin", "round")
            .attr("d", line)
            .attr("id", `plot_${categoryName}`)
            .append("title")
            .text(categoryName);
    });
}

const ElementChartKDEplot = (props) => {
    const [data, setData] = useState(new DataFrame([]));
    const { "data": dataArgs, colorMap } = props;
    const d3Container = useRef(null);
    const categoryToColorFunction = (categoryName) => (colorMap?.get(categoryName) || "#69b3a2");

    /*
    * This is a kernel density plot (coloured area under curve).
    *
    * props:
    *   colorMap        -   If present, maps category to column.
    *
    * dataArgs:
    *   groupByCol      -   If present, the column by which to group the contents of the curves.
    *
    * */

    useEffect(() => {
        if (!data.isEmpty() && d3Container.current) {
            const df = data.copy();
            const groupByCol = dataArgs.groupByCol;
            const margin = {
                top: 30,
                right: 30,
                bottom: 30,
                left: 50
            };
            const legendSize = 20;
            const width = 640 - margin.left - margin.right;
            const height = 400 - margin.top - margin.bottom;
            const svg = d3.select(d3Container.current)
                .attr("width", width + margin.left + margin.right)
                .attr("height", height + margin.top + margin.bottom)
                .append("g")
                .attr("transform", `translate(${margin.left},${margin.top})`);
            const axisCurveData = getAxisCurveData(df, width, height, groupByCol);

            svg.append("g")
                .attr("transform", `translate(0, ${height})`)
                .call(d3.axisBottom(axisCurveData.xAxis));
            svg.append("g")
                .call(d3.axisLeft(axisCurveData.yAxis));
            const categories = groupByCol ? df.uniques(groupByCol) : ["data"];
            renderCurves(svg, categories, axisCurveData, categoryToColorFunction);
            addLegend(svg, categories, legendSize, categoryToColorFunction);
        }
        // init active state
        d3.selectAll(".kdelement").classed("kdeactive", true);
    }, [data, d3Container.current]);

    return (
        <React.Fragment>
            <Element
                {...props}
                width="100%"
                primary
                setData={setData}>
            </Element>
            <svg ref={d3Container}/>
        </React.Fragment>
    );
};

export { ElementChartKDEplot as KDEplot };
