o
    i                     @  s<  d Z ddlmZ ddlZddlmZmZ ddlmZ ddl	Z	ddl	m
Z
mZ ddlmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZ ddlmZ ddlmZmZ ddlm Z  ddl!m"Z" ddl#m$Z$ ddl%m&Z&m'Z'm(Z(m)Z) e)ddd\Z*Z+e)ddd\Z,Z+G dd dej-Z.			d*d+d(d)Z/dS ),z{
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
    )annotationsN)CallableSequence)Any)Tensornn)	RetinaNetresnet_fpn_feature_extractor)AnchorGenerator)ATSSMatcher)BoxCoder)BoxSelector)check_training_targetspreprocess_images)HardNegativeSampler)ensure_dict_value_to_list_predict_with_inferer)box_iou)SlidingWindowInferer)resnet)	BlendModePytorchPadModeensure_tuple_repoptional_importz#torchvision.models.detection._utilsBalancedPositiveNegativeSampler)nameMatcherc                
      sF  e Zd ZdZeddddddfd fddZdddZddd Zdd#d$Zdd&d'Z	dd+d,Z
	-ddd2d3Zddd8d9Z	:ddd?d@ZddAdBZddCejdDejdEddddf
ddUdVZ	W	X	C	Y	-ddd_d`Z		dddgdhZdidj ZddodpZddsdtZ	-ddd~dZdddZdddZdddZdddZdddZdddZ  ZS )RetinaNetDetectora  
    Retinanet detector, expandable to other one stage anchor based box detectors in the future.
    An example of construction can found in the source code of
    :func:`~monai.apps.detection.networks.retinanet_detector.retinanet_resnet50_fpn_detector` .

    The input to the model is expected to be a list of tensors, each of shape (C, H, W) or  (C, H, W, D),
    one for each image, and should be in 0-1 range. Different images can have different sizes.
    Or it can also be a Tensor sized (B, C, H, W) or  (B, C, H, W, D). In this case, all images have same size.

    The behavior of the model changes depending if it is in training or evaluation mode.

    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
    containing:

    - boxes (``FloatTensor[N, 4]`` or ``FloatTensor[N, 6]``): the ground-truth boxes in ``StandardMode``, i.e.,
      ``[xmin, ymin, xmax, ymax]`` or ``[xmin, ymin, zmin, xmax, ymax, zmax]`` format,
      with ``0 <= xmin < xmax <= H``, ``0 <= ymin < ymax <= W``, ``0 <= zmin < zmax <= D``.
    - labels: the class label for each ground-truth box

    The model returns a Dict[str, Tensor] during training, containing the classification and regression
    losses.
    When saving the model, only self.network contains trainable parameters and needs to be saved.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

    - boxes (``FloatTensor[N, 4]`` or ``FloatTensor[N, 6]``): the predicted boxes in ``StandardMode``, i.e.,
      ``[xmin, ymin, xmax, ymax]`` or ``[xmin, ymin, zmin, xmax, ymax, zmax]`` format,
      with ``0 <= xmin < xmax <= H``, ``0 <= ymin < ymax <= W``, ``0 <= zmin < zmax <= D``.
    - labels (Int64Tensor[N]): the predicted labels for each image
    - labels_scores (Tensor[N]): the scores for each prediction

    Args:
        network: a network that takes an image Tensor sized (B, C, H, W) or (B, C, H, W, D) as input
            and outputs a dictionary Dict[str, List[Tensor]] or Dict[str, Tensor].
        anchor_generator: anchor generator.
        box_overlap_metric: func that compute overlap between two sets of boxes, default is Intersection over Union (IoU).
        debug: whether to print out internal parameters, used for debugging and parameter tuning.

    Notes:

        Input argument ``network`` can be a monai.apps.detection.networks.retinanet_network.RetinaNet(*) object,
        but any network that meets the following rules is a valid input ``network``.

        1. It should have attributes including spatial_dims, num_classes, cls_key, box_reg_key, num_anchors, size_divisible.

            - spatial_dims (int) is the spatial dimension of the network, we support both 2D and 3D.
            - num_classes (int) is the number of classes, excluding the background.
            - size_divisible (int or Sequence[int]) is the expectation on the input image shape.
              The network needs the input spatial_size to be divisible by size_divisible, length should be 2 or 3.
            - cls_key (str) is the key to represent classification in the output dict.
            - box_reg_key (str) is the key to represent box regression in the output dict.
            - num_anchors (int) is the number of anchor shapes at each location. it should equal to
              ``self.anchor_generator.num_anchors_per_location()[0]``.

            If network does not have these attributes, user needs to provide them for the detector.

        2. Its input should be an image Tensor sized (B, C, H, W) or (B, C, H, W, D).

        3. About its output ``head_outputs``, it should be either a list of tensors or a dictionary of str: List[Tensor]:

            - If it is a dictionary, it needs to have at least two keys:
              ``network.cls_key`` and ``network.box_reg_key``, representing predicted classification maps and box regression maps.
              ``head_outputs[network.cls_key]`` should be List[Tensor] or Tensor. Each Tensor represents
              classification logits map at one resolution level,
              sized (B, num_classes*num_anchors, H_i, W_i) or (B, num_classes*num_anchors, H_i, W_i, D_i).
              ``head_outputs[network.box_reg_key]`` should be List[Tensor] or Tensor. Each Tensor represents
              box regression map at one resolution level,
              sized (B, 2*spatial_dims*num_anchors, H_i, W_i)or (B, 2*spatial_dims*num_anchors, H_i, W_i, D_i).
              ``len(head_outputs[network.cls_key]) == len(head_outputs[network.box_reg_key])``.
            - If it is a list of 2N tensors, the first N tensors should be the predicted classification maps,
              and the second N tensors should be the predicted box regression maps.

    Example:

        .. code-block:: python

            # define a naive network
            import torch
            class NaiveNet(torch.nn.Module):
                def __init__(self, spatial_dims: int, num_classes: int):
                    super().__init__()
                    self.spatial_dims = spatial_dims
                    self.num_classes = num_classes
                    self.size_divisible = 2
                    self.cls_key = "cls"
                    self.box_reg_key = "box_reg"
                    self.num_anchors = 1
                def forward(self, images: torch.Tensor):
                    spatial_size = images.shape[-self.spatial_dims:]
                    out_spatial_size = tuple(s//self.size_divisible for s in spatial_size)  # half size of input
                    out_cls_shape = (images.shape[0],self.num_classes*self.num_anchors) + out_spatial_size
                    out_box_reg_shape = (images.shape[0],2*self.spatial_dims*self.num_anchors) + out_spatial_size
                    return {self.cls_key: [torch.randn(out_cls_shape)], self.box_reg_key: [torch.randn(out_box_reg_shape)]}

            # create a RetinaNetDetector detector
            spatial_dims = 3
            num_classes = 5
            anchor_generator = monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape(
                feature_map_scales=(1, ), base_anchor_shapes=((8,) * spatial_dims)
            )
            net = NaiveNet(spatial_dims, num_classes)
            detector = RetinaNetDetector(net, anchor_generator)

            # only detector.network may contain trainable parameters.
            optimizer = torch.optim.SGD(
                detector.network.parameters(),
                1e-3,
                momentum=0.9,
                weight_decay=3e-5,
                nesterov=True,
            )
            torch.save(detector.network.state_dict(), 'model.pt')  # save model
            detector.network.load_state_dict(torch.load('model.pt', weights_only=True))  # load model
    N   classificationbox_regressionFnetwork	nn.Moduleanchor_generatorr
   box_overlap_metricr   spatial_dims
int | Nonenum_classessize_divisibleSequence[int] | intcls_keystrbox_reg_keydebugboolc
                   sR  t    || _| jd|d| _| jd|d| _| jd|d| _t| j| j| _| jd|d| _| jd|d| _	|| _
| j
 d | _| jd| jd}
| j|
kr\td	|
 d
| j dd | _d | _|| _|	| _d | _| tjjdd | jtjjdddddd td| j d| _d| _d| _| jd | _d | _t| jdddddd| _ d S )Nr%   )default_valuer'   r(   r*   r,   r   num_anchorsz Number of feature map channels (z8) should match with number of anchors at each location (z).mean)	reductiongqq?)betar2   TF)	encode_gtdecode_pred)      ?r6   weightsboxeslabels_scores皙?        ?,  )r$   score_threshtopk_candidates_per_level
nms_threshdetections_per_imgapply_sigmoid)!super__init__r!   get_attribute_from_networkr%   r'   r(   r   r*   r,   r#   num_anchors_per_locationnum_anchors_per_loc
ValueErroranchorsprevious_image_shaper$   r-   fg_bg_samplerset_cls_losstorchr   BCEWithLogitsLossset_box_regression_lossSmoothL1Lossr   	box_codertarget_box_keytarget_label_keypred_score_keyinfererr   box_selector)selfr!   r#   r$   r%   r'   r(   r*   r,   r-   Znetwork_num_anchors	__class__ r/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/networks/retinanet_detector.pyrF      sN   

zRetinaNetDetector.__init__c                 C  s4   t | j|rt| j|S |d ur|S td| d)Nz network does not have attribute z$, please provide it in the detector.)hasattrr!   getattrrJ   )rY   	attr_namer/   r\   r\   r]   rG     s
   z,RetinaNetDetector.get_attribute_from_networkr8   tuple[float]returnNonec                 C  s>   t |d| j krtdd| j  d| dt|d| _dS )z
        Set the weights for box coder.

        Args:
            weights: a list/tuple with length of 2*self.spatial_dims

           zlen(weights) should be z, got weights=.r7   N)lenr%   rJ   r   rS   )rY   r8   r\   r\   r]   set_box_coder_weights
  s   z'RetinaNetDetector.set_box_coder_weightsbox_key	label_keyc                 C  s   || _ || _|d | _dS )aB  
        Set keys for the training targets and inference outputs.
        During training, both box_key and label_key should be keys in the targets
        when performing ``self.forward(input_images, targets)``.
        During inference, they will be the keys in the output dict of `self.forward(input_images)``.
        r;   N)rT   rU   rV   )rY   rh   ri   r\   r\   r]   set_target_keys  s   z!RetinaNetDetector.set_target_keyscls_lossc                 C  s
   || _ dS )a  
        Using for training. Set loss for classification that takes logits as inputs, make sure sigmoid/softmax is built in.

        Args:
            cls_loss: loss module for classification

        Example:
            .. code-block:: python

                detector.set_cls_loss(torch.nn.BCEWithLogitsLoss(reduction="mean"))
                detector.set_cls_loss(FocalLoss(reduction="mean", gamma=2.0))
        N)cls_loss_func)rY   rk   r\   r\   r]   rN   !  s   
