U
    PhdJ                     @  s   d Z ddlmZ ddlZddlZddlmZmZ ddlm	Z	m
Z
 ddlZddlmZmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZmZmZ eddd\ZZG dd dejZG dd dejZG dd dejZddddddddddZdS ) z{
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
    )annotationsN)CallableSequence)AnyDict)Tensornn)BackboneWithFPN_resnet_fpn_extractor)Conv)resnet)ensure_tuple_replook_up_optionoptional_importz+torchvision.models.detection.backbone_utils_validate_trainable_layers)namec                      s@   e Zd ZdZddddddd fddZddd	d
dZ  ZS )RetinaNetClassificationHeada  
    A classification head for use in RetinaNet.

    This head takes a list of feature maps as inputs, and outputs a list of classification maps.
    Each output map has same spatial size with the corresponding input feature map,
    and the number of output channel is num_anchors * num_classes.

    Args:
        in_channels: number of channels of the input feature
        num_anchors: number of anchors to be predicted
        num_classes: number of classes to be predicted
        spatial_dims: spatial dimension of the network, should be 2 or 3.
        prior_probability: prior probability to initialize classification convolutional layers.
    {Gz?intfloat)in_channelsnum_anchorsnum_classesspatial_dimsprior_probabilityc           
   
     s  t    ttj|f }g }tdD ]>}||||dddd |tjd|d |t  q$tj	| | _
| j
 D ]4}	t|	|rztjjj|	jdd tjj|	jd	 qz|||| dddd| _tjjj| jjdd tjj| jjtd| |   || _|| _d S )
N         kernel_sizestridepadding   
num_groupsnum_channelsr   stdr   )super__init__r   CONVrangeappendr   	GroupNormReLU
Sequentialconvchildren
isinstancetorchinitnormal_weight	constant_bias
cls_logitsmathlogr   r   )
selfr   r   r   r   r   	conv_typer0   _layer	__class__ d/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/detection/networks/retinanet_network.pyr)   L   s"    

$z$RetinaNetClassificationHead.__init__list[Tensor]xreturnc                 C  s~   g }t |tr|g}n|}|D ]Z}| |}| |}|| t| s\t| rt	 rnt
dqtd q|S )ai  
        It takes a list of feature maps as inputs, and outputs a list of classification maps.
        Each output classification map has same spatial size with the corresponding input feature map,
        and the number of output channel is num_anchors * num_classes.

        Args:
            x: list of feature map, x[i] is a (B, in_channels, H_i, W_i) or (B, in_channels, H_i, W_i, D_i) Tensor.

        Return:
            cls_logits_maps, list of classification map. cls_logits_maps[i] is a
            (B, num_anchors * num_classes, H_i, W_i) or (B, num_anchors * num_classes, H_i, W_i, D_i) Tensor.

        zcls_logits is NaN or Inf.)r2   r   r0   r9   r,   r3   isnananyisinfis_grad_enabled
ValueErrorwarningswarn)r<   rF   Zcls_logits_mapsfeature_mapsfeaturesr9   rB   rB   rC   forwarde   s    




z#RetinaNetClassificationHead.forward)r   __name__
__module____qualname____doc__r)   rQ   __classcell__rB   rB   r@   rC   r   <   s    r   c                      s:   e Zd ZdZdddd fddZddddd	Z  ZS )
RetinaNetRegressionHeada  
    A regression head for use in RetinaNet.

    This head takes a list of feature maps as inputs, and outputs a list of box regression maps.
    Each output box regression map has same spatial size with the corresponding input feature map,
    and the number of output channel is num_anchors * 2 * spatial_dims.

    Args:
        in_channels: number of channels of the input feature
        num_anchors: number of anchors to be predicted
        spatial_dims: spatial dimension of the network, should be 2 or 3.
    r   )r   r   r   c              
     s   t    ttj|f }g }tdD ]>}||||dddd |tjd|d |t  q$tj	| | _
|||d | dddd| _tjjj| jjdd	 tjj| jj | j
 D ]2}t||rtjjj|jdd	 tjj|j qd S )
Nr   r   r   r   r"   r#      r   r&   )r(   r)   r   r*   r+   r,   r   r-   r.   r/   r0   bbox_regr3   r4   r5   r6   zeros_r8   r1   r2   )r<   r   r   r   r=   r0   r>   r?   r@   rB   rC   r)      s    

