import { SimulationNodeDatum, SimulationLinkDatum } from "d3";
import * as d3 from "d3";

// Default stroke width for every link and node outline.
const baseStrokeWidth = 1.5;
// Stroke width for links incident to the clicked (selected) node.
const highlightedStrokeWidth = 2.5;
// Outline colour for hovered and selected nodes.
const highlightedStroke = "#000000";

// References:
// - https://observablehq.com/@borowski-9ld/d3-force-directed-graph (TODO: check the license for the zoom code)
// - https://levelup.gitconnected.com/creating-a-force-graph-using-react-and-d3-6579bcd9628c
// - https://observablehq.com/@d3/force-directed-graph
/**
 * Glyph used to draw a node. Exported because it is part of the public
 * `GraphVizNode` contract, so callers can type the `shape` field.
 */
export type Shape = "+" | "circle" | "rect" | "triangle" | "diamond";

/**
 * A node supplied by the caller.
 *
 * @typeParam T - type of the caller-owned payload carried in `data`.
 */
export interface GraphVizNode<T> {
  /** Node id; edges refer to nodes by it. Expected to be unique (it keys the node lookup and the DOM data join). */
  id: number;
  /** Label for the node. */
  label: string;
  /** Symbol size; the drawn d3 symbol's area is `size * 100`. */
  size: number;
  /** SVG fill colour of the symbol. */
  color: string;
  /** Glyph used to draw the node. */
  shape: Shape;
  /** Caller payload; handed back unchanged inside the node objects passed to callbacks. */
  data: T;
}

/**
 * An edge supplied by the caller, connecting two nodes by their ids.
 *
 * @typeParam T - type of the caller-owned payload carried in `data`.
 */
export interface GraphVizEdge<T> {
  /** Edge id; expected to be unique (it keys the DOM data join for the edge's line). */
  id: number;
  /** `id` of the source node. */
  source: number;
  /** `id` of the target node. */
  target: number;
  /** Caller payload; handed back unchanged in `onSelected`. */
  data: T;
}

// Simulation wrapper around a caller's node. It inherits the optional layout
// fields (x, y, fx, fy, ...) from SimulationNodeDatum that the force
// simulation and the drag handlers read and write.
// NOTE(review): within this module the name `Node` shadows the DOM's global
// `Node` type.
interface Node<N> extends SimulationNodeDatum {
  // Copy of node.id; used as the data-join key and for selection-set membership.
  key: number;
  // The caller's original node object.
  node: GraphVizNode<N>;
}
// Simulation link between two wrapped nodes. source/target hold the wrapper
// objects themselves (not ids).
interface Link<N, E> extends SimulationLinkDatum<Node<N>> {
  // Copy of edge.id; used as the data-join key and for selection-set membership.
  key: number;
  source: Node<N>;
  target: Node<N>;
  // The caller's original edge object.
  edge: GraphVizEdge<E>;
}

/**
 * Create a converter from caller edges to simulation links.
 *
 * The returned function resolves an edge's `source` and `target` ids to the
 * wrapped nodes held in `nodeMap` and keeps a reference to the original edge.
 *
 * @param nodeMap - wrapped nodes keyed by node id.
 * @returns a function mapping one edge to its `Link`. That function throws an
 *   Error when either endpoint id is missing from `nodeMap`.
 */
function edge2Link<N, E>(nodeMap: Map<number, Node<N>>) {
  return function toLink(edge: GraphVizEdge<E>): Link<N, E> {
    const [from, to] = [edge.source, edge.target].map((id) => nodeMap.get(id));
    if (!from || !to) {
      throw new Error("Source or target node not found");
    }
    return { key: edge.id, source: from, target: to, edge };
  };
}

/**
 * Map a node `Shape` to the d3 symbol type used to draw it.
 *
 * @param shape - the node's declared shape.
 * @returns the matching d3 symbol type. Values outside `Shape` (reachable
 *   only from untyped callers) fall back to a circle.
 */
function chooseSymbol(shape: Shape) {
  switch (shape) {
    case "+":
      // symbolCross is d3's plus-shaped symbol; symbolWye would draw a three-armed "Y".
      return d3.symbolCross;
    case "circle":
      return d3.symbolCircle;
    case "rect":
      return d3.symbolSquare;
    case "triangle":
      return d3.symbolTriangle;
    case "diamond":
      return d3.symbolDiamond;
    default:
      return d3.symbolCircle;
  }
}

/**
 * Index caller nodes and edges for the force layout.
 *
 * Wraps every `GraphVizNode` in a simulation `Node` (keyed by node id) and
 * every `GraphVizEdge` in a `Link` that points at those wrappers. It then
 * builds an index from each wrapped node to the links touching it.
 *
 * @param nodes - nodes to index. If ids repeat, edges attach to the last node with that id.
 * @param edges - edges whose `source`/`target` are node ids.
 * @returns lookup helpers over the wrapped nodes and links.
 * @throws Error if an edge references an id that is not present in `nodes`.
 */
