import cloneDeep from 'lodash/cloneDeep'
import { getIncomers, getOutgoers } from 'react-flow-renderer'

import {
  getLevelY,
  getElementById,
  getFragmentY,
  getLevelX,
  getFragmentX,
  getInitialX,
} from 'components/MotionTarget/motionTarget.utils'
import { elementsSeparator } from 'services/Utils/motionHelpers/motionHelpers.utils'

import type { Elements, Node, XYPosition } from 'react-flow-renderer'

import type { SegmentBuilderData } from 'models/motion/motionBuilder.model'
import type { NodePayload } from 'models/motion.model'
import { NodeTypeEnum } from 'models/motion.model'

/** Axis-aligned box recorded for a laid-out branch; positions inside it are treated as collisions. */
interface NodeLimit {
  /** Id of the node this limit was recorded for. */
  id: string
  /** Smallest X covered by the box (inclusive). */
  minX: number
  /** Largest X covered by the box (inclusive). */
  maxX: number
  /** Smallest Y covered by the box (inclusive). */
  minY: number
  /** Largest Y covered by the box (inclusive). */
  maxY: number
}

/** Arguments of `getChildNodesWithRestriction`. */
interface ChildNodesWithRestrictionProps {
  /** Merge node at which the top-down walk stops. */
  mergeNode: Node
  /** Node the walk starts from; it is never included in the result. */
  parentMainBranch: Node
  /** All nodes and edges of the graph. */
  elements: Elements
  /** Nodes to exclude; right-hand paths leaving them are not explored. */
  restrictedNodes: Node[]
}

/**
 * Re-lays out a motion-builder graph so that nodes use the available space more compactly.
 *
 * The input is deep-cloned with lodash `cloneDeep` first. `structuredClone` cannot be used here because
 * of a circular reference issue. Loop edges are set aside while the Y and then the X positions are
 * recomputed on the clone, and are appended again at the end.
 *
 * @param elements Nodes and edges of the motion; one node is expected to be flagged `isInitial`.
 * @returns The re-positioned elements. The untouched `elements` argument is returned when there is no
 * initial node, or no final node sharing the initial node's X coordinate.
 */
export const optimizePositions = (elements: Elements<{ isInitial: boolean; isFinal: boolean }>) => {
  const separated = elementsSeparator<
    { type: string; payload: NodePayload; isInitial: boolean; isFinal: boolean },
    { edgeLabel: string; edgeType: NodeTypeEnum }
  >(cloneDeep(elements))
  const clonedNodes = separated.nodes
  const graph = [...separated.edges, ...clonedNodes]

  const rootNode = clonedNodes.find((candidate) => candidate?.data?.isInitial)
  if (!rootNode) {
    return elements
  }

  // The main branch ends at the final node placed in the same column as the initial node.
  const mainBranchLast = clonedNodes.find(
    (candidate) => candidate?.data?.isFinal && candidate.position.x === rootNode.position.x,
  )
  if (!mainBranchLast) {
    return elements
  }

  // Vertical pass first, then the horizontal pass; both mutate the cloned nodes in place.
  updateYPosition(rootNode, graph)
  updateXPosition(rootNode, graph, mainBranchLast)

  return [...graph, ...separated.loopEdges]
}

/**
 * Iterate through all nodes and update the Y position based on the parent node or, for merges, the lowest relevant node.
 * The traversal is an iterative DFS. The left (same-X) child is pushed before the right child, so the right subtree is popped and processed first.
 * Positions are mutated in place on the node objects obtained from `elementsWithoutLoops`.
 * @param {Node} firstNode node the traversal starts from (the initial segment)
 * @param {Elements} elementsWithoutLoops nodes and edges of the graph, without loop edges
 */