z RetinaNetRegressionHead.__init__rD   rE   c                 C  s~   g }t |tr|g}n|}|D ]Z}| |}| |}|| t| s\t| rt	 rnt
dqtd q|S )a|  
        It takes a list of feature maps as inputs, and outputs a list of box regression maps.
        Each output box regression map has same spatial size with the corresponding input feature map,
        and the number of output channel is num_anchors * 2 * spatial_dims.

        Args:
            x: list of feature map, x[i] is a (B, in_channels, H_i, W_i) or (B, in_channels, H_i, W_i, D_i) Tensor.

        Return:
            box_regression_maps, list of box regression map. cls_logits_maps[i] is a
            (B, num_anchors * 2 * spatial_dims, H_i, W_i) or (B, num_anchors * 2 * spatial_dims, H_i, W_i, D_i) Tensor.

        zbox_regression is NaN or Inf.)r2   r   r0   rZ   r,   r3   rH   rI   rJ   rK   rL   rM   rN   )r<   rF   Zbox_regression_mapsrO   rP   box_regressionrB   rB   rC   rQ      s    




zRetinaNetRegressionHead.forwardrR   rB   rB   r@   rC   rX      s   rX   c                      sB   e Zd ZdZdddddddd fd	d
ZdddddZ  ZS )	RetinaNeta  
    The network used in RetinaNet.

    It takes an image tensor as inputs, and outputs either 1) a dictionary ``head_outputs``.
    ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
    ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
    or 2) a list of 2N tensors ``head_outputs``, with first N tensors being the predicted
    classification maps and second N tensors being the predicted box regression maps.

    Args:
        spatial_dims: number of spatial dimensions of the images. We support both 2D and 3D images.
        num_classes: number of output classes of the model (excluding the background).
        num_anchors: number of anchors at each location.
        feature_extractor: a network that outputs feature maps from the input images,
            each feature map corresponds to a different resolution.
            Its output can have a format of Tensor, Dict[Any, Tensor], or Sequence[Tensor].
            It can be the output of ``resnet_fpn_feature_extractor(*args, **kwargs)``.
        size_divisible: the spatial size of the network input should be divisible by size_divisible,
            decided by the feature_extractor.
        use_list_output: default False. If False, the network outputs a dictionary ``head_outputs``,
            ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
            ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
            If True, the network outputs a list of 2N tensors ``head_outputs``, with first N tensors being
            the predicted classification maps and second N tensors being the predicted box regression maps.

    Example:

        .. code-block:: python

            from monai.networks.nets import resnet
            spatial_dims = 3  # 3D network
            conv1_t_stride = (2,2,1)  # stride of first convolutional layer in backbone
            backbone = resnet.ResNet(
                spatial_dims = spatial_dims,
                block = resnet.ResNetBottleneck,
                layers = [3, 4, 6, 3],
                block_inplanes = resnet.get_inplanes(),
                n_input_channels= 1,
                conv1_t_stride = conv1_t_stride,
                conv1_t_size = (7,7,7),
            )
            # This feature_extractor outputs 4-level feature maps.
            # number of output feature maps is len(returned_layers)+1
            returned_layers = [1,2,3]  # returned layer from feature pyramid network
            feature_extractor = resnet_fpn_feature_extractor(
                backbone = backbone,
                spatial_dims = spatial_dims,
                pretrained_backbone = False,
                trainable_backbone_layers = None,
                returned_layers = returned_layers,
            )
            # This feature_extractor requires input image spatial size
            # to be divisible by (32, 32, 16).
            size_divisible = tuple(2*s*2**max(returned_layers) for s in conv1_t_stride)
            model = RetinaNet(
                spatial_dims = spatial_dims,
                num_classes = 5,
                num_anchors = 6,
                feature_extractor=feature_extractor,
                size_divisible = size_divisible,
            ).to(device)
            result = model(torch.rand(2, 1, 128,128,128))
            cls_logits_maps = result["classification"]  # a list of len(returned_layers)+1 Tensor
            box_regression_maps = result["box_regression"]  # a list of len(returned_layers)+1 Tensor
    r   Fr   z	nn.ModulezSequence[int] | intbool)r   r   r   feature_extractorsize_divisibleuse_list_outputc                   s   t    t|dddgd| _|| _t|| j| _|| _t|dsJt	d|| _
| j
j| _|| _t| j| j| j| jd| _t| j| j| jd| _d| _d	| _d S )
Nr   rY   r   )	supportedout_channelszfeature_extractor should contain an attribute out_channels specifying the number of output channels (assumed to be the same for all the levels))r   classificationr\   )r(   r)   r   r   r   r   r`   ra   hasattrrL   r_   rc   Zfeature_map_channelsr   r   classification_headrX   regression_headcls_keybox_reg_key)r<   r   r   r   r_   r`   ra   r@   rB   rC   r)     s2    	


     zRetinaNet.__init__r   r   )imagesrG   c                 C  s   |  |}t|tr|g}n,tj|tttf r@t| }nt|}t|d ts^t	d| j
