import {
  getDefaultTableCellWidth,
  isCellSelection,
  reduceAttributes,
  tableAligns,
} from '@blissbook/lib/document'
import { updateNodeAttributes } from '@blissbook/lib/document/commands'
import {
  Node,
  type NodeWithPos,
  type ParentConfig,
  callOrReturn,
  getExtensionField,
  isTextSelection,
} from '@tiptap/core'
import type { Attrs } from 'prosemirror-model'
import { TextSelection, type Transaction } from 'prosemirror-state'
import {
  TableMap,
  deleteColumn,
  deleteRow,
  deleteTable,
  goToNextCell,
  mergeCells,
  splitCell,
  tableEditing,
} from 'prosemirror-tables'
import { getContainerNodeFromPos } from '../../content'

// Module augmentation: registers the table commands on Tiptap's command
// typings and adds the `tableRole` field to node configs.
declare module '@tiptap/core' {
  interface Commands<ReturnType> {
    table: {
      /**
       * Inserts a new table at an empty text selection. `rows` and `cols`
       * both default to 3. Refused (returns false) when the selection is
       * inside a list item or inside another table.
       */
      insertTable: (options?: { rows?: number; cols?: number }) => ReturnType
      /**
       * Inserts a column at zero-based column index `col` in the table that
       * contains the selection; passing the current column count appends.
       */
      insertColumnAt: (col: number) => ReturnType
      /**
       * Inserts a row at zero-based row index `row` in the table that
       * contains the selection; passing the current row count appends.
       */
      insertRowAt: (row: number) => ReturnType
      /** Appends a column after the last column of the current table. */
      addColumn: () => ReturnType
      /** Appends a row after the last row of the current table. */
      addRow: () => ReturnType
      /** Delegates to the prosemirror-tables `deleteColumn` command. */
      deleteColumn: () => ReturnType
      /** Delegates to the prosemirror-tables `deleteRow` command. */
      deleteRow: () => ReturnType
      /** Delegates to the prosemirror-tables `deleteTable` command. */
      deleteTable: () => ReturnType
      /**
       * Clears `height` and sets `width` to `size.width` divided by the
       * table's column count (rounded) via `setCellAttributes`, i.e. on the
       * cells of the current cell selection.
       */
      tidyUpTable: (size: { width: number }) => ReturnType
      /** Delegates to the prosemirror-tables `mergeCells` command. */
      mergeCells: () => ReturnType
      /** Delegates to the prosemirror-tables `splitCell` command. */
      splitCell: () => ReturnType
      /**
       * Applies `attrs` to every cell of the current cell selection.
       * Returns false when the selection is not a cell selection.
       */
      setCellAttributes: (attrs: Attrs) => ReturnType
      /** Moves to the next cell (prosemirror-tables `goToNextCell(1)`). */
      goToNextCell: () => ReturnType
      /** Moves to the previous cell (prosemirror-tables `goToNextCell(-1)`). */
      goToPreviousCell: () => ReturnType
    }
  }

  interface NodeConfig<Options, Storage> {
    /**
     * Role this node plays within a table, given either as a string or as a
     * function returning one. `TableNode.extendNodeSchema` in this file
     * resolves the value and copies it onto the node's schema spec
     * (presumably for prosemirror-tables to read — confirm against its docs).
     */
    tableRole?:
      | string
      | ((this: {
          name: string
          options: Options
          storage: Storage
          parent: ParentConfig<NodeConfig<Options>>['tableRole']
        }) => string)
  }
}

/**
 * Adds a column to `table` at column index `col` by appending steps to `tr`.
 *
 * Walks the table top to bottom. Where a cell already spans across the
 * insertion boundary (the map entries left and right of the boundary are the
 * same cell), that cell's `colspan` is grown by one. Otherwise a new cell is
 * inserted, copying the attributes of a neighbouring template cell (the cell
 * to the left, or the cell at column 0 when inserting at the very start)
 * minus its `colspan`.
 *
 * @param col   Target column index; `map.width` appends after the last column.
 * @param map   Table map of `table.node`.
 * @param table The table node together with its document position.
 * @param tr    Transaction that receives the attribute updates and inserts.
 */