zRetinaNetDetector.set_cls_lossbox_lossr4   r5   c                 C  s   || _ || _|| _dS )a  
        Using for training. Set loss for box regression.

        Args:
            box_loss: loss module for box regression
            encode_gt: if True, will encode ground truth boxes to target box regression
                before computing the losses. Should be True for L1 loss and False for GIoU loss.
            decode_pred: if True, will decode predicted box regression into predicted boxes
                before computing losses. Should be False for L1 loss and True for GIoU loss.

        Example:
            .. code-block:: python

                detector.set_box_regression_loss(
                    torch.nn.SmoothL1Loss(beta=1.0 / 9, reduction="mean"),
                    encode_gt = True, decode_pred = False
                )
                detector.set_box_regression_loss(
                    monai.losses.giou_loss.BoxGIoULoss(reduction="mean"),
                    encode_gt = False, decode_pred = True
                )
        N)box_loss_funcr4   r5   )rY   rm   r4   r5   r\   r\   r]   rQ   0  s   
z)RetinaNetDetector.set_box_regression_lossTfg_iou_threshfloatbg_iou_threshallow_low_quality_matchesc                 C  s2   ||k rt d| d| dt|||d| _dS )a  
        Using for training. Set torchvision matcher that matches anchors with ground truth boxes.

        Args:
            fg_iou_thresh: foreground IoU threshold for Matcher, considered as matched if IoU > fg_iou_thresh
            bg_iou_thresh: background IoU threshold for Matcher, considered as not matched if IoU < bg_iou_thresh
            allow_low_quality_matches: if True, produce additional matches
                for predictions that have only low-quality match candidates.
        z:Require fg_iou_thresh >= bg_iou_thresh. Got fg_iou_thresh=z, bg_iou_thresh=re   )rr   N)rJ   r   proposal_matcher)rY   ro   rq   rr   r\   r\   r]   set_regular_matcherK  s   z%RetinaNetDetector.set_regular_matcher   num_candidatesintcenter_in_gtc                 C  s   t || j|| jd| _dS )a&  
        Using for training. Set ATSS matcher that matches anchors with ground truth boxes

        Args:
            num_candidates: number of positions to select candidates from.
                Smaller value will result in a higher matcher threshold and less matched candidates.
            center_in_gt: If False (default), matched anchor center points do not need
                to lie within the ground truth box. Recommend False for small objects.
                If True, will result in a strict matcher and less matched candidates.
        )r-   N)r   r$   r-   rs   )rY   rv   rx   r\   r\   r]   set_atss_matcher`  s   z"RetinaNetDetector.set_atss_matcher
   batch_size_per_imagepositive_fractionmin_neg	pool_sizec                 C  s   t ||||d| _dS )a  
        Using for training. Set hard negative sampler that samples part of the anchors for training.

        HardNegativeSampler is used to suppress false positive rate in classification tasks.
        During training, it select negative samples with high prediction scores.

        Args:
            batch_size_per_image: number of elements to be selected per image
            positive_fraction: percentage of positive elements in the selected samples
            min_neg: minimum number of negative samples to select if possible.
            pool_size: when we need ``num_neg`` hard negative samples, they will be randomly selected from
                ``num_neg * pool_size`` negative samples with the highest prediction scores.
                Larger ``pool_size`` gives more randomness, yet selects negative samples that are less 'hard',
                i.e., negative samples with lower prediction scores.
        )r{   r|   r}   r~   N)r   rM   )rY   r{   r|   r}   r~   r\   r\   r]   set_hard_negative_samplerm  s   z+RetinaNetDetector.set_hard_negative_samplerc                 C  s   t ||d| _dS )a  
        Using for training. Set torchvision balanced sampler that samples part of the anchors for training.

        Args:
            batch_size_per_image: number of elements to be selected per image
            positive_fraction: percentage of positive elements per batch

        )r{   r|   N)r   rM   )rY   r{   r|   r\   r\   r]   set_balanced_sampler  s   	z&RetinaNetDetector.set_balanced_samplerr>   g      ?        roi_sizesw_batch_sizeoverlapmodeBlendMode | strsigma_scaleSequence[float] | floatpadding_modePytorchPadMode | strcval	sw_devicetorch.device | str | Nonedeviceprogresscache_roi_weight_mapc                 C  s"   t |||||||||	|
|| _dS )zM
        Define sliding window inferer and store it to self.inferer.
        N)r   rW   )rY   r   r   r   r   r   r   r   r   r   r   r   r\   r\   r]   set_sliding_window_inferer  s   
z,RetinaNetDetector.set_sliding_window_infererr<   r=   r?   r@   rA   rB   rC   rD   c                 C  s   t | j|||||d| _dS )aW  
        Using for inference. Set the parameters that are used for box selection during inference.
        The box selection is performed with the following steps:

        #. For each level, discard boxes with scores less than self.score_thresh.
        #. For each level, keep boxes with top self.topk_candidates_per_level scores.
        #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
        #. For the whole image, keep boxes with top self.detections_per_img scores.

        Args:
            score_thresh: no box with scores less than score_thresh will be kept
            topk_candidates_per_level: max number of boxes to keep for each level
            nms_thresh: box overlapping threshold for NMS
            detections_per_img: max number of boxes to keep for each image
        )r$   rD   r@   rA   rB   rC   N)r   r$   rX   )rY   r@   rA   rB   rC   rD   r\   r\   r]   set_box_selector_parameters  s   z-RetinaNetDetector.set_box_selector_parametersinput_imageslist[Tensor] | Tensortargetslist[dict[str, Tensor]] | Noneuse_inferer+dict[str, Tensor] | list[dict[str, Tensor]]c                 C  sF  | j rt||| j| j| j}|   t|| j| j\}}| j s!|sQ| |}t	|t
tfrLi }|dt|d  || j< |t|d d || j< |}nt| n| jdu rZtdt|| j| j| jg| jd}| || dd || j D }| j| jfD ]}	| ||	 ||	< q~| j r| ||| j|}
|
S | || j||}|S )a  
        Returns a dict of losses during training, or a list predicted dict of boxes and labels during inference.

        Args:
            input_images: The input to the model is expected to be a list of tensors, each of shape (C, H, W) or  (C, H, W, D),
                one for each image, and should be in 0-1 range. Different images can have different sizes.
                Or it can also be a Tensor sized (B, C, H, W) or  (B, C, H, W, D). In this case, all images have same size.
            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,
                ground-truth boxes present in the image (optional).
            use_inferer: whether to use self.inferer, a sliding window inferer, to do the inference.
                If False, will simply forward the network.
                If True, will use self.inferer, and requires
                ``self.set_sliding_window_inferer(*args)`` to have been called before.

        Return:
            If training mode, will return a dict with at least two keys,
            including self.cls_key and self.box_reg_key, representing classification loss and box regression loss.

            If evaluation mode, will return a list of detection results.
            Each element corresponds to an images in ``input_images``, is a dict with at least three keys,
            including self.target_box_key, self.target_label_key, self.pred_score_key,
            representing predicted boxes, classification labels, and classification scores.

        Nrd   zZ`self.inferer` is not defined.Please refer to function self.set_sliding_window_inferer(*).)keysrW   c                 S  s   g | ]}|j d d  qS )rd   N)shapenumel).0xr\   r\   r]   
<listcomp>  s    z-RetinaNetDetector.forward.<locals>.<listcomp>)trainingr   r%   rU   rT   #_check_detector_training_componentsr   r(   r!   
isinstancetuplelistrf   r*   r,   r   rW   rJ   r   generate_anchors_reshape_mapscompute_lossrK   postprocess_detections)rY   r   r   r   imagesimage_sizeshead_outputsZtmp_dictnum_anchor_locs_per_levelkeylosses
detectionsr\   r\   r]   forward  s@   




zRetinaNetDetector.forwardc                 C  s8   t | ds	td| jdu r| jrtd dS dS dS )zc
        Check if self.proposal_matcher and self.fg_bg_sampler have been set for training.
        rs   z\Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*).Na  No balanced sampler is used. Negative samples are likely to be much more than positive samples. Please set balanced samplers with self.set_balanced_sampler(*) or self.set_hard_negative_sampler(*), or set classification loss function as Focal loss with self.set_cls_loss(*))r^   AttributeErrorrM   r-   warningswarnrY   r\   r\   r]   r   (  s   
z5RetinaNetDetector._check_detector_training_componentsr   r   r   dict[str, list[Tensor]]c                 C  s:   | j du s| j|jkr| ||| j | _ |j| _dS dS )aA  
        Generate anchors and store it in self.anchors: List[Tensor].
        We generate anchors only when there is no stored anchors,
        or the new coming images has different shape with self.previous_image_shape

        Args:
            images: input images, a (B, C, H, W) or (B, C, H, W, D) Tensor.
            head_outputs: head_outputs. ``head_output_reshape[self.cls_key]`` is a Tensor
              sized (B, sum(HW(D)A), self.num_classes). ``head_output_reshape[self.box_reg_key]`` is a Tensor
              sized (B, sum(HW(D)A), 2*self.spatial_dims)
        N)rK   rL   r   r#   r*   )rY   r   r   r\   r\   r]   r   8  s   z"RetinaNetDetector.generate_anchorsresult_mapslist[Tensor]c           	   	   C  s   g }|D ]n}|j d }|j d | j }|j | j d }|d|f| }||}| jdkr7|ddddd}n| jdkrG|dddddd}ntd	||d|}t|	 s`t
