import { useTheme } from "@mui/material";
import { nearlyEqual } from "api-shared";
import { get, sum, sumBy } from "lodash";
import { LabelList, LabelProps } from "recharts";
import { Props } from "recharts/types/component/LabelList";
import { ViewBox } from "recharts/types/util/types";
import { findNode, NamedTreeNode, traverseTree } from "../../../lib/tree";
import DeltaChip from "./DeltaChip";
import { transformFromLabelValue } from "./utils.ts";

// This interface is sadly not exposed by recharts, so it is mirrored here. Keep it in sync with upstream:
// see https://github.com/recharts/recharts/blob/master/src/component/LabelList.tsx#L13-L17
// It is used as the constraint for the label list data items consumed by DeltaLabelList.
interface RechartsLabelListData {
    // Raw label value of the entry; for stacked bars presumably a [start, end] pair (TODO confirm against recharts).
    value?: number | string | Array<number | string>;
    // Chart row the label belongs to; stack and axis values are read from it via lodash `get`.
    payload?: any;
    // Not read anywhere in this file; kept only to match the upstream definition.
    parentViewBox?: ViewBox;
}

/**
 * Returns the reference value of the tree node identified by `payloadKey` plus, recursively, the reference values of
 * all of its descendants.
 *
 * When no node matches `payloadKey`, the value stored under the `"null"` key is used for it instead. Entries missing
 * from `referenceValues` count as 0.
 *
 * @param payloadKey id of the node to start from; coerced with `Number` before the lookup
 * @param treeData full tree; every visited node is looked up in it again via `traverseTree`/`findNode`
 * @param referenceValues reference values keyed by node id
 * @param selectedPath only forwarded to the recursive calls, not read here
 * @returns the aggregated reference value
 */
function aggregateReferenceWithTreeChildren(
    payloadKey: string | number,
    treeData: NamedTreeNode[],
    referenceValues: Record<string, number>,
    selectedPath?: NamedTreeNode[],
): number {
    const currentNode = traverseTree(treeData, findNode(Number(payloadKey)));
    const ownValue = referenceValues[currentNode?.id ?? "null"] ?? 0;

    const descendants = currentNode?.children;
    if (descendants === undefined) {
        return ownValue;
    }

    const descendantsTotal = sumBy(descendants, ({ id }) =>
        aggregateReferenceWithTreeChildren(id, treeData, referenceValues, selectedPath),
    );
    return ownValue + descendantsTotal;
}

/**
 * Calculates the reference value a bar is compared against:
 * - without a tree structure: the reference value stored for the bar's category;
 * - with a tree structure and the bar's node on the selected path: likewise the stored reference value;
 * - with a tree structure and the node NOT on the selected path: the sum of the reference values of the node and all
 *   of its descendants.
 *
 * Categories without a stored reference value count as 0.
 *
 * @param payload chart row of the bar
 * @param axisDataKey key (lodash `get` path) of the category within `payload`
 * @param referenceValuesData reference values keyed by category
 * @param treeData optional category tree; node ids are compared with the category coerced via `Number`
 * @param selectedPath optional list of selected tree nodes
 * @returns the reference value, never `undefined`
 */
function calculateReference(
    payload: any,
    axisDataKey: string,
    referenceValuesData: Record<string, number>,
    treeData?: NamedTreeNode[],
    selectedPath?: NamedTreeNode[],
): number {
    const payloadKey = get(payload, axisDataKey);
    // Default missing entries to 0, as aggregateReferenceWithTreeChildren does for every node, so that the declared
    // `number` return type holds for categories without a stored reference value.
    const originalReferenceValue = referenceValuesData[payloadKey] ?? 0;

    if (!treeData) {
        return originalReferenceValue;
    }

    const nodeId = Number(payloadKey);
    const isNodeSelected = selectedPath?.some((selectedNode) => selectedNode.id === nodeId);
    if (isNodeSelected) {
        return originalReferenceValue;
    }

    return aggregateReferenceWithTreeChildren(payloadKey, treeData, referenceValuesData, selectedPath);
}

/** Per-bar values derived from the chart data that drive where a delta label is placed and what it shows. */
interface PayloadReferenceValue {
    /** Category of the bar, read from the payload via the axis data key. */
    key: string;
    /** End value of the single stack that should carry the label (compared against the label value). */
    max: number;
    /** Sum of the positive and the negative stack values. */
    total: number;
    /** Sum of all numeric stack values of the bar. */
    actual: number;
    /** Reference value the actual value is compared against. */
    reference: number;
}

/**
 * Computes, for every label list entry, the actual value (sum of all stacked bar values), the reference value (see
 * calculateReference) and the helper values used to pick the single stack that carries the label.
 *
 * @param data label list entries; each entry's `payload` holds the chart row
 * @param stackDataKeys keys (lodash `get` paths) of the stacked values within a payload
 * @param axisDataKey key of the category within a payload
 * @param referenceValuesData reference values keyed by category
 * @param treeData optional category tree, forwarded to calculateReference
 * @param selectedPath optional selected tree path, forwarded to calculateReference
 * @returns one entry per element of `data`, in the same order
 */