const updateYPosition = (firstNode: Node, elementsWithoutLoops: Elements) => {
  const { edges, nodes } = elementsSeparator<
    { type: string; payload: NodePayload },
    { edgeLabel: string; edgeType: NodeTypeEnum }
  >(elementsWithoutLoops)
  const stack = [firstNode]

  while (stack.length) {
    const currentNode = stack.pop() as Node<SegmentBuilderData>
    // Returns all direct child nodes of the passed node.
    const outgoers = getOutgoers(currentNode, elementsWithoutLoops) as Node<SegmentBuilderData>[]
    // visually left is down: the child in the same column
    const left = outgoers.find((outgoer) => outgoer.position.x === currentNode.position.x) ?? null
    // visually right is the child to the right (the "no" path)
    const right = outgoers.find((outgoer) => outgoer.position.x > currentNode.position.x) ?? null

    // right is pushed last, so it is popped (processed) first
    if (left) stack.push(left)
    if (right) stack.push(right)

    // the parent is the source of the first incoming edge that is not a Merge edge
    const parentNodeEdge = edges.find(
      (edge) => edge.target === currentNode.id && edge?.data?.edgeType !== NodeTypeEnum.Merge,
    )
    const parentNode = nodes.find((node) => node.id === parentNodeEdge?.source)
    if (!parentNode) {
      // if doesn't have parent it's the initial segment
      continue
    }

    // snap the node exactly one Y level below its parent when it is not already there
    const difference = getLevelY(currentNode.position.y) - getLevelY(parentNode.position.y)
    if (difference !== 1) {
      currentNode.position.y = parentNode.position.y + getFragmentY()
    }

    // a configured merge placed in its parent's column must sit below everything it merges
    if (isConfiguredMerge(currentNode) && parentNode.position.x === currentNode.position.x) {
      // every source node of an edge into the merge (both regular and Merge edges)
      const mergeParents = edges
        .filter((edge) => edge.target === currentNode.id)
        .map((edge) => getElementById(elementsWithoutLoops, edge.source) as Node)
        .filter((node: Node | undefined) => !!node)

      // right-most merge source
      const mergeTargetMaxX = mergeParents.reduce((prev, curr) => (curr.position.x > prev.position.x ? curr : prev))

      const innerMergeNodes = getInnerMergeNodes(currentNode, mergeTargetMaxX, elementsWithoutLoops)

      // lowest inner node; the seed object makes an empty list yield y = 0
      const innerMergeMaxY = innerMergeNodes.reduce((prev, curr) => (curr.position.y > prev.position.y ? curr : prev), {
        position: { y: 0, x: 0 },
      })

      const maxY = Math.max(innerMergeMaxY.position.y, mergeTargetMaxX.position.y, parentNode.position.y)

      // place the merge one fragment below the lowest of those nodes
      currentNode.position.y = maxY + getFragmentY()
    }
  }
}

/**
 * Update x position of the elements to optimize the space between nodes.
 * Walks the graph depth-first from `root`; the same-X child is popped before the right child.
 * For every node that is the first node on the right-hand side of a Branch node, the branch's nodes
 * (see getBranchNodesToUpdate) are shifted horizontally:
 * - to the right while the branch's last same-X node collides with an already recorded limit;
 * - otherwise to the left, as far as the recorded limits allow.
 * Each processed branch then records its own limit so later branches avoid it.
 * Node positions are mutated in place.
 * @param root start node
 * @param elementsWithoutLoops array of nodes and edges
 * @param endNode last node on the main branch; its Y closes the initial limit
 */