|	 rmt rhtd
td
 || qtj|ddS )a  
        Concat network output map list to a single Tensor.
        This function is used in both training and inference.

        Args:
            result_maps: a list of Tensor, each Tensor is a (B, num_channel*A, H, W) or (B, num_channel*A, H, W, D) map.
                A = self.num_anchors_per_loc

        Return:
            reshaped and concatenated result, sized (B, sum(HWA), num_channel) or (B, sum(HWDA), num_channel)
        r   r   Nrd      ru      zImages can only be 2D or 3D.z"Concatenated result is NaN or Inf.dim)r   rI   r%   viewpermuterJ   reshaperO   isnananyisinfis_grad_enabledr   r   appendcat)	rY   r   Zall_reshaped_result_map
result_map
batch_sizeZnum_channelspatial_size
view_shapeZreshaped_result_mapr\   r\   r]   r   H  s&   




zRetinaNetDetector._reshape_mapshead_outputs_reshapedict[str, Tensor]rK   r   list[list[int]]r   Sequence[int]need_sigmoidlist[dict[str, Tensor]]c              	     s
  fdd|D i }|D ]}t || jdd||< qfdd|D }|j }	|j }
|	d j t|}g }t|D ]Cfdd|
D }fdd|	D }| | }} fd	dt||D }j	|||\}}}|
j|j|j|i q?|S )
a  
        Postprocessing to generate detection result from classification logits and box regression.
        Use self.box_selector to select the final output boxes for each image.

        Args:
            head_outputs_reshape: reshaped head_outputs. ``head_output_reshape[self.cls_key]`` is a Tensor
              sized (B, sum(HW(D)A), self.num_classes). ``head_output_reshape[self.box_reg_key]`` is a Tensor
              sized (B, sum(HW(D)A), 2*self.spatial_dims)
            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,
                ground-truth boxes present in the image.
            anchors: a list of Tensor. Each Tensor represents anchors for each image,
                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).
                A = self.num_anchors_per_loc.

        Return:
            a list of dict, each dict corresponds to detection result on image.
        c                   s   g | ]}| j  qS r\   )rI   )r   Znum_anchor_locsr   r\   r]   r     s    
