import plotComponentFactory from "react-plotly.js/factory";
import Plotly from "plotly.js/dist/plotly";

import { useCallback, useEffect, useState } from "react";
import { QueryParser, get_pareto_front } from "../rust/componet/componet";
import { Components, Component } from "../proto/ts/componet.graph";
import { Affix } from "../proto/ts/componet";
import { COLUMNS } from "../utils/octopart";
import Alert from "./Alert";
import {
  Point,
  RequiredFilter,
  RequiredTrace,
  RequiredAxis,
  areTracesEqual,
} from "../utils/types";
import { YEARS } from "../utils/consts";
import {
  HoverConstants,
  LegendConstants,
  MarkerConstants,
  PlotConstants,
  SelectionConstants,
} from "./plot/consts";
import LeftSidebar from "./sidebar/left/LeftSidebar";
import RightSidebar from "./sidebar/right/RightSidebar";
import PlotInfo from "./PlotInfo";
import { OutlierTooltip } from "./OutlierTooltip";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { faSpinner } from "@fortawesome/free-solid-svg-icons";
import log from "loglevel";
import { useAtom, useAtomValue } from "jotai";
import {
  plottedAxesAtom,
  plottedFiltersAtom,
  plottedParetoQuadrantsAtom,
  plottedPointsWithXAtom,
  plottedTracesAtom,
  selectedAxesAtom,
  selectedFiltersAtom,
  selectedTracesAtom,
} from "../common/store";
import { getAxisRange } from "./plot/utils";
import ReactGA from "react-ga4";

