import { computeDirectSelections } from "./multiSelectUtils";

import {
  TreePickerActionType,
  TreePickerNode,
  TreePickerReducerAction,
  TreePickerReducerState,
} from "../../types";
import { getParentIds } from "../../utils";

/**
 * Reducer for the tree picker's selection state.
 *
 * Handled actions (both carry `{ node, nodeMap }` in their payload):
 * - `SelectNode`: marks `node` and its whole subtree as selected and flags
 *   every ancestor as "not selected, but has selected descendants".
 * - `DeselectNode`: clears `node` and its whole subtree, then recomputes each
 *   ancestor, nearest first, from the state of that ancestor's children.
 *
 * `nodeStates` is copied shallowly before any write, so the incoming state is
 * never mutated. `selectedNodeIds` is recomputed with `computeDirectSelections`
 * starting from the tree root.
 *
 * @param state - Current reducer state.
 * @param action - Action to apply.
 * @returns The next state. The same `state` reference is returned for
 *   unrecognised actions and for selecting an already-selected node.
 */
export const treePickerReducer = (
  state: TreePickerReducerState,
  action: TreePickerReducerAction,
): TreePickerReducerState => {
  // Builds the next state from updated node states. Any other fields of
  // `state` are carried over unchanged.
  const buildState = (
    nodes: readonly TreePickerNode[],
    nodeStates: TreePickerReducerState["nodeStates"],
  ): TreePickerReducerState => {
    // Prefer the parentless node as the root. Fall back to the first entry,
    // which is what this reducer previously assumed unconditionally.
    const root = nodes.find((n) => n.parentId == null) ?? nodes[0];
    return {
      ...state,
      selectedNodeIds: computeDirectSelections(root, nodeStates),
      nodeStates,
    };
  };

  switch (action.type) {
    case TreePickerActionType.SelectNode: {
      // Payload is read only inside handled cases, so unrelated actions
      // (possibly without a payload) fall through to `default` safely.
      const { node, nodeMap } = action.payload;

      // Selecting an already-selected node is a no-op.
      if (state.nodeStates[node.id]?.selected) return state;

      // Shallow copy so the previous state is never mutated.
      const updatedNodeStates = { ...state.nodeStates };

      // Mark the node and every descendant as selected. This reducer stores
      // hasSelectedDescendants: false on selected nodes.
      const selectSubtree = (subTreeNode: TreePickerNode) => {
        updatedNodeStates[subTreeNode.id] = {
          selected: true,
          hasSelectedDescendants: false,
        };
        subTreeNode.children.forEach(selectSubtree);
      };
      selectSubtree(node);

      // Every ancestor now has a selected descendant but is not itself
      // selected. Order does not matter here since each gets the same value.
      if (node.parentId !== null) {
        getParentIds(node, nodeMap).forEach((ancestorId) => {
          updatedNodeStates[ancestorId] = {
            selected: false,
            hasSelectedDescendants: true,
          };
        });
      }

      return buildState(Object.values(nodeMap), updatedNodeStates);
    }

    case TreePickerActionType.DeselectNode: {
      const { node, nodeMap } = action.payload;

      // Shallow copy so the previous state is never mutated.
      const updatedNodeStates = { ...state.nodeStates };

      // Clear the node and every descendant.
      const clearSubtree = (subTreeNode: TreePickerNode) => {
        updatedNodeStates[subTreeNode.id] = {
          selected: false,
          hasSelectedDescendants: false,
        };
        subTreeNode.children.forEach(clearSubtree);
      };
      clearSubtree(node);

      // Recompute ancestors nearest-first by following parentId links. Each
      // ancestor is derived from its children's entries, so a child ancestor
      // must be updated before its own parent is evaluated. Iterating an
      // ancestor list in root-first order would leave stale
      // hasSelectedDescendants flags.
      let ancestorId = node.parentId;
      while (ancestorId != null) {
        const ancestor = nodeMap[ancestorId];
        if (!ancestor) break; // dangling parent link: nothing further to update

        const hasSelectedDescendants = ancestor.children.some((child) => {
          const childState = updatedNodeStates[child.id];
          return Boolean(
            childState?.selected || childState?.hasSelectedDescendants,
          );
        });
        updatedNodeStates[ancestorId] = {
          selected: false,
          hasSelectedDescendants,
        };

        ancestorId = ancestor.parentId;
      }

      return buildState(Object.values(nodeMap), updatedNodeStates);
    }

    default:
      return state;
  }
};