z<RetinaNetDetector.postprocess_detections.<locals>.<listcomp>r   r   c                   s   g | ]	}t | qS r\   )r   split)r   a)num_anchors_per_levelr\   r]   r     s    r   c                      g | ]}|  qS r\   r\   )r   brindexr\   r]   r     s    c                   r   r\   r\   )r   clr   r\   r]   r     s    c                   s,   g | ]\}}j |tj| qS r\   )rS   decode_singletorO   float32)r   br   )compute_dtyperY   r\   r]   r     s    )r   r   r*   r,   dtyperf   rangeziprX   Zselect_boxes_per_imager   rT   rV   rU   )rY   r   rK   r   r   r   split_head_outputsksplit_anchorsclass_logitsr    
num_imagesr   box_regression_per_imagelogits_per_imageanchors_per_imageZimg_spatial_sizeboxes_per_imageZselected_boxesZselected_scoresZselected_labelsr\   )r   r   r   rY   r]   r   u  s>   





z(RetinaNetDetector.postprocess_detectionsc                 C  sH   |  |||}| || j ||}| || j |||}| j|| j|iS )a  
        Compute losses.

        Args:
            head_outputs_reshape: reshaped head_outputs. ``head_output_reshape[self.cls_key]`` is a Tensor
              sized (B, sum(HW(D)A), self.num_classes). ``head_output_reshape[self.box_reg_key]`` is a Tensor
              sized (B, sum(HW(D)A), 2*self.spatial_dims)
            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,
                ground-truth boxes present in the image.
            anchors: a list of Tensor. Each Tensor represents anchors for each image,
                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).
                A = self.num_anchors_per_loc.

        Return:
            a dict of several kinds of losses.
        )compute_anchor_matched_idxscompute_cls_lossr*   compute_box_lossr,   )rY   r   r   rK   r   matched_idxsZ
