o
    i4                     @  s   d Z ddlmZ ddlZddlmZmZ ddlmZm	Z	 ddl
mZ ddlZddlmZ ddlmZmZmZmZ dd	lmZ ed
ZG dd deZG dd deZededZdS )a  
The functions in this script are adapted from nnDetection,
https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py
which is adapted from torchvision.

These are the changes compared with nndetection:
1) comments and docstrings;
2) reformat;
3) add a debug option to ATSSMatcher to help the users to tune parameters;
4) add a corner case return in ATSSMatcher.compute_matches;
5) add support for float16 cpu
    )annotationsN)ABCabstractmethod)CallableSequence)TypeVar)Tensor)COMPUTE_DTYPEbox_iouboxes_center_distancecenters_in_boxes)convert_to_tensorinfc                   @  sP   e Zd ZU dZdZded< dZded< efdd	d
ZdddZ	e
dddZdS )Matcherz
    Base class of Matcher, which matches boxes and anchors to each other

    Args:
        similarity_fn: function for similarity computation between
            boxes and anchors
    intBELOW_LOW_THRESHOLDBETWEEN_THRESHOLDSsimilarity_fn"Callable[[Tensor, Tensor], Tensor]c                 C  s
   || _ d S )Nr   )selfr    r   i/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/utils/ATSS_matcher.py__init__h   s   
zMatcher.__init__boxestorch.Tensoranchorsnum_anchors_per_levelSequence[int]num_anchors_per_locreturn!tuple[torch.Tensor, torch.Tensor]c                 C  sX   |  dkr#|jd }tg |}tj|tjd| j}||fS | j	||||dS )a  
        Compute matches for a single image

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each boxes to each anchor [N, M]
            - vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]

        Note:
            ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
            also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D
            and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.
        r   )dtype)r   r   r   r!   )
numelshapetorchtensortoemptyint64fill_r   compute_matches)r   r   r   r   r!   num_anchorsmatch_quality_matrixmatchesr   r   r   __call__k   s   
zMatcher.__call__c                 C  s   t )a  
        Compute matches

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each boxes to each anchor [N, M]
            - vector which contains the matched box index for all
              anchors (if background `BELOW_LOW_THRESHOLD` is used
              and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]
        )NotImplementedError)r   r   r   r   r!   r   r   r   r-      s   zMatcher.compute_matchesN)r   r   
r   r   r   r   r   r    r!   r   r"   r#   )__name__
__module____qualname____doc__r   __annotations__r   r
   r   r1   r   r-   r   r   r   r   r   \   s   
 
%r   c                      s2   e Zd Zdeddfd fddZdddZ  ZS )ATSSMatcher   TFnum_candidatesr   r   r   center_in_gtbooldebugc                   sF   t  j|d || _d| _|| _|| _td| j d| j d dS )an  
        Compute matching based on ATSS https://arxiv.org/abs/1912.02424
        `Bridging the Gap Between Anchor-based and Anchor-free Detection
        via Adaptive Training Sample Selection`

        Args:
            num_candidates: number of positions to select candidates from.
                Smaller value will result in a higher matcher threshold and less matched candidates.
            similarity_fn: function for similarity computation between boxes and anchors
            center_in_gt: If False (default), matched anchor center points do not need
                to lie withing the ground truth box. Recommend False for small objects.
                If True, will result in a strict matcher and less matched candidates.
            debug: if True, will print the matcher threshold in order to
                tune ``num_candidates`` and ``center_in_gt``.
        r   g{Gz?z*Running ATSS Matching with num_candidates=z and center_in_gt .N)superr   r;   min_distr<   r>   logginginfo)r   r;   r   r<   r>   	__class__r   r   r      s   zATSSMatcher.__init__r   r   r   r   r    r!   r"   r#   c                  C  sJ  |j d }|j d }t||\}}}	t|}
g }d}t|D ]0\}}|||  }t| j| |}|
dd||f tj|ddd\}}|	||  |}qt
j|dd}| ||}|d|}|j d dkr~dt
j|ft
j|jd }d||< ||fS |jdd}|jdd}|| }||dddf k}| jrtd	|  | jrt
j||jt
jd
dddf | }t|	|d ||d | jd}t|}|||@ }t|D ]}||ddf  || 7  < qt
|t d}|d|d }|d| ||< ||}|tj dd\}}| j!||t k< ||fS )aa  
        Compute matches according to ATTS for a single image
        Adapted from
        (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss/loss.py#L180-L184)

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            - matrix which contains the similarity from each boxes to each anchor [N, M]
            - vector which contains the matched box index for all
              anchors (if background `BELOW_LOW_THRESHOLD` is used
              and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]

        Note:
            ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
            also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D
            and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.
        r   N   F)dimlargest)rG   r   )r$   devicezAnchor matcher threshold: )rI   r$   )eps)"r&   r   r   	enumerateminr;   r)   r	   topkappendr'   catr   gatheroneslongrI   meanstdr>   printr<   arange	expand_as
contiguousr   viewrA   view_asrange	full_likeINFmaxr   ) r   r   r   r   r!   num_gtr.   
distances__Zanchors_center	distancesZcandidate_idx_list	start_idxZaplend_idxrM   idxZcandidate_idxr/   Zcandidate_iousr0   Ziou_mean_per_gtZiou_std_per_gtZiou_thresh_per_gtis_posZ	boxes_idxZ	is_in_gt_Zis_in_gtngZious_infindexmatched_valsr   r   r   r-      sX   

*
zATSSMatcher.compute_matches)r;   r   r   r   r<   r=   r>   r=   r3   )r4   r5   r6   r
   r   r-   __classcell__r   r   rD   r   r9      s    r9   MatcherType)bound)r7   
__future__r   rB   abcr   r   collections.abcr   r   typingr   r'   r   monai.data.box_utilsr	   r
   r   r   monai.utils.type_conversionr   floatr]   r   r9   rk   r   r   r   r   <module>   s   =J}