# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import abc
from collections import defaultdict

import numpy as np
import pandas as pd

from nsys_recipe.lib import data_utils, overlap

# NOTE(review): presumably the 'domainId' of the default (unnamed) NVTX
# domain; it is not referenced by the code in this module section — confirm.
DEFAULT_DOMAIN_ID = 0
# Values of the NVTX 'eventType' column used by the filters below.
# Domain-creation event: its 'text' holds the domain name and its 'domainId'
# the ID assigned to that domain.
EVENT_TYPE_NVTX_DOMAIN_CREATE = 75
# Push/pop range event.
EVENT_TYPE_NVTX_PUSHPOP_RANGE = 59
# Start/end range event.
EVENT_TYPE_NVTX_STARTEND_RANGE = 60


def filter_by_domain_id(nvtx_df, domain_id):
    """Return the push/pop and start/end ranges that belong to `domain_id`.

    Rows of any other event type (e.g. domain creation) are dropped even when
    their 'domainId' matches.
    """
    range_event_types = [
        EVENT_TYPE_NVTX_PUSHPOP_RANGE,
        EVENT_TYPE_NVTX_STARTEND_RANGE,
    ]
    in_domain = nvtx_df["domainId"] == domain_id
    is_range = nvtx_df["eventType"].isin(range_event_types)
    return nvtx_df[in_domain & is_range]


def filter_by_domain_name(nvtx_df, domain_name):
    """Return the push/pop and start/end ranges of the domain named `domain_name`.

    The domain ID is taken from the first domain-creation event whose text is
    `domain_name`. If there is no such event, an empty DataFrame with the
    columns of `nvtx_df` is returned.
    """
    is_domain_create = nvtx_df["eventType"] == EVENT_TYPE_NVTX_DOMAIN_CREATE
    has_name = nvtx_df["text"] == domain_name
    domain_df = nvtx_df[is_domain_create & has_name]

    if not domain_df.empty:
        return filter_by_domain_id(nvtx_df, domain_df["domainId"].iloc[0])

    # No domain registered under this name: hand back the (empty) match frame.
    return domain_df


def combine_text_fields(nvtx_df, str_df):
    """Merge the 'textId' field into the 'text' field of the NVTX dataframe.

    An NVTX event message is stored either directly in 'text' or as an ID in
    'textId' that is resolved through the string table `str_df` (presumably
    the strings registered with 'nvtxDomainRegisterString' — confirm against
    the export schema). After this call, 'text' holds the message in both
    cases, so lookups only need to read one column.

    Returns a new DataFrame; `nvtx_df` itself is not modified here.
    """
    if nvtx_df["textId"].isnull().all():
        # Nothing to resolve: every message is already in 'text'.
        return nvtx_df.copy()

    resolved_df = data_utils.replace_id_with_value(
        nvtx_df, str_df, "textId", "textStr"
    )
    has_str = resolved_df["textStr"].notna()
    resolved_df.loc[has_str, "text"] = resolved_df.loc[has_str, "textStr"]
    return resolved_df.drop(columns=["textStr"])


def _compute_hierarchy_info(nvtx_df, nvtx_stream_map=None):
    hierarchy_df = nvtx_df.assign(parentId=None, stackLevel=0, rangeStack=None)
    stack = []

    exclude_ranges_w_gpu_ops_on_diff_streams = nvtx_stream_map is not None

    for row in hierarchy_df.itertuples():
        while stack and row.end > stack[-1].end:
            stack.pop()

        stack.append(row)

        if exclude_ranges_w_gpu_ops_on_diff_streams:
            # Exclude ranges from the stack where the GPU operations are run on
            # different streams by checking if the intersection of their streams is
            # non-empty.
            current_stack = [
                r
                for r in stack
                if nvtx_stream_map[r.originalIndex] & nvtx_stream_map[row.originalIndex]
            ]
        else:
            current_stack = stack

        # The current row is the last element of the stack.
        hierarchy_df.at[row.Index, "parentId"] = (
            current_stack[-2].Index if len(current_stack) > 1 else np.nan
        )
        # The stack level starts at 0.
        hierarchy_df.at[row.Index, "stackLevel"] = len(current_stack) - 1
        hierarchy_df.at[row.Index, "rangeStack"] = [r.Index for r in current_stack]

    hierarchy_df = hierarchy_df.reset_index().rename(columns={"index": "rangeId"})
    children_count = hierarchy_df["parentId"].value_counts()
    hierarchy_df["childrenCount"] = (
        hierarchy_df["rangeId"].map(children_count).fillna(0).astype(int)
    )
    # Convert to Int64 to support missing values (pd.NA) while keeping
    # integer type.
    hierarchy_df["parentId"] = hierarchy_df["parentId"].astype("Int64")

    return hierarchy_df