losses_clsZlosses_box_regressionr\   r\   r]   r     s   zRetinaNetDetector.compute_lossc           	   	   C  s  g }t ||D ]\}}|| j  dkr'|tj|dfdtj|jd qt	| j
tr@| || j |j|}| 
|}nt	| j
trY| 
|| j |j||| j\}}ntd| jrotdtj|ddd  d t|dk rtd	|| j  d || q|S )
a  
        Compute the matched indices between anchors and ground truth (gt) boxes in targets.
        output[k][i] represents the matched gt index for anchor[i] in image k.
        Suppose there are M gt boxes for image k. The range of it output[k][i] value is [-2, -1, 0, ..., M-1].
        [0, M - 1] indicates this anchor is matched with a gt box,
        while a negative value indicating that it is not matched.

        Args:
            anchors: a list of Tensor. Each Tensor represents anchors for each image,
                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).
                A = self.num_anchors_per_loc.
            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,
                ground-truth boxes present in the image.
            num_anchor_locs_per_level: each element represents HW or HWD at this level.


        Return:
            a list of matched index `matched_idxs_per_image` (Tensor[int64]), Tensor sized (sum(HWA),) or (sum(HWDA),).
            Suppose there are M gt boxes. `matched_idxs_per_image[i]` is a matched gt index in [0, M - 1]
            or a negative value indicating that anchor i could not be matched.
            BELOW_LOW_THRESHOLD = -1, BETWEEN_THRESHOLDS = -2
        r   r   )r   r   zCurrently support torchvision Matcher and monai ATSS matcher. Other types of matcher not supported. Please override self.compute_anchor_matched_idxs(*) for your own matcher.z.Max box overlap between anchors and gt boxes: r   r   re   zNo anchor is matched with GT boxes. Please adjust matcher setting, anchor setting, or the network setting to change zoom scale between network output and input images.GT boxes are )r   rT   r   r   rO   fullsizeint64r   r   rs   r   r$   r   r   rI   NotImplementedErrorr-   printmaxr   r   )	rY   rK   r   r   r   r   targets_per_imagematch_quality_matrixmatched_idxs_per_imager\   r\   r]   r     s@   
