import {
  type Attrs,
  Fragment,
  type Node as ProseMirrorNode,
  type Schema as ProseMirrorSchema,
  Slice,
} from 'prosemirror-model'
import { type Mapping, Step, StepResult } from 'prosemirror-transform'

/**
 * Produces a shallow copy of `prev` with the `changes` patch laid over it.
 *
 * For each own enumerable key of `changes`: a value of `undefined` removes
 * that key from the copy, any other value replaces (or adds) it. Neither
 * argument is mutated.
 */
function applyAttrs(prev: Attrs, changes: Attrs) {
  const result = { ...prev }

  Object.keys(changes).forEach((name) => {
    const next = changes[name]
    if (next !== undefined) {
      result[name] = next
      return
    }
    delete result[name]
  })

  return result
}

/**
 * A ProseMirror step that patches the attributes of the node starting at
 * `pos`, keeping its type, content and marks.
 *
 * `attrs` is a partial patch, merged via `applyAttrs`:
 * - a key with a defined value overwrites the node's current value;
 * - a key mapped to `undefined` is removed from the attribute object before
 *   the node is re-created;
 * - keys absent from the patch are left as they were.
 */
export class UpdateNodeAttributesStep extends Step {
  /** Position directly before the node whose attributes are updated. */
  pos: number
  /** Attribute patch; an `undefined` value means "remove this key". */
  attrs: Attrs

  constructor(pos: number, attrs: Attrs) {
    super()
    this.pos = pos
    this.attrs = attrs
  }

  /**
   * Replaces the node at `pos` with a copy carrying the patched attributes.
   *
   * Per the `Step.apply` contract, an inapplicable step yields a failed
   * `StepResult` instead of throwing. This happens when there is no node at
   * `pos`, or when that node is a text node. Callers that use
   * `Transform.maybeStep` can then handle the failure.
   */
  apply(doc: ProseMirrorNode) {
    const { pos } = this
    const prevNode = doc.nodeAt(pos)
    if (!prevNode) return StepResult.fail(`Invalid position: ${pos}`)
    // Text nodes cannot be constructed through NodeType.create.
    if (prevNode.isText)
      return StepResult.fail(`Cannot update attributes of a text node at ${pos}`)

    const attrs = applyAttrs(prevNode.attrs, this.attrs)
    const node = prevNode.type.create(attrs, prevNode.content, prevNode.marks)

    // Replace exactly the range occupied by the old node. Content is reused,
    // so the replacement has the same size.
    const slice = new Slice(Fragment.fromArray([node]), 0, 0)
    return StepResult.fromReplace(doc, pos, pos + prevNode.nodeSize, slice)
  }

  /**
   * Builds the step that undoes this one. `doc` is the document as it was
   * before this step was applied.
   *
   * For every key in the patch, the inverse records the node's previous
   * value. A key the node did not have becomes `undefined`, so the inverse
   * removes it again.
   *
   * @throws Error when there is no node at `pos` in `doc`.
   */
  invert(doc: ProseMirrorNode) {
    const { pos } = this

    const node = doc.nodeAt(pos)
    if (!node) throw new Error(`Invalid position: ${pos}`)

    const attrs = { ...this.attrs }
    for (const key of Object.keys(attrs)) {
      attrs[key] = node.attrs[key]
    }
    return new UpdateNodeAttributesStep(pos, attrs)
  }

  /**
   * Maps the step through `mapping`. Returns `null` (the step is dropped)
   * when the mapping reports the targeted position as deleted.
   */
  map(mapping: Mapping) {
    const mapped = mapping.mapResult(this.pos)
    if (mapped.deleted) return null
    return new UpdateNodeAttributesStep(mapped.pos, this.attrs)
  }

  /**
   * Combines this step with a following attribute step on the same node.
   * Values from `other` win on conflicting keys. Returns `null` for any
   * other kind of step, or for a step at a different position.
   */
  merge(other: Step) {
    const { pos } = this
    if (other instanceof UpdateNodeAttributesStep && other.pos === pos) {
      const attrs = { ...this.attrs, ...other.attrs }
      return new UpdateNodeAttributesStep(pos, attrs)
    }
    return null
  }

  /**
   * Serializes the step. `stepType` must match the id passed to
   * `Step.jsonID` for this class.
   *
   * NOTE(review): `JSON.stringify` omits properties whose value is
   * `undefined`, so "remove this key" entries in `attrs` are lost when the
   * result is stringified. Confirm whether deletions need to survive
   * serialization.
   */
  toJSON() {
    return {
      stepType: 'updateNodeAttributes',
      pos: this.pos,
      attrs: this.attrs,
    }
  }

  /**
   * Reconstructs a step from the output of `toJSON`.
   *
   * @throws RangeError when `pos` is not a number or `attrs` is not a plain
   *   (non-null, non-array) object.
   */
  static fromJSON(_schema: ProseMirrorSchema, json: Record<string, any>) {
    const { attrs, pos } = json
    // `typeof null === 'object'`, so null and arrays are rejected explicitly.
    // Otherwise the bad value would only fail later, inside apply().
    if (
      attrs === null ||
      typeof attrs !== 'object' ||
      Array.isArray(attrs) ||
      typeof pos !== 'number'
    )
      throw new RangeError(
        'Invalid input for UpdateNodeAttributesStep.fromJSON',
      )
    return new UpdateNodeAttributesStep(pos, attrs)
  }
}

// Register this step class under the same id that `toJSON` writes as
// `stepType`, so serialized steps can be turned back into
// UpdateNodeAttributesStep instances through `Step.fromJSON`.
Step.jsonID('updateNodeAttributes', UpdateNodeAttributesStep)
