import { Theme, alpha } from "@mui/material";
import { UseQueryResult } from "@tanstack/react-query";
import { AggregationMethod, PivotRowDto, ProcessWhiteSpotMatrixWidgetConfig } from "api-shared";
import { groupBy, mapValues, sum, sumBy } from "lodash";

/**
 * Pivot flat rows into a nested `row key -> column key -> value` lookup.
 *
 * Rows are grouped by `fields[rowPivotField]`. Each of those groups is grouped again by
 * `fields[columnPivotField]`, and the `value`s of every resulting cell are summed. Group keys come from
 * lodash `groupBy` (object property names), so they are strings.
 *
 * The null checks live here rather than in the component to keep the component simple. A missing pivot
 * field or missing data yields an empty map.
 *
 * @param data pivot rows; `undefined` yields an empty map
 * @param rowPivotField field name whose value selects the outer (row) key
 * @param columnPivotField field name whose value selects the inner (column) key
 * @returns nested map of summed cell values
 */
export function computeMatrixData(
    data: PivotRowDto[] | undefined,
    rowPivotField: string | null,
    columnPivotField: string | null,
): Map<string, Map<string, number>> {
    const matrix = new Map<string, Map<string, number>>();
    if (data === undefined || rowPivotField == null || columnPivotField == null) {
        return matrix;
    }

    const rowsByRowValue = groupBy(data, (row) => row.fields[rowPivotField]);
    for (const [rowValue, rowsInGroup] of Object.entries(rowsByRowValue)) {
        const rowsByColumnValue = groupBy(rowsInGroup, (row) => row.fields[columnPivotField]);
        const cellTotals = mapValues(rowsByColumnValue, (cellRows) => sumBy(cellRows, (row) => row.value));
        matrix.set(rowValue, new Map(Object.entries(cellTotals)));
    }
    return matrix;
}

/**
 * Combine the status flags of several query results into a few aggregate flags.
 *
 * With no queries at all, `areAllQueriesSuccess` is `true` (vacuous truth of `Array.prototype.every`),
 * while both "any" flags are `false`.
 *
 * @param queries query results to combine
 * @returns aggregate error / loading / success flags (add further flags here as needed)
 */
export function aggregateQueryStates(...queries: UseQueryResult[]) {
    const isAnyQueryError = queries.some(({ isError }) => isError);
    const isAnyQueryLoading = queries.some(({ isLoading }) => isLoading);
    const areAllQueriesSuccess = queries.every(({ isSuccess }) => isSuccess);

    return { isAnyQueryError, isAnyQueryLoading, areAllQueriesSuccess };
}

/**
 * Compute the marginal totals of a `row key -> column key -> value` matrix.
 *
 * @param matrixData nested matrix of cell values
 * @returns `rowSums` (row key -> sum of that row's cells), `columnSums` (column key -> sum of that column's
 * cells across all rows, in first-seen column order) and `totalSum` (sum of all column totals)
 */
export function computeMatrixSums(matrixData: Map<string, Map<string, number>>) {
    const rowSums = new Map<string, number>();
    const columnSums = new Map<string, number>();

    for (const [rowKey, cells] of matrixData) {
        rowSums.set(rowKey, sum([...cells.values()]));
        for (const [columnKey, cellValue] of cells) {
            columnSums.set(columnKey, (columnSums.get(columnKey) ?? 0) + cellValue);
        }
    }

    // Every cell lands in exactly one column, so the column totals add up to the grand total.
    const totalSum = sum([...columnSums.values()]);
    return { rowSums, columnSums, totalSum };
}

/**
 * Compute inline styles for a single matrix cell.
 *
 * The background is `color` with an alpha channel derived from `value` via `scaleStep`. `value` is scaled
 * against `totalSum` onto [0, 1] and snapped to one of `stepCount` discrete opacity levels. Negative values
 * are additionally rendered in the theme's dark error color.
 *
 * @param value cell value; a missing value (`undefined`, or `null` from untyped callers) yields no styling
 * @param totalSum reference value the cell value is scaled against (a value ≥ it gets full opacity)
 * @param stepCount number of discrete opacity levels
 * @param color base background color, passed to MUI's `alpha`
 * @param theme MUI theme providing the font color for negative values
 * @returns style object; `color` is only present for negative values
 */
export function getCellColor(
    value: number | undefined,
    totalSum: number,
    stepCount: number,
    color: string,
    theme: Theme,
): { backgroundColor?: string; color?: string } {
    // `== null` is the idiomatic nullish check (matches both undefined and null, same as the former `== undefined`)
    if (value == null) {
        return {};
    }

    const discreteAlpha = scaleStep(totalSum, 1, value, stepCount);
    const backgroundColor = alpha(color, discreteAlpha);
    const fontColor = value < 0 ? theme.palette.error.dark : undefined;
    // Only emit a `color` key when there is a font color to apply, so the default text color is kept otherwise.
    return fontColor ? { backgroundColor, color: fontColor } : { backgroundColor };
}