s| j| |i}| ||| j< |S | || | }|S dS )ag  
        It takes an image tensor as inputs, and outputs predicted classification maps
        and predicted box regression maps in ``head_outputs``.

        Args:
            images: input images, sized (B, img_channels, H, W) or (B, img_channels, H, W, D).

        Return:
            1) If self.use_list_output is False, output a dictionary ``head_outputs`` with
            keys including self.cls_key and self.box_reg_key.
            ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
            ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
            2) if self.use_list_output is True, outputs a list of 2N tensors ``head_outputs``, with first N tensors being
            the predicted classification maps and second N tensors being the predicted box regression maps.

        r   zWfeature_extractor output format must be Tensor, Dict[str, Tensor], or Sequence[Tensor].N)r_   r2   r   r3   jitr   strlistvaluesrL   ra   rh   rf   rg   ri   )r<   rj   rP   rO   head_outputsZhead_outputs_sequencerB   rB   rC   rQ   8  s    

zRetinaNet.forward)r   FrR   rB   rB   r@   rC   r]      s
   H  $r]   Fr   rY   r   zresnet.ResNetr   r^   zSequence[int]z
int | Noner	   )backboner   pretrained_backbonereturned_layerstrainable_backbone_layersrG   c                 C  s*   t ||ddd}t| ||t|dd}|S )ah
  
    Constructs a feature extractor network with a ResNet-FPN backbone, used as feature_extractor in RetinaNet.

    Reference: `"Focal Loss for Dense Object Detection" <https://arxiv.org/abs/1708.02002>`_.

    The returned feature_extractor network takes an image tensor as inputs,
    and outputs a dictionary that maps string to the extracted feature maps (Tensor).

    The input to the returned feature_extractor is expected to be a list of tensors,
    each of shape ``[C, H, W]`` or ``[C, H, W, D]``,
    one for each image. Different images can have different sizes.


    Args:
        backbone: a ResNet model, used as backbone.
        spatial_dims: number of spatial dimensions of the images. We support both 2D and 3D images.
        pretrained_backbone: whether the backbone has been pre-trained.
        returned_layers: returned layers to extract feature maps. Each returned layer should be in the range [1,4].
            len(returned_layers)+1 will be the number of extracted feature maps.
            There is an extra maxpooling layer LastLevelMaxPool() appended.
        trainable_backbone_layers: number of trainable (not frozen) resnet layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
            When pretrained_backbone is False, this value is set to be 5.
            When pretrained_backbone is True, if ``None`` is passed (the default) this value is set to 3.

    Example:

        .. code-block:: python

            from monai.networks.nets import resnet
            spatial_dims = 3 # 3D network
            backbone = resnet.ResNet(
                spatial_dims = spatial_dims,
                block = resnet.ResNetBottleneck,
                layers = [3, 4, 6, 3],
                block_inplanes = resnet.get_inplanes(),
                n_input_channels= 1,
                conv1_t_stride = (2,2,1),
                conv1_t_size = (7,7,7),
            )
            # This feature_extractor outputs 4-level feature maps.
            # number of output feature maps is len(returned_layers)+1
            feature_extractor = resnet_fpn_feature_extractor(
                backbone = backbone,
                spatial_dims = spatial_dims,
                pretrained_backbone = False,
                trainable_backbone_layers = None,
                returned_layers = [1,2,3],
            )
            model = RetinaNet(
                spatial_dims = spatial_dims,
                num_classes = 5,
                num_anchors = 6,
                feature_extractor=feature_extractor,
                size_divisible = 32,
            ).to(device)
       r   )	max_valuedefault_valueN)rs   extra_blocks)r   r
   rm   )rq   r   rr   rs   rt   Zvalid_trainable_backbone_layersr_   rB   rB   rC   resnet_fpn_feature_extractorc  s    B   ry   )Frp   N) rV   
__future__r   r:   rM   collections.abcr   r   typingr   r   r3   r   r   Z(monai.networks.blocks.backbone_fpn_utilsr	   r
   monai.networks.layers.factoriesr   monai.networks.netsr   monai.utilsr   r   r   r   r>   Moduler   rX   r]   ry   rB   rB   rB   rC   <module>#   s.    
MH    