import type { QueryDocumentSnapshot } from 'firebase/firestore'
import type { ChatMessage } from './types'
import { maxBy, sortBy } from 'lodash'

/**
 * A node in the in-memory tree of chat messages.
 *
 * Nodes are created unlinked by `generateMessageMap` and then wired together
 * by `buildTree`, which sets `parent` and appends to `children` according to
 * each message's `parent` field.
 */
export interface MessageTree {
  /** The node whose message this one points to via its `parent` field; left unset on the root. */
  parent?: MessageTree;
  /** The Firestore snapshot of the message held by this node. */
  message: QueryDocumentSnapshot<ChatMessage>;
  /** Nodes whose message lists this node's document id as its `parent`. */
  children: MessageTree[];
}

/**
 * A list of tree nodes; presumably a root-first path such as the one
 * `collectThread` returns — NOTE(review): confirm against callers.
 */
export type MessageThread = MessageTree[];

/**
 * Returns the path from the tree root down to `node` (inclusive), followed by
 * the entries of `acc`.
 *
 * The ancestry is gathered by walking `parent` links in a loop rather than by
 * recursion, so deep threads cannot overflow the call stack and the result is
 * built in linear time instead of re-spreading the accumulator at every step.
 *
 * @param node - Node whose ancestry to collect. When `null`, `acc` itself is
 *   returned unchanged.
 * @param acc - Nodes appended after the collected path. It is never mutated.
 * @returns A new root-first array ending with `node`, then the items of `acc`.
 * @throws Error if the `parent` links form a cycle (would otherwise never terminate).
 */
export function collectThread(
  node: MessageTree | null,
  acc: MessageTree[] = []
): MessageTree[] {
  if (node == null) {
    return acc
  }

  const path: MessageTree[] = []
  const visited = new Set<MessageTree>()
  let current: MessageTree | undefined = node
  while (current != null) {
    // Guard against corrupted data: a parent cycle would loop forever.
    if (visited.has(current)) {
      throw new Error('Cycle detected in message tree parent links')
    }
    visited.add(current)
    path.push(current)
    current = current.parent
  }

  // `path` was collected leaf-first; flip it so the root comes first.
  return [...path.reverse(), ...acc]
}

/**
 * Finds the most recently created leaf in the subtree rooted at `node`.
 *
 * A childless node is its own latest leaf. Otherwise the latest leaf of each
 * child subtree is found recursively, and the candidate with the greatest
 * `created_at` (as chosen by lodash `maxBy`) is returned.
 *
 * @param node - Root of the subtree to search.
 * @returns The leaf whose message has the greatest `created_at`.
 * @throws Error('No latest leaf found') when `maxBy` yields no result
 *   (e.g. every candidate's `created_at` is nullish).
 */
export function findLatestLeaf(node: MessageTree): MessageTree {
  const { children } = node
  if (children.length > 0) {
    const candidates = children.map((child) => findLatestLeaf(child))
    const newest = maxBy(
      candidates,
      (candidate) => candidate.message.data().created_at
    )
    if (newest != null) {
      return newest
    }
    throw new Error('No latest leaf found')
  }
  return node
}

/**
 * Builds a single message tree from a flat list of Firestore message snapshots.
 *
 * Every message becomes a node. A message whose `parent` field is nullish is
 * treated as the root; every other message is linked beneath the node whose
 * document id equals its `parent` value. Linking sets the child's `parent`
 * back-reference and appends it to the parent's `children`.
 *
 * @param messages - Snapshots of every message in the conversation. The parent
 *   of each non-root message must also be present in this list.
 * @returns The single root node, passed through `sortChildren` (which orders
 *   children by `created_at`).
 * @throws Error if a referenced parent is missing from `messages`, or if the
 *   list does not contain exactly one root (including when it is empty).
 */
export function buildTree(
  messages: QueryDocumentSnapshot<ChatMessage>[]
): MessageTree {
  const messageMap = generateMessageMap(messages)
  const roots: MessageTree[] = []

  for (const message of messages) {
    const node = messageMap.get(message.id)
    if (node == null) {
      // Defensive: generateMessageMap creates a node for every message id.
      throw new Error(`Node not found for message id ${message.id}`)
    }

    const parentId = message.data().parent
    if (parentId == null) {
      roots.push(node)
      continue
    }

    const parentNode = messageMap.get(parentId)
    if (parentNode == null) {
      // Report both ids: the missing parent and the message that references it.
      throw new Error(
        `Parent node ${parentId} not found for message id ${message.id}`
      )
    }
    node.parent = parentNode
    parentNode.children.push(node)
  }

  if (roots.length > 1) {
    throw new Error('Multiple roots found')
  }

  if (roots.length === 0) {
    throw new Error('No roots found')
  }

  return sortChildren(roots[0])
}

/**
 * Orders the `children` of every node in `tree` by ascending `created_at`,
 * using lodash `sortBy` (a stable sort, so equal keys keep their input order).
 *
 * Nodes are updated in place instead of copied. The root and every descendant
 * keep their identity, so the `parent` back-references set by `buildTree`
 * still point at the nodes callers receive, and ancestors reached via
 * `parent` expose the sorted `children` too. An explicit work list replaces
 * recursion so deep threads do not grow the call stack.
 *
 * @param tree - Root of the tree to sort; mutated in place.
 * @returns The same `tree` object.
 */
function sortChildren(tree: MessageTree): MessageTree {
  const pending: MessageTree[] = [tree]
  for (let node = pending.pop(); node !== undefined; node = pending.pop()) {
    // With zero or one child there is nothing to reorder.
    if (node.children.length > 1) {
      node.children = sortBy(
        node.children,
        (child) => child.message.data().created_at
      )
    }
    for (const child of node.children) {
      pending.push(child)
    }
  }
  return tree
}

/**
 * Creates one unlinked tree node per message, keyed by Firestore document id.
 *
 * Every node gets its own empty `children` array and no `parent`; linking is
 * done afterwards by the caller. If two snapshots share an id, the later one
 * replaces the earlier node while the key keeps its original insertion slot.
 *
 * @param messages - Message snapshots to wrap.
 * @returns A map from document id to its freshly created node.
 */
function generateMessageMap(messages: QueryDocumentSnapshot<ChatMessage>[]) {
  const entries = messages.map(
    (message): [string, MessageTree] => [message.id, { message, children: [] }]
  )
  return new Map<string, MessageTree>(entries)
}
