import * as Diff from 'diff'
import {
  Node as ASTNode,
  Html,
  Heading,
  Paragraph,
  Parent,
  Root,
  RootContent,
  TableCell,
  TableRow,
  Text,
  List,
  Code,
  FootnoteDefinition,
} from 'mdast'
import { remark } from 'remark'
import { remarkExtendedTable } from 'remark-extended-table'
import remarkGfm from 'remark-gfm'
import remarkParse from 'remark-parse'
import { v4 } from 'uuid'

import { ParagraphTokenizer } from './paragraph-tokenizer'
import * as Types from './types'
import { extractText } from './utils'

/**
 * Represents an operation to be performed on the AST.
 * These are the fields shared by every operation variant.
 */
type OperationBase = {
  type: 'insert' | 'delete' | 'modify'
  // Cumulative cost: this operation's own cost plus the cost of every
  // operation that follows it in the chosen sibling path.
  cost: number
  // Newlines to emit after this operation's output; only used when another
  // operation follows it.
  newLinesAfter: number
}

/**
 * A single diff step, discriminated on `type`: an insert carries only the new
 * node, a delete only the old node, and a modify carries both.
 */
export type Operation =
  | (OperationBase & { type: 'insert'; newNode: ASTNode })
  | (OperationBase & { type: 'delete'; oldNode: ASTNode })
  | (OperationBase & { type: 'modify'; oldNode: ASTNode; newNode: ASTNode })

/**
 * Rendering state threaded through operation application. Currently it only
 * tracks the innermost enclosing list so list-item markers can be rendered.
 */
export type HandlerState = {
  list: {
    // Whether the enclosing list renders numbered ("1. ") or bullet ("- ") markers.
    ordered: boolean
    // The number of the most recently rendered item; list items increment it
    // before their marker is printed.
    orderNumber: number
    // Nesting depth of the enclosing list (1 for a top-level list); drives
    // the marker indentation.
    nestCount: number
  }
}

// An mdast node tagged with a unique id, used to build memoisation cache keys.
type ASTNodeWithId = ASTNode & { id: string }

// Node types whose inline content is diffed token-by-token by paragraphHandler.
const PARAGRAPH_LIKE_TYPES = ['paragraph', 'heading']

/**
 * A service that diffs two markdown strings.
 *
 * Uses mdast (https://github.com/syntax-tree/mdast), via remark, to parse the
 * markdown into an AST, and diff (https://github.com/kpdecker/jsdiff) to diff
 * text content. The diff walks both ASTs. For every pair of sibling lists it
 * picks the cheapest sequence of insert / delete / modify operations: a
 * memoised edit-distance search, see `dfsOptimalChildrenPath`. It then renders
 * those operations back to markdown, wrapping added text in `<ins>` and
 * removed text in `<del>`. Looking at mdast's RootContentMap and
 * PhrasingContentMap helps to understand the AST structure.
 *
 * MDAST AST is roughly structured as follows:
 * Root -> RootContent[]
 * RootContent -> PhrasingContent[] (for paragraph-like nodes)
 * PhrasingContent -> Text (in various forms)
 */
export class MarkdownDiff {
  /**
   * Memoises the best operation path for the remaining siblings, keyed by
   * the ids of the current old/new child pair (see `getKey`). It is cleared
   * at the start of every `run`.
   */
  private readonly nodesToBestOperationCache: Map<
    string,
    { cost: number; operations: Operation[] }
  > = new Map()

  /**
   * Runs the diff algorithm on the given markdown strings.
   * @param oldText - The original markdown.
   * @param newText - The updated markdown.
   * @returns A string with the diffed markdown.
   */
  public run(oldText: string, newText: string): string {
    this.clearCache()

    // Every node gets a unique id so the memoisation cache can key on identity.
    const oldAstWithIds = this.addIdsToNodes(this.parseMarkdown(oldText))
    const newAstWithIds = this.addIdsToNodes(this.parseMarkdown(newText))

    return this.dfsDiff(oldAstWithIds, newAstWithIds)
  }

