import classNames from "classnames";
import { FC, ReactNode, useMemo } from "react";
import {
  CartesianGrid,
  ReferenceLine,
  ResponsiveContainer,
  Scatter,
  ScatterChart,
  Tooltip,
  XAxis,
  YAxis,
  ZAxis,
} from "recharts";
import { ClassNames } from "../classes";
import { FILL_COLORS } from "./common";
import { formatToIndianCurrency } from "../../utils/functions";

/**
 * A single point plotted by `ScatterPlotWithBuckets`.
 */
type DataPoint = {
  /** Horizontal position. This is the only coordinate the k-means bucketing reads. */
  x: number;
  /** Vertical position. */
  y: number;
  /** Optional heading, shown in bold by the default tooltip. */
  label?: string;
  /** Bucket index. The component overwrites it with the k-means assignment before plotting. */
  cluster?: number;
  /**
   * Used as the chart's Z-axis data key (marker size). A missing value counts
   * as 0 when the component computes the Z-axis domain.
   */
  value?: number;
};

/**
 * One-dimensional k-means (Lloyd's algorithm) over the `x` coordinate of `data`.
 *
 * Initial centroids are seeded at evenly spaced quantiles of the sorted `x`
 * values. The result is therefore independent of the order of `data`, and the
 * seeds are spread over the whole range. Seeding from the first `k` points
 * instead collapses onto the `k` smallest values whenever the caller has
 * already sorted by `x`, which traps the algorithm in a degenerate local
 * optimum.
 *
 * @param data Points to cluster; only `x` is read.
 * @param k Requested number of clusters. At most `data.length` centroids are
 *   produced, and `k <= 0` yields none.
 * @param maxIterations Upper bound on assignment/update rounds. It guarantees
 *   termination even if floating-point ties make assignments oscillate.
 * @returns `clusters[i]` is the index into `centroids` assigned to `data[i]`
 *   (`-1` when there are no centroids or `x` is not finite). `centroids`
 *   holds the final centroid x-values.
 */
const kMeansClustering = (data: DataPoint[], k: number, maxIterations = 300) => {
  const n = data.length;
  const centroidCount = Math.max(0, Math.min(Math.floor(k), n));

  // Seed with the (i + 0.5) / centroidCount quantiles of the sorted x values.
  // When centroidCount === n this picks every point exactly once.
  const sortedXs = data.map((point) => point.x).sort((a, b) => a - b);
  const centroids = Array.from(
    { length: centroidCount },
    (_, i) => sortedXs[Math.min(n - 1, Math.floor(((i + 0.5) * n) / centroidCount))],
  );
  const clusters: number[] = Array(n).fill(-1);

  for (let iteration = 0; iteration < maxIterations; iteration++) {
    let hasChanged = false;

    // Assignment step: move each point to its nearest centroid.
    // The strict `<` keeps the lowest index on ties.
    for (let i = 0; i < n; i++) {
      const x = data[i].x;
      let nearest = -1;
      let nearestDistance = Infinity;
      for (let c = 0; c < centroidCount; c++) {
        const distance = Math.abs(x - centroids[c]);
        if (distance < nearestDistance) {
          nearestDistance = distance;
          nearest = c;
        }
      }
      if (clusters[i] !== nearest) {
        clusters[i] = nearest;
        hasChanged = true;
      }
    }

    // Update step: recentre each non-empty cluster on the mean of its points.
    // Empty clusters keep their previous centroid.
    const sums = Array<number>(centroidCount).fill(0);
    const counts = Array<number>(centroidCount).fill(0);
    for (let i = 0; i < n; i++) {
      const c = clusters[i];
      if (c >= 0) {
        sums[c] += data[i].x;
        counts[c] += 1;
      }
    }
    for (let c = 0; c < centroidCount; c++) {
      if (counts[c] > 0) {
        centroids[c] = sums[c] / counts[c];
      }
    }

    if (!hasChanged) break;
  }

  return { clusters, centroids };
};

/** Coordinate keys that `ScatterPlotWithBuckets` can sort by. */
type SortKey = "x" | "y";

/** Props for `ScatterPlotWithBuckets`. */
type ScatterPlotWithBucketsProps = {
  /** Points to plot; bucketing uses their `x` values. */
  data: DataPoint[];
  /**
   * Keys to sort a copy of `data` by before clustering and plotting. Later
   * keys break ties of earlier ones.
   */
  sort?: SortKey[];
  /** X-axis caption. */
  xLabel?: string;
  /** Y-axis caption. */
  yLabel?: string;
  /** Custom tooltip body. When omitted, the tooltip shows `label`, `X` and `Y`. */
  tooltip?: (point: DataPoint) => ReactNode;
  /** Number of k-means buckets along x (default 3). */
  buckets?: number;
};