const generateGraph = <N, E>(
  nodes: GraphVizNode<N>[],
  edges: GraphVizEdge<E>[]
) => {
  const simNodes: Node<N>[] = nodes.map((node) => ({ key: node.id, node: node }));
  const nodeMap = new Map(simNodes.map((n) => [n.node.id, n]));
  // edge2Link throws if an endpoint id is missing, so every link below has
  // resolved source/target wrappers.
  const simLinks: Link<N, E>[] = edges.map(edge2Link(nodeMap));

  // node -> links incident to it. Lists are appended in place (O(E) total)
  // rather than copied for every link.
  const nodeEdgeMap = new Map<Node<N>, Link<N, E>[]>();
  const addIncident = (n: Node<N>, link: Link<N, E>) => {
    const incident = nodeEdgeMap.get(n);
    if (incident) incident.push(link);
    else nodeEdgeMap.set(n, [link]);
  };
  for (const link of simLinks) {
    addIncident(link.source, link);
    // A self-loop is recorded only once for its node.
    if (link.target !== link.source) addIncident(link.target, link);
  }

  return {
    /** Links incident to `node`, in edge order; empty if it has none. */
    getLinks: (node: Node<N>): Link<N, E>[] => {
      return nodeEdgeMap.get(node) ?? [];
    },
    /** Distinct neighbours of `node` (excluding `node` itself), in first-seen order. */
    getLinkedNodes: (node: Node<N>): Node<N>[] => {
      const neighbours = (nodeEdgeMap.get(node) ?? [])
        .flatMap((link) => [link.source, link.target])
        .filter((n) => n.key !== node.key);
      return Array.from(new Set(neighbours));
    },
    /** All wrapped nodes, in input order. */
    getAllNodes: () => simNodes,
    /** All links, in input edge order. */
    getAllLinks: () => simLinks,
  };
};

/**
 * Render an interactive force-directed graph into `container`.
 *
 * An `<svg>` is created in the container, or reused if one already exists.
 * Its viewBox is taken from the container's size at call time. Nodes are drawn
 * as d3 symbols and edges as lines, and a d3 force simulation drives their
 * positions.
 *
 * Interaction:
 * - Hovering a node outlines it and calls `onMouseOver` / `onMouseOut`.
 * - Clicking a node selects it together with its neighbours, hides every other
 *   node and link, and calls `onSelected`.
 * - Clicking the background clears the selection and calls `onSelected([], [])`.
 * - Nodes can be dragged; the view can be zoomed and panned (scale 1/2 to 64).
 *
 * @param container - element that receives the `<svg>`; its size is read once.
 * @param nodes - nodes to draw.
 * @param edges - edges to draw; every `source`/`target` must be a node id.
 * @param onMouseOver - called with the node the pointer enters.
 * @param onMouseOut - called with the node the pointer leaves.
 * @param onSelected - called with the clicked node's neighbours followed by the
 *   clicked node, plus the edges incident to the clicked node.
 * @returns an object with two functions:
 *   - `destroy` stops the simulation; drags started afterwards no longer restart it.
 *   - `nodes` returns the `<svg>` element.
 * @throws Error if an edge references an unknown node id.
 */