  /**
   * Parses markdown using the plugin stack (GFM + extended tables) shared by
   * both sides of the diff.
   */
  private parseMarkdown(text: string): Root {
    return remark()
      .use(remarkParse)
      .use(remarkGfm)
      .use(remarkExtendedTable)
      .parse(text) as Root
  }

  /**
   * Diffs two root nodes and renders the result, starting outside any list.
   */
  private dfsDiff(oldNode: ASTNode, newNode: ASTNode): string {
    const { operations } = this.dfsOptimalChildrenPath(oldNode, newNode)
    return this.applyOperations(operations, {
      list: { ordered: false, orderNumber: 0, nestCount: 0 },
    })
  }

  /**
   * Renders a sequence of sibling operations. Each operation after the first
   * is preceded by the previous operation's `newLinesAfter` newlines.
   */
  applyOperations(operations: Operation[], state: HandlerState) {
    let result = ''

    for (let i = 0; i < operations.length; i++) {
      const operation = operations[i]
      switch (operation.type) {
        case 'insert':
          result += this.insertNode(operation.newNode, state)
          break
        case 'delete':
          result += this.deleteNode(operation.oldNode, state)
          break
        case 'modify':
          result += this.modifyNode(operation.oldNode, operation.newNode, state)
          break
        default:
          console.warn('Unsupported operation type:', operation)
      }
      if (i + 1 < operations.length)
        result += '\n'.repeat(operation.newLinesAfter)
    }
    return result
  }

