U
    Ph:)                     @  s   d Z ddlmZ ddlmZ ddlmZ ddlm  m	Z
 ddlmZmZ ddlmZmZ dd	d
dgZG dd dejZG dd	 d	eZG dd
 d
eZG dd dejZdS )z
This script is modified from from torchvision to support N-D images,
by overriding the definition of convolutional layers and pooling layers.

https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py
    )annotations)OrderedDict)CallableN)Tensornn)ConvPoolExtraFPNBlockLastLevelMaxPoolLastLevelP6P7FeaturePyramidNetworkc                   @  s"   e Zd ZdZddddddZdS )r	   z
    Base class for the extra block in the FPN.

    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py
    list[Tensor]	list[str])resultsxnamesc                 C  s   dS )av  
        Compute extended set of results of the FPN and their names.

        Args:
            results: the result of the FPN
            x: the original feature maps
            names: the names for each one of the original feature maps

        Returns:
            - the extended set of results of the FPN
            - the extended set of names for the results
        N selfr   r   r   r   r   b/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/feature_pyramid_network.pyforwardI   s    zExtraFPNBlock.forwardN)__name__
__module____qualname____doc__r   r   r   r   r   r	   B   s   c                      s:   e Zd ZdZdd fddZddddd	d
dZ  ZS )r
   z
    Applies a max_pool2d or max_pool3d on top of the last feature map. Serves as an ``extra_blocks``
    in :class:`~monai.networks.blocks.feature_pyramid_network.FeaturePyramidNetwork` .
    int)spatial_dimsc                   s,   t    ttj|f }|dddd| _d S )N      r   kernel_sizestridepadding)super__init__r   MAXmaxpool)r   r   	pool_type	__class__r   r   r$   _   s    
zLastLevelMaxPool.__init__r   r   tuple[list[Tensor], list[str]]r   r   r   returnc                 C  s&   | d | | |d  ||fS )Npool)appendr&   r   r   r   r   r   d   s    
zLastLevelMaxPool.forwardr   r   r   r   r$   r   __classcell__r   r   r(   r   r
   Y   s   c                      s>   e Zd ZdZdddd fddZddddd	d
dZ  ZS )r   z
    This module is used in RetinaNet to generate extra layers, P6 and P7.
    Serves as an ``extra_blocks``
    in :class:`~monai.networks.blocks.feature_pyramid_network.FeaturePyramidNetwork` .
    r   )r   in_channelsout_channelsc                   s   t    ttj|f }|||dddd| _|||dddd| _| j| jfD ]&}tjj|j	dd tj
|jd qL||k| _d S )N   r   r   r   ar   )r#   r$   r   CONVp6p7r   initkaiming_uniform_weight	constant_biasuse_P5)r   r   r2   r3   	conv_typemoduler(   r   r   r$   q   s    
zLastLevelP6P7.__init__r   r   r*   r+   c           	      C  s^   |d |d  }}| j r|n|}| |}| t|}|||g |ddg ||fS )Nr.   r8   r9   )r?   r8   r9   Freluextend)	r   r   r   r   Zp5c5x5r8   r9   r   r   r   r   {   s    
zLastLevelP6P7.forwardr0   r   r   r(   r   r   j   s   
c                      sb   e Zd ZdZdddddd fddZd	dd	d
ddZd	dd	d
ddZdddddZ  ZS )r   a  
    Module that adds a FPN from on top of a set of feature maps. This is based on
    `"Feature Pyramid Network for Object Detection" <https://arxiv.org/abs/1612.03144>`_.

    The feature maps are currently supposed to be in increasing depth
    order.

    The input to the model is expected to be an OrderedDict[Tensor], containing
    the feature maps on top of which the FPN will be added.

    Args:
        spatial_dims: 2D or 3D images
        in_channels_list: number of channels for each feature map that
            is passed to the module
        out_channels: number of channels of the FPN representation
        extra_blocks: if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names

    Examples::

        >>> m = FeaturePyramidNetwork(2, [10, 20, 30], 5)
        >>> # get some dummy data
        >>> x = OrderedDict()
        >>> x['feat0'] = torch.rand(1, 10, 64, 64)
        >>> x['feat2'] = torch.rand(1, 20, 16, 16)
        >>> x['feat3'] = torch.rand(1, 30, 8, 8)
        >>> # compute the FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('feat0', torch.Size([1, 5, 64, 64])),
        >>>    ('feat2', torch.Size([1, 5, 16, 16])),
        >>>    ('feat3', torch.Size([1, 5, 8, 8]))]

    Nr   z	list[int]zExtraFPNBlock | None)r   in_channels_listr3   extra_blocksc                   s   t    ttj|f }t | _t | _|D ]H}|dkrDtd|||d}|||ddd}| j	| | j	| q0ttj|f }	| 
 D ]0}
t|
|	rtjj|
jdd tj|
jd q|d k	rt|tst|| _d S )Nr   z(in_channels=0 is currently not supportedr   r4   )r"   r5   g        )r#   r$   r   r7   r   
ModuleListinner_blockslayer_blocks
ValueErrorr/   modules
isinstancer:   r;   r<   r=   r>   r	   AssertionErrorrH   )r   r   rG   r3   rH   r@   r2   Zinner_block_moduleZlayer_block_moduleZ
conv_type_mr(   r   r   r$      s(    




zFeaturePyramidNetwork.__init__r   )r   idxr,   c                 C  sF   t | j}|dk r||7 }|}t| jD ]\}}||kr(||}q(|S )zs
        This is equivalent to self.inner_blocks[idx](x),
        but torchscript doesn't support this yet
        r   )lenrJ   	enumerater   r   rQ   Z
num_blocksoutirA   r   r   r   get_result_from_inner_blocks   s    

z2FeaturePyramidNetwork.get_result_from_inner_blocksc                 C  sF   t | j}|dk r||7 }|}t| jD ]\}}||kr(||}q(|S )zs
        This is equivalent to self.layer_blocks[idx](x),
        but torchscript doesn't support this yet
        r   )rR   rK   rS   rT   r   r   r   get_result_from_layer_blocks   s    

z2FeaturePyramidNetwork.get_result_from_layer_blockszdict[str, Tensor])r   r,   c                 C  s   t | }t | }| |d d}g }|| |d tt|d ddD ]N}| || |}|jdd }t	j
||dd}	||	 }|d| || qR| jdk	r| |||\}}tt t||}
|
S )z
        Computes the FPN for a set of feature maps.

        Args:
            x: feature maps for each feature level.

        Returns:
            feature maps after FPN layers. They are ordered from highest resolution first.
        r.   r   Nnearest)sizemoder   )listkeysvaluesrW   r/   rX   rangerR   shaperB   interpolateinsertrH   r   zip)r   r   r   x_valuesZ
last_innerr   rQ   Zinner_lateralZ
feat_shapeZinner_top_downrU   r   r   r   r      s    
zFeaturePyramidNetwork.forward)N)	r   r   r   r   r$   rW   rX   r   r1   r   r   r(   r   r      s   + !)r   
__future__r   collectionsr   collections.abcr   torch.nn.functionalr   
functionalrB   torchr   monai.networks.layers.factoriesr   r   __all__Moduler	   r
   r   r   r   r   r   r   <module>.   s   