export function runForceGraph<N, E>(
  container: HTMLDivElement,
  nodes: GraphVizNode<N>[],
  edges: GraphVizEdge<E>[],
  onMouseOver?: (node: GraphVizNode<N>) => void,
  onMouseOut?: (node: GraphVizNode<N>) => void,
  onSelected?: (connected: GraphVizNode<N>[], edges: GraphVizEdge<E>[]) => void
) {
  const containerRect = container.getBoundingClientRect();
  const height = containerRect.height;
  const width = containerRect.width;

  const graph = generateGraph(nodes, edges);
  const simNodes: Node<N>[] = graph.getAllNodes();
  const simLinks: Link<N, E>[] = graph.getAllLinks();

  // Selection state shared by the node and background click handlers.
  let selectedNodeKeys = new Set<number>();
  let selectedLinkKeys = new Set<number>();
  // Set by destroy() so a later drag cannot restart the stopped simulation.
  let destroyed = false;

  // The "tick" listener is attached below, once the `link` and `node`
  // selections it updates have been created.
  const simulation = d3
    .forceSimulation(simNodes)
    .force("charge", d3.forceManyBody())
    .force("link", d3.forceLink(simLinks))
    .force("x", d3.forceX(width / 2))
    .force("y", d3.forceY(height / 2));

  const svg = d3
    .select(container)
    .selectAll("svg")
    .data([true])
    .join("svg")
    .attr("viewBox", [-width / 2, -height / 2, width, height]);

  // Single <g> holding all lines and symbols; the zoom behaviour transforms it.
  const graphGroup = svg.selectAll("g").data([true]).join("g");

  // Joined before the nodes so that, on first render, lines sit underneath the symbols.
  const link = graphGroup
    .selectAll("line")
    .data(simLinks, (d: any) => d.key)
    .join("line")
    .attr("stroke", "#000000")
    .attr("stroke-opacity", 1)
    .attr("stroke-width", baseStrokeWidth)
    .attr("stroke-linecap", "round");

  const node = graphGroup
    .selectAll<SVGElement, Node<N>>("path")
    .data(simNodes, (d: Node<N>) => d.key)
    .join("path")
    .attr(
      "d",
      d3
        .symbol()
        .size((d) => d.node.size * 100)
        .type((d) => chooseSymbol(d.node.shape))
    )
    .attr("fill", (d) => d.node.color)
    .attr("stroke-opacity", 1)
    .attr("stroke-width", baseStrokeWidth)
    .attr("transform", (d) => `translate(${d.x},${d.y})`)
    .attr("stroke", null)
    .on("mouseover", function (event, d) {
      d3.select(this).attr("stroke", highlightedStroke);
      if (onMouseOver) onMouseOver(d.node);
    })
    .on("mouseout", function (event, d) {
      // Nodes in the current selection keep their outline.
      if (!selectedNodeKeys.has(d.key)) {
        d3.select(this).attr("stroke", null);
      }
      if (onMouseOut) onMouseOut(d.node);
    })
    .on("click", function (event, clicked) {
      // Compute the neighbourhood once; it drives both highlighting and onSelected.
      const neighbours = graph.getLinkedNodes(clicked);
      const incident = graph.getLinks(clicked);
      selectedNodeKeys = new Set([
        clicked.key,
        ...neighbours.map((n) => n.key),
      ]);
      selectedLinkKeys = new Set(incident.map((l) => l.key));

      // Outline the selected nodes and thicken the clicked node's links...
      node.attr("stroke", (n) =>
        selectedNodeKeys.has(n.key) ? highlightedStroke : null
      );
      link.attr("stroke-width", (l) =>
        selectedLinkKeys.has(l.key) ? highlightedStrokeWidth : baseStrokeWidth
      );

      // ...and hide everything outside the selection.
      node.attr("visibility", (n) =>
        selectedNodeKeys.has(n.key) ? "visible" : "hidden"
      );
      link.attr("visibility", (l) =>
        selectedLinkKeys.has(l.key) ? "visible" : "hidden"
      );

      if (onSelected) {
        onSelected(
          [...neighbours.map((n) => n.node), clicked.node],
          incident.map((l) => l.edge)
        );
      }

      // Keep the click from reaching the <svg> background handler, which would clear the selection.
      event.stopPropagation();
    })
    // @ts-ignore
    .call(drag());

  simulation.on("tick", ticked);

  // Clicking the background clears the selection and shows everything again.
  svg.on("click", () => {
    selectedNodeKeys.clear();
    selectedLinkKeys.clear();
    node.attr("stroke", null);
    link.attr("stroke-width", baseStrokeWidth);
    node.attr("visibility", "visible");
    link.attr("visibility", "visible");
    if (onSelected) onSelected([], []);
  });

  const zoom = d3
    .zoom<SVGSVGElement, unknown>()
    .scaleExtent([1 / 2, 64])
    .on("zoom", function (event) {
      graphGroup.attr("transform", event.transform);
    });

  // Centre the view on (width / 2, height / 2), the point the forceX/forceY forces pull towards.
  // @ts-ignore
  svg.call(zoom).call(zoom.translateTo, width / 2, height / 2);

  // Drag behaviour that pins the dragged node (fx/fy) to the pointer and
  // reheats the simulation while dragging, unless destroy() has been called.
  function drag() {
    function dragstarted(event: any) {
      if (!event.active && !destroyed) simulation.alphaTarget(0.3).restart();
      event.subject.fx = event.subject.x;
      event.subject.fy = event.subject.y;
    }

    function dragged(event: any) {
      event.subject.fx = event.x;
      event.subject.fy = event.y;
    }

    function dragended(event: any) {
      if (!event.active) simulation.alphaTarget(0);
      // Release the pin so the simulation positions the node again.
      event.subject.fx = null;
      event.subject.fy = null;
    }

    return d3
      .drag()
      .on("start", dragstarted)
      .on("drag", dragged)
      .on("end", dragended);
  }

  // Copy simulated positions onto the line endpoints and symbol transforms.
  function ticked() {
    link
      .attr("x1", (d) => (d.source.x ? d.source.x : 0))
      .attr("y1", (d) => (d.source.y ? d.source.y : 0))
      .attr("x2", (d) => (d.target.x ? d.target.x : 0))
      .attr("y2", (d) => (d.target.y ? d.target.y : 0));

    node.attr("transform", function (d) {
      return "translate(" + d.x + "," + d.y + ")";
    });
  }

  return {
    destroy: () => {
      destroyed = true;
      simulation.stop();
    },
    nodes: () => {
      return svg.node();
    },
  };
}