  /**
   * Finds the cheapest sequence of operations that turns the children of
   * `oldParent` (from `state.oldIndex` on) into the children of `newParent`
   * (from `state.newIndex` on).
   *
   * At each step there are three choices: delete the current old child,
   * insert the current new child, or pair the two as a modification. When
   * costs are equal, deletion is preferred over insertion, and insertion
   * over modification. A parent that is null, or not a parent node, is
   * treated as having no children; this is how pure inserts and deletes
   * reuse this search.
   *
   * @returns The total cost and the operations in output order.
   * @throws Error if no option could be selected (not expected in practice,
   *   since deletions and insertions always have a finite cost).
   */
  private dfsOptimalChildrenPath(
    oldParent: ASTNode | null,
    newParent: ASTNode | null,
    state: {
      oldIndex: number
      newIndex: number
    } = { oldIndex: 0, newIndex: 0 }
  ): { cost: number; operations: Operation[] } {
    let deletionResult: { cost: number; operations: Operation[] } | null = null
    let insertionResult: { cost: number; operations: Operation[] } | null = null
    let modificationResult: { cost: number; operations: Operation[] } | null =
      null

    const deletingChild = Types.isParent(oldParent)
      ? oldParent.children[state.oldIndex]
      : null

    const insertingChild = Types.isParent(newParent)
      ? newParent.children[state.newIndex]
      : null

    // Both sibling lists are exhausted.
    if (!deletingChild && !insertingChild) {
      return { cost: 0, operations: [] }
    }

    // Children carry unique ids (see addIdsToNodes), so the current pair of
    // children stands in for this sub-problem.
    const cacheKey = this.getKey(deletingChild, insertingChild)
    const cached = this.nodesToBestOperationCache.get(cacheKey)
    if (cached) return cached

    // Option 1: delete the current old child.
    let deletionCost = Infinity
    if (deletingChild && Types.isParent(oldParent)) {
      deletionCost = this.deletionCost(deletingChild)
      deletionResult = this.dfsOptimalChildrenPath(oldParent, newParent, {
        oldIndex: state.oldIndex + 1,
        newIndex: state.newIndex,
      })
      deletionCost += deletionResult.cost
    }

    // Option 2: insert the current new child.
    let insertionCost = Infinity
    if (insertingChild && Types.isParent(newParent)) {
      insertionCost = this.insertionCost(insertingChild)
      insertionResult = this.dfsOptimalChildrenPath(oldParent, newParent, {
        oldIndex: state.oldIndex,
        newIndex: state.newIndex + 1,
      })
      insertionCost += insertionResult.cost
    }

    // Option 3: modify the current old child into the current new child.
    let modificationCost = Infinity
    if (
      deletingChild &&
      insertingChild &&
      Types.isParent(oldParent) &&
      Types.isParent(newParent)
    ) {
      modificationCost = this.modificationCost(deletingChild, insertingChild)
      modificationResult = this.dfsOptimalChildrenPath(oldParent, newParent, {
        oldIndex: state.oldIndex + 1,
        newIndex: state.newIndex + 1,
      })
      modificationCost += modificationResult.cost
    }

    if (
      deletionResult &&
      deletionCost <= insertionCost &&
      deletionCost <= modificationCost &&
      deletingChild &&
      Types.isParent(oldParent)
    ) {
      // A deleted node's spacing comes from the old document, so its next
      // sibling must be looked up with the *old* index.
      const newLinesAfter = this.getNewlinesBetweenNodes(
        deletingChild,
        oldParent.children[state.oldIndex + 1]
      )
      const deleteOperation: Operation = {
        type: 'delete',
        oldNode: deletingChild,
        cost: deletionCost,
        newLinesAfter,
      }

      const result = {
        cost: deletionCost,
        operations: [deleteOperation, ...deletionResult.operations],
      }

      this.nodesToBestOperationCache.set(cacheKey, result)
      return result
    }

    if (
      insertionResult &&
      insertionCost <= deletionCost &&
      insertionCost <= modificationCost &&
      insertingChild &&
      Types.isParent(newParent)
    ) {
      const newLinesAfter = this.getNewlinesBetweenNodes(
        insertingChild,
        newParent.children[state.newIndex + 1]
      )
      const insertOperation: Operation = {
        type: 'insert',
        newNode: insertingChild,
        cost: insertionCost,
        newLinesAfter,
      }

      const result = {
        cost: insertionCost,
        operations: [insertOperation, ...insertionResult.operations],
      }

      this.nodesToBestOperationCache.set(cacheKey, result)
      return result
    }

    if (
      modificationResult &&
      modificationCost <= deletionCost &&
      modificationCost <= insertionCost &&
      deletingChild &&
      insertingChild &&
      Types.isParent(oldParent) &&
      Types.isParent(newParent)
    ) {
      const modifyOperation: Operation = {
        type: 'modify',
        oldNode: deletingChild,
        newNode: insertingChild,
        cost: modificationCost,
        // Keep the larger gap of the two documents so neither side's blocks
        // get merged together.
        newLinesAfter: Math.max(
          this.getNewlinesBetweenNodes(
            insertingChild,
            newParent.children[state.newIndex + 1]
          ),
          this.getNewlinesBetweenNodes(
            deletingChild,
            oldParent.children[state.oldIndex + 1]
          )
        ),
      }

      const result = {
        cost: modificationCost,
        operations: [modifyOperation, ...modificationResult.operations],
      }

      this.nodesToBestOperationCache.set(cacheKey, result)
      return result
    }

    throw new Error('Unexpected: no path found')
  }

  /**
   * Returns a key for the cache based on the old and new node.
   * A missing node is encoded as the literal 'null'.
   */
  private getKey(oldNode: ASTNode | null, newNode: ASTNode | null): string {
    return `${(oldNode as ASTNodeWithId)?.id ?? 'null'}-${
      (newNode as ASTNodeWithId)?.id ?? 'null'
    }`
  }

  /** Cost of inserting a node: its complexity plus a fixed step cost of 1. */
  private insertionCost(node: ASTNode): number {
    return this.getNodeComplexity(node) + 1
  }

  /** Cost of deleting a node: its complexity plus a fixed step cost of 1. */
  private deletionCost(node: ASTNode): number {
    return this.getNodeComplexity(node) + 1
  }