const updateXPosition = (root: Node, elementsWithoutLoops: Elements, endNode: Node) => {
  const stack = [root]
  const limits: NodeLimit[] = [
    {
      id: root.id,
      minX: root.position.x,
      maxX: root.position.x,
      minY: root.position.y,
      maxY: endNode.position.y,
    },
    // initial limit: the main-branch column, from root down to endNode
  ]

  while (stack.length) {
    const curr = stack.pop() as Node
    // Returns all direct child nodes of the passed node.
    const outgoers = getOutgoers(curr, elementsWithoutLoops)
    // Returns all direct incoming nodes of the passed node.
    const incommers = getIncomers(curr, elementsWithoutLoops) as Node<SegmentBuilderData>[]

    if (incommers.length === 1) {
      const parentNode = incommers[0]
      // true when curr is the first node of a Branch node's right-hand ("no") path
      const isBranchFirstNoChildNode =
        parentNode.data?.type === NodeTypeEnum.Branch && parentNode.position.x < curr.position.x

      if (isBranchFirstNoChildNode) {
        // deepest node in curr's column
        const lastXNode = getLastNode(curr, elementsWithoutLoops)
        const lastNodeOutgoers = getOutgoers(lastXNode, elementsWithoutLoops) as Node<SegmentBuilderData>[]
        // the Merge node the branch's last node flows into, if any
        const isMergeTarget = lastNodeOutgoers.find((node) => node.data?.type === NodeTypeEnum.Merge)
        // live reference: it reflects the shifts applied to lastXNode below
        const lastXNodePosition = lastXNode.position
        const isOverlapping = !!getOverlapLimit(limits, lastXNodePosition)

        if (isOverlapping) {
          // if is overlapping we have to move nodes to the right
          overlapUpdateNodes(limits, lastXNodePosition, parentNode, elementsWithoutLoops)
        } else {
          // if doesn't overlap check if we can move node to the left
          const overlapPosition = getOverlapPosition(lastXNodePosition, parentNode.position.x, limits)
          // overlapPosition is the first colliding X found while scanning left (or the first X at/left of the
          // parent's X); getXLvlToUpdate adds one level, so the branch ends up just right of that position

          const nodesToUpdate = getBranchNodesToUpdate(parentNode, elementsWithoutLoops)
          const difference = getXLvlToUpdate(overlapPosition.x, lastXNodePosition.x)

          // difference is presumably <= 0 here (a move to the left); 0 means the branch stays put
          if (difference) {
            nodesToUpdate.forEach((node: Node) => {
              node.position.x = node.position.x + getFragmentX() * difference
            })
          }
        }

        if (isMergeTarget) {
          const innerNodes = getInnerMergeNodes(isMergeTarget, lastXNode, elementsWithoutLoops)
          if (innerNodes.length) {
            // right-most node enclosed by the merge
            const nodeLimit = innerNodes.reduce((prev, curr) => (curr.position.x > prev.position.x ? curr : prev))

            // region from the Branch node's X to the right-most inner node, spanning from just below that node down to the merge
            const mergeLimit = {
              id: nodeLimit.id,
              minX: parentNode.position.x,
              maxX: nodeLimit.position.x,
              minY: nodeLimit.position.y + getFragmentY(),
              maxY: isMergeTarget.position.y,
            }
            // probe in curr's column at the merge's Y; the branch is shifted right while it falls inside mergeLimit
            const imaginaryLastPosition = {
              x: curr.position.x,
              y: isMergeTarget.position.y,
            }

            overlapUpdateNodes([mergeLimit], imaginaryLastPosition, parentNode, elementsWithoutLoops)
          }
        }

        // record this branch's box after shifting, so later branches avoid it
        limits.push({
          id: curr.id,
          minX: parentNode.position.x,
          maxX: curr.position.x,
          minY: curr.position.y - getFragmentY(),
          maxY: lastXNode.position.y,
        })
      }
    }

    // visually left is down: the child in the same column
    const left = outgoers.find((outgoer) => outgoer.position.x === curr.position.x) ?? null
    // visually right is the child to the right (the "no" path)
    const right = outgoers.find((outgoer) => outgoer.position.x > curr.position.x) ?? null

    // left is pushed last, so the same-column path is processed first
    if (right) {
      stack.push(right)
    }
    if (left) {
      stack.push(left)
    }
  }
}
/**
 * Collects the nodes enclosed by a merge.
 * It walks top-down from the merge's main parent branch to `mergeNode`. The nodes returned by
 * `getNoBranchesNodes` are used as the restricted set for that walk.
 * @param mergeNode - The merge node to check.
 * @param target - The target node of the merge.
 * @param elements - An array of elements.
 * @returns The inner nodes, or an empty array when no main parent branch can be found.
 */
export const getInnerMergeNodes = (mergeNode: Node, target: Node, elements: Elements): Node[] => {
  const parentMainBranch = getMergeMainParentBranch(mergeNode, target, elements)

  return parentMainBranch
    ? getChildNodesWithRestriction({
        mergeNode,
        parentMainBranch,
        elements,
        restrictedNodes: getNoBranchesNodes(mergeNode, elements, parentMainBranch, target),
      })
    : []
}

/**
 * Finds the main parent branch of a given merge node and target node.
 * Starting from `target`, it repeatedly moves up to the left-most incoming node (smallest X). It stops at
 * the first node that is either:
 * - a Branch node whose X is not to the right of `mergeNode`, or
 * - a node on the initial X column.
 * @param mergeNode - The merge node whose X bounds the search.
 * @param target - The target node of the merge; the upward walk starts from its incomers.
 * @param elements - An array of elements.
 * @returns The main parent branch. Returns null when the walk runs out of incomers without finding one,
 * or when it revisits a node (cyclic input).
 */