function getPayloadReferenceValues(
    data: any[],
    stackDataKeys: (string | number)[],
    axisDataKey: string,
    referenceValuesData: Record<string, number>,
    treeData?: NamedTreeNode[],
    selectedPath?: NamedTreeNode[],
): PayloadReferenceValue[] {
    return data.map(({ payload }) => {
        // Only numeric stack values take part in the sums; missing or non-numeric values are ignored.
        const stackValues = stackDataKeys
            .map((key) => get(payload, key))
            .filter((value): value is number => typeof value === "number");

        const negativeValues = stackValues.filter((value) => value < 0);
        const sumOfPositiveValues = sum(stackValues.filter((value) => value > 0));
        const sumOfNegativeValues = sum(negativeValues);

        return {
            // ReCharts renders a label for each stack in a bar, but we want a single label on the stack with the most
            // positive value. `max` is the end value of that stack:
            // - if there is any positive value, the top of the positive stack, i.e. the sum of the positive values
            //   (it is always >= the non-positive negative sum, so no comparison is needed);
            // - otherwise the first negative value (undefined when there are no negative values either).
            max: sumOfPositiveValues === 0 ? negativeValues[0] : sumOfPositiveValues,
            total: sumOfPositiveValues + sumOfNegativeValues,
            key: get(payload, axisDataKey),
            // Same values in the same order as above, so there is no need to read them from the payload again.
            actual: sum(stackValues),
            reference: calculateReference(payload, axisDataKey, referenceValuesData, treeData, selectedPath),
        };
    });
}

/**
 * Creates the `content` renderer for the label list.
 *
 * For the bar at the label's index, the returned function renders a DeltaChip showing `actual - reference` (0 when
 * both are nearly equal). It does so only on the one stack whose end value matches the bar's `max`, or 1 in the
 * relative representation. Every other label renders nothing.
 *
 * @param payloadReferenceValues per-bar values, indexed by the label index
 * @param isRelativeRepresentation when true, the labelled stack is expected to end at 1
 * @param isVerticalLayout forwarded to DeltaChip
 * @param chartWidth forwarded to DeltaChip
 */
function renderLabel(
    payloadReferenceValues: PayloadReferenceValue[],
    isRelativeRepresentation = false,
    isVerticalLayout = false,
    chartWidth?: number,
) {
    return (props: LabelProps) => {
        const { value, index } = props;

        if (index === undefined || index < 0 || index > payloadReferenceValues.length - 1) {
            return null;
        }

        const values = transformFromLabelValue(value);

        // For stacked bars the label value is a [start, end] pair of the segment. A segment whose start equals its
        // end does not contribute to the bar's total, so it never carries the label; this renders the label only once.
        if (Array.isArray(values) && values[0] === values[1]) {
            return null;
        }
        const labelValue = Number(Array.isArray(values) ? values[1] : values);

        const entry = payloadReferenceValues[index];
        const expectedEndValue = isRelativeRepresentation ? 1 : entry.max;
        if (!nearlyEqual(labelValue, expectedEndValue, 1e-4)) {
            return null;
        }

        // Bars consisting only of negative values get their label on the opposite side.
        let position = props.position;
        if (labelValue < 0) {
            position = props.position === "right" ? "left" : "bottom";
        }

        const { actual } = entry;
        const reference = entry.reference ?? 0;
        const difference = nearlyEqual(actual, reference, 1e-4) ? 0 : actual - reference;

        return (
            <DeltaChip
                {...props}
                // `content` must not reach the label again, otherwise `renderLabel` would be invoked endlessly
                content={undefined}
                chartWidth={chartWidth}
                isVerticalLayout={isVerticalLayout}
                value={difference}
                position={position}
            />
        );
    };
}

/** Props that DeltaLabelList consumes itself instead of passing them through to recharts' `LabelList`. */
type DeltaLabelListOwnProps = {
    /** Keys (lodash `get` paths) of the stacked values within each payload; their sum is the actual value. */
    stackDataKeys: (string | number)[];
    /** Key of the category within each payload, also used to look up the reference values. */
    axisDataKey: string;
    /** Reference values keyed by category; `null` renders no labels at all. */
    referenceValues: { data: Record<string, number> } | null;
    /** Optional category tree used to aggregate reference values of unselected nodes. */
    treeData?: NamedTreeNode[];
    /** Optional list of selected tree nodes; their own reference value is used without aggregation. */
    selectedPath?: NamedTreeNode[];
    /** Forwarded to DeltaChip. */
    chartWidth?: number;
    /** When true, the labelled stack is expected to end at 1. */
    isRelativeRepresentation?: boolean;
    /** Selects the default label position ("top" vs "right") and is forwarded to DeltaChip. */
    isVerticalLayout?: boolean;
};

type DeltaLabelListProps<T extends RechartsLabelListData> = Props<T> & DeltaLabelListOwnProps;

/**
 * Drop-in replacement for recharts' `LabelList` that renders, once per bar, a DeltaChip with the difference between the
 * bar's actual value (sum of its stacked values) and its reference value.
 *
 * Renders nothing when there are no reference values or no label data. All remaining props are passed through to
 * `LabelList` and take precedence over the defaults set here.
 */
export const DeltaLabelList = <T extends RechartsLabelListData>({
    stackDataKeys,
    data,
    treeData,
    selectedPath,
    referenceValues,
    chartWidth,
    isVerticalLayout,
    isRelativeRepresentation,
    axisDataKey,
    ...otherProps
}: DeltaLabelListProps<T>) => {
    // Hooks must run unconditionally, i.e. before the early return below.
    const { palette, typography } = useTheme();

    if (data === undefined || referenceValues === null) {
        return null;
    }

    const labelContent = renderLabel(
        getPayloadReferenceValues(data, stackDataKeys, axisDataKey, referenceValues.data, treeData, selectedPath),
        isRelativeRepresentation,
        isVerticalLayout,
        chartWidth,
    );

    return (
        <LabelList
            position={isVerticalLayout ? "top" : "right"}
            fill={palette.text.primary}
            style={{ ...typography.body2, fontSize: 12 }}
            content={labelContent}
            data={data}
            {...otherProps}
        />
    );
};

// Required by recharts: its `findAllByType` helper collects the label list components for the label list layer and
// would otherwise filter this component out, since it is identified here by display name rather than by type.
// The name must therefore stay exactly "LabelList".
DeltaLabelList.displayName = "LabelList";