  /**
   * Cost of modifying `oldNode` into `newNode`.
   * @returns Infinity if the pair cannot be modified. For text and parent
   *   nodes, the number of characters in changed sentence chunks of their
   *   text. 0 for any other node pair.
   */
  public modificationCost(oldNode: ASTNode, newNode: ASTNode): number {
    if (!this.nodePairIsModifiable(oldNode, newNode)) {
      return Infinity
    }

    const type = oldNode.type
    if (type === 'text') {
      return this.textModificationCost(
        (oldNode as Text).value,
        (newNode as Text).value
      )
    } else if (Types.isParent(oldNode) && Types.isParent(newNode)) {
      return this.parentTextModificationCost(oldNode, newNode)
    }

    return 0
  }

  /** Sums the lengths of all added and removed chunks between two strings. */
  private textModificationCost(oldText: string, newText: string): number {
    const changes = this.fastDiff(oldText, newText)
    return changes.reduce(
      (cost, change) =>
        cost + (change.added || change.removed ? change.value.length : 0),
      0
    )
  }

  /**
   * Extracts text content from parent node and calculates the modification cost
   **/
  private parentTextModificationCost(oldNode: Parent, newNode: Parent): number {
    const oldText = extractText(oldNode)
    const newText = extractText(newNode)
    return this.textModificationCost(oldText, newText)
  }

  /**
   * Size metric for a node. A parent counts 1 plus its children's total. A
   * text node counts its character length (at least 1). Any other node
   * counts 1.
   */
  private getNodeComplexity(node: ASTNode): number {
    if (Types.isParent(node)) {
      return (
        1 +
        (node.children as ASTNode[]).reduce(
          (sum, child) => sum + this.getNodeComplexity(child),
          0
        )
      )
    }

    if (Types.isText(node)) {
      return node.value?.length || 1
    }
    return 1
  }

  /**
   * Returns a dedicated renderer for node types that are not diffed
   * child-by-child (tables, HTML, code, thematic breaks), or null.
   */
  private nodeToHandler(node: ASTNode) {
    if (Types.isTable(node)) {
      return this.tableHandler.bind(this)
    } else if (Types.isHtml(node)) {
      return this.htmlHandler.bind(this)
    } else if (Types.isCode(node)) {
      return this.codeHandler.bind(this)
    } else if (Types.isThematicBreak(node)) {
      return this.thematicBreakHandler.bind(this)
    }
    return null
  }

  /** Renders a node that exists only in the new document. */
  private insertNode(node: ASTNode, state: HandlerState): string {
    if (Types.isText(node)) {
      return `<ins>${(node as Text).value}</ins>`
    }
    const newState = this.updateState(null, node, state)

    if (PARAGRAPH_LIKE_TYPES.includes(node.type)) {
      let result = ''
      result += this.openTag(node, newState)
      result += this.paragraphHandler(null, node as Paragraph)
      result += this.closeTag(node)
      return result
    }

    const handler = this.nodeToHandler(node)
    if (handler) return handler(null, node as any)

    let result = ''
    if (Types.isParent(node)) {
      // With no old parent, the search yields one insert per child.
      const { operations } = this.dfsOptimalChildrenPath(null, node)
      result += this.openTag(node, newState)
      result += this.applyOperations(operations, newState)
      result += this.closeTag(node)
    }
    return result
  }

  /** Renders a node that exists only in the old document. */
  private deleteNode(node: ASTNode, state: HandlerState): string {
    if (Types.isText(node)) {
      return `<del>${(node as Text).value}</del>`
    }
    const newState = this.updateState(node, null, state)

    if (PARAGRAPH_LIKE_TYPES.includes(node.type)) {
      let result = ''
      result += this.openTag(node, newState)
      result += this.paragraphHandler(node as Paragraph, null)
      result += this.closeTag(node)
      return result
    }

    const handler = this.nodeToHandler(node)
    if (handler) return handler(node as any, null)

    let result = ''
    if (Types.isParent(node)) {
      // With no new parent, the search yields one delete per child.
      const { operations } = this.dfsOptimalChildrenPath(node, null)
      result += this.openTag(node, newState)
      result += this.applyOperations(operations, newState)
      result += this.closeTag(node)
    }
    return result
  }