def _aggregate_cuda_ranges(
    cuda_df, cuda_nvtx_index_map, innermost_nvtx_indices, row_offset, compute_hierarchy
):
    # Each NVTX index will be associated with the minimum start time and the
    # maximum end time of the CUDA operations that the corresponding NVTX range
    # encloses.
    nvtx_gpu_start_dict = {}
    nvtx_gpu_end_dict = {}

    indices = []
    starts = []
    ends = []

    nvtx_stream_map = defaultdict(set)

    for cuda_row in cuda_df.itertuples():
        if cuda_row.Index not in cuda_nvtx_index_map:
            continue

        nvtx_indices = cuda_nvtx_index_map[cuda_row.Index]

        for nvtx_index in nvtx_indices:
            nvtx_stream_map[nvtx_index].add(cuda_row.streamId)

            start = cuda_row.gpu_start
            end = cuda_row.gpu_end

            # Handle cases where the innermost NVTX range encloses CUDA events
            # that result in multiple GPU ranges (e.g. CUDA graphs). In this
            # case, we don't group them and keep them as separate NVTX ranges.
            if (
                hasattr(cuda_row, "groupId")
                and not pd.isna(cuda_row.groupId)
                and nvtx_index in innermost_nvtx_indices
            ):
                indices.append(nvtx_index)
                starts.append(start)
                ends.append(end)
                continue

            if nvtx_index not in nvtx_gpu_start_dict:
                nvtx_gpu_start_dict[nvtx_index] = start
                nvtx_gpu_end_dict[nvtx_index] = end
                continue

            if start < nvtx_gpu_start_dict[nvtx_index]:
                nvtx_gpu_start_dict[nvtx_index] = start
            if end > nvtx_gpu_end_dict[nvtx_index]:
                nvtx_gpu_end_dict[nvtx_index] = end

    indices += list(nvtx_gpu_start_dict.keys())
    starts += list(nvtx_gpu_start_dict.values())
    ends += list(nvtx_gpu_end_dict.values())

    df = (
        pd.DataFrame({"originalIndex": indices, "start": starts, "end": ends})
        # Preserve original order for rows with identical "start" and "end"
        # values using the index.
        .sort_values(
            by=["start", "end", "originalIndex"], ascending=[True, False, True]
        ).reset_index(drop=True)
    )

    if not compute_hierarchy:
        return df

    df.index = range(row_offset, row_offset + len(df))
    return _compute_hierarchy_info(df, nvtx_stream_map)


def _compute_gpu_projection_df(
    cuda_df,
    group_columns,
    cuda_nvtx_index_map,
    innermost_nvtx_indices,
    row_offset,
    compute_hierarchy,
):
    """Project NVTX ranges onto the GPU separately for each group of CUDA rows.

    Parameters
    ----------
    cuda_df : pd.DataFrame
        CUDA rows of a single thread.
    group_columns : list of str
        Columns of `cuda_df` to group by; an empty list means a single group.
    cuda_nvtx_index_map, innermost_nvtx_indices, compute_hierarchy
        Forwarded to `_aggregate_cuda_ranges`.
    row_offset : int
        First range ID to assign. It is advanced after every group so that
        range IDs stay unique across groups.

    Returns
    -------
    pd.DataFrame
        Concatenated per-group projections with the group column values added,
        or an empty DataFrame if nothing was projected.
    """
    if group_columns:
        cuda_groups = cuda_df.groupby(group_columns)
    else:
        cuda_groups = [(None, cuda_df)]

    dfs = []
    for group_keys, cuda_group_df in cuda_groups:
        df = _aggregate_cuda_ranges(
            cuda_group_df,
            cuda_nvtx_index_map,
            innermost_nvtx_indices,
            row_offset,
            compute_hierarchy,
        )
        if df.empty:
            continue

        row_offset += len(df)

        if group_columns:
            # pandas < 2.0 yields a scalar (not a 1-tuple) key when grouping by
            # a one-element list; normalize so keys line up with the columns.
            if not isinstance(group_keys, tuple):
                group_keys = (group_keys,)
            for column, value in zip(group_columns, group_keys):
                df[column] = value

        dfs.append(df)

    if not dfs:
        return pd.DataFrame()

    return pd.concat(dfs, ignore_index=True)