z-RetinaNetDetector.compute_anchor_matched_idxs
cls_logitsr   c                 C  sz   g }g }t |||D ]\}}}| |||\}	}
||	 ||
 q
tj|dd}tj|dd}| |||j}|S )a  
        Compute classification losses.

        Args:
            cls_logits: classification logits, sized (B, sum(HW(D)A), self.num_classes)
            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,
                ground-truth boxes present in the image.
            matched_idxs: a list of matched index. each element is sized (sum(HWA),) or  (sum(HWDA),)

        Return:
            classification losses.
        r   r   )r   get_cls_train_sample_per_imager   rO   r   rl   r   r   )rY   r   r   r   Ztotal_cls_logits_listZtotal_gt_classes_target_listr   cls_logits_per_imager   Zsampled_cls_logits_per_imageZsampled_gt_classes_targetZtotal_cls_logitsZtotal_gt_classes_targetr   r\   r\   r]   r   "  s   
z"RetinaNetDetector.compute_cls_lossc                 C  s   g }g }t ||||D ]\}}}	}
| |||	|
\}}|| || qtj|dd}tj|dd}|jd dkrBtd}|S | |||j	}|S )a  
        Compute box regression losses.

        Args:
            box_regression: box regression results, sized (B, sum(HWA), 2*self.spatial_dims)
            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,
                ground-truth boxes present in the image.
            anchors: a list of Tensor. Each Tensor represents anchors for each image,
                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).
                A = self.num_anchors_per_loc.
            matched_idxs: a list of matched index. each element is sized (sum(HWA),) or  (sum(HWDA),)

        Return:
            box regression losses.
        r   r   r   )