  /**
   * Renders the difference between two nodes. Falls back to delete + insert
   * when the pair cannot be modified in place.
   */
  private modifyNode(
    oldNode: ASTNode | null,
    newNode: ASTNode | null,
    state: HandlerState
  ): string {
    if (oldNode && !newNode) return this.deleteNode(oldNode, state)
    if (newNode && !oldNode) return this.insertNode(newNode, state)
    if (!this.nodePairIsModifiable(oldNode, newNode)) {
      if (oldNode && newNode)
        return this.deleteNode(oldNode, state) + this.insertNode(newNode, state)
      else if (oldNode) return this.deleteNode(oldNode, state)
      else if (newNode) return this.insertNode(newNode, state)
    }
    const newState = this.updateState(oldNode, newNode, state)

    // XXX: These should be the same (this.nodePairIsModifiable enforced)
    const type = oldNode?.type || newNode?.type

    // Just for type checking
    if (!oldNode || !newNode)
      throw new Error('Unexpected: oldNode or newNode is null')

    if (type === 'text') {
      return this.diffText(
        (oldNode as Text | null)?.value || '',
        (newNode as Text | null)?.value || ''
      )
    } else if (PARAGRAPH_LIKE_TYPES.includes(type!)) {
      let result = ''
      result += this.openTag(oldNode, newState)
      result += this.paragraphHandler(
        oldNode as Paragraph,
        newNode as Paragraph
      )
      result += this.closeTag(oldNode)
      return result
    }

    const handler = this.nodeToHandler(oldNode!)
    if (handler) return handler(oldNode as any, newNode as any)

    if (Types.isParent(oldNode) && Types.isParent(newNode)) {
      let result = ''
      const { operations } = this.dfsOptimalChildrenPath(oldNode, newNode)
      result += this.openTag(oldNode, newState)
      result += this.applyOperations(operations, newState)
      result += this.closeTag(oldNode)
      return result
    } else {
      console.warn(
        `Unhandled node type in modifyNode: ${oldNode.type} ${newNode.type}`
      )
      return this.deleteNode(oldNode, state) + this.insertNode(newNode, state)
    }
  }

  /**
   * Derives the rendering state for a node's subtree.
   *
   * A list starts a fresh list context. A list item bumps the enclosing
   * list's counter *in place*, so successive sibling items (which share the
   * same state object) get consecutive numbers. Other nodes pass the state
   * through unchanged.
   */
  private updateState(
    oldNode: ASTNode | null,
    newNode: ASTNode | null,
    state: HandlerState
  ): HandlerState {
    if (!oldNode && !newNode) return state

    const node = (oldNode ?? newNode) as ASTNode

    if (Types.isList(node)) {
      return {
        list: {
          ordered: node.ordered || false,
          // Nest count is 1-indexed
          nestCount: (state.list?.nestCount ?? 0) + 1,
          // Starts one below the list's `start` because each item increments
          // it before its marker is rendered.
          orderNumber: (node?.start ?? 1) - 1,
        },
      }
    } else if (Types.isListItem(node)) {
      state.list.orderNumber++
      return state
    }

    return state
  }

  /**
   * Whether two nodes may be diffed in place. They must both exist and share
   * a type. Headings must also share a depth, and lists the same `ordered`
   * flag.
   */
  private nodePairIsModifiable(
    oldNode: ASTNode | null,
    newNode: ASTNode | null
  ): boolean {
    if (!oldNode || !newNode) return false
    if (oldNode.type !== newNode.type) return false

    const type = oldNode.type

    if (type === 'heading') {
      if ((oldNode as Heading).depth !== (newNode as Heading).depth)
        return false
    } else if (type === 'list') {
      if ((oldNode as List).ordered !== (newNode as List).ordered) return false
      // TODO: other list properties
    }
    return true
  }

