# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import pandas as pd


def format_columns(df):
    """Format a predefined set of statistical columns.

    - Quantile columns, labelled either "25%"/"50%"/"75%" or
      0.25/0.5/0.75, are renamed to "Q1", "Median" and "Q3".
    - String column names and string index level names are title-cased
      (e.g. "count" -> "Count"). Non-string names, such as the ``None`` name
      of an unnamed index or a numeric column label, are left unchanged.
    - The statistical columns are ordered according to a predefined sequence.
    - Any remaining columns that are not part of the predefined set are
      appended at the end.

    Parameters
    ----------
    df : dataframe
        DataFrame whose columns include count, mean, std, min, the three
        quartiles, max and sum (in any letter case).

    Returns
    -------
    formatted_df : dataframe
        A new dataframe with renamed and reordered columns.

    Raises
    ------
    KeyError
        If any of the predefined statistical columns is missing.
    """

    def _title_case(name):
        # Only strings can be title-cased. Keep other labels (None for an
        # unnamed index level, numeric column labels, ...) untouched instead
        # of failing or turning them into NaN.
        return name.title() if isinstance(name, str) else name

    formatted_df = df.rename(
        {
            "25%": "Q1",
            "50%": "Median",
            "75%": "Q3",
            0.25: "Q1",
            0.5: "Median",
            0.75: "Q3",
        },
        axis="columns",
    )

    formatted_df.columns = formatted_df.columns.map(_title_case)
    formatted_df = formatted_df.rename_axis(index=_title_case)

    stats_cols = ["Count", "Mean", "Std", "Min", "Q1", "Median", "Q3", "Max", "Sum"]
    other_columns = [col for col in formatted_df.columns if col not in stats_cols]

    return formatted_df[stats_cols + other_columns]


def aggregate_stats_df(df, index_col=None, column_agg_dict=None):
    """Aggregate statistical rows based on their index values.

    The input dataframe should be formatted using the 'format_columns'
    function. It should include all the statistical columns present in the
    default pandas describe() function, along with the sum column.

    Count, Min, Max and Sum are combined directly. Mean is the count-weighted
    mean (total Sum / total Count). The other statistics are approximated,
    as flagged by their column names:

    - "Std (approx)": mean of the per-row standard deviations.
    - "Q1 (approx)": minimum of the per-row first quartiles.
    - "Median (approx)": median of the per-row medians.
    - "Q3 (approx)": maximum of the per-row third quartiles.

    Parameters
    ----------
    df : dataframe
        DataFrame to calculate statistics for.
    index_col : str or list, optional
        Name of one or more columns to group by. The default is the indices of
        the input dataframe. All index levels are then grouped by position,
        so unnamed index levels are supported.
    column_agg_dict: dict, optional
        Dictionary with additional columns to aggregate. The keys represent
        the column names and the values indicate the aggregation operations
        to apply to these columns.

    Returns
    -------
    aggregated_df : dataframe
        One row per group, indexed by the group keys, with all values rounded
        to one decimal place.
    """
    if index_col is None:
        # Group by every index level by position rather than by name. This
        # also works for unnamed levels (name None) and cannot be confused
        # with a column that happens to share a level's name.
        index_col = list(df.index.names)
        stats_gdf = df.groupby(level=list(range(df.index.nlevels)))
    else:
        if not isinstance(index_col, list):
            index_col = [index_col]
        stats_gdf = df.groupby(index_col)

    sum_total = stats_gdf["Sum"].sum()
    count_total = stats_gdf["Count"].sum()
    weighted_mean = sum_total / count_total

    # Every series below is indexed by the same group keys, so the
    # constructor aligns them directly. No explicit index is needed.
    aggregated_df = pd.DataFrame(
        {
            "Count": count_total,
            "Mean": weighted_mean,
            "Std (approx)": stats_gdf["Std"].mean(),
            "Min": stats_gdf["Min"].min(),
            "Q1 (approx)": stats_gdf["Q1"].min(),
            "Median (approx)": stats_gdf["Median"].median(),
            "Q3 (approx)": stats_gdf["Q3"].max(),
            "Max": stats_gdf["Max"].max(),
            "Sum": sum_total,
        }
    )

    aggregated_df.index.names = index_col

    if column_agg_dict is not None:
        for key, value in column_agg_dict.items():
            aggregated_df[key] = stats_gdf[key].agg(value)

    return aggregated_df.round(1)