/**
 * Returns a sorted copy of `points`, ordered by each key in `keys` in turn.
 * Returns `points` itself when `keys` is empty. Never mutates the input.
 */
const sortByKeys = (points: DataPoint[], keys: readonly SortKey[]): DataPoint[] => {
  if (keys.length === 0) return points;
  return [...points].sort((a, b) => {
    for (const key of keys) {
      const diff = a[key] - b[key];
      if (diff !== 0) return diff;
    }
    return 0;
  });
};

/**
 * Decision thresholds between adjacent buckets: the midpoints of consecutive
 * centroids in ascending order. A nearest-centroid assignment switches
 * buckets exactly at these x-values. Does not mutate `centroids`.
 */
const centroidBoundaries = (centroids: readonly number[]): number[] => {
  const ordered = [...centroids].sort((a, b) => a - b);
  return ordered.slice(1).map((upper, i) => (ordered[i] + upper) / 2);
};

/**
 * Min and max of `value` (missing treated as 0) over `points`. Returns
 * `{ 0, 0 }` for an empty list so the Z-axis domain stays finite. Uses a loop
 * rather than spreading into `Math.min`/`Math.max`, which can overflow the
 * call stack on large inputs.
 */
const valueRange = (points: readonly DataPoint[]): { minVal: number; maxVal: number } => {
  if (points.length === 0) return { minVal: 0, maxVal: 0 };
  let minVal = Infinity;
  let maxVal = -Infinity;
  for (const point of points) {
    const v = point.value ?? 0;
    if (v < minVal) minVal = v;
    if (v > maxVal) maxVal = v;
  }
  return { minVal, maxVal };
};

/**
 * Scatter plot whose points are grouped into `buckets` clusters along x with
 * 1-D k-means. Dashed reference lines mark the thresholds between adjacent
 * buckets, and each threshold is also listed below the chart, formatted as
 * Indian currency. Marker size follows each point's `value`.
 */
export const ScatterPlotWithBuckets: FC<ScatterPlotWithBucketsProps> = ({
  data: inputData,
  sort = [],
  xLabel,
  yLabel,
  tooltip,
  buckets = 3,
}) => {
  // `sort` defaults to a fresh [] on every render, and callers often pass
  // inline literals. Memoize on its contents rather than its identity so
  // clustering only reruns when the sort order really changes.
  const sortKey = sort.join(",");

  const { boundaries, clusteredData } = useMemo(() => {
    const sortedData = sortByKeys(inputData, sort);
    const { clusters, centroids } = kMeansClustering(sortedData, buckets);
    return {
      boundaries: centroidBoundaries(centroids),
      clusteredData: sortedData.map((point, idx) => ({
        ...point,
        cluster: clusters[idx],
      })),
    };
  }, [inputData, sortKey, buckets]);

  const { minVal, maxVal } = useMemo(() => valueRange(clusteredData), [clusteredData]);

  return (
    <div className="h-full w-full">
      <ResponsiveContainer className="h-full w-full">
        <ScatterChart>
          <CartesianGrid className="stroke-white/10" />
          <XAxis
            dataKey="x"
            type="number"
            domain={["dataMin", "dataMax"]}
            label={{ value: xLabel, position: "insideBottom", offset: -5 }}
          />
          <YAxis
            dataKey="y"
            type="number"
            domain={["dataMin", "dataMax"]}
            label={{ value: yLabel, angle: -90, position: "insideLeft" }}
          />
          <Tooltip
            content={({ active, payload }) => {
              if (!active || !payload || payload.length === 0) return null;
              const point = payload[0].payload;
              if (tooltip != null) {
                return <div className={ClassNames.ChartTooltip}>{tooltip(point)}</div>;
              }
              const { x, y, label } = point;
              return (
                <div className={ClassNames.ChartTooltip}>
                  {label && <strong>{label}</strong>}
                  <div>X: {x}</div>
                  <div>Y: {y}</div>
                </div>
              );
            }}
          />
          <ZAxis type="number" dataKey="value" domain={[minVal, maxVal]} range={[15, 255]} />
          <Scatter
            className={FILL_COLORS[0]}
            name="Data Points"
            data={clusteredData}
          />
          {boundaries.map((boundary, idx) => (
            <ReferenceLine
              key={`boundaries-${idx}`}
              x={boundary}
              stroke="rgba(0,255,0,0.3)"
              strokeWidth={2}
              strokeDasharray="5 5"
            />
          ))}
        </ScatterChart>
      </ResponsiveContainer>
      <div className="flex gap-4 items-center mb-4 justify-center">
        {boundaries.map((boundary, i) => (
          <div key={`classification-${i}`} className={classNames(ClassNames.Text, "text-sm")}>
            <span className="font-bold">Classification Point {i + 1}:</span>{" "}
            {formatToIndianCurrency(boundary)}
          </div>
        ))}
      </div>
    </div>
  );
};