def _validate_group_columns(df, group_columns):
    if isinstance(group_columns, str):
        group_columns = [group_columns]
    elif group_columns is None:
        group_columns = []

    for col in group_columns:
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in the DataFrame.")

    return group_columns


def _get_innermost_nvtx_indices(nvtx_df):
    parent_nvtx_df = pd.Series(np.nan, index=nvtx_df.index)
    stack = []

    for row in nvtx_df.itertuples():
        while stack and row.end > stack[-1].end:
            stack.pop()

        if stack:
            parent_nvtx_df[row.Index] = stack[-1].Index

        stack.append(row)

    parent_ids = parent_nvtx_df.dropna().unique()
    return set(parent_nvtx_df[~parent_nvtx_df.index.isin(parent_ids)].index)


def project_nvtx_onto_gpu(
    nvtx_df, cuda_df, group_columns=None, compute_hierarchy=False
):
    """Project the NVTX ranges from the CPU onto the GPU.

    The projected range will have the start timestamp of the first enclosed GPU
    operation and the end timestamp of the last enclosed GPU operation. Only
    complete ranges (non-null 'start' and 'end') that end on the thread they
    started on (null 'endGlobalTid') are considered, and only for threads
    ('globalTid') that also appear in `cuda_df`.

    Parameters
    ----------
    nvtx_df : pd.DataFrame
        DataFrame containing NVTX ranges.
    cuda_df : pd.DataFrame
        DataFrame containing CUDA events. It must contain both the runtime and
        GPU operations.
    group_columns : str or list of str, optional
        Column names in the CUDA table by which events should be grouped when
        projecting the associated NVTX range onto the GPU.
    compute_hierarchy : bool, optional
        Whether to compute the hierarchy information for the projected NVTX
        ranges, which includes:
        - stackLevel: Level of the range in the stack.
        - parentId: ID of the parent range.
        - rangeStack: IDs of the ranges that make up the stack.
        - childrenCount: Number of child ranges.
        - rangeId: Arbitrary ID assigned to the range.

    Returns
    -------
    proj_nvtx_df : pd.DataFrame
        DataFrame containing the projected NVTX ranges, with the columns
        'text', 'start', 'end', 'pid', 'tid', the group columns and, only when
        `compute_hierarchy` is True, the hierarchy columns listed above. If no
        range could be projected, an empty DataFrame with the same set of
        columns is returned.
    """
    group_columns = _validate_group_columns(cuda_df, group_columns)

    # Filter ranges that are incomplete or end on a different thread.
    filtered_nvtx_df = nvtx_df[
        nvtx_df["start"].notnull()
        & nvtx_df["end"].notnull()
        & nvtx_df["endGlobalTid"].isnull()
    ]

    nvtx_gdf = filtered_nvtx_df.groupby("globalTid")
    cuda_gdf = cuda_df.groupby("globalTid")

    dfs = []
    # Running count of projected rows, used as the first range ID of the next
    # thread so that range IDs are unique across threads.
    total_rows = 0

    for global_tid, nvtx_tid_df in nvtx_gdf:
        if global_tid not in cuda_gdf.groups:
            continue

        cuda_tid_df = cuda_gdf.get_group(global_tid)
        cuda_nvtx_index_map = overlap.map_overlapping_ranges(
            nvtx_tid_df, cuda_tid_df, fully_contained=True
        )

        innermost_nvtx_indices = _get_innermost_nvtx_indices(nvtx_tid_df)
        df = _compute_gpu_projection_df(
            cuda_tid_df,
            group_columns,
            cuda_nvtx_index_map,
            innermost_nvtx_indices,
            total_rows,
            compute_hierarchy,
        )
        if df.empty:
            continue

        total_rows += len(df)

        df["text"] = df["originalIndex"].map(nvtx_tid_df["text"])
        # The values of pid and tid are the same within each group of globalTid.
        for col in ["pid", "tid"]:
            df[col] = nvtx_tid_df[col].iat[0]

        df = df.drop(columns=["originalIndex"])

        dfs.append(df)

    if not dfs:
        # Advertise the same columns a non-empty result would have.
        columns = ["text", "start", "end", "pid", "tid"]
        if compute_hierarchy:
            columns += [
                "stackLevel",
                "parentId",
                "rangeStack",
                "childrenCount",
                "rangeId",
            ]
        return pd.DataFrame(columns=columns + list(group_columns))

    return pd.concat(dfs, ignore_index=True)