function insertColumnAt({
  col,
  map,
  table,
  tr,
}: {
  col: number
  map: TableMap
  table: NodeWithPos
  tr: Transaction
}) {
  const { width, height } = map
  const offsets = map.map
  const tableNode = table.node

  // Converts a table-relative cell offset into a position in the current
  // transaction document (+1 steps over the table's opening token).
  const toDocPos = (offset: number) => tr.mapping.map(table.pos + offset + 1)

  // Only a boundary strictly inside the table can be straddled by a cell
  const isInterior = col > 0 && col < width
  const templateShift = col > 0 ? -1 : 0

  // `row` is advanced inside the body by the rowspan of the handled cell
  for (let row = 0; row < height; ) {
    const index = row * width + col
    const straddled = isInterior && offsets[index - 1] === offsets[index]

    if (straddled) {
      // An existing cell covers both sides of the boundary: widen it
      const offset = offsets[index]
      const spanning = tableNode.nodeAt(offset)
      const { colspan, rowspan } = spanning.attrs

      updateNodeAttributes(tr, toDocPos(offset), { colspan: colspan + 1 })

      row += rowspan
      continue
    }

    // Otherwise insert a fresh cell modelled on the neighbouring cell
    const template = tableNode.nodeAt(offsets[index + templateShift])
    const insertOffset = map.positionAt(row, col, tableNode)
    const attrs = reduceAttributes({ ...template.attrs })
    // biome-ignore lint/performance/noDelete: app relies on this behavior
    delete attrs.colspan
    tr.insert(toDocPos(insertOffset), template.type.createAndFill(attrs))

    // The new cell inherits the template's rowspan, so skip the rows it covers
    row += attrs.rowspan
  }
}

/**
 * Adds a row to `table` at row index `row` by appending steps to `tr`.
 *
 * Walks the table map left to right along the insertion boundary. Where a
 * cell from the row above already spans across the boundary, that cell's
 * `rowspan` is grown by one and no new cell is created for the columns it
 * covers. Otherwise a new cell is created from the attributes of a template
 * cell (the cell directly above, or the cell in row 0 when inserting at the
 * very top) minus its `rowspan`. The collected cells are inserted as one new
 * `tableRow` node.
 *
 * @param map   Table map of `table.node`.
 * @param row   Target row index; `map.height` appends after the last row.
 * @param table The table node together with its document position.
 * @param tr    Transaction that receives the attribute updates and the insert.
 */
function insertRowAt({
  map,
  row,
  table,
  tr,
}: {
  map: TableMap
  row: number
  table: NodeWithPos
  tr: Transaction
}) {
  // Document position of the boundary before row `row`: step over the
  // table's opening token, then over every row that precedes the new one.
  let rowPos = table.pos + 1
  for (let i = 0; i < row; i++) {
    rowPos += table.node.child(i).nodeSize
  }

  const cells = []
  // Template cells come from the row above, or from row 0 when inserting first
  const refRow = row > 0 ? -1 : 0

  // `index` is the map index of (row, col); both must advance in lockstep
  let index = map.width * row
  let col = 0
  while (col < map.width) {
    // Covered by a rowspan cell
    if (
      row > 0 &&
      row < map.height &&
      map.map[index] === map.map[index - map.width]
    ) {
      const cellPos = map.map[index]
      const cell = table.node.nodeAt(cellPos)
      const { colspan, rowspan } = cell.attrs

      updateNodeAttributes(tr, tr.mapping.map(table.pos + cellPos + 1), {
        rowspan: rowspan + 1,
      })

      // Skip every column the spanning cell covers. Advancing `index` by 1
      // here (instead of by `colspan`) desynchronised it from `col` for cells
      // that span several columns, dropping cells from the new row.
      col += colspan
      index += colspan
    } else {
      const cell = table.node.nodeAt(map.map[index + refRow * map.width])
      const attrs = reduceAttributes({ ...cell.attrs })
      // biome-ignore lint/performance/noDelete: app relies on this behavior
      delete attrs.rowspan
      cells.push(cell.type.createAndFill(attrs))

      // The new cell keeps the template's colspan, so skip what it covers
      const { colspan } = attrs
      col += colspan
      index += colspan
    }
  }

  const rowType = table.node.type.schema.nodes.tableRow
  tr.insert(rowPos, rowType.create(null, cells))
}