export default function GraphForm() {
  // Selection Parameters: the user's pending choices from the sidebar.
  // addToPlot() validates these and copies them into the "plotted" atoms below.
  const [selectedAxes, setSelectedAxes] = useAtom(selectedAxesAtom);
  const [selectedTraces, setSelectedTraces] = useAtom(selectedTracesAtom);
  const [selectedFilters, setSelectedFilters] = useAtom(selectedFiltersAtom);

  // Data that corresponds to user selections for the plot. Changes to the
  // plotted axes/traces/filters/Pareto quadrants re-run the effect that calls
  // computePlot(), which refetches and rebuilds the traces.
  // `plottedPointsWithX` holds the points the user has clicked (selected).
  const [plottedPointsWithX, setPlottedPointsWithX] = useAtom(
    plottedPointsWithXAtom
  );
  const plottedParetoQuadrants = useAtomValue(plottedParetoQuadrantsAtom);
  const [plottedAxes, setPlottedAxes] = useAtom(plottedAxesAtom);
  const [plottedTraces, setPlottedTraces] = useAtom(plottedTracesAtom);
  const [plottedFilters, setPlottedFilters] = useAtom(plottedFiltersAtom);

  // Actual plot data: Plotly traces built by graphData() (component marker
  // traces and Pareto-front line traces) and the layout built by graphLayout().
  const [plotComponentData, setPlotComponentData] = useState<any[]>([]);
  const [plotParetoFrontData, setPlotParetoFrontData] = useState<any[]>([]);
  const [plotLayout, setPlotLayout] = useState<{ [key: string]: any }>(
    PlotConstants.defaultLayout
  );

  // NOTE(review): the factory runs on every render, so `Plot` is a new
  // component type each time and React remounts the chart on every render.
  // Point select/deselect mutate the marker arrays in place (same array
  // identity), which Plotly.react would likely not pick up without a new
  // `layout.datarevision` — the remount may be what makes highlights appear.
  // Only hoist/memoize this together with immutable marker updates. TODO confirm.
  const Plot = plotComponentFactory(Plotly);

  // True while plot data is being fetched from `/api`; drives the spinner overlay.
  const [loading, setLoading] = useState<boolean>(false);

  // Error/validation message for the Alert banner; a falsy value hides it.
  const [alertMessage, setAlertMessage] = useState<string>();
  // If the plot data changes, re-apply the "selected" marker styling to every
  // point whose MPN is in the selection. Matching is by MPN only, so a part
  // selected in one trace is highlighted in every trace that contains it.
  // Marker arrays are mutated in place, consistent with onPlotSelectPoint.
  useEffect(() => {
    if (plotComponentData.length === 0 || plottedPointsWithX.length === 0)
      return;

    // Set lookup keeps this linear in the number of points instead of
    // scanning the whole selection once per point.
    const selectedMpns = new Set(plottedPointsWithX.map((point) => point.mpn));

    plotComponentData.forEach((component) => {
      component.metadata.mpns.forEach((mpn: any, ptIdx: number) => {
        if (!selectedMpns.has(mpn)) return;
        // Update color, size, and symbol for this point
        component.marker.color[ptIdx] = MarkerConstants.selectColor;
        component.marker.size[ptIdx] = MarkerConstants.selectSize;
        component.marker.symbol[ptIdx] = MarkerConstants.selectSymbol;
      });
    });
  }, [plotComponentData, plottedPointsWithX]);

  /**
   * Plotly `onClick` handler. Records the clicked point as selected (unless a
   * point with the same MPN is already selected) and gives its marker the
   * "selected" color, size and symbol.
   *
   * Clicks without a point payload, or on traces that are not component marker
   * traces (e.g. a Pareto-front line, whose point indices do not map to MPNs),
   * are ignored.
   */
  const onPlotSelectPoint = (e?: any) => {
    log.debug("Plot click", e);
    const clicked = e?.points?.[0];
    if (!clicked || plotComponentData.length === 0) return;
    if (clicked.data?.metadata?.type !== "component") return;

    const fullDataIdx = clicked.data.metadata.componentIndex;
    const ptIdx = clicked.pointIndex;
    const component = plotComponentData[fullDataIdx];
    // The click may refer to trace data that has since been replaced.
    if (!component) return;

    const mpn = component.metadata.mpns[ptIdx];
    const manufacturer = component.metadata.manufacturers[ptIdx];
    const point: Point = {
      ptIdx,
      fullDataIdx,
      mpn: mpn ?? "",
      color: component.marker.color[ptIdx],
      manufacturer: manufacturer ?? "",
      // Encode the MPN: part numbers can contain characters such as `#`, `&`
      // or `+` that would otherwise corrupt the query string.
      link: `https://octopart.com/search?q=${encodeURIComponent(
        mpn ?? ""
      )}&view=list`,
      xAxis: component.x[ptIdx],
      yAxis: component.y[ptIdx],
      xUnits: component.metadata.unitX,
      yUnits: component.metadata.unitY,
    };
    ReactGA.event("select_point", point);

    // Skip duplicates
    if (plottedPointsWithX.some((p) => p.mpn === point.mpn)) return;

    setPlottedPointsWithX((prev) => [...prev, point]);

    component.marker.color[ptIdx] = MarkerConstants.selectColor;
    component.marker.size[ptIdx] = MarkerConstants.selectSize;
    component.marker.symbol[ptIdx] = MarkerConstants.selectSymbol;
  };

  /**
   * Removes the `i`-th selected point and restores normal marker styling.
   *
   * Markers are located by MPN in the current plot data instead of the
   * stored `fullDataIdx`/`ptIdx`. Those indices go stale (or out of range)
   * once the plot is recomputed, and the highlight effect styles every
   * occurrence of the MPN, so every occurrence is restored here.
   * The marker mutation happens outside the state updater so the updater
   * stays side-effect free.
   */
  const onPlotDeselectPoint = (point: Point, i: number) => {
    ReactGA.event("deselect_point", point);

    plotComponentData.forEach((component) => {
      // graphData fills marker colors with `trace?.color ?? defaultTraceColor`
      // and stores `trace?.color` as the hover border color, so this
      // reproduces the trace's unselected color even after recoloring.
      const baseColor =
        component.hoverlabel?.bordercolor ?? MarkerConstants.defaultTraceColor;
      component.metadata.mpns.forEach((mpn: any, ptIdx: number) => {
        if (mpn !== point.mpn) return;
        component.marker.color[ptIdx] = baseColor;
        component.marker.size[ptIdx] = MarkerConstants.normalSize;
        component.marker.symbol[ptIdx] = MarkerConstants.normalSymbol;
      });
    });

    setPlottedPointsWithX((prev) => prev.filter((_, index) => index !== i));
  };

  /**
   * Validates the pending sidebar selections, merges them with what is already
   * plotted and publishes the result to the plotted atoms, which triggers a
   * refetch. On invalid input an alert is shown and nothing is changed.
   *
   * - Axes: the selected pair replaces the plotted pair only when both x and y
   *   are fully specified; otherwise the current plotted axes are kept.
   * - Traces: new, fully specified traces are appended after the plotted ones
   *   and all traces are recolored by position.
   * - Filters: a selected filter replaces any plotted filter with the same name.
   */
  const addToPlot = () => {
    // Keep only fully specified axes.
    const axes: RequiredAxis[] = [];
    for (const axis of selectedAxes) {
      if (!axis.name || !axis.kind || !axis.abbreviation) {
        log.info("Axis is missing a required field", axis);
        continue;
      }
      axes.push(axis as RequiredAxis);
    }

    // Keep only fully specified traces that are not plotted yet, dropping
    // duplicates (same name and year) within the selection itself.
    const newTraces: RequiredTrace[] = [];
    const seenTraces = new Set<string>();
    for (const trace of selectedTraces) {
      if (!trace.name || !trace.visible || !trace.abbreviation || !trace.year) {
        log.info("Trace is missing a required field", trace);
        continue;
      }
      if (
        plottedTraces.some(
          (t) => t.name === trace.name && t.year === trace.year
        )
      ) {
        log.info("Skipping duplicate trace", trace);
        continue;
      }
      const traceKey = `${trace.name}_${trace.year}`;
      if (seenTraces.has(traceKey)) continue;
      seenTraces.add(traceKey);
      newTraces.push(trace as RequiredTrace);
    }

    // Colors depend on position in the merged list, so recolor everything.
    // Copies keep the plotted/selected atom state from being mutated in place.
    const traces: RequiredTrace[] = [...plottedTraces, ...newTraces].map(
      (trace, i) => ({ ...trace, color: MarkerConstants.traceColors[i] })
    );

    // Keep only fully specified filters; a newly selected filter replaces any
    // plotted filter with the same name.
    const newFilters: RequiredFilter[] = [];
    const seenFilters = new Set<string>();
    let keptFilters = [...plottedFilters];
    for (const filter of selectedFilters) {
      // Explicitly check for undefined for `min` and `max`, since 0 is a valid value but evaluates to false.
      if (
        !filter.kind ||
        !filter.abbreviation ||
        !filter.name ||
        !filter.id ||
        filter.min === undefined ||
        filter.max === undefined
      ) {
        log.info("Filter is missing a required field", filter);
        continue;
      }
      if (seenFilters.has(filter.name)) continue;
      seenFilters.add(filter.name);
      keptFilters = keptFilters.filter((f) => f.name !== filter.name);
      newFilters.push(filter as RequiredFilter);
    }
    const filters: RequiredFilter[] = [...keptFilters, ...newFilters];

    // Only a complete x/y pair may replace the plotted axes. Publishing an
    // empty or single-axis list would reset the plot (no axes) or break the
    // axis lookups in computePlot (one axis).
    const axesToPlot = axes.length === 2 ? axes : plottedAxes;
    if (axesToPlot.length !== 2) {
      setAlertMessage(`Please select both x and y axes to plot and try again.`);
      return;
    }

    // `traces` already includes the plotted traces.
    if (traces.length < 1) {
      setAlertMessage(
        `Please select at least one trace to plot and try again.`
      );
      return;
    }

    setPlottedAxes(axesToPlot);
    setPlottedTraces(traces);
    setPlottedFilters(filters);
    setAlertMessage("");

    setLoading(true);

    // Clear current selections.
    setSelectedFilters(SelectionConstants.selectedFilters);
    setSelectedTraces(SelectionConstants.selectedTraces);
  };

  /**
   * Removes `trace` from the plotted traces and recolors the remaining ones by
   * position. When no traces remain, the whole plot is reset.
   *
   * The new list is computed up front so that the reset (which sets other
   * state) is not triggered from inside a state updater. The remaining
   * traces are copied rather than mutated, since they are atom state.
   */
  const removeFromPlot = (trace: RequiredTrace) => {
    setLoading(true);

    const remaining = plottedTraces
      .filter((t) => !areTracesEqual(t, trace))
      .map((t, i) => ({ ...t, color: MarkerConstants.traceColors[i] }));
    setPlottedTraces(remaining);

    // Handle the case for zero traces
    if (remaining.length === 0) {
      resetPlot();
    }
  };

  /**
   * Returns the view to its initial state: clears the rendered traces, layout
   * and point selections, clears the pending sidebar selections, and stops the
   * spinner. The plotted axes/traces/filters atoms are left untouched.
   */
  const resetPlot = () => {
    ReactGA.event("reset_plot");

    // Rendered output and clicked points.
    setPlotComponentData([]);
    setPlotParetoFrontData([]);
    setPlotLayout(PlotConstants.defaultLayout);
    setPlottedPointsWithX([]);

    // Pending sidebar selections.
    setSelectedAxes([]);
    setSelectedTraces(SelectionConstants.selectedTraces);
    setSelectedFilters(SelectionConstants.selectedFilters);

    setLoading(false);
  };

  // Check if the return data is valid: true only for a plain object with no
  // own enumerable keys (i.e. `{}`). Falsy inputs are returned unchanged, so
  // the result is falsy for them too.
  const isEmpty = (data: any) => {
    if (!data) {
      return data;
    }
    const isPlainObject = Object.getPrototypeOf(data) === Object.prototype;
    return isPlainObject && Object.keys(data).length === 0;
  };

  /**
   * Converts fetched components into Plotly traces and stores them in
   * `plotComponentData` / `plotParetoFrontData`.
   *
   * For each component, in the order given:
   * - one marker trace (`scattergl`, or `scatter3d` when the component has
   *   more than two axes), colored and shown according to the matching
   *   plotted trace (matched by name + year);
   * - when at least one Pareto quadrant is enabled, one Pareto-front line
   *   trace computed by the Rust `get_pareto_front`.
   *
   * `metadata.componentIndex` is the component's position in `components`,
   * which is also its index in `plotComponentData`; onPlotSelectPoint uses it
   * to map a clicked point back to its MPN.
   */
  const graphData = useCallback(
    (components: Component[]) => {
      let componentData: any[] = [];
      let paretoFrontData: any[] = [];
      // Keys of the enabled Pareto quadrants, converted to numbers.
      const activeQuadrants = Object.entries(plottedParetoQuadrants)
        .filter(([, isActive]) => isActive)
        .map(([quadrantNumber]) => Number(quadrantNumber));

      components.forEach((component, i) => {
        // Per-point HTML (MPN, manufacturer, year) injected into the hover
        // template below through `%{text}`.
        const hoverText: string[] = component.mpns.map(
          (
            _,
            idx
          ) => `<span style='font-size: 12px; font-family: "CMU Serif", "Times New Roman", serif; color: black;'>&nbsp;&nbsp;MPN: <b style="color: #3B7EA1;">${component.mpns[idx]}</b>&nbsp;&nbsp;</span><br>
<span style='font-size: 12px; font-family: "CMU Serif", "Times New Roman", serif; color: black;'>&nbsp;&nbsp;Manufacturer: <b style="color: #3B7EA1;">${component.manufacturers[idx]}</b>&nbsp;&nbsp;</span><br>
<span style='font-size: 12px; font-family: "CMU Serif", "Times New Roman", serif; color: black;'>&nbsp;&nbsp;Year: <b style="color: #3B7EA1;">${component.year}</b>&nbsp;&nbsp;</span>
    `
        );

        // Plotted trace for this component; supplies its color and visibility.
        const trace = plottedTraces.find(
          (t) => t.name === component.name && t.year === component.year
        );

        // Short display name from COLUMNS, falling back to the full name.
        let abbreviation =
          COLUMNS.find((c) => c.name === component.name)?.abbreviation ??
          component.name;
        const plotSettings: { [key: string]: any } = {
          x: component.axes?.[0]?.data,
          y: component.axes?.[1]?.data,
          hoverlabel: {
            bgcolor: HoverConstants.backgroundColor,
            bordercolor: trace?.color,
          },
          text: hoverText,
          hovertemplate: `
<br>
<span style='font-family: "CMU Serif", "Times New Roman", serif; color: black; font-size: 16px;'>&nbsp;&nbsp;${abbreviation} &nbsp;&nbsp;</span><br><br>
%{text}<br><br>
<span style='font-family: "CMU Serif", "Times New Roman", serif; color: black; font-size: 12px;'>&nbsp;&nbsp;%{yaxis.title.text}: <b style="color: #3B7EA1;">%{y}</b>&nbsp;&nbsp;</span> <br>
<span style='font-family: "CMU Serif", "Times New Roman", serif; color: black; font-size: 12px;'>&nbsp;&nbsp;%{xaxis.title.text}: <b style="color: #3B7EA1;">%{x}</b>&nbsp;&nbsp;</span> <br>
<extra></extra>
  		`,
          name: `${abbreviation} [${component.year}]`
            .replace(" Capacitors", "")
            .replace(" Inductors", ""),
          type: "scattergl",
          mode: "markers",
          // Per-point arrays (not scalars) so individual markers can be
          // restyled when a point is selected or deselected.
          marker: {
            color: new Array(component.mpns.length).fill(
              trace?.color ?? MarkerConstants.defaultTraceColor
            ),
            size: new Array(component.mpns.length).fill(
              MarkerConstants.normalSize
            ),
            symbol: new Array(component.mpns.length).fill(
              MarkerConstants.normalSymbol
            ),
            line: {
              width: MarkerConstants.borderSize,
            },
            opacity: MarkerConstants.opacity,
          },
          // Without a matching plotted trace, the series is shown only in the legend.
          visible: trace?.visible ?? "legendonly",
          // Component i gets zorder 2i and its Pareto line 2i + 1, intended to
          // keep each front just above its own markers. NOTE(review): Plotly
          // documents zorder for SVG traces — confirm it affects scattergl.
          zorder: i * 2,
          metadata: {
            type: "component",
            abbreviation,
            year: component.year,
            componentIndex: i,
            mpns: component.mpns,
            manufacturers: component.manufacturers,
            unitX: component.axes?.[0]?.unit,
            unitY: component.axes?.[1]?.unit,
          },
        };

        // A third axis switches the trace to a 3D scatter.
        if (component.axes?.length > 2) {
          plotSettings["z"] = component.axes?.[2]?.data;
          plotSettings["type"] = "scatter3d";
        }
        componentData.push(plotSettings);

        // Only compute a Pareto front when at least one quadrant is enabled.
        if (activeQuadrants.length !== 0) {
          // The Rust binding takes the component as JSON and returns a
          // map-like result; its `x`/`y` entries are used directly as the line
          // coordinates (assumed to be Plotly-compatible arrays — confirm).
          const result = get_pareto_front(
            JSON.stringify(component),
            activeQuadrants
          );
          const x = result.get("x");
          const y = result.get("y");

          const sampleLineTrace = {
            x: x,
            y: y,
            type: "scattergl",
            mode: "lines",
            name: "Pareto Front",
            showlegend: false,
            hoverinfo: "skip",
            line: {
              color: trace?.color ?? MarkerConstants.defaultTraceColor,
              width: 2,
            },
            visible: trace?.visible ?? "legendonly",
            zorder: i * 2 + 1,
            // Same metadata as the marker trace, tagged as "pareto".
            metadata: {
              type: "pareto",
              abbreviation,
              year: component.year,
              componentIndex: i,
              mpns: component.mpns,
              manufacturers: component.manufacturers,
              unitX: component.axes?.[0]?.unit,
              unitY: component.axes?.[1]?.unit,
            },
          };
          paretoFrontData.push(sampleLineTrace);
        }
      });

      setPlotComponentData(componentData);
      setPlotParetoFrontData(paretoFrontData);
    },
    [plottedParetoQuadrants, plottedTraces]
  );

  /**
   * Builds the Plotly layout (two log-scale axes, fonts, legend) for the given
   * components and data ranges, and stores it in `plotLayout`.
   * Axis units, affixes and the "computed" flag are read from the first
   * component; axis titles come from the plotted axes.
   */
  const graphLayout = useCallback(
    (
      components: Component[],
      xmin: number,
      xmax: number,
      ymin: number,
      ymax: number
    ) => {
      // Config for one log-scale axis. `axisIdx` selects both the plotted axis
      // (for the title) and the first component's axis metadata; [lo, hi] is
      // padded by the log-range buffer on both sides.
      const axisLayout = (axisIdx: number, lo: number, hi: number) => {
        const meta = components[0].axes[axisIdx];
        const computed = meta?.computed;
        // Units are only drawn as a tick prefix/suffix for non-computed axes.
        const affix = computed ? undefined : meta?.affix;
        return {
          title: plottedAxes?.[axisIdx]?.name + ` [${meta?.unit}]`,
          type: "log",
          autorange: false,
          range: [
            lo - PlotConstants.logRangeBuffer,
            hi + PlotConstants.logRangeBuffer,
          ],
          automargin: true,
          exponentformat: computed ? "power" : "B",
          tickfont: {
            size: computed
              ? PlotConstants.computedPropertyTickFontSize
              : PlotConstants.fontSize,
            color: PlotConstants.fontColor,
          },
          ticksuffix: affix === Affix.SUFFIX ? meta?.unit : "",
          tickprefix: affix === Affix.PREFIX ? meta?.unit : "",
          mirror: true,
          ticks: "inside",
          showline: true,
        };
      };

      const legend = {
        x: LegendConstants.location.x,
        y: LegendConstants.location.y,
        bordercolor: LegendConstants.borderColor,
        borderwidth: LegendConstants.borderWidth,
        itemsizing: "constant",
        marker: {
          size: LegendConstants.markerSize,
        },
      };

      const layout: { [key: string]: any } = {
        autosize: true,
        xaxis: axisLayout(0, xmin, xmax),
        yaxis: axisLayout(1, ymin, ymax),
        font: {
          family: PlotConstants.fontFamily,
          size: PlotConstants.fontSize,
          color: PlotConstants.fontColor,
        },
        showlegend: true,
        legend,
      };

      setPlotLayout(layout);
    },
    [plottedAxes]
  );

  /**
   * Queries `/api` for the plotted traces, axes and filters, then:
   * 1. parses the response with the Rust/WASM `QueryParser`;
   * 2. aligns each component's axes with the plotted x/y axes;
   * 3. orders components to match `plottedTraces`;
   * 4. hands the result to `graphData` / `graphLayout`.
   *
   * The loading flag is always cleared when this finishes: on early exit, on
   * success, and on request/parse failure (which also shows an alert).
   */
  const computePlot = () => {
    if (!plotComponentData || !plotLayout || plottedTraces.length === 0) {
      log.debug("Returning early from computePlot");
      // The caller raises `loading` before calling; without this the spinner
      // would never clear (e.g. after the last trace is removed).
      setLoading(false);
      return;
    }

    const searchParams = new URLSearchParams();

    for (const trace of plottedTraces) {
      const id = COLUMNS.find((column) => column.name === trace.name)?.id;
      if (id && trace.year) {
        log.debug("Found octopart id for component: ", trace);
        searchParams.append("categories", String(id));
        searchParams.append("years", trace.year);
      } else {
        console.warn("Could not find octopart id for component: ", trace);
      }
    }

    for (const axis of plottedAxes) {
      const column = COLUMNS.find((c) => c.name === axis.name)?.column;
      if (column == null) {
        // Don't send the literal string "undefined" to the API.
        console.warn("Could not find octopart column for axis: ", axis);
        continue;
      }
      searchParams.append("attributes", String(column));
    }

    for (const filter of plottedFilters) {
      searchParams.append("filters", filter.id);
      searchParams.append("filtersMin", filter.min.toString());
      searchParams.append("filtersMax", filter.max.toString());
    }

    const query = searchParams.toString();
    log.debug(query);
    fetch("/api?" + query)
      .then((res) => {
        if (!res.ok) {
          throw new Error(`Plot data request failed with HTTP ${res.status}`);
        }
        return res.json();
      })
      .then((data) => {
        const componentString = QueryParser.parse(
          JSON.stringify(data),
          YEARS
        ) as unknown as string;
        const parsed = JSON.parse(componentString);

        // Nothing to draw: keep the current plot as-is.
        if (isEmpty(parsed)) return;
        const components = Components.fromJSON(parsed).components;
        if (components.length === 0) return;

        // The API does not guarantee axis order, so move each component's
        // plotted x axis to index 0 and y axis to index 1 (other axes keep
        // their relative order). Doing this per component and before the
        // ranges are computed keeps data and ranges consistent; axes whose
        // names cannot be found are left untouched.
        const xName = plottedAxes[0]?.name;
        const yName = plottedAxes[1]?.name;
        for (const component of components) {
          const xIdx = component.axes.findIndex((axis) => axis.name === xName);
          const yIdx = component.axes.findIndex((axis) => axis.name === yName);
          if (xIdx < 0 || yIdx < 0 || xIdx === yIdx) continue;
          if (xIdx === 0 && yIdx === 1) continue;
          const xAxis = component.axes[xIdx];
          const yAxis = component.axes[yIdx];
          const others = component.axes.filter(
            (_, k) => k !== xIdx && k !== yIdx
          );
          component.axes.splice(
            0,
            component.axes.length,
            xAxis,
            yAxis,
            ...others
          );
        }

        // Grab the min and max values from the (now aligned) data.
        const { min: xmin, max: xmax } = getAxisRange(components, 0);
        const { min: ymin, max: ymax } = getAxisRange(components, 1);

        // Rearrange the components to match the order of plottedTraces. We
        // want the lowest trace index to appear on top, so rank against the
        // reversed list (built once, not per comparison).
        const drawOrder = [...plottedTraces].reverse();
        const rank = (c: Component) =>
          drawOrder.findIndex(
            (trace) => trace.name === c.name && trace.year === c.year
          );
        components.sort((a, b) => rank(a) - rank(b));

        graphData(components);
        graphLayout(components, xmin, xmax, ymin, ymax);
      })
      .catch((error: unknown) => {
        log.error("Failed to compute plot", error);
        setAlertMessage("Failed to load plot data. Please try again.");
      })
      .finally(() => setLoading(false));
  };

  // Rebuild the plot whenever what is plotted changes. With no plotted axes
  // there is nothing to query, so the plot is reset instead.
  useEffect(() => {
    setLoading(true);
    const hasAxes = plottedAxes.length !== 0;
    if (!hasAxes) {
      resetPlot();
      return;
    }
    const analyticsParams = {
      plottedAxes,
      plottedTraces,
      plottedFilters,
      plottedParetoQuadrants,
    };
    ReactGA.event("update_plot", analyticsParams);
    computePlot();
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [plottedAxes, plottedTraces, plottedFilters, plottedParetoQuadrants]);

  return (
    <div
      className="grid grid-cols-12 grid-flow-col"
      style={{ height: "calc(100vh - 6rem)" }}
    >
      {/* 6rem is `h-24`, i.e. the height of the header */}
      <div className="col-span-3 border-r-2 border-black overflow-y-auto">
        <LeftSidebar handleRefreshPlot={addToPlot} />
      </div>
      <div className="col-span-6 overflow-y-auto">
        {alertMessage ? (
          <div className="col-span-3 px-8 py-4">
            <Alert message={alertMessage} onClose={() => setAlertMessage("")} />
          </div>
        ) : null}
        <div className="flex justify-center items-center w-full">
          <div
            className="w-full h-auto px-8"
            style={{
              maxWidth: "calc(100vh - 6rem)",
              height: "calc((100vh - 6rem) / 1.2)",
            }}
          >
            <Plot
              // NOTE(review): react-plotly.js names its container-id prop
              // `divId`; `div` looks like a no-op. Confirm before renaming,
              // since `divId` would add id="graph" to the DOM.
              div="graph"
              // Interleave each component trace with its Pareto front (when
              // present) so every front follows its own markers.
              data={Array.from(
                {
                  length: Math.max(
                    plotComponentData.length,
                    plotParetoFrontData.length
                  ),
                },
                (_, i) => [plotComponentData[i], plotParetoFrontData[i]]
              )
                .flat()
                .filter((x) => x !== undefined)}
              layout={plotLayout}
              onClick={onPlotSelectPoint}
              // Boolean prop (previously the string "true").
              useResizeHandler
              config={{
                displaylogo: false,
                responsive: true,
              }}
              className="h-full w-full"
            />
            {loading ? (
              <div className="absolute inset-0 flex items-center justify-center bg-gray-500 bg-opacity-50 z-50">
                <FontAwesomeIcon
                  icon={faSpinner}
                  className="text-white text-3xl animate-spin"
                />
              </div>
            ) : null}
          </div>
        </div>
        <div className="col-span-3 px-8 py-4">
          {plotComponentData.length > 0 ? (
            <PlotInfo components={plotComponentData} />
          ) : null}
        </div>
        <div className="col-span-3 px-8 pb-4">
          <OutlierTooltip />
        </div>
      </div>
      <div className="col-span-3 border-l-2 border-black overflow-y-auto">
        <RightSidebar
          handleRemoveTrace={removeFromPlot}
          handleRemovePoint={onPlotDeselectPoint}
        />
      </div>
    </div>
  );
}