  /** Markdown syntax emitted before a node's content. */
  private openTag(node: ASTNode, state: HandlerState): string {
    switch (node.type) {
      case 'heading':
        return `${'#'.repeat((node as Heading).depth)} `
      case 'listItem':
        return (
          '  '.repeat(state.list.nestCount - 1) +
          (state.list.ordered ? `${state.list.orderNumber}. ` : '- ')
        )
      case 'emphasis':
        return '*'
      case 'strong':
        return '**'
      case 'delete':
        return '~~'
      case 'footnoteDefinition':
        return `[^${(node as FootnoteDefinition).identifier}]: `
      default:
        return ''
    }
  }

  /** Markdown syntax emitted after a node's content. */
  private closeTag(node: ASTNode): string {
    switch (node.type) {
      case 'emphasis':
        return '*'
      case 'strong':
        return '**'
      case 'delete':
        return '~~'
      default:
        return ''
    }
  }

  /** Word-level diff of two strings, marking additions and removals. */
  private diffText(oldText: string, newText: string): string {
    const changes = Diff.diffWords(oldText, newText)
    return changes
      .map((change) => {
        if (change.added) {
          return `<ins>${change.value}</ins>`
        }
        if (change.removed) {
          return `<del>${change.value}</del>`
        }
        return change.value
      })
      .join('')
  }

  /**
   * Diffs paragraph content, walking through children and comparing them.
   *
   * Both sides are tokenized, and the token arrays are diffed, with tokens
   * comparing equal only when their text and formatting attributes match.
   * The result is re-rendered, opening/closing `<ins>`/`<del>` and
   * `**`/`*`/`~~` markers only at run boundaries, so adjacent tokens with the
   * same state share one set of markers.
   */
  // TODO: Rename if this is going to be used more generically
  private paragraphHandler(
    oldNode: Paragraph | TableCell | null,
    newNode: Paragraph | TableCell | null
  ): string {
    const tokenize = (node: Paragraph | TableCell | null) =>
      new ParagraphTokenizer().tokenizeParagraph(node)

    const oldTokenizedText = tokenize(oldNode)
    const newTokenizedText = tokenize(newNode)

    const diff = Diff.diffArrays(oldTokenizedText, newTokenizedText, {
      comparator: (a, b) =>
        JSON.stringify({ token: a.token, attributes: a.attributes }) ===
        JSON.stringify({ token: b.token, attributes: b.attributes }),
    })
    // Flatten the diff chunks into one entry per token.
    const diffObjects = diff.flatMap((diffEntry) =>
      diffEntry.value.map((v) => ({
        added: !!diffEntry.added,
        removed: !!diffEntry.removed,
        value: v,
      }))
    )

    const result = diffObjects.reduce(
      // eslint-disable-next-line max-params
      (acc, diffEntry, i, arr) => {
        const { value: parsed, added, removed } = diffEntry
        const nextParsed = arr[i + 1]?.value
        let resultToken = ''

        // Opens a formatting marker if this token has it and it is not already open.
        const handleOpenAttribute = (attr: string, tag: string) => {
          if (
            parsed.attributes.includes(attr) &&
            !acc.attributes.includes(attr)
          ) {
            resultToken += tag
            acc.attributes.push(attr)
          }
        }

        // Closes a formatting marker if the next token does not continue it.
        const handleCloseAttribute = (attr: string, tag: string) => {
          if (
            !nextParsed?.attributes.includes(attr) &&
            acc.attributes.includes(attr)
          ) {
            resultToken += tag
            acc.attributes = acc.attributes.filter((a) => a !== attr)
          }
        }

        const handleOpenSpecialAttribute = (
          condition: boolean,
          attr: 'ins' | 'del',
          openTag: string
        ) => {
          if (condition && !acc.attributes.includes(attr)) {
            resultToken += openTag
            acc.attributes.push(attr)
          }
        }

        // Closes <ins>/<del> when the next token is not also added/removed.
        const handleCloseSpecialAttribute = (
          attr: 'ins' | 'del',
          closeTag: string
        ) => {
          if (
            !arr[i + 1]?.[attr === 'ins' ? 'added' : 'removed'] &&
            acc.attributes.includes(attr)
          ) {
            resultToken += closeTag
            acc.attributes = acc.attributes.filter((a) => a !== attr)
          }
        }

        handleOpenSpecialAttribute(added, 'ins', '<ins>')
        handleOpenSpecialAttribute(removed, 'del', '<del>')
        handleOpenAttribute('strong', '**')
        handleOpenAttribute('emphasis', '*')
        handleOpenAttribute('delete', '~~')

        resultToken += parsed.token

        handleCloseAttribute('delete', '~~')
        handleCloseAttribute('emphasis', '*')
        handleCloseAttribute('strong', '**')
        handleCloseSpecialAttribute('ins', '</ins>')
        handleCloseSpecialAttribute('del', '</del>')

        if (
          !parsed.noSpace?.after &&
          !nextParsed?.noSpace?.before &&
          i !== arr.length - 1
        ) {
          resultToken += ' '
        }

        acc.result.push(resultToken)
        return acc
      },
      { result: [] as string[], attributes: [] as string[] }
    )

    const applyFinalCleanup = (result: string[]) => {
      // An ins/del pair that replaces a word should read as one substitution,
      // so drop the separating space.
      const removeSpaceBetweenInsAndDel = (result: string) => {
        return result
          .replace(/<\/ins> <del>/g, '</ins><del>')
          .replace(/<\/del> <ins>/g, '</del><ins>')
      }
      return removeSpaceBetweenInsAndDel(result.join(''))
    }

    return applyFinalCleanup(result.result)
  }