class GroupingStrategy(abc.ABC):
    """Interface describing how `NvtxGrouper` groups NVTX ranges."""

    @property
    @abc.abstractmethod
    def key_column(self):
        """
        The key column by which NVTX ranges are grouped.
        """

    @property
    @abc.abstractmethod
    def all_columns(self):
        """
        All columns that are involved in grouping one way or another (they can safely
        be aggregated using "first" function):
            - key column,
            - columns involved in the key column calculation,
            - and other columns.
        """

    @abc.abstractmethod
    def fill_columns(self, nvtx_df):
        """
        Fill key columns to be used for grouping and any additional columns if required.
        """


class FlatGroupingStrategy(GroupingStrategy):
    """
    A strategy for grouping NVTX ranges with the same text, domain ID and PID index.
    """

    def __init__(self):
        # Because PID differs from run to run, we use the PID index (the order
        # in which each PID first appears) rather than the PID itself.
        self._pin_idx_column = "pinIdx"
        self._input_columns = ["text", "domainId", self._pin_idx_column]

    @property
    def key_column(self):
        return "flatId"

    @property
    def all_columns(self):
        return [self.key_column] + self._input_columns

    def fill_columns(self, nvtx_df):
        """Add 'pinIdx' and 'flatId' to `nvtx_df` in place unless already present.

        NOTE: 'flatId' is Python's built-in hash() of (text, domainId, pinIdx).
        Hashes of str are salted per interpreter process (PYTHONHASHSEED), so
        IDs are only comparable within one process.
        """
        if self._pin_idx_column not in nvtx_df.columns:
            nvtx_df[self._pin_idx_column] = nvtx_df.groupby("pid", sort=False).ngroup()
        if self.key_column not in nvtx_df.columns:
            # Hash the input columns row by row. Zipping the columns is much
            # faster than DataFrame.apply(axis=1) and also works on an empty
            # frame, where apply() returns a DataFrame that cannot be assigned
            # to a single column.
            nvtx_df[self.key_column] = [
                hash(values)
                for values in zip(*(nvtx_df[col] for col in self._input_columns))
            ]


class TopDownGroupingStrategy(GroupingStrategy):
    """
    A strategy for grouping NVTX ranges with the same text,
    domain ID, PID index and NVTX call stack.
    To use this strategy, NVTX ranges should either contain stack information
    ('rangeId', 'rangeStack' and 'parentId' columns) or have precomputed key
    columns.
    """

    def __init__(self):
        self._flat_strategy = FlatGroupingStrategy()

    @property
    def key_column(self):
        return "topDownId"

    @property
    def par_key_column(self):
        return "parTopDownId"

    @property
    def all_columns(self):
        return [self.key_column, self.par_key_column] + self._flat_strategy.all_columns

    def fill_columns(self, nvtx_df):
        """Add the flat columns plus 'topDownId'/'parTopDownId' to `nvtx_df` in place.

        'topDownId' hashes the 'flatId' of every range on the range's stack.
        'parTopDownId' is the 'topDownId' of the parent range, or -1 when the
        parent is missing (top-level ranges).
        """
        self._flat_strategy.fill_columns(nvtx_df)

        if (
            self.key_column in nvtx_df.columns
            and self.par_key_column in nvtx_df.columns
        ):
            return

        range_to_flat_id_map = dict(
            zip(nvtx_df["rangeId"], nvtx_df[self._flat_strategy.key_column])
        )

        def get_topdown_id(range_stack):
            return hash(
                tuple(range_to_flat_id_map.get(range_id, None) for range_id in range_stack)
            )

        nvtx_df[self.key_column] = nvtx_df["rangeStack"].map(get_topdown_id)

        range_to_topdown_id_map = dict(
            zip(nvtx_df["rangeId"], nvtx_df[self.key_column])
        )
        nvtx_df[self.par_key_column] = nvtx_df["parentId"].map(
            lambda parent_id: range_to_topdown_id_map.get(parent_id, -1)
        )


