import { useTheme } from "@mui/material";
import { nearlyEqual } from "api-shared";
import { get, sum } from "lodash";
import { Label, LabelList, LabelProps } from "recharts";
import { Props } from "recharts/types/component/LabelList";
import { ViewBox } from "recharts/types/util/types";
import { transformFromLabelValue, transformToLabelValue } from "./utils.ts";

// This interface is sadly not exposed by recharts, so it is copied from recharts' `LabelList` source. In this file it
// constrains the label list data type (`T`) and types the entries handed to the value accessor of `SumLabelList`.
// See https://github.com/recharts/recharts/blob/master/src/component/LabelList.tsx#L13-L17
// NOTE(review): the link points at `master`, so the referenced line range may drift; consider pinning it to a tag.
interface RechartsLabelListData {
    // Raw label value; may be an array, in which case recharts' default accessor would use its last element.
    value?: number | string | Array<number | string>;
    // Data entry the stacked values are read from (see `getPayloadReferenceValues`).
    payload?: any;
    // Not used in this file; kept so the shape mirrors recharts' interface.
    parentViewBox?: ViewBox;
}

/**
 * Props of `SumLabelList`: everything recharts' `LabelList` accepts, plus the options below.
 */
export type SumLabelListProps<T extends RechartsLabelListData> = Props<T> & {
    /** Keys (or lodash paths) of the stacked series; their numeric values are read from each entry's `payload`. */
    stackDataKeys: Array<string | number>;
    /** Place the label to the right of the bar instead of on top of it. */
    isHorizontal?: boolean;
    /**
     * Compare segment values against `1` instead of the bar's own reference value.
     * Presumably used when stacks are normalized to 1 — confirm with the chart's stack offset.
     */
    isRelativeRepresentation?: boolean;
    /** Skip the labels of bars at even indices (0, 2, 4, ...), so only every second bar is labelled. */
    hideEverySecondLabel?: boolean;
    /** Currently unused: `SumLabelList` strips it from the props it forwards to `LabelList`. */
    chartWidth?: number;
    /** Currently unused: `SumLabelList` strips it from the props it forwards to `LabelList`. */
    yAxisPosition?: number;
};

/**
 * Per-bar values used to decide which stack segment gets the sum label and what it shows.
 */
interface PayloadReferenceValue {
    /** Sum of all numeric stack values of the bar; this is the number displayed in the label. */
    total: number;
    /**
     * Value the labelled segment is expected to end at: the sum of the positive stack values, or — when there are
     * none — the first negative stack value.
     * NOTE: despite the `number` type this is `undefined` at runtime when a bar has no non-zero numeric stack value.
     */
    max: number;
}

/**
 * Compute, for every bar, the reference values the sum label renderer needs.
 *
 * For each entry of `data`, the values stored under `stackDataKeys` in the entry's `payload` are looked up with lodash
 * `get`; anything that is not a number is ignored.
 *
 * ReCharts renders a label for each stack of a bar, but we only want one label per bar: on the stack with the most
 * positive value. `max` identifies that stack.
 *
 * @param data Label list data; each entry's `payload` holds the stacked values.
 * @param stackDataKeys Keys/paths of the stacked series inside each payload.
 * @returns One entry per element of `data`, in the same order, where
 *   - `total` is the sum of all numeric stack values, and
 *   - `max` is the sum of the positive values, or the first negative value if there is no positive value
 *     (`undefined` if the bar has no non-zero numeric value at all).
 */
function getPayloadReferenceValues(data: any[], stackDataKeys: (string | number)[]): PayloadReferenceValue[] {
    return data.map(({ payload }) => {
        const numericValues = stackDataKeys
            .map((stackKey) => get(payload, stackKey))
            .filter((candidate) => typeof candidate === "number");

        const positiveValues = numericValues.filter((candidate) => candidate > 0);
        const negativeValues = numericValues.filter((candidate) => candidate < 0);

        const positiveTotal = sum(positiveValues);
        const negativeTotal = sum(negativeValues);

        // Without any positive value the bar has no positive stack, so fall back to the first negative value.
        // Otherwise the positive sum is the larger of the two sums.
        const max = positiveTotal === 0 ? negativeValues[0] : Math.max(positiveTotal, negativeTotal);

        return { total: positiveTotal + negativeTotal, max };
    });
}

