o
    iJ                     @  s   d Z ddlmZ ddlZddlZddlmZmZ ddlm	Z	 ddl
Z
ddl
mZmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZmZmZ eddd\ZZG dd dejZG dd dejZG dd dejZ			d$d%d"d#ZdS )&z{
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
    )annotationsN)CallableSequence)Any)Tensornn)BackboneWithFPN_resnet_fpn_extractor)Conv)resnet)ensure_tuple_replook_up_optionoptional_importz+torchvision.models.detection.backbone_utils_validate_trainable_layers)namec                      s0   e Zd ZdZ	dd fd
dZdddZ  ZS )RetinaNetClassificationHeada  
    A classification head for use in RetinaNet.

    This head takes a list of feature maps as inputs, and outputs a list of classification maps.
    Each output map has same spatial size with the corresponding input feature map,
    and the number of output channel is num_anchors * num_classes.

    Args:
        in_channels: number of channels of the input feature
        num_anchors: number of anchors to be predicted
        num_classes: number of classes to be predicted
        spatial_dims: spatial dimension of the network, should be 2 or 3.
        prior_probability: prior probability to initialize classification convolutional layers.
    {Gz?in_channelsintnum_anchorsnum_classesspatial_dimsprior_probabilityfloatc           
   
     s  t    ttj|f }g }tdD ]}||||dddd |tjd|d |t  qtj	| | _
| j
 D ]}	t|	|rWtjjj|	jdd tjj|	jd	 q=|||| dddd| _tjjj| jjdd tjj| jjtd| |   || _|| _d S )
N         kernel_sizestridepadding   
num_groupsnum_channelsr   stdr   )super__init__r
   CONVrangeappendr   	GroupNormReLU
Sequentialconvchildren
isinstancetorchinitnormal_weight	constant_bias
cls_logitsmathlogr   r   )
selfr   r   r   r   r   	conv_typer/   _layer	__class__ q/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/networks/retinanet_network.pyr(   L   s$   

$
z$RetinaNetClassificationHead.__init__xlist[Tensor]returnc                 C  |   g }t |tr|g}n|}|D ],}| |}| |}|| t| s.t| r;t	 r6t
dtd q|S )ai  
        It takes a list of feature maps as inputs, and outputs a list of classification maps.
        Each output classification map has same spatial size with the corresponding input feature map,
        and the number of output channel is num_anchors * num_classes.

        Args:
            x: list of feature map, x[i] is a (B, in_channels, H_i, W_i) or (B, in_channels, H_i, W_i, D_i) Tensor.

        Return:
            cls_logits_maps, list of classification map. cls_logits_maps[i] is a
            (B, num_anchors * num_classes, H_i, W_i) or (B, num_anchors * num_classes, H_i, W_i, D_i) Tensor.

        zcls_logits is NaN or Inf.)r1   r   r/   r8   r+   r2   isnananyisinfis_grad_enabled
ValueErrorwarningswarn)r;   rC   Zcls_logits_mapsfeature_mapsfeaturesr8   rA   rA   rB   forwarde      




z#RetinaNetClassificationHead.forward)r   )
r   r   r   r   r   r   r   r   r   r   rC   rD   rE   rD   __name__
__module____qualname____doc__r(   rP   __classcell__rA   rA   r?   rB   r   <   s
    r   c                      s,   e Zd ZdZd fddZdddZ  ZS )RetinaNetRegressionHeada  
    A regression head for use in RetinaNet.

    This head takes a list of feature maps as inputs, and outputs a list of box regression maps.
    Each output box regression map has same spatial size with the corresponding input feature map,
    and the number of output channel is num_anchors * 2 * spatial_dims.

    Args:
        in_channels: number of channels of the input feature
        num_anchors: number of anchors to be predicted
        spatial_dims: spatial dimension of the network, should be 2 or 3.
    r   r   r   r   c              
     s   t    ttj|f }g }tdD ]}||||dddd |tjd|d |t  qtj	| | _
|||d | dddd| _tjjj| jjdd	 tjj| jj | j
 D ]}t||rxtjjj|jdd	 tjj|j q_d S )