export const TableNode = Node.create({
  name: 'table',

  content: 'tableRow+',

  tableRole: 'table',

  isolating: true,

  /**
   * Declares the table node's attributes: `align` (stored in HTML as one of
   * the class names listed in `tableAligns`) and `borderColor` (stored as the
   * inline `border-color` style).
   */
  addAttributes() {
    return {
      align: {
        default: 'center',
        keepOnSplit: true,
        // Picks the first known alignment whose class is on the element;
        // yields undefined when none matches.
        parseHTML: (element: HTMLElement) =>
          tableAligns.find(({ className }) =>
            element.classList.contains(className),
          )?.value,
        renderHTML: ({ align }) => {
          if (!align) return {}

          // Unknown alignment values produce no attributes (undefined)
          const match = tableAligns.find(({ value }) => value === align)
          return match ? { class: match.className } : undefined
        },
      },
      borderColor: {
        default: undefined,
        // An empty inline border colour is treated as "not set"
        parseHTML: (element: HTMLElement) =>
          element.style.borderColor || undefined,
        renderHTML: ({ borderColor }) =>
          borderColor ? { style: `border-color: ${borderColor}` } : {},
      },
    }
  },

  /** Parses any `<table>` element into this node. */
  parseHTML() {
    const tableRule = { tag: 'table' }
    return [tableRule]
  },

  /**
   * Renders as `<table>` carrying the node's HTML attributes, with the rows
   * placed inside a `<tbody>` (the `0` marks the content hole).
   */
  renderHTML({ HTMLAttributes }) {
    const body = ['tbody', 0]
    return ['table', HTMLAttributes, body]
  },

  /**
   * Implements the `table` commands declared on Tiptap's `Commands`
   * interface. Every command returns false when it cannot apply and only
   * touches the transaction when `dispatch` is provided, so the commands can
   * also be used for dry-run checks.
   */
  addCommands() {
    return {
      // Inserts a rows x cols table (3 x 3 by default) at the cursor
      insertTable:
        ({ rows = 3, cols = 3 } = {}) =>
        ({ dispatch, editor, tr }) => {
          const { schema } = editor

          // Only allow inserts on empty text selections
          const { selection } = tr
          if (!isTextSelection(selection) || !selection.empty) return false

          // Do not allow tables inside list items, and do not allow nested tables
          const listItem = getContainerNodeFromPos(
            tr.selection.$from,
            'listItem',
          )
          const table = getContainerNodeFromPos(tr.selection.$from, 'table')
          if (listItem || table) return false

          // With no rows the table's `tableRow+` content cannot be satisfied
          // and `createChecked` below would throw; zero columns would likewise
          // yield empty rows. Refuse instead (`!(x > 0)` also rejects NaN).
          if (!(rows > 0) || !(cols > 0)) return false

          const cellWidth = getDefaultTableCellWidth(cols)
          const trNodes = []
          for (let rowIndex = 0; rowIndex < rows; rowIndex++) {
            const tdNodes = []
            for (let colIndex = 0; colIndex < cols; colIndex++) {
              const tdNode = schema.nodes.tableCell.createAndFill({
                width: cellWidth,
              })
              tdNodes.push(tdNode)
            }
            const trNode = schema.nodes.tableRow.createChecked(null, tdNodes)
            trNodes.push(trNode)
          }
          const tableNode = schema.nodes.table.createChecked(null, trNodes)

          if (dispatch) {
            // Captured before the replacement; resolved against the updated
            // document below to put the selection near the start of the table
            const offset = tr.selection.anchor + 1

            tr.replaceSelectionWith(tableNode)
              .scrollIntoView()
              .setSelection(TextSelection.near(tr.doc.resolve(offset)))
          }

          return true
        },
      // Inserts a column at index `col` of the table containing the selection
      insertColumnAt:
        (col) =>
        ({ state, dispatch, tr }) => {
          const { selection } = state
          const table = getContainerNodeFromPos(selection.$from, 'table')
          if (!table) return false

          // Valid targets are 0..width (width appends). Anything else would
          // index outside the table map and fail part-way through the edit.
          const map = TableMap.get(table.node)
          if (!Number.isInteger(col) || col < 0 || col > map.width) {
            return false
          }

          if (dispatch) {
            insertColumnAt({ col, map, table, tr })
          }

          return true
        },
      // Inserts a row at index `row` of the table containing the selection
      insertRowAt:
        (row) =>
        ({ state, dispatch, tr }) => {
          const { selection } = state
          const table = getContainerNodeFromPos(selection.$from, 'table')
          if (!table) return false

          // Valid targets are 0..height (height appends). Anything else would
          // index outside the table map and fail part-way through the edit.
          const map = TableMap.get(table.node)
          if (!Number.isInteger(row) || row < 0 || row > map.height) {
            return false
          }

          if (dispatch) {
            insertRowAt({ map, row, table, tr })
          }

          return true
        },
      // Appends a column after the last column
      addColumn:
        () =>
        ({ state, dispatch, tr }) => {
          const { selection } = state
          const table = getContainerNodeFromPos(selection.$from, 'table')
          if (!table) return false

          if (dispatch) {
            const map = TableMap.get(table.node)
            const col = map.width
            insertColumnAt({ col, map, table, tr })
          }

          return true
        },
      // Appends a row after the last row
      addRow:
        () =>
        ({ state, dispatch, tr }) => {
          const { selection } = state
          const table = getContainerNodeFromPos(selection.$from, 'table')
          if (!table) return false

          if (dispatch) {
            const map = TableMap.get(table.node)
            const row = map.height
            insertRowAt({ map, row, table, tr })
          }

          return true
        },
      // The three delete commands delegate to prosemirror-tables
      deleteColumn:
        () =>
        ({ state, dispatch }) => {
          return deleteColumn(state, dispatch)
        },
      deleteRow:
        () =>
        ({ state, dispatch }) => {
          return deleteRow(state, dispatch)
        },
      deleteTable:
        () =>
        ({ state, dispatch }) => {
          return deleteTable(state, dispatch)
        },
      // Clears cell heights and spreads `size.width` evenly over the columns
      tidyUpTable:
        (size) =>
        ({ commands, state, dispatch }) => {
          const { selection } = state
          const table = getContainerNodeFromPos(selection.$from, 'table')
          if (!table) return false

          if (dispatch) {
            const map = TableMap.get(table.node)
            // NOTE(review): setCellAttributes only affects the cells of a cell
            // selection and returns false otherwise; that result is ignored
            // here, so this command reports success even when nothing changed.
            commands.setCellAttributes({
              height: undefined,
              width: Math.round(size.width / map.width),
            })
          }

          return true
        },
      // Merge and split delegate to prosemirror-tables
      mergeCells:
        () =>
        ({ state, dispatch }) => {
          return mergeCells(state, dispatch)
        },
      splitCell:
        () =>
        ({ state, dispatch }) => {
          return splitCell(state, dispatch)
        },
      // Applies `attrs` to every cell of the current cell selection
      setCellAttributes:
        (attrs) =>
        ({ state, dispatch, tr }) => {
          const { selection } = state
          if (!isCellSelection(selection)) return false

          if (dispatch) {
            selection.forEachCell((_node, pos) => {
              updateNodeAttributes(tr, pos, attrs)
            })
          }

          return true
        },
      // Cell navigation delegates to prosemirror-tables (1 = next, -1 = previous)
      goToNextCell:
        () =>
        ({ state, dispatch }) => {
          return goToNextCell(1)(state, dispatch)
        },
      goToPreviousCell:
        () =>
        ({ state, dispatch }) => {
          return goToNextCell(-1)(state, dispatch)
        },
    }
  },

  /**
   * Keyboard navigation between cells. Tab moves to the next cell and, when
   * that is not possible but a row can be added, appends a row and then moves
   * into it. Shift-Tab moves to the previous cell.
   */
  addKeyboardShortcuts() {
    return {
      Tab: () => {
        const movedToNextCell = this.editor.commands.goToNextCell()
        if (movedToNextCell) return true

        // No next cell: extend the table if allowed, otherwise let the key
        // event fall through
        const canAppendRow = this.editor.can().addRow()
        return canAppendRow
          ? this.editor.chain().addRow().goToNextCell().run()
          : false
      },
      'Shift-Tab': () => {
        return this.editor.commands.goToPreviousCell()
      },
    }
  },

  /**
   * Installs the prosemirror-tables `tableEditing` plugin with
   * `allowTableNodeSelection` enabled.
   */
  addProseMirrorPlugins() {
    const tableEditingPlugin = tableEditing({ allowTableNodeSelection: true })
    return [tableEditingPlugin]
  },

  /**
   * Copies each extension's `tableRole` config field onto its node schema
   * spec, resolving it first when it is given as a function.
   */
  extendNodeSchema(extension) {
    const { name, options, storage } = extension
    const roleField = getExtensionField(extension, 'tableRole', {
      name,
      options,
      storage,
    })

    return { tableRole: callOrReturn(roleField) }
  },
})