r   get_box_train_sample_per_imager   rO   r   r   tensorrn   r   r   )rY   r    r   rK   r   Ztotal_box_regression_listZtotal_target_regression_listr   r   r   r   Zdecode_box_regression_per_imagematched_gt_boxes_per_imageZtotal_box_regressionZtotal_target_regressionr   r\   r\   r]   r   @  s"   

z"RetinaNetDetector.compute_box_lossr   r   r   tuple[Tensor, Tensor]c                 C  s  t | st | rt  rtdtd |dk}t|	 }|| j
 jd }| jrPtd| d| d |dkrP|d| k rPtd| d| d	 t |}d
|||| j ||  f< | jdu rn|| jjk}nUt| jtrt j|t jddd }	| |d g|	\}
}nt| jtr| |d g\}
}ntdt t j|
ddd }t t j|ddd }t j||gdd}||ddf ||ddf fS )a;  
        Get samples from one image for classification losses computation.

        Args:
            cls_logits_per_image: classification logits for one image, (sum(HWA), self.num_classes)
            targets_per_image: a dict with at least two keys: self.target_box_key and self.target_label_key,
                ground-truth boxes present in the image.
            matched_idxs_per_image: matched index, Tensor sized (sum(HWA),) or (sum(HWDA),)
                Suppose there are M gt boxes. matched_idxs_per_image[i] is a matched gt index in [0, M - 1]
                or a negative value indicating that anchor i could not be matched.
                BELOW_LOW_THRESHOLD = -1, BETWEEN_THRESHOLDS = -2

        Return:
            paired predicted and GT samples from one image for classification losses computation
        z.NaN or Inf in predicted classification logits.r   z&Number of positive (matched) anchors: z; Number of GT box: re   rd   zOnly z anchors are matched with z GT boxes. Please consider adjusting matcher setting, anchor setting, or the network setting to change zoom scale between network output and input images.r6   Nr   r   zCurrently support torchvision BalancedPositiveNegativeSampler and monai HardNegativeSampler matcher. Other types of sampler not supported. Please override self.get_cls_train_sample_per_image(*) for your own sampler.)rO   r   r   r   r   rJ   r   r   rw   sumrT   r   r-   r   
zeros_likerU   rM   rs   BETWEEN_THRESHOLDSr   r   r   r   r   r   r   wherer   )rY   r   r   r   foreground_idxs_per_imagenum_foreground
num_gt_boxgt_classes_targetvalid_idxs_per_imageZmax_cls_logits_per_imageZsampled_pos_inds_listZsampled_neg_inds_listsampled_pos_indssampled_neg_indsr\   r\   r]   r   o  sL   


	

 z0RetinaNetDetector.get_cls_train_sample_per_imager   r   c           
      C  s   t | st | rt  rtdtd t |dkd }|| j	 j
d }|dkrD|ddddf |ddddf fS || j	 ||  |j}||ddf }||ddf }|}|}	| jro| j||}| jry| j|	|}	|	|fS )a  
        Get samples from one image for box regression losses computation.

        Args:
            box_regression_per_image: box regression result for one image, (sum(HWA), 2*self.spatial_dims)
            targets_per_image: a dict with at least two keys: self.target_box_key and self.target_label_key,
                ground-truth boxes present in the image.
            anchors_per_image: anchors of one image,
                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).
                A = self.num_anchors_per_loc.
            matched_idxs_per_image: matched index, sized (sum(HWA),) or  (sum(HWDA),)

        Return:
            paired predicted and GT samples from one image for box regression losses computation
        z'NaN or Inf in predicted box regression.r   N)rO   r   r   r   r   rJ   r   r   r   rT   r   r   r   r4   rS   encode_singler5   r   )