/**
 * Create the `content` renderer used by the sum `LabelList`.
 *
 * ReCharts renders a label for each stack segment of a bar, so the returned renderer is called once per segment. It
 * renders a `Label` showing the bar's total only for the segment whose value matches the bar's reference value, and
 * returns `null` (render nothing) for every other segment.
 *
 * @param payloadReferenceValues Reference values per bar, indexed like the label list's `data`.
 * @param isRelativeRepresentation If `true`, the segment whose value is (nearly) `1` gets the label instead of the one
 *   matching the bar's `max` reference value.
 * @param hideEverySecondLabel If `true`, the labels of bars at even indices (0, 2, 4, ...) are skipped.
 * @returns A renderer returning a `Label` element, or `null` when nothing should be rendered for the segment.
 */
function renderLabel(payloadReferenceValues: PayloadReferenceValue[], isRelativeRepresentation: boolean, hideEverySecondLabel: boolean) {
    return (props: LabelProps) => {
        const { value, index } = props;

        if (index === undefined) {
            return null;
        }

        // Out-of-range (including negative or non-integer) indices yield `undefined` here, so this single guard
        // protects every later access to the bar's reference values.
        const referenceValues = payloadReferenceValues[index];
        if (referenceValues === undefined) {
            return null;
        }

        // This is a very simple way to show the sums for every second label only: bars at even indices are skipped.
        // For now, we only use it in the WeeklySavingsRunUpChart. In the future we should replace this with a better solution.
        // See https://valued.atlassian.net/browse/DEV-4185
        if (hideEverySecondLabel && index % 2 === 0) {
            return null;
        }

        const values = transformFromLabelValue(value);

        let labelValue: number;
        if (Array.isArray(values)) {
            // values[0] is the start and values[1] the end of this bar segment. A segment whose start equals its end
            // does not contribute to the bar's sum; skipping it ensures the sum label is rendered only once.
            if (values[0] === values[1]) {
                return null;
            }
            labelValue = Number(values[1]);
        } else {
            labelValue = Number(values);
        }

        // Only the segment ending at the bar's reference value (or at 1 in the relative representation) is labelled.
        const compareValue = isRelativeRepresentation ? 1 : referenceValues.max;
        if (!nearlyEqual(labelValue, compareValue)) {
            return null;
        }

        // Override the position when we only have negative bars, so the label sits beyond the bar's end.
        let position = props.position;
        if (labelValue < 0) {
            position = props.position === "right" ? "left" : "bottom";
        }

        return (
            <Label
                {...props}
                // remove this function from the props given to the label, otherwise there will be an endless loop of `renderLabel`
                content={undefined}
                value={referenceValues.total}
                position={position}
            />
        );
    };
}

/**
 * `LabelList` variant that renders a single label per stacked bar, showing the total of all stacked values.
 *
 * The label is placed on top of each bar, or to its right when `isHorizontal` is set. All props that are not
 * specific to this component are forwarded to recharts' `LabelList`; `chartWidth` and `yAxisPosition` are
 * currently unused and are not forwarded. Renders nothing when `data` is `undefined`.
 */
const SumLabelList = <T extends RechartsLabelListData>({
    chartWidth,
    isHorizontal,
    data,
    stackDataKeys,
    yAxisPosition,
    isRelativeRepresentation = false,
    hideEverySecondLabel = false,
    ...otherProps
}: SumLabelListProps<T>) => {
    const { palette, typography } = useTheme();

    if (data === undefined) {
        return null;
    }

    const labelContent = renderLabel(
        getPayloadReferenceValues(data, stackDataKeys),
        isRelativeRepresentation,
        hideEverySecondLabel,
    );

    return (
        <LabelList
            position={isHorizontal ? "right" : "top"}
            fill={palette.text.primary}
            style={{ ...typography.body2 }}
            content={labelContent}
            data={data}
            {...otherProps}
            // `LabelList`'s default accessor uses `value` (or its last element for arrays); here the value is
            // converted with `transformToLabelValue` instead. Set after `otherProps` so callers cannot override it.
            valueAccessor={(entry: RechartsLabelListData) => transformToLabelValue(entry.value)}
        />
    );
};

// Required by ReCharts: its `findAllByType` function collects the label list components for the label list layer by
// component type. Without presenting this component as "LabelList", it would be filtered out there and never rendered.
SumLabelList.displayName = "LabelList";

export default SumLabelList;