def describe_column(series_groupby):
    """Summarize one numeric column of a grouped series, per group.

    Computes the usual pandas describe() statistics (count, mean, std, min,
    quartiles, max) plus the sum, then formats them with 'format_columns'.

    Parameters
    ----------
    series_groupby : SeriesGroupBy
        Grouped series to describe. This should contain only a single numerical
        column.

    Returns
    -------
    stats_df : dataframe or None
        One row per group (indexed by the grouping keys) with the statistics
        as columns, rounded to one decimal place. None when there are no
        groups.
    """
    if series_groupby.ngroups == 0:
        return None

    summary = series_groupby.agg(["min", "max", "count", "std", "mean", "sum"])

    # quantile() stacks the quantile level last; unstack it into columns.
    quartiles = series_groupby.quantile([0.25, 0.5, 0.75]).unstack()
    quartiles.columns = ["25%", "50%", "75%"]

    merged = summary.merge(quartiles, left_index=True, right_index=True)
    return format_columns(merged).round(1)


def describe_columns(df_groupby, index_name):
    """Summarize every numeric column of a grouped dataframe, per group.

    Produces the pandas describe() statistics plus the sum for each
    (group, column) pair. Each original column becomes an extra, innermost
    index level named 'index_name'.

    Parameters
    ----------
    df_groupby : DataFrameGroupBy
        Grouped dataframe to describe. It should contain only numeric columns.
    index_name : str
        Name of the index in the output dataframe, with values corresponding
        to the original column names.

    Returns
    -------
    stats_df : DataFrame or None
        Statistics as columns, indexed by the grouping keys followed by
        'index_name', rounded to one decimal place. None when there are no
        groups.
    """
    if df_groupby.ngroups == 0:
        return None

    # agg() yields (metric, aggregator) columns; stacking level 0 moves the
    # metric names into the row index.
    basic = df_groupby.agg(["min", "max", "count", "std", "mean", "sum"]).stack(level=0)

    # quantile() yields (keys..., quantile) rows. Move quantiles into the
    # columns, then move the metric names into the rows.
    quantile_rows = (
        df_groupby.quantile([0.25, 0.5, 0.75])
        .unstack(level=-1)
        .stack(level=0)
    )

    combined = pd.concat([basic, quantile_rows], axis=1)
    combined.index.names = list(combined.index.names[:-1]) + [index_name]
    return format_columns(combined).round(1)


def describe_df(df, index_name):
    """Summarize every column of a dataframe.

    Computes the pandas describe() statistics plus the sum. Each input
    column becomes one output row (index named 'index_name'), and each
    statistic becomes a column.

    Parameters
    ----------
    df : dataframe
        DataFrame with columns to describe. It should contain only numeric
        columns.
    index_name : str
        Name of the index in the output dataframe, with values corresponding
        to the original column names.

    Returns
    -------
    stats_df : dataframe
        Dataframe containing summary statistics, rounded to one decimal
        place, with an integer Count column.
    """
    aggregators = ["min", "max", "count", "std", "mean", "sum"]
    quantile_levels = [0.25, 0.5, 0.75]

    # Both results have statistics as rows. Transpose so that each input
    # column becomes a row.
    summary = pd.concat([df.agg(aggregators), df.quantile(quantile_levels)]).T

    summary[index_name] = summary.index
    summary = summary.set_index(index_name)

    # The transpose upcasts count to float; restore it to an integer.
    summary["count"] = summary["count"].astype(int)
    return format_columns(summary).round(1)