rY   r   r   r   r   r   r  r   Zmatched_gt_boxes_per_image_Zbox_regression_per_image_r\   r\   r]   r     s,   
(z0RetinaNetDetector.get_box_train_sample_per_image)r!   r"   r#   r
   r$   r   r%   r&   r'   r&   r(   r)   r*   r+   r,   r+   r-   r.   )N)r8   ra   rb   rc   )rh   r+   ri   r+   rb   rc   )rk   r"   rb   rc   )rm   r"   r4   r.   r5   r.   rb   rc   )T)ro   rp   rq   rp   rr   r.   rb   rc   )ru   F)rv   rw   rx   r.   rb   rc   )r   rz   )
r{   rw   r|   rp   r}   rw   r~   rp   rb   rc   )r{   rw   r|   rp   rb   rc   )r   r)   r   rw   r   rp   r   r   r   r   r   r   r   rp   r   r   r   r   r   r.   r   r.   rb   rc   )r<   r=   r>   r?   T)r@   rp   rA   rw   rB   rp   rC   rw   rD   r.   rb   rc   )NF)r   r   r   r   r   r.   rb   r   )r   r   r   r   rb   rc   )r   r   rb   r   )r   r   rK   r   r   r   r   r   r   r.   rb   r   )
r   r   r   r   rK   r   r   r   rb   r   )rK   r   r   r   r   r   rb   r   )r   r   r   r   r   r   rb   r   )
r    r   r   r   rK   r   r   r   rb   r   )r   r   r   r   r   r   rb   r   )
r   r   r   r   r   r   r   r   rb   r   ) __name__
__module____qualname____doc__r   rF   rG   rg   rj   rN   rQ   rt   ry   r   r   r   CONSTANTr   r   r   r   r   r   r   r   r   r   r   r   r   r   __classcell__r\   r\   rZ   r]   r   D   sh    y
H




!$U

3
G

H

/Or   r   rd   r   FTr'   rw   r#   r
   returned_layersr   
pretrainedr.   r   kwargsr   rb   c                   sr   t j||fi |}t|jj}t|||d d}| d }	 fdd|jjjD }
t|| |	||
d}t	||S )aX  
    Returns a RetinaNet detector using a ResNet-50 as backbone, which can be pretrained
    from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`
    _.

    Args:
        num_classes: number of output classes of the model (excluding the background).
        anchor_generator: AnchorGenerator,
        returned_layers: returned layers to extract feature maps. Each returned layer should be in the range [1,4].
            len(returned_layers)+1 will be the number of extracted feature maps.
            There is an extra maxpooling layer LastLevelMaxPool() appended.
        pretrained: If True, returns a backbone pre-trained on 23 medical datasets
        progress: If True, displays a progress bar of the download to stderr

    Return:
        A RetinaNetDetector object with resnet50 as backbone

    Example:

        .. code-block:: python

            # define a naive network
            resnet_param = {
                "pretrained": False,
                "spatial_dims": 3,
                "n_input_channels": 2,
                "num_classes": 3,
                "conv1_t_size": 7,
                "conv1_t_stride": (2, 2, 2)
            }
            returned_layers = [1]
            anchor_generator = monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape(
                feature_map_scales=(1, 2), base_anchor_shapes=((8,) * resnet_param["spatial_dims"])
            )
            detector = retinanet_resnet50_fpn_detector(
                **resnet_param, anchor_generator=anchor_generator, returned_layers=returned_layers
            )
    N)backboner%   pretrained_backbonetrainable_backbone_layersr  r   c                   s    g | ]}|d  d t    qS )rd   )r   )r   sr  r\   r]   r   1  s     z3retinanet_resnet50_fpn_detector.<locals>.<listcomp>)r%   r'   r0   feature_extractorr(   )
r   resnet50rf   conv1strider	   rH   bodyr   r   )r'   r#   r  r  r   r  r  r%   r  r0   r(   r!   r\   r  r]   retinanet_resnet50_fpn_detector  s&   /
r  )r  FT)r'   rw   r#   r
   r  r   r  r.   r   r.   r  r   rb   r   )0r  
__future__r   r   collections.abcr   r   typingr   rO   r   r   Z/monai.apps.detection.networks.retinanet_networkr   r	   Z'monai.apps.detection.utils.anchor_utilsr
   Z'monai.apps.detection.utils.ATSS_matcherr   Z$monai.apps.detection.utils.box_coderr   Z'monai.apps.detection.utils.box_selectorr   Z)monai.apps.detection.utils.detector_utilsr   r   Z0monai.apps.detection.utils.hard_negative_samplerr   Z(monai.apps.detection.utils.predict_utilsr   r   monai.data.box_utilsr   monai.inferersr   monai.networks.netsr   monai.utilsr   r   r   r   r   _r   Moduler   r  r\   r\   r\   r]   <module>   sF   "
       =