class NvtxGrouper:
    """
    A class for grouping NVTX ranges according to a given grouping strategy.
    It allows to aggregate and obtain middle NVTX ranges for grouped data.
    This grouper can be reused to aggregate already grouped data in different ways.

    Note that construction calls the strategy's `fill_columns`, which may add
    columns to `nvtx_df` in place.
    """

    def __init__(self, nvtx_df, nvtx_grouping_strategy):
        self._df = nvtx_df
        self._grouping_strategy = nvtx_grouping_strategy
        self._df_grouped = self._group(self._df, self._grouping_strategy)

    @property
    def df(self):
        return self._df

    @property
    def df_grouped(self):
        return self._df_grouped

    def _group(self, df, grouping_strategy):
        """
        Group NVTX ranges by the `grouping_strategy` provided.
        """
        grouping_strategy.fill_columns(df)
        return df.groupby(grouping_strategy.key_column)

    def aggregate(self, col_to_agg_func_map, rest_col_agg_funcs=None):
        """
        Aggregate NVTX ranges for `self.df_grouped` grouped data.

        Data is aggregated using provided functions from the `col_to_agg_func_map`
        dictionary, which maps a specific column to an aggregation function,
        and from the `rest_col_agg_funcs` list (optional),
        which represents common aggregation functions for the remaining columns
        uncovered in `col_to_agg_func_map`. Columns of the grouping strategy
        that are not mapped are aggregated with "first".

        `col_to_agg_func_map` is not modified.

        When several functions are applied to a column, the result columns are
        named '<column>_<function>'; columns aggregated with a single string
        function keep their name.
        """
        # Work on a copy so the caller's mapping is left untouched.
        agg_map = dict(col_to_agg_func_map)
        for column in self._grouping_strategy.all_columns:
            agg_map.setdefault(column, "first")

        if rest_col_agg_funcs is not None:
            agg_map = {
                col: agg_map.get(col, rest_col_agg_funcs) for col in self._df.columns
            }

        agg_df = self.df_grouped.agg(agg_map).reset_index(drop=True)

        if isinstance(agg_df.columns, pd.MultiIndex):
            agg_df.columns = [
                name if isinstance(agg_map[name], str) else f"{name}_{agg_func}"
                for name, agg_func in agg_df.columns
            ]

        return agg_df

    def mid_ranges(self):
        """
        Get middle NVTX ranges for `self.df_grouped` grouped data.
        """
        res = self.df_grouped.apply(lambda x: x.iloc[len(x) // 2])
        return res.reset_index(drop=True)


def compute_callstack(nvtx_df):
    """
    Accompany NVTX ranges with the NVTX call stack information.

    The hierarchy columns ('rangeId', 'parentId', 'stackLevel', 'rangeStack',
    'childrenCount') are computed independently for every 'globalTid' by
    `_compute_hierarchy_info`. The per-thread results are then merged and
    sorted by start time. The stack-based nesting detection presumably expects
    the ranges of each thread to be ordered by start time, with enclosing
    ranges first — confirm with callers.

    An empty input (or one whose 'globalTid' values are all null) yields an
    empty DataFrame with the hierarchy columns instead of raising.
    """
    dfs = [
        _compute_hierarchy_info(nvtx_gtid_df)
        for _, nvtx_gtid_df in nvtx_df.groupby("globalTid")
    ]
    if not dfs:
        # pd.concat([]) raises, so build the (empty) result directly.
        return _compute_hierarchy_info(nvtx_df.iloc[:0])

    # A stable sort keeps the per-thread order (e.g. a parent before a child
    # starting at the same time) for ranges that share a start time.
    return (
        pd.concat(dfs).sort_values(by=["start"], kind="stable").reset_index(drop=True)
    )


def add_original_indices(nvtx_df):
    """
    Store each row's own index label, wrapped in a list, in `originalIndices`.

    The column is written into `nvtx_df` in place and the same DataFrame is
    returned. Keeping the labels in lists lets later aggregations concatenate
    them, so every aggregated range can still point back to the NVTX ranges
    it was built from.
    """
    wrapped_labels = [[label] for label in nvtx_df.index]
    # dtype=object keeps each list as one cell instead of a 2-D array.
    nvtx_df["originalIndices"] = pd.Series(
        wrapped_labels, index=nvtx_df.index, dtype=object
    )
    return nvtx_df


def consolidate_parallel_ranges(nvtx_df):
    """
    Get NVTX ranges aggregated from parallel threads.
    The input DataFrame `nvtx_df` represents the SQL table `NVTX_EVENTS`
    with `tid`, `pid` columns and call stack information.
    It is assumed to be sorted by start time.

    The resulting DataFrame will contain agg. NVTX ranges
    with reference to the original ranges being aggregated (NVTX indices,
    stored as lists in 'originalIndices').

    Ranges with the same text, domain ID, PID index and call stack that
    overlap in time on different threads are merged into one row. In a merged
    row, 'tid' and 'globalTid' hold lists, 'start'/'end' span all merged
    ranges, 'pid', grouping and hierarchy columns come from the first range
    (main thread first, then earliest start), and any other column is
    summed. Groups seen on a single thread are kept unchanged, with scalar
    'tid'/'globalTid'.

    Note that `nvtx_df` is modified in place: 'originalIndices' is added when
    missing, and the grouping columns are filled by the grouping strategy.
    When `nvtx_df` is empty or all its ranges come from a single thread, it is
    returned as is (after adding 'originalIndices').
    """

    # TopDownGroupingStrategy presents data for grouping NVTX ranges with
    # the same text, domain ID, PID index and NVTX call stack.
    td_strategy = TopDownGroupingStrategy()

    if "originalIndices" not in nvtx_df.columns:
        nvtx_df = add_original_indices(nvtx_df)

    col_to_agg_func_map = {
        "originalIndices": "sum",
        "tid": list,
        "globalTid": list,
        "pid": "first",
        "start": "min",
        "end": "max",
    }

    # Columns: text, domainId, pinIdx, flatId, topDownId, parTopDownId and etc.
    for col in td_strategy.all_columns:
        col_to_agg_func_map[col] = "first"

    hierarchy_info_columns = [
        "rangeId",
        "parentId",
        "stackLevel",
        "rangeStack",
        "childrenCount",
    ]
    for col in hierarchy_info_columns:
        col_to_agg_func_map[col] = "first"

    # Any remaining column of `nvtx_df` is summed.
    all_columns = list(col_to_agg_func_map) + list(nvtx_df.columns)
    col_to_agg_func_map = {
        col: col_to_agg_func_map.get(col, "sum") for col in all_columns
    }

    # Nothing to consolidate when there are no ranges or they all happened on
    # the same thread: no overlapping/concurrent ranges can exist.
    if nvtx_df.empty or nvtx_df["globalTid"].nunique() == 1:
        return nvtx_df

    dfs = []
    for _, nvtx_group_df in NvtxGrouper(nvtx_df, td_strategy).df_grouped:

        # Check if group of NVTX ranges was happened on the same thread.
        # In this case we don't need to check for overlapping/parallel ranges.
        if nvtx_group_df["tid"].nunique() == 1:
            dfs.append(nvtx_group_df)
            continue

        # Build new frames rather than writing into the group slice yielded by
        # groupby, which would trigger chained-assignment warnings.
        nvtx_group_df = nvtx_group_df.assign(
            overlap=overlap.group_overlapping_ranges(nvtx_group_df)
        )
        # We sort by (`is_main_tid`, `start`) to get the data for the
        # `hierarchy_info_columns` from the range in the main thread or/and from
        # the range with minimum start time.
        nvtx_group_df = nvtx_group_df.assign(
            is_main_tid=nvtx_group_df["tid"] == nvtx_group_df["pid"]
        ).sort_values(by=["is_main_tid", "start"], ascending=[False, True])
        consolidated_range_df = nvtx_group_df.groupby("overlap").agg(
            col_to_agg_func_map
        )
        dfs.append(consolidated_range_df)

    # A stable sort keeps a deterministic order for ranges sharing a start time.
    return (
        pd.concat(dfs).sort_values(by=["start"], kind="stable").reset_index(drop=True)
    )
