U
    Ph4                     @  s   d Z ddlmZ ddlZddlmZmZ ddlmZm	Z	 ddl
mZ ddlZddlmZ ddlmZmZmZmZ dd	lmZ ed
ZG dd deZG dd deZededZdS )a  
The functions in this script are adapted from nnDetection,
https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py
which is adapted from torchvision.

These are the changes compared with nndetection:
1) comments and docstrings;
2) reformat;
3) add a debug option to ATSSMatcher to help the users to tune parameters;
4) add a corner case return in ATSSMatcher.compute_matches;
5) add support for float16 cpu
    )annotationsN)ABCabstractmethod)CallableSequence)TypeVar)Tensor)COMPUTE_DTYPEbox_iouboxes_center_distancecenters_in_boxes)convert_to_tensorinfc                   @  sl   e Zd ZU dZdZded< dZded< efddd	d
ZddddddddZ	e
ddddddddZdS )Matcherz
    Base class of Matcher, which matches boxes and anchors to each other

    Args:
        similarity_fn: function for similarity computation between
            boxes and anchors
    intBELOW_LOW_THRESHOLDBETWEEN_THRESHOLDS"Callable[[Tensor, Tensor], Tensor]similarity_fnc                 C  s
   || _ d S )Nr   )selfr    r   \/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/detection/utils/ATSS_matcher.py__init__h   s    zMatcher.__init__torch.TensorSequence[int]!tuple[torch.Tensor, torch.Tensor]boxesanchorsnum_anchors_per_levelnum_anchors_per_locreturnc                 C  sX   |  dkrF|jd }tg |}tj|tjd| j}||fS | j	||||dS )a  
        Compute matches for a single image

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each boxes to each anchor [N, M]
            - vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]

        Note:
            ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
            also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D
            and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.
        r   )dtype)r    r!   r"   r#   )
numelshapetorchtensortoemptyint64fill_r   compute_matches)r   r    r!   r"   r#   num_anchorsmatch_quality_matrixmatchesr   r   r   __call__k   s    
zMatcher.__call__c                 C  s   t dS )a  
        Compute matches

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each boxes to each anchor [N, M]
            - vector which contains the matched box index for all
              anchors (if background `BELOW_LOW_THRESHOLD` is used
              and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]
        N)NotImplementedError)r   r    r!   r"   r#   r   r   r   r.      s    zMatcher.compute_matchesN)__name__
__module____qualname____doc__r   __annotations__r   r
   r   r2   r   r.   r   r   r   r   r   \   s   
%r   c                      sH   e Zd Zdeddfddddd fdd	Zd
d
ddddddZ  ZS )ATSSMatcher   TFr   r   bool)num_candidatesr   center_in_gtdebugc                   sF   t  j|d || _d| _|| _|| _td| j d| j d dS )an  
        Compute matching based on ATSS https://arxiv.org/abs/1912.02424
        `Bridging the Gap Between Anchor-based and Anchor-free Detection
        via Adaptive Training Sample Selection`

        Args:
            num_candidates: number of positions to select candidates from.
                Smaller value will result in a higher matcher threshold and less matched candidates.
            similarity_fn: function for similarity computation between boxes and anchors
            center_in_gt: If False (default), matched anchor center points do not need
                to lie withing the ground truth box. Recommend False for small objects.
                If True, will result in a strict matcher and less matched candidates.
            debug: if True, will print the matcher threshold in order to
                tune ``num_candidates`` and ``center_in_gt``.
        r   g{Gz?z*Running ATSS Matching with num_candidates=z and center_in_gt .N)superr   r<   min_distr=   r>   logginginfo)r   r<   r   r=   r>   	__class__r   r   r      s    zATSSMatcher.__init__r   r   r   r   c                  C  sP  |j d }|j d }t||\}}}	t|}
g }d}t|D ]`\}}|||  }t| j| |}|
dd||f tj|ddd\}}|	||  |}q<t
j|dd}| ||}|d|}|j d dkrdt
j|ft
j|jd }d||< ||fS |jdd}|jdd}|| }||dddf k}| jrFtd	|  | jrt
j||jt
jd
dddf | }t|	|d ||d | jd}t|}|||@ }t|D ]"}||ddf  || 7  < qt
|t d}|d|d }|d| ||< ||}|tj dd\}}| j!||t k< ||fS )aa  
        Compute matches according to ATTS for a single image
        Adapted from
        (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss/loss.py#L180-L184)

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each boxes to each anchor [N, M]
            - vector which contains the matched box index for all
              anchors (if background `BELOW_LOW_THRESHOLD` is used
              and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]

        Note:
            ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
            also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D
            and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.
        r   N   F)dimlargest)rG   r   )r%   devicezAnchor matcher threshold: )rI   r%   )eps)"r'   r   r   	enumerateminr<   r*   r	   topkappendr(   catr   gatheroneslongrI   meanstdr>   printr=   arange	expand_as
contiguousr   viewrA   view_asrange	full_likeINFmaxr   ) r   r    r!   r"   r#   num_gtr/   
distances__Zanchors_center	distancesZcandidate_idx_list	start_idxZaplend_idxrM   idxZcandidate_idxr0   Zcandidate_iousr1   Ziou_mean_per_gtZiou_std_per_gtZiou_thresh_per_gtis_posZ	boxes_idxZ	is_in_gt_Zis_in_gtngZious_infindexmatched_valsr   r   r   r.      s\    

*    
zATSSMatcher.compute_matches)r4   r5   r6   r
   r   r.   __classcell__r   r   rD   r   r9      s   r9   MatcherType)bound)r7   
__future__r   rB   abcr   r   collections.abcr   r   typingr   r(   r   monai.data.box_utilsr	   r
   r   r   monai.utils.type_conversionr   floatr]   r   r9   rk   r   r   r   r   <module>>   s   J}