const getMergeMainParentBranch = (mergeNode: Node, target: Node, elements: Elements): Node | null => {
  const visitedIds = new Set<string>()
  let candidates = getIncomers(target, elements) as Node<SegmentBuilderData>[]

  // Stop once a node has no incomers: reducing an empty array without a seed would throw a TypeError.
  while (candidates.length) {
    // left-most incomer; on equal X the later candidate wins
    const parentNode = candidates.reduce((prev, curr) => (prev.position.x < curr.position.x ? prev : curr))

    // guard against cyclic input, which would otherwise walk forever
    if (visitedIds.has(parentNode.id)) {
      return null
    }
    visitedIds.add(parentNode.id)

    const isBranch = parentNode.data?.type === NodeTypeEnum.Branch
    if ((isBranch && parentNode.position.x <= mergeNode.position.x) || parentNode.position.x === getInitialX()) {
      return parentNode
    }

    candidates = getIncomers(parentNode, elements) as Node<SegmentBuilderData>[]
  }

  return null
}
/**
 * Depth-first walk (top to bottom) starting at `parentMainBranch`; collects the nodes enclosed by a merge.
 *
 * A visited node is NOT collected when any of these holds:
 * - it is in `restrictedNodes`;
 * - it is `mergeNode`;
 * - it is `parentMainBranch`;
 * - it starts a right-hand path whose parent is in `restrictedNodes`.
 * In the first and last cases the node's right-hand child is not explored either.
 * The walk ends as soon as `mergeNode` is popped.
 *
 * @returns The collected nodes, in visiting order.
 */
const getChildNodesWithRestriction = ({
  mergeNode,
  parentMainBranch,
  elements,
  restrictedNodes,
}: ChildNodesWithRestrictionProps): Node[] => {
  const stack: Node[] = [parentMainBranch]
  const traversed: Node[] = []
  // loop-invariant: nodes that can never be collected as inner nodes
  const invalidInnerNodes = [...restrictedNodes, mergeNode]

  while (stack.length) {
    const curr = stack.pop() as Node
    // Returns all direct child nodes of the passed node.
    const outgoers = getOutgoers(curr, elements)
    // First direct incoming node. It is undefined for a node without incomers (e.g. when parentMainBranch
    // is the graph root), which previously crashed on `parentNode.id`.
    const parentNode: Node | undefined = getIncomers(curr, elements)[0]

    const currentNodeX = curr.position.x
    // visually left is down: the child in the same column
    const left = outgoers.find((outgoer) => outgoer.position.x === currentNodeX) ?? null
    // visually right is the child to the right (the "no" path)
    const right = outgoers.find((outgoer) => outgoer.position.x > currentNodeX) ?? null
    const isReachedMerge = curr.id === mergeNode.id
    // curr begins a right-hand path of a restricted node
    const skipBranch =
      !!parentNode && isNodeInArray(restrictedNodes, parentNode.id) && parentNode.position.x !== currentNodeX

    const isInvalidInner = isNodeInArray(invalidInnerNodes, curr.id)
    const isMainParentBranch = parentMainBranch.id === curr.id

    if (!skipBranch && !isInvalidInner && !isMainParentBranch) {
      traversed.push(curr)
    }

    if (isReachedMerge) {
      return traversed
    }
    if (left) {
      stack.push(left)
    }
    if (right && !skipBranch && !isInvalidInner) {
      // skip going to the "no" path when the node is restricted
      stack.push(right)
    }
  }

  return traversed
}

/** Whether any node in `nodes` has the given id. */
const isNodeInArray = (nodes: Node[], id: string) => nodes.findIndex((candidate) => candidate.id === id) !== -1

/**
 * Returns the nodes of the "No" branches between a merge and `stopNode`.
 * Walks upward (bottom to top) through incoming edges, starting from `mergeNode`. Every node reached is
 * recorded, until the walk hits `stopNode`, a node in `mergeNode`'s X column, or a node without incomers.
 * @param mergeNode - The merge node the walk starts from
 * @param elements - The array of all elements
 * @param stopNode - Node at which the walk stops; it is not included in the result
 * @param target - The merge target; it is preferred among several incomers. When it is itself a configured
 * merge, the walk also goes up through its incomers.
 * @returns an array of the visited nodes, in the order they were reached
 */