Nr   r   r   r   r!   r"      r   r%   )r'   r(   r
   r)   r*   r+   r   r,   r-   r.   r/   bbox_regr2   r3   r4   r5   zeros_r7   r0   r1   )r;   r   r   r   r<   r/   r=   r>   r?   rA   rB   r(      s"   

z RetinaNetRegressionHead.__init__rC   rD   rE   c                 C  rF   )a|  
        It takes a list of feature maps as inputs, and outputs a list of box regression maps.
        Each output box regression map has same spatial size with the corresponding input feature map,
        and the number of output channel is num_anchors * 2 * spatial_dims.

        Args:
            x: list of feature map, x[i] is a (B, in_channels, H_i, W_i) or (B, in_channels, H_i, W_i, D_i) Tensor.

        Return:
            box_regression_maps, list of box regression map. cls_logits_maps[i] is a
            (B, num_anchors * 2 * spatial_dims, H_i, W_i) or (B, num_anchors * 2 * spatial_dims, H_i, W_i, D_i) Tensor.

        zbox_regression is NaN or Inf.)r1   r   r/   r[   r+   r2   rG   rH   rI   rJ   rK   rL   rM   )r;   rC   Zbox_regression_mapsrN   rO   box_regressionrA   rA   rB   rP      rQ   zRetinaNetRegressionHead.forward)r   r   r   r   r   r   rR   rS   rA   rA   r?   rB   rY      s    rY   c                      s2   e Zd ZdZ		dd fddZdddZ  ZS )	RetinaNeta  
    The network used in RetinaNet.

    It takes an image tensor as inputs, and outputs either 1) a dictionary ``head_outputs``.
    ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
    ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
    or 2) a list of 2N tensors ``head_outputs``, with first N tensors being the predicted
    classification maps and second N tensors being the predicted box regression maps.

    Args:
        spatial_dims: number of spatial dimensions of the images. We support both 2D and 3D images.
        num_classes: number of output classes of the model (excluding the background).
        num_anchors: number of anchors at each location.
        feature_extractor: a network that outputs feature maps from the input images,
            each feature map corresponds to a different resolution.
            Its output can have a format of Tensor, Dict[Any, Tensor], or Sequence[Tensor].
            It can be the output of ``resnet_fpn_feature_extractor(*args, **kwargs)``.
        size_divisible: the spatial size of the network input should be divisible by size_divisible,
            decided by the feature_extractor.
        use_list_output: default False. If False, the network outputs a dictionary ``head_outputs``,
            ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
            ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
            If True, the network outputs a list of 2N tensors ``head_outputs``, with first N tensors being
            the predicted classification maps and second N tensors being the predicted box regression maps.

    Example:

        .. code-block:: python

            from monai.networks.nets import resnet
            spatial_dims = 3  # 3D network
            conv1_t_stride = (2,2,1)  # stride of first convolutional layer in backbone
            backbone = resnet.ResNet(
                spatial_dims = spatial_dims,
                block = resnet.ResNetBottleneck,
                layers = [3, 4, 6, 3],
                block_inplanes = resnet.get_inplanes(),
                n_input_channels= 1,
                conv1_t_stride = conv1_t_stride,
                conv1_t_size = (7,7,7),
            )
            # This feature_extractor outputs 4-level feature maps.
            # number of output feature maps is len(returned_layers)+1
            returned_layers = [1,2,3]  # returned layer from feature pyramid network
            feature_extractor = resnet_fpn_feature_extractor(
                backbone = backbone,
                spatial_dims = spatial_dims,
                pretrained_backbone = False,
                trainable_backbone_layers = None,
                returned_layers = returned_layers,
            )
            # This feature_extractor requires input image spatial size
            # to be divisible by (32, 32, 16).
            size_divisible = tuple(2*s*2**max(returned_layers) for s in conv1_t_stride)
            model = RetinaNet(
                spatial_dims = spatial_dims,
                num_classes = 5,
                num_anchors = 6,
                feature_extractor=feature_extractor,
                size_divisible = size_divisible,
            ).to(device)
            result = model(torch.rand(2, 1, 128,128,128))
            cls_logits_maps = result["classification"]  # a list of len(returned_layers)+1 Tensor
            box_regression_maps = result["box_regression"]  # a list of len(returned_layers)+1 Tensor
    r   Fr   r   r   r   feature_extractor	nn.Modulesize_divisibleSequence[int] | intuse_list_outputboolc                   s   t    t|g dd| _|| _t|| j| _|| _t|ds$t	d|| _
| j
j| _|| _t| j| j| j| jd| _t| j| j| jd| _d| _d| _d S )Nr   rZ   r   )	supportedout_channelszfeature_extractor should contain an attribute out_channels specifying the number of output channels (assumed to be the same for all the levels))r   classificationr]   )r'   r(   r   r   r   r   ra   rc   hasattrrK   r_   rg   Zfeature_map_channelsr   r   classification_headrY   regression_headcls_keybox_reg_key)r;   r   r   r   r_   ra   rc   r?   rA   rB   r(     s(   
	