/**
 * Map `x` from the input interval [0, reference] onto the target interval [0, target]. The result is
 * snapped to one of `stepCount` evenly spaced levels; `stepFunction` defines how the interval is partitioned.
 *
 * Degenerate inputs return 0:
 * - a `reference` of 0 (nothing to scale against), and
 * - a `stepCount` of at most 1 (only the first, zero-valued step exists).
 *
 * Note that the fraction is computed as `x / reference`, so a negative `reference` flips the sign of `x`.
 *
 * @param reference value of `x` that maps onto the top of the target interval
 * @param target upper bound of the output interval
 * @param x value to scale
 * @param stepCount number of discrete output levels, spaced evenly from 0 to `target`
 * @returns `stepFunction(x / reference, stepCount) * target / (stepCount - 1)`, or 0 for degenerate inputs
 */
export function scaleStep(reference: number, target: number, x: number, stepCount: number): number {
    const isDegenerate = reference === 0 || stepCount <= 1;
    if (isDegenerate) {
        return 0;
    }

    const levelIndex = stepFunction(x / reference, stepCount);
    const levelHeight = target / (stepCount - 1);
    return levelIndex * levelHeight;
}

/**
 * Convert a continuous fraction to a discrete step index ∈ [0, stepCount - 1].
 *
 * The first and last steps are open-ended:
 * - every fraction ≤ 0 maps to step 0, and
 * - every fraction ≥ 1 maps to the last step (`stepCount - 1`).
 *
 * The remaining `stepCount - 2` inner steps split the open interval (0, 1) into equally sized
 * sub-intervals. A closed formula is used instead of iterating over the step intervals, for performance.
 *
 * Degenerate inputs:
 * - `stepCount ≤ 1` always yields 0, because only the first step exists. Previously `stepCount ≤ 0`
 *   produced negative indices.
 * - A NaN fraction is treated like a non-positive one and yields 0. Previously NaN was returned.
 *
 * @param fraction continuous value, typically `x / reference`
 * @param stepCount total number of steps, including the two open-ended outer steps
 * @returns step index ∈ [0, max(stepCount - 1, 0)]
 */
export function stepFunction(fraction: number, stepCount: number): number {
    // A single step (or none) exists, so everything falls into the first step.
    if (stepCount <= 1) {
        return 0;
    }
    // Lower open step: fraction ∈ (-∞, 0] => 0. Written as !(fraction > 0) so NaN lands here too.
    if (!(fraction > 0)) {
        return 0;
    }
    // Upper open step: fraction ∈ [1, ∞) => stepCount - 1
    if (fraction >= 1) {
        return stepCount - 1;
    }

    // The remaining stepCount - 2 steps divide (0, 1) equally.
    const innerIntervalCount = stepCount - 2;
    // Math.ceil alone fails on sharp boundaries: with stepCount = 6 and fraction = 0.25 it would yield
    // index 1 instead of 2. Use Math.floor() + 1 instead.
    return Math.floor(fraction * innerIntervalCount) + 1;
}

/**
 * Determine the reference (maximum) value for the matrix. This is presumably the value cells are scaled
 * against for coloring; verify against callers.
 *
 * If `config.useManualMaxValues` is set, the configured maximum for the current aggregation is returned:
 * `maxDisplayPotential` for `AggregationMethod.Sum`, `maxDisplayCount` otherwise. Otherwise the largest
 * cell value in `matrixData` is returned, or 0 when the matrix has no cells.
 *
 * @param matrixData nested `row key -> column key -> value` matrix
 * @param config widget config subset controlling manual maximum values
 * @returns the reference value
 */
export function getReferenceValue(
    matrixData: Map<string, Map<string, number>>,
    config: Pick<ProcessWhiteSpotMatrixWidgetConfig, "useManualMaxValues" | "aggregation" | "maxDisplayCount" | "maxDisplayPotential">,
) {
    if (config.useManualMaxValues) {
        return config.aggregation === AggregationMethod.Sum ? config.maxDisplayPotential : config.maxDisplayCount;
    }

    // Fold instead of Math.max(...cellValues). Spreading every cell as a call argument throws a RangeError
    // once the cell count exceeds the engine's argument/stack limit, and this also avoids building a
    // temporary array.
    let maxValue: number | undefined;
    for (const row of matrixData.values()) {
        for (const cellValue of row.values()) {
            maxValue = maxValue === undefined ? cellValue : Math.max(maxValue, cellValue);
        }
    }
    return maxValue ?? 0;
}