  /**
   * Renders a (possibly one-sided) table diff. Rows are aligned with the same
   * edit-distance search as other siblings, and a `|-|` separator row follows
   * the first rendered row.
   * @throws Error if either side is present but is not a table.
   */
  private tableHandler(
    oldNode: RootContent | null,
    newNode: RootContent | null
  ): string {
    const result: string[] = []

    if (
      (oldNode && !Types.isTable(oldNode)) ||
      (newNode && !Types.isTable(newNode))
    ) {
      throw new Error(
        `Unexpected node type in tableHandler: ${oldNode?.type} ${newNode?.type}`
      )
    }

    const oldNodeRows = oldNode?.children || []
    const newNodeRows = newNode?.children || []

    // The separator row is sized by the wider of the two header rows.
    const maxCols = Math.max(
      oldNodeRows[0]?.children.length || 0,
      newNodeRows[0]?.children.length || 0
    )

    const { operations: optimalTableRowPath } = this.dfsOptimalChildrenPath(
      oldNode,
      newNode
    )

    for (let i = 0; i < optimalTableRowPath.length; i++) {
      result.push(this.applyTableRowOperation(optimalTableRowPath[i]))
      if (i === 0) result.push(`|${'-|'.repeat(maxCols)}`)
    }

    return result.join('\n')
  }

  /**
   * Renders one table-row operation as a `|cell|cell|` line. Modified rows
   * are diffed cell by cell by position.
   */
  private applyTableRowOperation(operation: Operation) {
    const result: string[] = []

    if (operation.type === 'insert') {
      const newNode = operation.newNode as TableRow
      for (let i = 0; i < newNode.children.length; i++) {
        const cell = newNode.children[i]
        result.push(this.tableCellHandler(null, cell))
      }
    } else if (operation.type === 'delete') {
      const oldNode = operation.oldNode as TableRow

      for (let i = 0; i < oldNode.children.length; i++) {
        const cell = oldNode.children[i]
        result.push(this.tableCellHandler(cell, null))
      }
    } else {
      const oldNode = operation.oldNode as TableRow
      const newNode = operation.newNode as TableRow

      const maxLength = Math.max(
        oldNode?.children.length || 0,
        newNode?.children.length || 0
      )
      for (let i = 0; i < maxLength; i++) {
        const oldCell = oldNode?.children[i] || null
        const newCell = newNode?.children[i] || null

        result.push(this.tableCellHandler(oldCell, newCell))
      }
    }
    return `|${result.join('|')}|`
  }