export const getNoBranchesNodes = (mergeNode: Node, elements: Elements, stopNode: Node, target: Node): Node[] => {
  const chainOfChildrens: Node[] = []

  // parentNode is only passed for the nested walk started from the merge target
  const traverse = (incommers: Node[], parentNode?: Node) => {
    // the merge target, when it is one of the current incomers
    const isTargetInIncommers = incommers.find((node) => node.id === target.id)

    if (!incommers.length) {
      return
    }

    // with several incomers prefer the merge target; otherwise the smallest-X incomer (ties keep the earlier one)
    let currentNode =
      incommers.length > 1 && isTargetInIncommers
        ? isTargetInIncommers
        : incommers.reduce((prev, curr) => (prev.position.x > curr.position.x ? curr : prev))

    if (isTargetInIncommers && isConfiguredMerge(isTargetInIncommers as Node<SegmentBuilderData>)) {
      // traverse on merge target: nested walk through the incomers of the selected node
      traverse(getIncomers(currentNode, elements), currentNode)
    }
    if (parentNode) {
      // in the nested walk always follow the smallest-X incomer (ties keep the later one)
      currentNode = incommers.reduce((prev, curr) => (prev.position.x < curr.position.x ? prev : curr))
    }

    if (currentNode && currentNode.id !== stopNode.id && currentNode.position.x !== mergeNode.position.x) {
      // traverse on any other nodes: record it and keep walking upward
      chainOfChildrens.push(currentNode)
      traverse(getIncomers(currentNode, elements))
    }
  }

  traverse(getIncomers(mergeNode, elements))

  return chainOfChildrens
}

/**
 * Pushes a branch to the right until a probe position no longer collides with any of `limits`.
 *
 * The probe starts as a copy of `lastNodePosition`. On each step, the probe and every node returned by
 * `getBranchNodesToUpdate(parentNode, allElements)` move right by the number of X levels that
 * `getXLvlToUpdate` reports for the colliding limit.
 *
 * Node positions are mutated in place; `limits` and `lastNodePosition` are left untouched.
 * @param limits - Boxes the branch must not overlap
 * @param lastNodePosition - Starting probe position (typically the branch's last node)
 * @param parentNode - The parent node of the branch to update
 * @param allElements - The array of all elements in the diagram
 */
const overlapUpdateNodes = (
  limits: NodeLimit[],
  lastNodePosition: XYPosition,
  parentNode: Node,
  allElements: Elements,
) => {
  const probe = { ...lastNodePosition }
  const branchNodes = getBranchNodesToUpdate(parentNode, allElements)

  let collidingLimit = getOverlapLimit(limits, probe)
  while (collidingLimit) {
    const levels = getXLvlToUpdate(collidingLimit.maxX, probe.x)
    const shift = getFragmentX() * levels

    probe.x = probe.x + shift
    // mutate the nodes directly; callers rely on the shared references
    branchNodes.forEach((branchNode: Node) => {
      branchNode.position.x = branchNode.position.x + shift
    })

    collidingLimit = getOverlapLimit(limits, probe)
  }
}

/**
 * Scans leftwards from `startPosition`, one X fragment at a time.
 * It stops at the first position that collides with `limits`, or at the first position whose X is
 * at or left of `currentNodeX`.
 * @returns The position where the scan stopped (a copy; `startPosition` is not modified).
 */
const getOverlapPosition = (startPosition: XYPosition, currentNodeX: number, limits: NodeLimit[]) => {
  const probe = { ...startPosition }

  do {
    probe.x = probe.x - getFragmentX()
  } while (probe.x > currentNodeX && !getOverlapLimit(limits, probe))

  return probe
}

/**
 * Number of X levels a node at `nodeX` must move so that it sits one level past `limitX`.
 * The result is negative when the node is already further right than that.
 */
const getXLvlToUpdate = (limitX: number, nodeX: number) => {
  // extra level so the node lands beside the limit rather than on it
  const LEVEL_GAP = 1
  const nodeLevel = getLevelX(nodeX)

  return getLevelX(limitX) - nodeLevel + LEVEL_GAP
}

/** Keeps only the limits whose `minY` is not greater than the node's Y, preserving their order. */
const getOnlyLimitsThatAffectsNode = (limits: NodeLimit[], nodePosition: XYPosition) => {
  const { y } = nodePosition

  return limits.filter(({ minY }) => minY <= y)
}

