o
    iG#                     @  sb   d Z ddlmZ ddlmZ ddlZddlmZ ddlmZm	Z	m
Z
 ddlmZ G dd	 d	ZdS )
z{
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
    )annotations)CallableN)Tensor)batched_nmsbox_iouclip_boxes_to_image)floor_dividec                   @  s<   e Zd ZdZedddddfd!ddZd"ddZd#ddZd S )$BoxSelectora  
    Box selector which selects the predicted boxes.
    The box selection is performed with the following steps:

    #. For each level, discard boxes with scores less than self.score_thresh.
    #. For each level, keep boxes with top self.topk_candidates_per_level scores.
    #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
    #. For the whole image, keep boxes with top self.detections_per_img scores.

    Args:
        apply_sigmoid: whether to apply sigmoid to get scores from classification logits
        score_thresh: no box with scores less than score_thresh will be kept
        topk_candidates_per_level: max number of boxes to keep for each level
        nms_thresh: box overlapping threshold for NMS
        detections_per_img: max number of boxes to keep for each image

    Example:

        .. code-block:: python

            input_param = {
                "apply_sigmoid": True,
                "score_thresh": 0.1,
                "topk_candidates_per_level": 2,
                "nms_thresh": 0.1,
                "detections_per_img": 5,
            }
            box_selector = BoxSelector(**input_param)
            boxes = [torch.randn([3,6]), torch.randn([7,6])]
            logits = [torch.randn([3,3]), torch.randn([7,3])]
            spatial_size = (8,8,8)
            selected_boxes, selected_scores, selected_labels = box_selector.select_boxes_per_image(
                boxes, logits, spatial_size
            )
    Tg?i  g      ?i,  box_overlap_metricr   apply_sigmoidboolscore_threshfloattopk_candidates_per_levelint
nms_threshdetections_per_imgc                 C  s(   || _ || _|| _|| _|| _|| _d S )N)r
   r   r   r   r   r   )selfr
   r   r   r   r   r    r   i/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/utils/box_selector.py__init__X   s   	
zBoxSelector.__init__logitsr   returntuple[Tensor, Tensor, Tensor]c                 C  s   |j d }| jrt|tj }n| }|| jk}|| }t|d }t	| j
|d}|tj|\}}|| }|| }	t||}
|
||	fS )a  
        Select indices with highest scores.

        The indices selection is performed with the following steps:

        #. If self.apply_sigmoid, get scores by applying sigmoid to logits. Otherwise, use logits as scores.
        #. Discard indices with scores less than self.score_thresh
        #. Keep indices with top self.topk_candidates_per_level scores

        Args:
            logits: predicted classification logits, Tensor sized (N, num_classes)

        Return:
            - topk_idxs: selected M indices, Tensor sized (M, )
            - selected_scores: selected M scores, Tensor sized (M, )
            - selected_labels: selected M labels, Tensor sized (M, )
        r   )shaper   torchsigmoidtofloat32flattenr   whereminr   sizetopkr   )r   r   num_classesscores	keep_idxsZflatten_topk_idxsnum_topkselected_scoresidxsselected_labels	topk_idxsr   r   r   select_top_score_idx_per_leveli   s   



z*BoxSelector.select_top_score_idx_per_level
boxes_listlist[Tensor]logits_listspatial_sizelist[int] | tuple[int]c                 C  s"  t |t |krtdt | dt | g }g }g }|d j}|d j}t||D ],\}	}
| |
\}}}|	| }	t|	|dd\}	}||	 |||  |||  q+tj|dd}tj|dd}tj|dd}t	|||| j
| j| jd}|| |}|| |}|| }|||fS )a6  
        Postprocessing to generate detection result from classification logits and boxes.

        The box selection is performed with the following steps:

        #. For each level, discard boxes with scores less than self.score_thresh.
        #. For each level, keep boxes with top self.topk_candidates_per_level scores.
        #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
        #. For the whole image, keep boxes with top self.detections_per_img scores.

        Args:
            boxes_list: list of predicted boxes from a single image,
                each element i is a Tensor sized (N_i, 2*spatial_dims)
            logits_list: list of predicted classification logits from a single image,
                each element i is a Tensor sized (N_i, num_classes)
            spatial_size: spatial size of the image

        Return:
            - selected boxes, Tensor sized (P, 2*spatial_dims)
            - selected_scores, Tensor sized (P, )
            - selected_labels, Tensor sized (P, )
        zFlen(boxes_list) should equal to len(logits_list). Got len(boxes_list)=z, len(logits_list)=r   T)remove_empty)dim)r
   max_proposals)len
ValueErrordtypezipr-   r   appendr   catr   r   r
   r   r   )r   r.   r0   r1   image_boxesimage_scoresimage_labelsZboxes_dtypeZlogits_dtypeboxes_per_levellogits_per_levelr,   scores_per_levellabels_per_levelkeepZimage_boxes_tZimage_scores_tZimage_labels_tkeep_tselected_boxesr)   r+   r   r   r   select_boxes_per_image   sJ   



	
z"BoxSelector.select_boxes_per_imageN)r
   r   r   r   r   r   r   r   r   r   r   r   )r   r   r   r   )r.   r/   r0   r/   r1   r2   r   r   )__name__
__module____qualname____doc__r   r   r-   rF   r   r   r   r   r	   3   s    &
+r	   )rJ   
__future__r   collections.abcr   r   r   monai.data.box_utilsr   r   r   0monai.transforms.utils_pytorch_numpy_unificationr   r	   r   r   r   r   <module>   s   "