  /** Table cells contain phrasing content, so they are diffed like paragraphs. */
  private tableCellHandler(
    oldNode: TableCell | null,
    newNode: TableCell | null
  ) {
    return this.paragraphHandler(oldNode, newNode)
  }

  /**
   * HTML is not diffed. If both sides exist and differ, both values are
   * concatenated (old first); otherwise whichever side exists is returned.
   */
  private htmlHandler(oldNode: Html | null, newNode: Html | null): string {
    if (oldNode && newNode && oldNode.value !== newNode.value) {
      return oldNode.value + newNode.value
    }
    return (newNode ?? oldNode)?.value ?? ''
  }

  /**
   * Doesn't really diff them. If both blocks exist and differ, it returns
   * the old code followed by the new code; otherwise it returns whichever
   * block exists (so a deleted block still shows its old code).
   */
  private codeHandler(oldNode: Code | null, newNode: Code | null): string {
    const toFence = (node: Code): string => {
      // `lang`/`meta` are absent (null) for fences without an info string;
      // interpolating them directly would print "```null".
      const info = [node.lang, node.meta].filter(Boolean).join(' ')
      return `\`\`\`${info}\n${node.value}\n\`\`\``
    }

    if (
      oldNode &&
      newNode &&
      (oldNode.value !== newNode.value || oldNode.lang !== newNode.lang)
    ) {
      return `${toFence(oldNode)}\n\n${toFence(newNode)}`
    }

    const node = newNode ?? oldNode
    return node ? toFence(node) : ''
  }

  /**
   * There isn't a nice way to diff thematic breaks, so return it if it exists
   * in the new AST, otherwise return an empty string if it was deleted.
   */
  private thematicBreakHandler(
    _oldNode: RootContent | null,
    newNode: RootContent | null
  ): string {
    if (newNode) return '---'
    return ''
  }

  /**
   * Each node is separated by newlines. This function calculates
   * the number of newlines needed to separate the nodes.
   *
   * When both nodes have source positions, it uses the line gap between
   * them (at least 1 after a parent node). Otherwise it falls back to 2
   * after block content, 1 after other parents, and 0 after anything else.
   */
  public getNewlinesBetweenNodes(
    currentNode: ASTNode,
    nextNode: ASTNode | undefined
  ): number {
    if (!currentNode.position || !nextNode?.position) {
      if (Types.isBlockContent(currentNode)) return 2
      if (Types.isParent(currentNode)) return 1
      return 0
    }
    return Math.max(
      Math.abs(nextNode.position.start.line - currentNode.position.end.line),
      Types.isParent(currentNode) ? 1 : 0
    )
  }

  /** Drops all memoised paths; called at the start of every run. */
  private clearCache() {
    this.nodesToBestOperationCache.clear()
  }

  /** Recursively tags every node in the tree (in place) with a fresh UUID v4. */
  private addIdsToNodes(node: ASTNode) {
    ;(node as ASTNodeWithId).id = v4()

    if (Types.isParent(node as ASTNode)) {
      ;(node as Parent).children.forEach((child) => this.addIdsToNodes(child))
    }
    return node as ASTNode
  }

  /**
   * Sentence-level diff used only for cost estimation. It is coarser, and
   * therefore cheaper, than the word-level diff used for rendering.
   */
  private fastDiff = (oldText: string, newText: string): Diff.Change[] => {
    return Diff.diffSentences(oldText, newText)
  }
}