/**
 * Finds the first limit (in array order) that the given position collides with.
 *
 * Only limits whose `minY` is not greater than the position's Y are considered. The position collides
 * with such a limit when either:
 * - it lies inside the limit's box (all bounds inclusive), or
 * - it is not to the right of the limit's `maxX` AND it is below the `maxY` of any considered limit.
 * @param limits - The array of node limits to check.
 * @param nodePosition - The position of the node to check for overlap.
 * @returns The first colliding limit, or null if none is found.
 */
export const getOverlapLimit = (limits: NodeLimit[], nodePosition: XYPosition): NodeLimit | null => {
  const { x, y } = nodePosition
  const relevantLimits = getOnlyLimitsThatAffectsNode(limits, nodePosition)
  // NOTE(review): this compares against *any* relevant limit's maxY, not the candidate's own maxY.
  // It is kept as-is to preserve behavior; confirm which of the two is intended.
  const isBelowAnyLimit = relevantLimits.some(({ maxY }) => y > maxY)

  const isInsideBox = ({ minX, maxX, minY, maxY }: NodeLimit) => y >= minY && y <= maxY && x >= minX && x <= maxX

  for (const candidate of relevantLimits) {
    if (isInsideBox(candidate) || (x <= candidate.maxX && isBelowAnyLimit)) {
      return candidate
    }
  }

  return null
}

/** A merge is "configured" when it is a Merge node with at least one target in its payload. */
const isConfiguredMerge = (node: Node<SegmentBuilderData>): boolean => {
  if (node.data?.type !== NodeTypeEnum.Merge) {
    return false
  }

  return Boolean(node.data?.payload?.targets?.length)
}

/**
 * Returns every node reachable through outgoing edges from `node` that lies strictly to its right (greater X).
 *
 * Traversal does not enter a configured merge whose `targets` include the node it was reached from.
 * Each node is collected and expanded at most once. Without this, diamond-shaped graphs were re-walked
 * exponentially often and cyclic input recursed forever. For acyclic graphs the result (content and
 * order) is the same as a naive re-walking traversal.
 * @param node - The starting node (not included in the result).
 * @param elements - An array of elements.
 * @returns The collected nodes, in first-discovery order.
 */
export const getBranchNodesToUpdate = (node: Node, elements: Elements): Node[] => {
  const childBranchNodes: Node[] = []
  const visitedIds = new Set<string>()

  const traverse = (outgoingNodes: Node<SegmentBuilderData>[], previousNode?: Node) => {
    for (const currentNode of outgoingNodes) {
      const isNodeOnRight = currentNode.position.x > node.position.x
      // a merge that closes the path we came from is not part of this branch
      const isConnectedMerge =
        previousNode &&
        isConfiguredMerge(currentNode) &&
        currentNode.data?.payload?.targets?.some((id: string) => id === previousNode.id)

      if (!isNodeOnRight || isConnectedMerge || visitedIds.has(currentNode.id)) {
        continue
      }

      visitedIds.add(currentNode.id)
      childBranchNodes.push(currentNode)
      traverse(getOutgoers(currentNode, elements) as Node<SegmentBuilderData>[], currentNode)
    }
  }

  traverse(getOutgoers(node, elements) as Node<SegmentBuilderData>[])

  return childBranchNodes
}

/**
 * Returns the last node of the same x axis.
 * Follows outgoing edges through nodes that share `node`'s X coordinate. A same-X node with no same-X
 * child is a candidate; the last candidate found wins.
 * @param node - The starting node of the branch.
 * @param elements - The elements to search through.
 * @returns The last node, or `node` itself when no same-X descendant exists.
 */
export const getLastNode = (node: Node, elements: Elements): Node => {
  // explicitly typed: it is assigned inside the closure, which would otherwise make it an implicit `any`
  let lastNode: Node | undefined

  const traverse = (nodes: Node[]) => {
    for (const currentNode of nodes) {
      if (currentNode.position.x !== node.position.x) {
        continue
      }

      // queried once and reused for both the leaf check and the descent
      const childNodes = getOutgoers(currentNode, elements)
      const hasSameXChild = childNodes.some((childNode) => childNode.position.x === node.position.x)

      if (hasSameXChild) {
        traverse(childNodes)
      } else {
        lastNode = currentNode
      }
    }
  }

  traverse(getOutgoers(node, elements))

  return lastNode ?? node
}
