# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =========================================================================
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
# which has the following license...
# https://github.com/pytorch/vision/blob/main/LICENSE

# BSD 3-Clause License

# Copyright (c) Soumith Chintala 2016,
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.

# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
"""
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
"""

from __future__ import annotations

from collections.abc import Callable

import torch
from torch import Tensor

from monai.data.box_utils import batched_nms, box_iou, clip_boxes_to_image
from monai.transforms.utils_pytorch_numpy_unification import floor_divide


class BoxSelector:
    """
    Box selector which selects the predicted boxes.
    The box selection is performed with the following steps:

    #. For each level, discard boxes with scores less than self.score_thresh.
    #. For each level, keep boxes with top self.topk_candidates_per_level scores.
    #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
    #. For the whole image, keep boxes with top self.detections_per_img scores.

    Args:
        apply_sigmoid: whether to apply sigmoid to get scores from classification logits
        score_thresh: no box with scores less than score_thresh will be kept
        topk_candidates_per_level: max number of boxes to keep for each level
        nms_thresh: box overlapping threshold for NMS
        detections_per_img: max number of boxes to keep for each image

    Example:

        .. code-block:: python

            input_param = {
                "apply_sigmoid": True,
                "score_thresh": 0.1,
                "topk_candidates_per_level": 2,
                "nms_thresh": 0.1,
                "detections_per_img": 5,
            }
            box_selector = BoxSelector(**input_param)
            boxes = [torch.randn([3,6]), torch.randn([7,6])]
            logits = [torch.randn([3,3]), torch.randn([7,3])]
            spatial_size = (8,8,8)
            selected_boxes, selected_scores, selected_labels = box_selector.select_boxes_per_image(
                boxes, logits, spatial_size
            )
    """

    def __init__(
        self,
        box_overlap_metric: Callable = box_iou,
        apply_sigmoid: bool = True,
        score_thresh: float = 0.05,
        topk_candidates_per_level: int = 1000,
        nms_thresh: float = 0.5,
        detections_per_img: int = 300,
    ):
        self.box_overlap_metric = box_overlap_metric

        self.apply_sigmoid = apply_sigmoid
        self.score_thresh = score_thresh
        self.topk_candidates_per_level = topk_candidates_per_level
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

    def select_top_score_idx_per_level(self, logits: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """
        Select indices with highest scores.

        The indices selection is performed with the following steps:

        #. If self.apply_sigmoid, get scores by applying sigmoid to logits. Otherwise, use logits as scores.
        #. Discard indices with scores less than self.score_thresh
        #. Keep indices with top self.topk_candidates_per_level scores

        Args:
            logits: predicted classification logits, Tensor sized (N, num_classes)

        Return:
            - topk_idxs: selected M indices, Tensor sized (M, )
            - selected_scores: selected M scores, Tensor sized (M, )
            - selected_labels: selected M labels, Tensor sized (M, )
        """
        num_classes = logits.shape[-1]

        # apply sigmoid to classification logits if asked
        if self.apply_sigmoid:
            scores = torch.sigmoid(logits.to(torch.float32)).flatten()
        else:
            scores = logits.flatten()

        # remove low scoring boxes
        keep_idxs = scores > self.score_thresh
        scores = scores[keep_idxs]
        flatten_topk_idxs = torch.where(keep_idxs)[0]

        # keep only topk scoring predictions
        num_topk = min(self.topk_candidates_per_level, flatten_topk_idxs.size(0))
        selected_scores, idxs = scores.to(torch.float32).topk(
            num_topk
        )  # half precision not implemented for cpu float16
        flatten_topk_idxs = flatten_topk_idxs[idxs]

        selected_labels = flatten_topk_idxs % num_classes

        topk_idxs = floor_divide(flatten_topk_idxs, num_classes)
        return topk_idxs, selected_scores, selected_labels  # type: ignore

    def select_boxes_per_image(
        self, boxes_list: list[Tensor], logits_list: list[Tensor], spatial_size: list[int] | tuple[int]
    ) -> tuple[Tensor, Tensor, Tensor]:
        """
        Postprocessing to generate detection result from classification logits and boxes.

        The box selection is performed with the following steps:

        #. For each level, discard boxes with scores less than self.score_thresh.
        #. For each level, keep boxes with top self.topk_candidates_per_level scores.
        #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
        #. For the whole image, keep boxes with top self.detections_per_img scores.

        Args:
            boxes_list: list of predicted boxes from a single image,
                each element i is a Tensor sized (N_i, 2*spatial_dims)
            logits_list: list of predicted classification logits from a single image,
                each element i is a Tensor sized (N_i, num_classes)
            spatial_size: spatial size of the image

        Return:
            - selected boxes, Tensor sized (P, 2*spatial_dims)
            - selected_scores, Tensor sized (P, )
            - selected_labels, Tensor sized (P, )
        """

        if len(boxes_list) != len(logits_list):
            raise ValueError(
                "len(boxes_list) should equal to len(logits_list). "
                f"Got len(boxes_list)={len(boxes_list)}, len(logits_list)={len(logits_list)}"
            )

        image_boxes = []
        image_scores = []
        image_labels = []

        boxes_dtype = boxes_list[0].dtype
        logits_dtype = logits_list[0].dtype

        for boxes_per_level, logits_per_level in zip(boxes_list, logits_list):
            # select topk boxes for each level
            topk_idxs: Tensor
            topk_idxs, scores_per_level, labels_per_level = self.select_top_score_idx_per_level(logits_per_level)
            boxes_per_level = boxes_per_level[topk_idxs]

            keep: Tensor
            boxes_per_level, keep = clip_boxes_to_image(  # type: ignore
                boxes_per_level, spatial_size, remove_empty=True
            )
            image_boxes.append(boxes_per_level)
            image_scores.append(scores_per_level[keep])
            image_labels.append(labels_per_level[keep])

        image_boxes_t: Tensor = torch.cat(image_boxes, dim=0)
        image_scores_t: Tensor = torch.cat(image_scores, dim=0)
        image_labels_t: Tensor = torch.cat(image_labels, dim=0)

        # non-maximum suppression on detected boxes from all levels
        keep_t: Tensor = batched_nms(  # type: ignore
            image_boxes_t,
            image_scores_t,
            image_labels_t,
            self.nms_thresh,
            box_overlap_metric=self.box_overlap_metric,
            max_proposals=self.detections_per_img,
        )

        selected_boxes = image_boxes_t[keep_t].to(boxes_dtype)
        selected_scores = image_scores_t[keep_t].to(logits_dtype)
        selected_labels = image_labels_t[keep_t]

        return selected_boxes, selected_scores, selected_labels