zRetinaNet.__init__imagesr   rE   r   c                 C  s   |  |}t|tr|g}ntj|tttf r t| }nt|}t|d ts/t	d| j
sD| j| |i}| ||| j< |S | || | }|S )ag  
        It takes an image tensor as inputs, and outputs predicted classification maps
        and predicted box regression maps in ``head_outputs``.

        Args:
            images: input images, sized (B, img_channels, H, W) or (B, img_channels, H, W, D).

        Return:
            1) If self.use_list_output is False, output a dictionary ``head_outputs`` with
            keys including self.cls_key and self.box_reg_key.
            ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.
            ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.
            2) if self.use_list_output is True, outputs a list of 2N tensors ``head_outputs``, with first N tensors being
            the predicted classification maps and second N tensors being the predicted box regression maps.

        r   zWfeature_extractor output format must be Tensor, Dict[str, Tensor], or Sequence[Tensor].)r_   r1   r   r2   jitdictstrlistvaluesrK   rc   rl   rj   rk   rm   )r;   rn   rO   rN   head_outputsZhead_outputs_sequencerA   rA   rB   rP   8  s   

zRetinaNet.forward)r   F)r   r   r   r   r   r   r_   r`   ra   rb   rc   rd   )rn   r   rE   r   rS   rA   rA   r?   rB   r^      s    H$r^   Fre   backboneresnet.ResNetr   r   pretrained_backbonerd   returned_layersSequence[int]trainable_backbone_layers
int | NonerE   r   c                 C  s*   t ||ddd}t| ||t|dd}|S )ah
  
    Constructs a feature extractor network with a ResNet-FPN backbone, used as feature_extractor in RetinaNet.

    Reference: `"Focal Loss for Dense Object Detection" <https://arxiv.org/abs/1708.02002>`_.

    The returned feature_extractor network takes an image tensor as inputs,
    and outputs a dictionary that maps string to the extracted feature maps (Tensor).

    The input to the returned feature_extractor is expected to be a list of tensors,
    each of shape ``[C, H, W]`` or ``[C, H, W, D]``,
    one for each image. Different images can have different sizes.


    Args:
        backbone: a ResNet model, used as backbone.
        spatial_dims: number of spatial dimensions of the images. We support both 2D and 3D images.
        pretrained_backbone: whether the backbone has been pre-trained.
        returned_layers: returned layers to extract feature maps. Each returned layer should be in the range [1,4].
            len(returned_layers)+1 will be the number of extracted feature maps.
            There is an extra maxpooling layer LastLevelMaxPool() appended.
        trainable_backbone_layers: number of trainable (not frozen) resnet layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
            When pretrained_backbone is False, this value is set to be 5.
            When pretrained_backbone is True, if ``None`` is passed (the default) this value is set to 3.

    Example:

        .. code-block:: python

            from monai.networks.nets import resnet
            spatial_dims = 3 # 3D network
            backbone = resnet.ResNet(
                spatial_dims = spatial_dims,
                block = resnet.ResNetBottleneck,
                layers = [3, 4, 6, 3],
                block_inplanes = resnet.get_inplanes(),
                n_input_channels= 1,
                conv1_t_stride = (2,2,1),
                conv1_t_size = (7,7,7),
            )
            # This feature_extractor outputs 4-level feature maps.
            # number of output feature maps is len(returned_layers)+1
            feature_extractor = resnet_fpn_feature_extractor(
                backbone = backbone,
                spatial_dims = spatial_dims,
                pretrained_backbone = False,
                trainable_backbone_layers = None,
                returned_layers = [1,2,3],
            )
            model = RetinaNet(
                spatial_dims = spatial_dims,
                num_classes = 5,
                num_anchors = 6,
                feature_extractor=feature_extractor,
                size_divisible = 32,
            ).to(device)
       r   )	max_valuedefault_valueN)rx   extra_blocks)r   r	   rr   )ru   r   rw   rx   rz   Zvalid_trainable_backbone_layersr_   rA   rA   rB   resnet_fpn_feature_extractorc  s   Br   )Fre   N)ru   rv   r   r   rw   rd   rx   ry   rz   r{   rE   r   )rW   
__future__r   r9   rL   collections.abcr   r   typingr   r2   r   r   Z(monai.networks.blocks.backbone_fpn_utilsr   r	   monai.networks.layers.factoriesr
   monai.networks.netsr   monai.utilsr   r   r   r   r=   Moduler   rY   r^   r   rA   rA   rA   rB   <module>   s.   "
MH 